[TFBart] Split TF-Bart (#9497)
* make templates ready * make add_new_model_command_ready * finish tf bart * prepare tf mbart * finish tf bart * add tf mbart * add marian * prep pegasus * add tf pegasus * push blenderbot tf * add blenderbot * add blenderbot small * clean-up * make fix copy * define blend bot tok * fix * up * make style * add to docs * add copy statements * overwrite changes * improve * fix docs * finish * fix last slow test * fix missing git conflict line * fix blenderbot * up * fix blenderbot small * load changes * finish copied from * upload fix
This commit is contained in:
committed by
GitHub
parent
0ecbb69806
commit
7f28613213
@@ -225,7 +225,7 @@ TensorFlow and/or Flax.
|
|||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| BlenderbotSmall | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| BlenderbotSmall | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||||
|
|||||||
@@ -98,10 +98,15 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
TFBlenderbotModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFBlenderbotModel
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
TFBlenderbotForConditionalGeneration
|
TFBlenderbotForConditionalGeneration
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
|
|
||||||
|
|
||||||
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
|
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
|
||||||
:members:
|
:members: call
|
||||||
|
|||||||
@@ -68,3 +68,17 @@ BlenderbotSmallForConditionalGeneration
|
|||||||
|
|
||||||
.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
|
.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
TFBlenderbotSmallModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFBlenderbotSmallModel
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
|
TFBlenderbotSmallForConditionalGeneration
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFBlenderbotSmallForConditionalGeneration
|
||||||
|
:members: call
|
||||||
|
|||||||
@@ -193,7 +193,15 @@ MarianMTModel
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
TFMarianModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFMarianModel
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
TFMarianMTModel
|
TFMarianMTModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.TFMarianMTModel
|
.. autoclass:: transformers.TFMarianMTModel
|
||||||
|
:members: call
|
||||||
|
|||||||
@@ -124,8 +124,15 @@ MBartForSequenceClassification
|
|||||||
.. autoclass:: transformers.MBartForSequenceClassification
|
.. autoclass:: transformers.MBartForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
|
TFMBartModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFMBartModel
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
TFMBartForConditionalGeneration
|
TFMBartForConditionalGeneration
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.TFMBartForConditionalGeneration
|
.. autoclass:: transformers.TFMBartForConditionalGeneration
|
||||||
:members:
|
:members: call
|
||||||
|
|||||||
@@ -131,7 +131,15 @@ PegasusForConditionalGeneration
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
TFPegasusModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TFPegasusModel
|
||||||
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
TFPegasusForConditionalGeneration
|
TFPegasusForConditionalGeneration
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.TFPegasusForConditionalGeneration
|
.. autoclass:: transformers.TFPegasusForConditionalGeneration
|
||||||
|
:members: call
|
||||||
|
|||||||
@@ -868,7 +868,10 @@ if is_tf_available():
|
|||||||
"TFBertPreTrainedModel",
|
"TFBertPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.blenderbot"].append("TFBlenderbotForConditionalGeneration")
|
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
|
||||||
|
_import_structure["models.blenderbot_small"].extend(
|
||||||
|
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
|
||||||
|
)
|
||||||
_import_structure["models.camembert"].extend(
|
_import_structure["models.camembert"].extend(
|
||||||
[
|
[
|
||||||
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -986,8 +989,8 @@ if is_tf_available():
|
|||||||
"TFLxmertVisualFeatureEncoder",
|
"TFLxmertVisualFeatureEncoder",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.marian"].append("TFMarianMTModel")
|
_import_structure["models.marian"].extend(["TFMarianMTModel", "TFMarianModel"])
|
||||||
_import_structure["models.mbart"].append("TFMBartForConditionalGeneration")
|
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
|
||||||
_import_structure["models.mobilebert"].extend(
|
_import_structure["models.mobilebert"].extend(
|
||||||
[
|
[
|
||||||
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1028,7 +1031,7 @@ if is_tf_available():
|
|||||||
"TFOpenAIGPTPreTrainedModel",
|
"TFOpenAIGPTPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.pegasus"].append("TFPegasusForConditionalGeneration")
|
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
|
||||||
_import_structure["models.roberta"].extend(
|
_import_structure["models.roberta"].extend(
|
||||||
[
|
[
|
||||||
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1855,7 +1858,8 @@ if TYPE_CHECKING:
|
|||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFBertPreTrainedModel,
|
TFBertPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.blenderbot import TFBlenderbotForConditionalGeneration
|
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||||
|
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
|
||||||
from .models.camembert import (
|
from .models.camembert import (
|
||||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFCamembertForMaskedLM,
|
TFCamembertForMaskedLM,
|
||||||
@@ -1953,8 +1957,8 @@ if TYPE_CHECKING:
|
|||||||
TFLxmertPreTrainedModel,
|
TFLxmertPreTrainedModel,
|
||||||
TFLxmertVisualFeatureEncoder,
|
TFLxmertVisualFeatureEncoder,
|
||||||
)
|
)
|
||||||
from .models.marian import TFMarianMTModel
|
from .models.marian import TFMarianModel, TFMarianMTModel
|
||||||
from .models.mbart import TFMBartForConditionalGeneration
|
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
|
||||||
from .models.mobilebert import (
|
from .models.mobilebert import (
|
||||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFMobileBertForMaskedLM,
|
TFMobileBertForMaskedLM,
|
||||||
@@ -1989,7 +1993,7 @@ if TYPE_CHECKING:
|
|||||||
TFOpenAIGPTModel,
|
TFOpenAIGPTModel,
|
||||||
TFOpenAIGPTPreTrainedModel,
|
TFOpenAIGPTPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.pegasus import TFPegasusForConditionalGeneration
|
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
|
||||||
from .models.roberta import (
|
from .models.roberta import (
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
|
|||||||
@@ -44,7 +44,11 @@ from ..bert.modeling_tf_bert import (
|
|||||||
TFBertLMHeadModel,
|
TFBertLMHeadModel,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
)
|
)
|
||||||
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
|
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||||
|
from ..blenderbot_small.modeling_tf_blenderbot_small import (
|
||||||
|
TFBlenderbotSmallForConditionalGeneration,
|
||||||
|
TFBlenderbotSmallModel,
|
||||||
|
)
|
||||||
from ..camembert.modeling_tf_camembert import (
|
from ..camembert.modeling_tf_camembert import (
|
||||||
TFCamembertForMaskedLM,
|
TFCamembertForMaskedLM,
|
||||||
TFCamembertForMultipleChoice,
|
TFCamembertForMultipleChoice,
|
||||||
@@ -100,8 +104,8 @@ from ..longformer.modeling_tf_longformer import (
|
|||||||
TFLongformerModel,
|
TFLongformerModel,
|
||||||
)
|
)
|
||||||
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
|
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
|
||||||
from ..marian.modeling_tf_marian import TFMarianMTModel
|
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
|
||||||
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration
|
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
|
||||||
from ..mobilebert.modeling_tf_mobilebert import (
|
from ..mobilebert.modeling_tf_mobilebert import (
|
||||||
TFMobileBertForMaskedLM,
|
TFMobileBertForMaskedLM,
|
||||||
TFMobileBertForMultipleChoice,
|
TFMobileBertForMultipleChoice,
|
||||||
@@ -122,7 +126,7 @@ from ..mpnet.modeling_tf_mpnet import (
|
|||||||
)
|
)
|
||||||
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
|
||||||
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
|
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
|
||||||
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
|
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
|
||||||
from ..roberta.modeling_tf_roberta import (
|
from ..roberta.modeling_tf_roberta import (
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
TFRobertaForMultipleChoice,
|
TFRobertaForMultipleChoice,
|
||||||
@@ -167,6 +171,7 @@ from .configuration_auto import (
|
|||||||
BartConfig,
|
BartConfig,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BlenderbotConfig,
|
BlenderbotConfig,
|
||||||
|
BlenderbotSmallConfig,
|
||||||
CamembertConfig,
|
CamembertConfig,
|
||||||
CTRLConfig,
|
CTRLConfig,
|
||||||
DistilBertConfig,
|
DistilBertConfig,
|
||||||
@@ -225,6 +230,12 @@ TF_MODEL_MAPPING = OrderedDict(
|
|||||||
(FunnelConfig, TFFunnelModel),
|
(FunnelConfig, TFFunnelModel),
|
||||||
(DPRConfig, TFDPRQuestionEncoder),
|
(DPRConfig, TFDPRQuestionEncoder),
|
||||||
(MPNetConfig, TFMPNetModel),
|
(MPNetConfig, TFMPNetModel),
|
||||||
|
(BartConfig, TFBartModel),
|
||||||
|
(MBartConfig, TFMBartModel),
|
||||||
|
(MarianConfig, TFMarianModel),
|
||||||
|
(PegasusConfig, TFPegasusModel),
|
||||||
|
(BlenderbotConfig, TFBlenderbotModel),
|
||||||
|
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -328,6 +339,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
|||||||
(MBartConfig, TFMBartForConditionalGeneration),
|
(MBartConfig, TFMBartForConditionalGeneration),
|
||||||
(PegasusConfig, TFPegasusForConditionalGeneration),
|
(PegasusConfig, TFPegasusForConditionalGeneration),
|
||||||
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
|
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
|
||||||
|
(BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
|
||||||
(BartConfig, TFBartForConditionalGeneration),
|
(BartConfig, TFBartForConditionalGeneration),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from ..bart.tokenization_bart import BartTokenizer
|
|||||||
from ..bert.tokenization_bert import BertTokenizer
|
from ..bert.tokenization_bert import BertTokenizer
|
||||||
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
|
from ..bert_japanese.tokenization_bert_japanese import BertJapaneseTokenizer
|
||||||
from ..bertweet.tokenization_bertweet import BertweetTokenizer
|
from ..bertweet.tokenization_bertweet import BertweetTokenizer
|
||||||
|
from ..blenderbot.tokenization_blenderbot import BlenderbotTokenizer
|
||||||
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
|
from ..blenderbot_small.tokenization_blenderbot_small import BlenderbotSmallTokenizer
|
||||||
from ..ctrl.tokenization_ctrl import CTRLTokenizer
|
from ..ctrl.tokenization_ctrl import CTRLTokenizer
|
||||||
from ..deberta.tokenization_deberta import DebertaTokenizer
|
from ..deberta.tokenization_deberta import DebertaTokenizer
|
||||||
@@ -58,6 +59,7 @@ from .configuration_auto import (
|
|||||||
BertConfig,
|
BertConfig,
|
||||||
BertGenerationConfig,
|
BertGenerationConfig,
|
||||||
BlenderbotConfig,
|
BlenderbotConfig,
|
||||||
|
BlenderbotSmallConfig,
|
||||||
CamembertConfig,
|
CamembertConfig,
|
||||||
CTRLConfig,
|
CTRLConfig,
|
||||||
DebertaConfig,
|
DebertaConfig,
|
||||||
@@ -201,7 +203,8 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||||||
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
|
(MBartConfig, (MBartTokenizer, MBartTokenizerFast)),
|
||||||
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
|
(XLMRobertaConfig, (XLMRobertaTokenizer, XLMRobertaTokenizerFast)),
|
||||||
(MarianConfig, (MarianTokenizer, None)),
|
(MarianConfig, (MarianTokenizer, None)),
|
||||||
(BlenderbotConfig, (BlenderbotSmallTokenizer, None)),
|
(BlenderbotSmallConfig, (BlenderbotSmallTokenizer, None)),
|
||||||
|
(BlenderbotConfig, (BlenderbotTokenizer, None)),
|
||||||
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
|
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
|
||||||
(BartConfig, (BartTokenizer, BartTokenizerFast)),
|
(BartConfig, (BartTokenizer, BartTokenizerFast)),
|
||||||
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
|
(LongformerConfig, (LongformerTokenizer, LongformerTokenizerFast)),
|
||||||
|
|||||||
@@ -170,16 +170,6 @@ class BartConfig(PretrainedConfig):
|
|||||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN
|
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN
|
||||||
|
|
||||||
# IMPORTANT
|
|
||||||
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
|
|
||||||
self.extra_pos_embeddings = 2
|
|
||||||
self.normalize_before = False
|
|
||||||
self.add_final_layer_norm = False
|
|
||||||
self.do_blenderbot_90_layernorm = False
|
|
||||||
self.normalize_embedding = True
|
|
||||||
self.static_position_embeddings = False
|
|
||||||
self.add_bias_logits = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
return self.encoder_attention_heads
|
return self.encoder_attention_heads
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,19 +12,18 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""TF BART model, ported from the fairseq repo."""
|
""" TF 2.0 Bart model. """
|
||||||
|
|
||||||
|
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
import warnings
|
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from ...activations_tf import ACT2FN
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
|
add_end_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
@@ -55,13 +54,14 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "BartConfig"
|
_CONFIG_FOR_DOC = "BartConfig"
|
||||||
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
||||||
|
|
||||||
|
|
||||||
LARGE_NEGATIVE = -1e8
|
LARGE_NEGATIVE = -1e8
|
||||||
|
|
||||||
|
|
||||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, eos_token_id: int):
|
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||||
shifted_input_ids = tf.cast(input_ids, tf.int32)
|
shifted_input_ids = tf.cast(input_ids, tf.int32)
|
||||||
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
||||||
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), eos_token_id)
|
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
||||||
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
||||||
# replace possible -100 values in labels by `pad_token_id`
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
shifted_input_ids = tf.where(
|
shifted_input_ids = tf.where(
|
||||||
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
|
|||||||
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
|
return tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))
|
||||||
|
|
||||||
|
|
||||||
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
|
||||||
"""
|
"""
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
"""
|
"""
|
||||||
@@ -108,18 +108,15 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
|
|||||||
|
|
||||||
class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
|
class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
|
||||||
"""
|
"""
|
||||||
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
|
This module learns positional embeddings up to a fixed maximum size.
|
||||||
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
|
|
||||||
the forward function.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset, **kwargs):
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||||
|
assert padding_idx is not None, "padding_idx cannot be None"
|
||||||
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||||
# and adjust num_embeddings appropriately. Other models don't have this hack
|
# and adjust num_embeddings appropriately. Other models don't have this hack
|
||||||
self.offset = offset
|
self.offset = 2
|
||||||
assert padding_idx is not None, "padding_idx cannot be None"
|
super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)
|
||||||
num_embeddings += offset
|
|
||||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
|
||||||
|
|
||||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
||||||
"""Input is expected to be of size [bsz x seqlen]."""
|
"""Input is expected to be of size [bsz x seqlen]."""
|
||||||
@@ -128,56 +125,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||||||
positions = tf.range(
|
positions = tf.range(
|
||||||
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
|
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
|
||||||
)
|
)
|
||||||
return super().call(positions + self.offset) # super object is not callable for some reason
|
return super().call(positions + self.offset)
|
||||||
|
|
||||||
|
|
||||||
class TFBartSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
|
|
||||||
"""This module produces sinusoidal positional embeddings of any length."""
|
|
||||||
|
|
||||||
def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
|
|
||||||
|
|
||||||
if embedding_dim % 2 != 0:
|
|
||||||
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
|
|
||||||
super().__init__(
|
|
||||||
num_positions,
|
|
||||||
embedding_dim,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def build(self, input_shape: tf.TensorShape):
|
|
||||||
"""
|
|
||||||
Build shared token embedding layer Shared weights logic adapted from
|
|
||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
|
||||||
"""
|
|
||||||
super().build(input_shape) # Instantiates self.weight so it can be loaded
|
|
||||||
weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
|
|
||||||
self.set_weights([weight]) # overwrite self.weight to correct value
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _init_weight(n_pos: int, dim: int):
|
|
||||||
"""
|
|
||||||
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
|
|
||||||
the 2nd half of the vector. [dim // 2:]
|
|
||||||
"""
|
|
||||||
position_enc = np.array(
|
|
||||||
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
|
||||||
)
|
|
||||||
# index 0 is all zero
|
|
||||||
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
|
|
||||||
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
|
|
||||||
# convert to tensor
|
|
||||||
table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
|
|
||||||
tf.stop_gradient(table)
|
|
||||||
return table
|
|
||||||
|
|
||||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
|
||||||
"""Input is expected to be of size [bsz x seqlen]."""
|
|
||||||
bsz, seq_len = input_shape[:2]
|
|
||||||
|
|
||||||
positions = tf.range(
|
|
||||||
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
|
|
||||||
)
|
|
||||||
return super().call(positions)
|
|
||||||
|
|
||||||
|
|
||||||
class TFBartAttention(tf.keras.layers.Layer):
|
class TFBartAttention(tf.keras.layers.Layer):
|
||||||
@@ -310,10 +258,9 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
self.self_attn = TFBartAttention(
|
self.self_attn = TFBartAttention(
|
||||||
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
|
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
|
||||||
)
|
)
|
||||||
self.normalize_before = config.normalize_before
|
|
||||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.activation_fn = ACT2FN[config.activation_function]
|
self.activation_fn = get_tf_activation(config.activation_function)
|
||||||
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
||||||
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
|
self.fc1 = tf.keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
|
||||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||||
@@ -327,8 +274,6 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if self.normalize_before:
|
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
||||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||||
hidden_states=hidden_states, attention_mask=attention_mask
|
hidden_states=hidden_states, attention_mask=attention_mask
|
||||||
)
|
)
|
||||||
@@ -339,19 +284,15 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
if not self.normalize_before:
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if self.normalize_before:
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
|
||||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||||
hidden_states = self.activation_dropout(hidden_states, training=training)
|
hidden_states = self.activation_dropout(hidden_states, training=training)
|
||||||
hidden_states = self.fc2(hidden_states)
|
hidden_states = self.fc2(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
if not self.normalize_before:
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states, self_attn_weights
|
return hidden_states, self_attn_weights
|
||||||
|
|
||||||
@@ -368,9 +309,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.activation_fn = ACT2FN[config.activation_function]
|
self.activation_fn = get_tf_activation(config.activation_function)
|
||||||
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
||||||
self.normalize_before = config.normalize_before
|
|
||||||
|
|
||||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||||
self.encoder_attn = TFBartAttention(
|
self.encoder_attn = TFBartAttention(
|
||||||
@@ -405,8 +345,6 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if self.normalize_before:
|
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
@@ -419,15 +357,12 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
if not self.normalize_before:
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if self.normalize_before:
|
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
@@ -439,24 +374,19 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
if not self.normalize_before:
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||||
present_key_value = present_key_value + cross_attn_present_key_value
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if self.normalize_before:
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
|
||||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||||
hidden_states = self.activation_dropout(hidden_states, training=training)
|
hidden_states = self.activation_dropout(hidden_states, training=training)
|
||||||
hidden_states = self.fc2(hidden_states)
|
hidden_states = self.fc2(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
if not self.normalize_before:
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -472,8 +402,8 @@ class TFBartPretrainedModel(TFPreTrainedModel):
|
|||||||
@property
|
@property
|
||||||
def dummy_inputs(self):
|
def dummy_inputs(self):
|
||||||
pad_token = 1
|
pad_token = 1
|
||||||
input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
|
input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
|
||||||
decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
|
decoder_input_ids = tf.cast(tf.convert_to_tensor(DUMMY_INPUTS), tf.int32)
|
||||||
dummy_inputs = {
|
dummy_inputs = {
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": tf.math.not_equal(input_ids, pad_token),
|
"attention_mask": tf.math.not_equal(input_ids, pad_token),
|
||||||
@@ -520,14 +450,6 @@ class TFBartPretrainedModel(TFPreTrainedModel):
|
|||||||
return self.serving_output(output)
|
return self.serving_output(output)
|
||||||
|
|
||||||
|
|
||||||
class TFPretrainedBartModel(TFBartPretrainedModel):
|
|
||||||
def __init_subclass__(self):
|
|
||||||
warnings.warn(
|
|
||||||
"The class `TFPretrainedBartModel` has been deprecated, please use `TFBartPretrainedModel` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
BART_START_DOCSTRING = r"""
|
BART_START_DOCSTRING = r"""
|
||||||
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
||||||
generic methods the library implements for all its models (such as downloading or saving, resizing the input
|
generic methods the library implements for all its models (such as downloading or saving, resizing the input
|
||||||
@@ -563,6 +485,36 @@ BART_START_DOCSTRING = r"""
|
|||||||
model weights.
|
model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
BART_GENERATION_EXAMPLE = r"""
|
||||||
|
Summarization example::
|
||||||
|
|
||||||
|
>>> from transformers import BartTokenizer, TFBartForConditionalGeneration, BartConfig
|
||||||
|
|
||||||
|
>>> model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large')
|
||||||
|
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
|
||||||
|
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
||||||
|
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='tf')
|
||||||
|
|
||||||
|
>>> # Generate Summary
|
||||||
|
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
|
||||||
|
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
|
||||||
|
|
||||||
|
Mask filling example::
|
||||||
|
|
||||||
|
>>> from transformers import BartTokenizer, TFBartForConditionalGeneration
|
||||||
|
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
|
||||||
|
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
||||||
|
|
||||||
|
>>> model = TFBartForConditionalGeneration.from_pretrained('facebook/bart-large')
|
||||||
|
>>> input_ids = tokenizer([TXT], return_tensors='tf')['input_ids']
|
||||||
|
>>> logits = model(input_ids).logits
|
||||||
|
>>> probs = tf.nn.softmax(logits[0])
|
||||||
|
>>> # probs[5] is associated with the mask token
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
BART_INPUTS_DOCSTRING = r"""
|
BART_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
|
input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
|
||||||
@@ -581,8 +533,21 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
Provide for translation and summarization training. By default, the model will create this tensor by
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
shifting the input_ids right, following the paper.
|
|
||||||
|
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
|
Bart uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
|
||||||
|
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
||||||
|
:obj:`past_key_values`).
|
||||||
|
|
||||||
|
For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
|
||||||
|
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
|
||||||
|
the right for denoising pre-training following the paper.
|
||||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
@@ -603,7 +568,7 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.TFModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||||
behaviors between training and evaluation).
|
behaviors between training and evaluation).
|
||||||
@@ -626,36 +591,19 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.layerdrop = config.encoder_layerdrop
|
self.layerdrop = config.encoder_layerdrop
|
||||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.max_source_positions = config.max_position_embeddings
|
self.max_source_positions = config.max_position_embeddings
|
||||||
|
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
if config.static_position_embeddings:
|
self.embed_positions = TFBartLearnedPositionalEmbedding(
|
||||||
self.embed_positions = TFBartSinusoidalPositionalEmbedding(
|
config.max_position_embeddings,
|
||||||
config.max_position_embeddings,
|
config.d_model,
|
||||||
config.d_model,
|
self.padding_idx,
|
||||||
name="embed_positions",
|
name="embed_positions",
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
self.embed_positions = TFBartLearnedPositionalEmbedding(
|
|
||||||
config.max_position_embeddings,
|
|
||||||
config.d_model,
|
|
||||||
self.padding_idx,
|
|
||||||
config.extra_pos_embeddings,
|
|
||||||
name="embed_positions",
|
|
||||||
)
|
|
||||||
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||||
self.layernorm_embedding = (
|
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
|
||||||
if config.normalize_embedding
|
|
||||||
else tf.keras.layers.Layer()
|
|
||||||
)
|
|
||||||
self.layer_norm = (
|
|
||||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
|
||||||
if config.add_final_layer_norm
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_embed_tokens(self, embed_tokens):
|
def set_embed_tokens(self, embed_tokens):
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
@@ -725,11 +673,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if inputs["inputs_embeds"] is None:
|
if inputs["inputs_embeds"] is None:
|
||||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
||||||
else:
|
|
||||||
inputs["inputs_embeds"] = inputs["inputs_embeds"]
|
|
||||||
|
|
||||||
inputs["inputs_embeds"] = inputs["inputs_embeds"] * self.embed_scale
|
|
||||||
|
|
||||||
embed_pos = self.embed_positions(input_shape)
|
embed_pos = self.embed_positions(input_shape)
|
||||||
hidden_states = inputs["inputs_embeds"] + embed_pos
|
hidden_states = inputs["inputs_embeds"] + embed_pos
|
||||||
@@ -739,7 +683,9 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
# check attention mask and invert
|
# check attention mask and invert
|
||||||
if inputs["attention_mask"] is not None:
|
if inputs["attention_mask"] is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])
|
attention_mask = _expand_mask(inputs["attention_mask"])
|
||||||
|
else:
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
encoder_states = () if inputs["output_hidden_states"] else None
|
encoder_states = () if inputs["output_hidden_states"] else None
|
||||||
all_attentions = () if inputs["output_attentions"] else None
|
all_attentions = () if inputs["output_attentions"] else None
|
||||||
@@ -754,12 +700,11 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
continue
|
continue
|
||||||
|
|
||||||
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
|
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_attentions += (attn,)
|
all_attentions += (attn,)
|
||||||
if self.layer_norm:
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
|
|
||||||
@@ -786,36 +731,18 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
|
||||||
self.layerdrop = config.decoder_layerdrop
|
self.layerdrop = config.decoder_layerdrop
|
||||||
if config.static_position_embeddings:
|
self.embed_positions = TFBartLearnedPositionalEmbedding(
|
||||||
self.embed_positions = TFBartSinusoidalPositionalEmbedding(
|
config.max_position_embeddings,
|
||||||
config.max_position_embeddings,
|
config.d_model,
|
||||||
config.d_model,
|
self.padding_idx,
|
||||||
name="embed_positions",
|
name="embed_positions",
|
||||||
)
|
)
|
||||||
else:
|
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||||
self.embed_positions = TFBartLearnedPositionalEmbedding(
|
|
||||||
config.max_position_embeddings,
|
|
||||||
config.d_model,
|
|
||||||
self.padding_idx,
|
|
||||||
config.extra_pos_embeddings,
|
|
||||||
name="embed_positions",
|
|
||||||
)
|
|
||||||
self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
||||||
self.layernorm_embedding = (
|
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
|
||||||
if config.normalize_embedding
|
|
||||||
else tf.keras.layers.Layer()
|
|
||||||
)
|
|
||||||
self.layer_norm = (
|
|
||||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
|
||||||
if config.add_final_layer_norm
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
|
|
||||||
|
|
||||||
def set_embed_tokens(self, embed_tokens):
|
def set_embed_tokens(self, embed_tokens):
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
@@ -912,16 +839,16 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
past_key_values_length = (
|
past_key_values_length = (
|
||||||
inputs["past_key_values"][0][0].shape[2] if inputs["past_key_values"] is not None else 0
|
shape_list(inputs["past_key_values"][0][0])[2] if inputs["past_key_values"] is not None else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||||
|
|
||||||
if inputs["inputs_embeds"] is None:
|
if inputs["inputs_embeds"] is None:
|
||||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
||||||
|
|
||||||
hidden_states = inputs["inputs_embeds"] * self.embed_scale
|
hidden_states = inputs["inputs_embeds"]
|
||||||
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
@@ -931,35 +858,16 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||||
inputs["attention_mask"] = tf.cast(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
inputs["attention_mask"] = tf.concat(
|
|
||||||
[
|
|
||||||
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
|
|
||||||
inputs["attention_mask"],
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
inputs["attention_mask"] = tf.ones(
|
|
||||||
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
|
||||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
|
||||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
|
inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])
|
||||||
|
|
||||||
if self.do_blenderbot_90_layernorm:
|
hidden_states = self.layernorm_embedding(hidden_states + positions)
|
||||||
hidden_states = self.layernorm_embedding(hidden_states) + positions
|
|
||||||
else:
|
|
||||||
hidden_states = self.layernorm_embedding(hidden_states + positions)
|
|
||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
# decoder layers
|
# decoder layers
|
||||||
@@ -991,10 +899,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
if self.layer_norm is not None: # same as if config.add_final_layer_norm
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
|
||||||
|
|
||||||
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
else:
|
||||||
@@ -1002,7 +906,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
||||||
|
|
||||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values) if inputs["use_cache"] else None
|
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
||||||
@@ -1098,7 +1002,7 @@ class TFBartModel(TFBartPretrainedModel):
|
|||||||
|
|
||||||
if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
|
if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
inputs["input_ids"], self.config.pad_token_id, self.config.eos_token_id
|
inputs["input_ids"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["encoder_outputs"] is None:
|
if inputs["encoder_outputs"] is None:
|
||||||
@@ -1206,6 +1110,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
@add_end_docstrings(BART_GENERATION_EXAMPLE)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -1224,22 +1129,14 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
|||||||
training=False,
|
training=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
r"""
|
||||||
|
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
|
||||||
|
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
# Mask filling only works for bart-large
|
|
||||||
from transformers import BartTokenizer, TFBartForConditionalGeneration
|
|
||||||
import tensorflow as tf
|
|
||||||
mname = 'facebook/bart-large'
|
|
||||||
tokenizer = BartTokenizer.from_pretrained(mname)
|
|
||||||
TXT = "My friends are <mask> but they eat too many carbs."
|
|
||||||
model = TFBartForConditionalGeneration.from_pretrained(mname)
|
|
||||||
batch = tokenizer([TXT], return_tensors='tf')
|
|
||||||
logits = model(inputs=batch.input_ids).logits
|
|
||||||
probs = tf.nn.softmax(logits[0])
|
|
||||||
# probs[5] is associated with the mask token
|
|
||||||
"""
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
@@ -1265,7 +1162,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
|||||||
inputs["use_cache"] = False
|
inputs["use_cache"] = False
|
||||||
if inputs["decoder_input_ids"] is None:
|
if inputs["decoder_input_ids"] is None:
|
||||||
inputs["decoder_input_ids"] = shift_tokens_right(
|
inputs["decoder_input_ids"] = shift_tokens_right(
|
||||||
inputs["labels"], self.config.pad_token_id, self.config.eos_token_id
|
inputs["labels"], self.config.pad_token_id, self.config.decoder_start_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
@@ -1363,7 +1260,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
|||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
for layer_past_key_values in past_key_values:
|
for layer_past_key_values in past_key_values:
|
||||||
reordered_past += (
|
reordered_past += (
|
||||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values),
|
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
|
||||||
|
+ layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration"]
|
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -52,7 +52,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
|
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -161,17 +161,6 @@ class BlenderbotConfig(PretrainedConfig):
|
|||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
|
|
||||||
# IMPORTANT
|
|
||||||
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
|
|
||||||
self.extra_pos_embeddings = 0
|
|
||||||
self.normalize_before = True
|
|
||||||
self.add_final_layer_norm = True
|
|
||||||
self.do_blenderbot_90_layernorm = True
|
|
||||||
self.normalize_embedding = False
|
|
||||||
self.static_position_embeddings = False
|
|
||||||
self.add_bias_logits = False
|
|
||||||
self.force_bos_token_to_be_generated = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
return self.encoder_attention_heads
|
return self.encoder_attention_heads
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...file_utils import _BaseLazyModule, is_torch_available
|
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@@ -33,6 +33,11 @@ if is_torch_available():
|
|||||||
"BlenderbotSmallPreTrainedModel",
|
"BlenderbotSmallPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
_import_structure["modeling_tf_blenderbot_small"] = [
|
||||||
|
"TFBlenderbotSmallForConditionalGeneration",
|
||||||
|
"TFBlenderbotSmallModel",
|
||||||
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig
|
from .configuration_blenderbot_small import BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotSmallConfig
|
||||||
@@ -46,6 +51,9 @@ if TYPE_CHECKING:
|
|||||||
BlenderbotSmallPreTrainedModel,
|
BlenderbotSmallPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -866,6 +866,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions else None
|
all_cross_attentions = () if output_attentions else None
|
||||||
next_decoder_cache = () if use_cache else None
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -30,7 +30,7 @@ logger = logging.get_logger(__name__)
|
|||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
"vocab_file": "vocab.json",
|
"vocab_file": "vocab.json",
|
||||||
"merges_file": "merges.txt",
|
"merges_file": "merges.txt",
|
||||||
# "tokenizer_config_file": "tokenizer_config.json",
|
"tokenizer_config_file": "tokenizer_config.json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -75,13 +75,20 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
|
|||||||
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
|
vocab_files_names = {
|
||||||
|
"vocab_file": "vocab.json",
|
||||||
|
"merges_file": "merges.txt",
|
||||||
|
"tokenizer_config": "tokenizer_config.json",
|
||||||
|
}
|
||||||
pretrained_vocab_files_map = {
|
pretrained_vocab_files_map = {
|
||||||
"vocab_file": {
|
"vocab_file": {
|
||||||
"facebook/blenderbot_small-90M": "https://cdn.huggingface.co/facebook/blenderbot_small-90M/vocab.json"
|
"facebook/blenderbot_small-90M": "https://huggingface.co/facebook/blenderbot_small-90M/blob/main/vocab.json"
|
||||||
},
|
},
|
||||||
"merges_file": {
|
"merges_file": {
|
||||||
"facebook/blenderbot_small-90M": "https://cdn.huggingface.co/facebook/blenderbot_small-90M/merges.txt"
|
"facebook/blenderbot_small-90M": "https://huggingface.co/facebook/blenderbot_small-90M/blob/main/merges.txt"
|
||||||
|
},
|
||||||
|
"tokenizer_config_file": {
|
||||||
|
"facebook/blenderbot_small-90M": "https://huggingface.co/facebook/blenderbot_small-90M/blob/main/tokenizer.json"
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
|
max_model_input_sizes = {"facebook/blenderbot_small-90M": 512}
|
||||||
|
|||||||
@@ -1475,7 +1475,7 @@ LED_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.TFModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||||
behaviors between training and evaluation).
|
behaviors between training and evaluation).
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_marian"] = ["TFMarianMTModel"]
|
_import_structure["modeling_tf_marian"] = ["TFMarianMTModel", "TFMarianModel"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -60,7 +60,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_marian import TFMarianMTModel
|
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -159,17 +159,6 @@ class MarianConfig(PretrainedConfig):
|
|||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
|
|
||||||
# IMPORTANT
|
|
||||||
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
|
|
||||||
self.extra_pos_embeddings = 0
|
|
||||||
self.normalize_before = False
|
|
||||||
self.add_final_layer_norm = False
|
|
||||||
self.do_blenderbot_90_layernorm = False
|
|
||||||
self.normalize_embedding = False
|
|
||||||
self.static_position_embeddings = True
|
|
||||||
self.add_bias_logits = False
|
|
||||||
self.force_bos_token_to_be_generated = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
return self.encoder_attention_heads
|
return self.encoder_attention_heads
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -47,7 +47,7 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration"]
|
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -70,7 +70,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_mbart import TFMBartForConditionalGeneration
|
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -159,17 +159,6 @@ class MBartConfig(PretrainedConfig):
|
|||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
|
|
||||||
# IMPORTANT
|
|
||||||
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
|
|
||||||
self.extra_pos_embeddings = 2
|
|
||||||
self.normalize_before = True
|
|
||||||
self.add_final_layer_norm = True
|
|
||||||
self.do_blenderbot_90_layernorm = False
|
|
||||||
self.normalize_embedding = True
|
|
||||||
self.static_position_embeddings = False
|
|
||||||
self.add_bias_logits = False
|
|
||||||
self.force_bos_token_to_be_generated = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
return self.encoder_attention_heads
|
return self.encoder_attention_heads
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -45,7 +45,7 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration"]
|
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -66,7 +66,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
|
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -159,17 +159,6 @@ class PegasusConfig(PretrainedConfig):
|
|||||||
self.gradient_checkpointing = gradient_checkpointing
|
self.gradient_checkpointing = gradient_checkpointing
|
||||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||||
|
|
||||||
# IMPORTANT
|
|
||||||
# DELETE ALL OF THE FOLLOWING LINES AS SOON AS TF IS READY
|
|
||||||
self.extra_pos_embeddings = 0
|
|
||||||
self.normalize_before = True
|
|
||||||
self.add_final_layer_norm = True
|
|
||||||
self.do_blenderbot_90_layernorm = False
|
|
||||||
self.normalize_embedding = False
|
|
||||||
self.static_position_embeddings = True
|
|
||||||
self.add_bias_logits = False
|
|
||||||
self.force_bos_token_to_be_generated = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
return self.encoder_attention_heads
|
return self.encoder_attention_heads
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -369,6 +369,33 @@ class TFBlenderbotForConditionalGeneration:
|
|||||||
requires_tf(self)
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBlenderbotModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBlenderbotSmallForConditionalGeneration:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFBlenderbotSmallModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -952,6 +979,11 @@ class TFLxmertVisualFeatureEncoder:
|
|||||||
requires_tf(self)
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFMarian:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
class TFMarianMTModel:
|
class TFMarianMTModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_tf(self)
|
requires_tf(self)
|
||||||
@@ -970,6 +1002,15 @@ class TFMBartForConditionalGeneration:
|
|||||||
requires_tf(self)
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFMBartModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1211,6 +1252,15 @@ class TFPegasusForConditionalGeneration:
|
|||||||
requires_tf(self)
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TFPegasusModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_tf(self)
|
||||||
|
|
||||||
|
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -1464,7 +1464,6 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
|||||||
)
|
)
|
||||||
|
|
||||||
{% else %}
|
{% else %}
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -1936,7 +1935,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
|
input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
Indices can be obtained using :class:`~transformers.{{cookiecutter.camelcase_modelname}}Tokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
details.
|
details.
|
||||||
|
|
||||||
@@ -1949,8 +1948,21 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
Provide for translation and summarization training. By default, the model will create this tensor by
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
shifting the input_ids right, following the paper.
|
|
||||||
|
Indices can be obtained using :class:`~transformers.{{cookiecutter.camelcase_modelname}}Tokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
|
{{cookiecutter.camelcase_modelname}} uses the :obj:`eos_token_id` as the starting token for
|
||||||
|
:obj:`decoder_input_ids` generation. If :obj:`past_key_values` is used, optionally only the last
|
||||||
|
:obj:`decoder_input_ids` have to be input (see :obj:`past_key_values`).
|
||||||
|
|
||||||
|
For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
|
||||||
|
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
|
||||||
|
the right for denoising pre-training following the paper.
|
||||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||||
@@ -1971,7 +1983,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.TFModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||||
behaviors between training and evaluation).
|
behaviors between training and evaluation).
|
||||||
@@ -1996,7 +2008,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
self.layerdrop = config.encoder_layerdrop
|
self.layerdrop = config.encoder_layerdrop
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.max_source_positions = config.max_position_embeddings
|
self.max_source_positions = config.max_position_embeddings
|
||||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
|
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
@@ -2077,14 +2089,10 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if inputs["inputs_embeds"] is None:
|
if inputs["inputs_embeds"] is None:
|
||||||
inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
||||||
else:
|
|
||||||
inputs_embeds = inputs["inputs_embeds"]
|
|
||||||
|
|
||||||
inputs_embeds = inputs_embeds
|
|
||||||
|
|
||||||
embed_pos = self.embed_positions(input_shape)
|
embed_pos = self.embed_positions(input_shape)
|
||||||
hidden_states = inputs_embeds + embed_pos
|
hidden_states = inputs["inputs_embeds"] + embed_pos
|
||||||
hidden_states = self.layernorm_embedding(hidden_states)
|
hidden_states = self.layernorm_embedding(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
|
|
||||||
@@ -2146,7 +2154,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
self.padding_idx,
|
self.padding_idx,
|
||||||
name="embed_positions",
|
name="embed_positions",
|
||||||
)
|
)
|
||||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||||
self.layers = [TF{{cookiecutter.camelcase_modelname}}DecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
self.layers = [TF{{cookiecutter.camelcase_modelname}}DecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
||||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||||
|
|
||||||
@@ -2259,7 +2267,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = inputs["inputs_embeds"]
|
hidden_states = inputs["inputs_embeds"]
|
||||||
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
combined_attention_mask = None
|
|
||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||||
else:
|
else:
|
||||||
@@ -2267,21 +2274,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||||
inputs["attention_mask"] = tf.cast(
|
combined_attention_mask = combined_attention_mask + _expand_mask(inputs["attention_mask"], tgt_len=input_shape[-1])
|
||||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
|
||||||
)
|
|
||||||
inputs["attention_mask"] = tf.concat(
|
|
||||||
[
|
|
||||||
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
|
|
||||||
inputs["attention_mask"],
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
inputs["attention_mask"] = tf.ones(
|
|
||||||
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
@@ -2683,7 +2677,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
for layer_past_key_values in past_key_values:
|
for layer_past_key_values in past_key_values:
|
||||||
reordered_past += (
|
reordered_past += (
|
||||||
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values),
|
tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:],
|
||||||
)
|
)
|
||||||
return (past[0], reordered_past)
|
return (past[0], reordered_past)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import {{cookiecutter.camelcase_modelname}}Config, is_tf_available
|
from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Config
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
@@ -28,12 +28,12 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
|
||||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -323,8 +323,12 @@ class TF{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCa
|
|||||||
{% else %}
|
{% else %}
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer, is_tf_available
|
from transformers import (
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_tf, slow
|
is_tf_available,
|
||||||
|
{{cookiecutter.camelcase_modelname}}Config,
|
||||||
|
{{cookiecutter.camelcase_modelname}}Tokenizer,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
@@ -333,7 +337,10 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, TF{{cookiecutter.camelcase_modelname}}Model
|
from transformers import (
|
||||||
|
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
|
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@@ -453,7 +460,7 @@ def prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict(
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
|
decoder_attention_mask = tf.concat([tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8), tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8)], axis=-1)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -21,8 +21,8 @@ import unittest
|
|||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
from .test_configuration_common import ConfigTester
|
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -31,15 +31,17 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
{{cookiecutter.camelcase_modelname}}Config,
|
{{cookiecutter.camelcase_modelname}}Config,
|
||||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
|
||||||
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
|
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
)
|
)
|
||||||
from transformers.models.{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
|
||||||
|
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class {{cookiecutter.camelcase_modelname}}ModelTester:
|
class {{cookiecutter.camelcase_modelname}}ModelTester:
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2020 HuggingFace Inc. team.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -13,35 +13,158 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tests.test_configuration_common import ConfigTester
|
from transformers import BlenderbotConfig, BlenderbotTokenizer, is_tf_available
|
||||||
from tests.test_modeling_tf_bart import TFBartModelTester
|
|
||||||
from tests.test_modeling_tf_common import TFModelTesterMixin
|
|
||||||
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
|
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
|
from transformers.testing_utils import require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
|
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||||
|
|
||||||
|
|
||||||
class TFBlenderbotModelTester(TFBartModelTester):
|
@require_tf
|
||||||
config_updates = dict(
|
class TFBlenderbotModelTester:
|
||||||
normalize_before=True,
|
|
||||||
static_position_embeddings=True,
|
|
||||||
do_blenderbot_90_layernorm=True,
|
|
||||||
normalize_embeddings=True,
|
|
||||||
)
|
|
||||||
config_cls = BlenderbotConfig
|
config_cls = BlenderbotConfig
|
||||||
|
config_updates = {}
|
||||||
|
hidden_act = "gelu"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=False,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=20,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
    """Build a small config together with matching encoder/decoder inputs.

    Returns:
        A ``(config, inputs_dict)`` pair; ``inputs_dict`` is produced by
        ``prepare_blenderbot_inputs_dict``.
    """
    # Random encoder tokens, with the last position of every row forced to EOS.
    encoder_body = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
    eos_ids = tf.constant([self.eos_token_id] * self.batch_size)
    eos_column = tf.expand_dims(eos_ids, 1)
    input_ids = tf.concat([encoder_body, eos_column], axis=1)

    decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    # Encoder and decoder share the same depth, head count and feed-forward width.
    num_layers = self.num_hidden_layers
    num_heads = self.num_attention_heads
    ffn_dim = self.intermediate_size

    config = self.config_cls(
        vocab_size=self.vocab_size,
        d_model=self.hidden_size,
        encoder_layers=num_layers,
        decoder_layers=num_layers,
        encoder_attention_heads=num_heads,
        decoder_attention_heads=num_heads,
        encoder_ffn_dim=ffn_dim,
        decoder_ffn_dim=ffn_dim,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        # NOTE(review): hard-coded value under a plural key; self.eos_token_id is not
        # forwarded here - confirm this is intended.
        eos_token_ids=[2],
        bos_token_id=self.bos_token_id,
        pad_token_id=self.pad_token_id,
        decoder_start_token_id=self.pad_token_id,
        **self.config_updates,
    )
    return config, prepare_blenderbot_inputs_dict(config, input_ids, decoder_input_ids)
|
||||||
|
|
||||||
|
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
    """Check that decoding with cached key/values matches decoding from scratch.

    The decoder is run once on the first example with ``use_cache=True``. Three
    extra tokens are then decoded twice: once appended to the full input with no
    cache, and once alone with the cache from the first pass. The outputs for
    those three positions must agree on a randomly chosen hidden feature.

    Args:
        config: model configuration used to build ``TFBlenderbotModel``.
        inputs_dict: mapping providing ``"input_ids"`` and ``"attention_mask"``.
    """
    model = TFBlenderbotModel(config=config).get_decoder()

    # Only the first example is used. The size is kept in a local so the tester's
    # own batch_size is not silently overwritten for later callers.
    batch_size = 1
    input_ids = inputs_dict["input_ids"][:batch_size, :]
    attention_mask = inputs_dict["attention_mask"][:batch_size, :]

    # first forward pass
    outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

    _, past_key_values = outputs.to_tuple()
    # NOTE(review): only element 1 of the returned cache is fed back - presumably
    # the decoder key/value cache; confirm against the decoder's output layout.
    past_key_values = past_key_values[1]

    # create hypothetical next tokens and a random 0/1 mask for them
    next_tokens = ids_tensor((batch_size, 3), config.vocab_size)
    next_attn_mask = tf.cast(ids_tensor((batch_size, 3), 2), tf.int8)

    # extend input_ids and attention_mask with the new positions
    next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
    next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

    output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
    output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]

    self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

    # select random slice
    random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
    output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
    output_from_past_slice = output_from_past[:, :, random_slice_idx]

    # test that outputs are equal for slice
    tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_blenderbot_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
    """Assemble the model inputs, deriving any attention mask that was not given.

    A missing ``attention_mask`` marks every position whose id differs from
    ``config.pad_token_id``. A missing ``decoder_attention_mask`` does the same
    for the decoder ids, except that the first position is always kept.

    Returns:
        A dict with the keys ``input_ids``, ``decoder_input_ids``,
        ``attention_mask`` and ``decoder_attention_mask``.
    """
    pad_id = config.pad_token_id

    if attention_mask is None:
        is_real_token = tf.math.not_equal(input_ids, pad_id)
        attention_mask = tf.cast(is_real_token, tf.int8)

    if decoder_attention_mask is None:
        # First decoder position is always attended, whatever its id is.
        first_position = tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8)
        later_positions = tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], pad_id), tf.int8)
        decoder_attention_mask = tf.concat([first_position, later_positions], axis=-1)

    inputs = {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids}
    inputs["attention_mask"] = attention_mask
    inputs["decoder_attention_mask"] = decoder_attention_mask
    return inputs
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
all_model_classes = (TFBlenderbotForConditionalGeneration, TFBlenderbotModel) if is_tf_available() else ()
|
||||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -53,9 +176,9 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_decoder_model_past_large_inputs(self):
|
||||||
# inputs_embeds not supported
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
pass
|
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -77,8 +200,22 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
name = model.get_bias()
|
name = model.get_bias()
|
||||||
assert name is None
|
assert name is None
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_hidden_states_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_attentions_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
def test_saved_model_creation(self):
|
def test_saved_model_creation(self):
|
||||||
# This test is too long (>30sec) and makes fail the CI
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_saved_model_creation_extended(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
@@ -145,17 +282,33 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
    """Assert that two tensors are close, raising a readable AssertionError if not.

    Args:
        a: first tensor, or None.
        b: second tensor, or None.
        atol: absolute tolerance forwarded to ``tf.debugging.assert_near``.
        prefix: optional text put in front of the failure message.

    Returns:
        True when both arguments are None or when the comparison succeeds.

    Raises:
        AssertionError: if the values are not close, or cannot be compared.
    """
    if a is None and b is None:
        return True
    try:
        # assert_near signals a mismatch by raising. Its return value must not be
        # tested for truth: when it returns None, the old `if ...: return True`
        # fell through to a bare `raise` and the helper failed even for equal tensors.
        tf.debugging.assert_near(a, b, atol=atol)
    except Exception as err:
        msg = "{} != {}".format(a, b)
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg) from err
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
    """Return ``tok_lst`` as a constant ``tf.int32`` tensor."""
    token_tensor = tf.constant(tok_lst, dtype=tf.int32)
    return token_tensor
|
||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
|
class TFBlenderbot400MIntegrationTests(unittest.TestCase):
|
||||||
src_text = [
|
src_text = ["My friends are cool but they eat too many carbs."]
|
||||||
"Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
model_name = "facebook/blenderbot-400M-distill"
|
||||||
]
|
|
||||||
model_name = "facebook/blenderbot-90M"
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def tokenizer(self):
|
def tokenizer(self):
|
||||||
return BlenderbotSmallTokenizer.from_pretrained(self.model_name)
|
return BlenderbotTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def model(self):
|
def model(self):
|
||||||
@@ -163,17 +316,13 @@ class TFBlenderbot90MIntegrationTests(unittest.TestCase):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_90_generation_from_long_input(self):
|
def test_generation_from_long_input(self):
|
||||||
model_inputs = self.tokenizer(self.src_text, return_tensors="tf")
|
model_inputs = self.tokenizer(self.src_text, return_tensors="tf")
|
||||||
generated_ids = self.model.generate(
|
generated_ids = self.model.generate(
|
||||||
model_inputs.input_ids,
|
model_inputs.input_ids,
|
||||||
attention_mask=model_inputs.attention_mask,
|
|
||||||
num_beams=2,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
)
|
||||||
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)[0]
|
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)[0]
|
||||||
assert generated_words in (
|
assert (
|
||||||
"i don't know. i just feel like i'm going to throw up. it's not fun.",
|
generated_words
|
||||||
"i'm not sure. i just feel like i've been feeling like i have to be in a certain place",
|
== " That's unfortunate. Are they trying to lose weight or are they just trying to be healthier?"
|
||||||
"i'm not sure. i just feel like i've been in a bad situation.",
|
|
||||||
)
|
)
|
||||||
|
|||||||
328
tests/test_modeling_tf_blenderbot_small.py
Normal file
328
tests/test_modeling_tf_blenderbot_small.py
Normal file
@@ -0,0 +1,328 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import BlenderbotSmallConfig, BlenderbotSmallTokenizer, is_tf_available
|
||||||
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.testing_utils import require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
class TFBlenderbotSmallModelTester:
    """Builds a tiny BlenderbotSmall config and inputs, and runs the decoder cache check."""

    config_cls = BlenderbotSmallConfig
    # Extra keyword arguments forwarded to config_cls when the config is built.
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    ):
        """Store the owning test case and the hyper-parameters of the tiny test model."""
        # Owning test case; its assertion helpers are used by the check_* methods.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs_for_common(self):
        """Build a small config together with matching encoder/decoder inputs.

        Returns:
            A ``(config, inputs_dict)`` pair; ``inputs_dict`` is produced by
            ``prepare_blenderbot_small_inputs_dict``.
        """
        # Random encoder tokens, with the last position of every row forced to EOS.
        input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
        eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
        input_ids = tf.concat([input_ids, eos_tensor], axis=1)

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.config_cls(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            # NOTE(review): hard-coded value under a plural key; self.eos_token_id is
            # not forwarded here - confirm this is intended.
            eos_token_ids=[2],
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.pad_token_id,
            **self.config_updates,
        )
        inputs_dict = prepare_blenderbot_small_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that decoding with cached key/values matches decoding from scratch.

        The decoder is run once on the first example with ``use_cache=True``. Three
        extra tokens are then decoded twice: appended to the full input with no
        cache, and alone with the cache from the first pass. The outputs for those
        three positions must agree on a randomly chosen hidden feature.
        """
        model = TFBlenderbotSmallModel(config=config).get_decoder()

        # Only the first example is used. The size is kept in a local so the
        # tester's own batch_size is not silently overwritten for later callers.
        batch_size = 1
        input_ids = inputs_dict["input_ids"][:batch_size, :]
        attention_mask = inputs_dict["attention_mask"][:batch_size, :]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        _, past_key_values = outputs.to_tuple()
        # NOTE(review): only element 1 of the returned cache is fed back - presumably
        # the decoder key/value cache; confirm against the decoder's output layout.
        past_key_values = past_key_values[1]

        # create hypothetical next tokens and a random 0/1 mask for them
        next_tokens = ids_tensor((batch_size, 3), config.vocab_size)
        next_attn_mask = tf.cast(ids_tensor((batch_size, 3), 2), tf.int8)

        # extend input_ids and attention_mask with the new positions
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_blenderbot_small_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
    """Assemble the model inputs, deriving any attention mask that was not given.

    A missing ``attention_mask`` marks every position whose id differs from
    ``config.pad_token_id``. A missing ``decoder_attention_mask`` does the same
    for the decoder ids, except that the first position is always kept.

    Returns:
        A dict with the keys ``input_ids``, ``decoder_input_ids``,
        ``attention_mask`` and ``decoder_attention_mask``.
    """
    pad_id = config.pad_token_id

    if attention_mask is None:
        is_real_token = tf.math.not_equal(input_ids, pad_id)
        attention_mask = tf.cast(is_real_token, tf.int8)

    if decoder_attention_mask is None:
        # First decoder position is always attended, whatever its id is.
        first_position = tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8)
        later_positions = tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], pad_id), tf.int8)
        decoder_attention_mask = tf.concat([first_position, later_positions], axis=-1)

    inputs = {"input_ids": input_ids, "decoder_input_ids": decoder_input_ids}
    inputs["attention_mask"] = attention_mask
    inputs["decoder_attention_mask"] = decoder_attention_mask
    return inputs
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
    # Model classes exercised by the shared mixin tests; empty when TF is unavailable.
    all_model_classes = (
        (TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel) if is_tf_available() else ()
    )
    all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False

    def setUp(self):
        # A fresh tester pair is created for every test method.
        self.model_tester = TFBlenderbotSmallModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BlenderbotSmallConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past_large_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_decoder_model_past_large_inputs(config, inputs_dict)

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            input_embeddings = model.get_input_embeddings()
            assert isinstance(input_embeddings, tf.keras.layers.Layer)
            assert model.get_output_layer_with_bias() is None
            assert model.get_prefix_bias_name() is None

    @slow
    def test_saved_model_with_hidden_states_output(self):
        # TODO(JPLU, PVP) - fix this with s2s tf-serving PR
        pass

    @slow
    def test_saved_model_with_attentions_output(self):
        # TODO(JPLU, PVP) - fix this with s2s tf-serving PR
        pass

    def test_saved_model_creation(self):
        # TODO(JPLU, PVP) - fix this with s2s tf-serving PR
        pass

    def test_saved_model_creation_extended(self):
        # TODO(JPLU, PVP) - fix this with s2s tf-serving PR
        pass

    def test_resize_token_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        def _get_word_embedding_weight(model, embedding_layer):
            if hasattr(embedding_layer, "weight"):
                return embedding_layer.weight
            # The weight may not exist until the model has been called once, so
            # run the dummy inputs through it and look again.
            model(model.dummy_inputs)
            return embedding_layer.weight if hasattr(embedding_layer, "weight") else None

        def _rows_unchanged(old_rows, new_rows):
            # True when no pair of rows (up to the shorter length) differs.
            return not any(
                tf.math.reduce_sum(tf.math.abs(before - after)) > 0 for before, after in zip(old_rows, new_rows)
            )

        for model_class in self.all_model_classes:
            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
                # snapshot the embeddings of a freshly built model
                model = model_class(config=config)
                input_before = _get_word_embedding_weight(model, model.get_input_embeddings())
                output_before = _get_word_embedding_weight(model, model.get_output_embeddings())
                bias_before = model.get_bias()

                # resize, then snapshot again
                model.resize_token_embeddings(size)
                input_after = _get_word_embedding_weight(model, model.get_input_embeddings())
                output_after = _get_word_embedding_weight(model, model.get_output_embeddings())
                bias_after = model.get_bias()

                # a size of None is expected to leave the vocabulary size untouched
                expected_size = config.vocab_size if size is None else size

                self.assertEqual(input_after.shape[0], expected_size)
                # rows present both before and after must keep their values
                self.assertTrue(_rows_unchanged(input_before.value(), input_after.value()))

                if output_before is not None and output_after is not None:
                    self.assertEqual(output_after.shape[0], expected_size)
                    self.assertTrue(_rows_unchanged(output_before.value(), output_after.value()))

                if bias_before is not None and bias_after is not None:
                    bias_before = bias_before["final_logits_bias"]
                    bias_after = bias_after["final_logits_bias"]
                    self.assertEqual(bias_after.shape[0], 1)
                    self.assertEqual(bias_after.shape[1], expected_size)

                    bias_kept = all(
                        _rows_unchanged(old, new) for old, new in zip(bias_before.value(), bias_after.value())
                    )
                    self.assertTrue(bias_kept)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
    """Assert that two tensors are close, raising a readable AssertionError if not.

    Args:
        a: first tensor, or None.
        b: second tensor, or None.
        atol: absolute tolerance forwarded to ``tf.debugging.assert_near``.
        prefix: optional text put in front of the failure message.

    Returns:
        True when both arguments are None or when the comparison succeeds.

    Raises:
        AssertionError: if the values are not close, or cannot be compared.
    """
    if a is None and b is None:
        return True
    try:
        # assert_near signals a mismatch by raising. Its return value must not be
        # tested for truth: when it returns None, the old `if ...: return True`
        # fell through to a bare `raise` and the helper failed even for equal tensors.
        tf.debugging.assert_near(a, b, atol=atol)
    except Exception as err:
        msg = "{} != {}".format(a, b)
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg) from err
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
    """Return ``tok_lst`` as a constant ``tf.int32`` tensor."""
    token_tensor = tf.constant(tok_lst, dtype=tf.int32)
    return token_tensor
|
||||||
|
|
||||||
|
|
||||||
|
@require_tokenizers
@require_tf
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
    """Slow generation check against the pretrained checkpoint named in ``model_name``.

    ``@require_tf`` is applied because the model is loaded through
    ``TFAutoModelForSeq2SeqLM``, which this module only imports when TensorFlow
    is available; without the decorator the test errors instead of skipping.
    """

    src_text = [
        "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
    ]
    model_name = "facebook/blenderbot_small-90M"

    @cached_property
    def tokenizer(self):
        # use "old" tokenizer here because of bug when downloading new tokenizer
        return BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot-90M")

    @cached_property
    def model(self):
        model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        return model

    @slow
    def test_90_generation_from_long_input(self):
        model_inputs = self.tokenizer(self.src_text, return_tensors="tf")
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            num_beams=2,
            use_cache=True,
        )
        generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)[0]
        # Several outputs are accepted. assertIn reports the actual text on failure
        # and, unlike a bare assert, is not stripped under `python -O`.
        self.assertIn(
            generated_words,
            (
                "i don't know. i just feel like i'm going to throw up. it's not fun.",
                "i'm not sure. i just feel like i've been feeling like i have to be in a certain place",
                "i'm not sure. i just feel like i've been in a bad situation.",
            ),
        )
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2020 HuggingFace Inc. team.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -13,48 +13,174 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
|
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
|
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_bart import TFBartModelTester
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFAutoModelForSeq2SeqLM, TFMarianMTModel
|
from transformers import TFAutoModelForSeq2SeqLM, TFMarianModel, TFMarianMTModel
|
||||||
|
|
||||||
|
|
||||||
class ModelTester(TFBartModelTester):
|
|
||||||
config_updates = dict(static_position_embeddings=True, add_bias_logits=True)
|
|
||||||
config_cls = MarianConfig
|
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFMarianMTModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFMarianModelTester:
|
||||||
all_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
config_cls = MarianConfig
|
||||||
|
config_updates = {}
|
||||||
|
hidden_act = "gelu"
|
||||||
|
|
||||||
|
def __init__(
    self,
    parent,
    batch_size=13,
    seq_length=7,
    is_training=True,
    use_labels=False,
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=5,
    num_attention_heads=4,
    intermediate_size=37,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=20,
    eos_token_id=2,
    pad_token_id=1,
    bos_token_id=0,
):
    """Store the owning test case and the hyper-parameters of the tiny test model.

    Every argument is kept unchanged on the instance under its own name.
    """
    # Owning test case; its assertion helpers are used by the check_* methods.
    self.parent = parent

    # Input shape and mode switches.
    self.batch_size, self.seq_length = batch_size, seq_length
    self.is_training, self.use_labels = is_training, use_labels

    # Model dimensions.
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.max_position_embeddings = max_position_embeddings

    # Dropout rates.
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob

    # Special token ids.
    self.eos_token_id, self.pad_token_id, self.bos_token_id = eos_token_id, pad_token_id, bos_token_id
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
    """Build a small config together with matching encoder/decoder inputs.

    Returns:
        A ``(config, inputs_dict)`` pair; ``inputs_dict`` is produced by
        ``prepare_marian_inputs_dict``.
    """
    # Random encoder tokens, with the last position of every row forced to EOS.
    encoder_body = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
    eos_ids = tf.constant([self.eos_token_id] * self.batch_size)
    eos_column = tf.expand_dims(eos_ids, 1)
    input_ids = tf.concat([encoder_body, eos_column], axis=1)

    decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

    # Encoder and decoder share the same depth, head count and feed-forward width.
    num_layers = self.num_hidden_layers
    num_heads = self.num_attention_heads
    ffn_dim = self.intermediate_size

    config = self.config_cls(
        vocab_size=self.vocab_size,
        d_model=self.hidden_size,
        encoder_layers=num_layers,
        decoder_layers=num_layers,
        encoder_attention_heads=num_heads,
        decoder_attention_heads=num_heads,
        encoder_ffn_dim=ffn_dim,
        decoder_ffn_dim=ffn_dim,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        max_position_embeddings=self.max_position_embeddings,
        # NOTE(review): hard-coded value under a plural key; self.eos_token_id is not
        # forwarded here - confirm this is intended.
        eos_token_ids=[2],
        bos_token_id=self.bos_token_id,
        pad_token_id=self.pad_token_id,
        decoder_start_token_id=self.pad_token_id,
        **self.config_updates,
    )
    return config, prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
|
||||||
|
|
||||||
|
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
    """Check that decoding with cached key/values matches decoding from scratch.

    The decoder is run once on the first example with ``use_cache=True``. Three
    extra tokens are then decoded twice: once appended to the full input with no
    cache, and once alone with the cache from the first pass. The outputs for
    those three positions must agree on a randomly chosen hidden feature.

    Args:
        config: model configuration used to build ``TFMarianModel``.
        inputs_dict: mapping providing ``"input_ids"`` and ``"attention_mask"``.
    """
    model = TFMarianModel(config=config).get_decoder()

    # Only the first example is used. The size is kept in a local so the tester's
    # own batch_size is not silently overwritten for later callers.
    batch_size = 1
    input_ids = inputs_dict["input_ids"][:batch_size, :]
    attention_mask = inputs_dict["attention_mask"][:batch_size, :]

    # first forward pass
    outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

    _, past_key_values = outputs.to_tuple()
    # NOTE(review): only element 1 of the returned cache is fed back - presumably
    # the decoder key/value cache; confirm against the decoder's output layout.
    past_key_values = past_key_values[1]

    # create hypothetical next tokens and a random 0/1 mask for them
    next_tokens = ids_tensor((batch_size, 3), config.vocab_size)
    next_attn_mask = tf.cast(ids_tensor((batch_size, 3), 2), tf.int8)

    # extend input_ids and attention_mask with the new positions
    next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
    next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

    output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
    output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]

    self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

    # select random slice
    random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
    output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
    output_from_past_slice = output_from_past[:, :, random_slice_idx]

    # test that outputs are equal for slice
    tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_marian_inputs_dict(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids,
|
||||||
|
attention_mask=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
):
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
|
if decoder_attention_mask is None:
|
||||||
|
decoder_attention_mask = tf.concat(
|
||||||
|
[
|
||||||
|
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
|
||||||
|
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (TFMarianMTModel, TFMarianModel) if is_tf_available() else ()
|
||||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||||
model_tester_cls = ModelTester
|
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = self.model_tester_cls(self)
|
self.model_tester = TFMarianModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=MarianConfig)
|
self.config_tester = ConfigTester(self, config_class=MarianConfig)
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_decoder_model_past_large_inputs(self):
|
||||||
# inputs_embeds not supported
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
pass
|
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_compile_tf_model(self):
|
def test_compile_tf_model(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -107,8 +233,22 @@ class TFMarianMTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
name = model.get_bias()
|
name = model.get_bias()
|
||||||
assert name is None
|
assert name is None
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_hidden_states_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_attentions_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
def test_saved_model_creation(self):
|
def test_saved_model_creation(self):
|
||||||
# This test is too long (>30sec) and makes fail the CI
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_saved_model_creation_extended(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
@@ -175,6 +315,25 @@ class TFMarianMTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
if a is None and b is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
if tf.debugging.assert_near(a, b, atol=atol):
|
||||||
|
return True
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
msg = "{} != {}".format(a, b)
|
||||||
|
if prefix:
|
||||||
|
msg = prefix + ": " + msg
|
||||||
|
raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
|
||||||
|
return tf.constant(tok_lst, dtype=tf.int32)
|
||||||
|
|
||||||
|
|
||||||
class AbstractMarianIntegrationTest(unittest.TestCase):
|
class AbstractMarianIntegrationTest(unittest.TestCase):
|
||||||
maxDiff = 1000 # show more chars for failing integration tests
|
maxDiff = 1000 # show more chars for failing integration tests
|
||||||
|
|
||||||
@@ -219,7 +378,6 @@ class AbstractMarianIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
@is_pt_tf_cross_test
|
|
||||||
class TestMarian_MT_EN(AbstractMarianIntegrationTest):
|
class TestMarian_MT_EN(AbstractMarianIntegrationTest):
|
||||||
"""Cover low resource/high perplexity setting. This breaks if pad_token_id logits not set to LARGE_NEGATIVE."""
|
"""Cover low resource/high perplexity setting. This breaks if pad_token_id logits not set to LARGE_NEGATIVE."""
|
||||||
|
|
||||||
@@ -233,7 +391,6 @@ class TestMarian_MT_EN(AbstractMarianIntegrationTest):
|
|||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TestMarian_en_zh(AbstractMarianIntegrationTest):
|
class TestMarian_en_zh(AbstractMarianIntegrationTest):
|
||||||
@@ -247,7 +404,6 @@ class TestMarian_en_zh(AbstractMarianIntegrationTest):
|
|||||||
self._assert_generated_batch_equal_expected()
|
self._assert_generated_batch_equal_expected()
|
||||||
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TestMarian_en_ROMANCE(AbstractMarianIntegrationTest):
|
class TestMarian_en_ROMANCE(AbstractMarianIntegrationTest):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2020 HuggingFace Inc. team.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,47 +12,107 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tests.test_configuration_common import ConfigTester
|
|
||||||
from tests.test_modeling_tf_bart import TFBartModelTester
|
|
||||||
from tests.test_modeling_tf_common import TFModelTesterMixin
|
|
||||||
from transformers import AutoTokenizer, MBartConfig, is_tf_available
|
from transformers import AutoTokenizer, MBartConfig, is_tf_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
|
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFAutoModelForSeq2SeqLM, TFMBartForConditionalGeneration
|
from transformers import TFAutoModelForSeq2SeqLM, TFMBartForConditionalGeneration, TFMBartModel
|
||||||
|
|
||||||
|
|
||||||
class ModelTester(TFBartModelTester):
|
|
||||||
config_updates = dict(normalize_before=True, add_final_layer_norm=True)
|
|
||||||
config_cls = MBartConfig
|
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFMBartModelTester:
|
||||||
all_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
config_cls = MBartConfig
|
||||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
config_updates = {}
|
||||||
model_tester_cls = ModelTester
|
hidden_act = "gelu"
|
||||||
is_encoder_decoder = True
|
|
||||||
test_pruning = False
|
|
||||||
|
|
||||||
def setUp(self):
|
def __init__(
|
||||||
self.model_tester = self.model_tester_cls(self)
|
self,
|
||||||
self.config_tester = ConfigTester(self, config_class=MBartConfig)
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=False,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=20,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
def test_config(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
self.config_tester.run_common_tests()
|
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
|
||||||
|
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
|
||||||
|
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
# inputs_embeds not supported
|
|
||||||
pass
|
config = self.config_cls(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.hidden_size,
|
||||||
|
encoder_layers=self.num_hidden_layers,
|
||||||
|
decoder_layers=self.num_hidden_layers,
|
||||||
|
encoder_attention_heads=self.num_attention_heads,
|
||||||
|
decoder_attention_heads=self.num_attention_heads,
|
||||||
|
encoder_ffn_dim=self.intermediate_size,
|
||||||
|
decoder_ffn_dim=self.intermediate_size,
|
||||||
|
dropout=self.hidden_dropout_prob,
|
||||||
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
eos_token_ids=[2],
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.pad_token_id,
|
||||||
|
**self.config_updates,
|
||||||
|
)
|
||||||
|
inputs_dict = prepare_mbart_inputs_dict(config, input_ids, decoder_input_ids)
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||||
|
model = TFMBartModel(config=config).get_decoder()
|
||||||
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
|
input_ids = input_ids[:1, :]
|
||||||
|
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||||
|
self.batch_size = 1
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||||
|
|
||||||
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
past_key_values = past_key_values[1]
|
||||||
|
|
||||||
def test_compile_tf_model(self):
|
def test_compile_tf_model(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -60,13 +120,11 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||||
|
|
||||||
model_class = self.all_generative_model_classes[0]
|
model_class = self.all_generative_model_classes[0]
|
||||||
input_ids = {
|
input_ids = {
|
||||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Prepare our model
|
# Prepare our model
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||||
@@ -74,17 +132,58 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model = model_class.from_pretrained(tmpdirname)
|
model = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
outputs_dict = model(input_ids)
|
outputs_dict = model(input_ids)
|
||||||
hidden_states = outputs_dict[0]
|
hidden_states = outputs_dict[0]
|
||||||
|
|
||||||
# Add a dense layer on top to test integration with other keras modules
|
# Add a dense layer on top to test integration with other keras modules
|
||||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||||
|
|
||||||
# Compile extended model
|
# Compile extended model
|
||||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_mbart_inputs_dict(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids,
|
||||||
|
attention_mask=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
):
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
|
if decoder_attention_mask is None:
|
||||||
|
decoder_attention_mask = tf.concat(
|
||||||
|
[
|
||||||
|
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
|
||||||
|
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (TFMBartForConditionalGeneration, TFMBartModel) if is_tf_available() else ()
|
||||||
|
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||||
|
is_encoder_decoder = True
|
||||||
|
test_pruning = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = TFMBartModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=MBartConfig)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_decoder_model_past_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -105,8 +204,22 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
name = model.get_bias()
|
name = model.get_bias()
|
||||||
assert name is None
|
assert name is None
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_hidden_states_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_attentions_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
def test_saved_model_creation(self):
|
def test_saved_model_creation(self):
|
||||||
# This test is too long (>30sec) and makes fail the CI
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_saved_model_creation_extended(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
@@ -173,10 +286,31 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
if a is None and b is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
if tf.debugging.assert_near(a, b, atol=atol):
|
||||||
|
return True
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
msg = "{} != {}".format(a, b)
|
||||||
|
if prefix:
|
||||||
|
msg = prefix + ": " + msg
|
||||||
|
raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
|
||||||
|
return tf.constant(tok_lst, dtype=tf.int32)
|
||||||
|
|
||||||
|
|
||||||
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TestMBartEnRO(unittest.TestCase):
|
class TFMBartModelIntegrationTest(unittest.TestCase):
|
||||||
src_text = [
|
src_text = [
|
||||||
" UN Chief Says There Is No Military Solution in Syria",
|
" UN Chief Says There Is No Military Solution in Syria",
|
||||||
]
|
]
|
||||||
@@ -191,7 +325,7 @@ class TestMBartEnRO(unittest.TestCase):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def model(self):
|
def model(self):
|
||||||
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
|
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2020 HuggingFace Inc. team.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -18,46 +18,167 @@ import unittest
|
|||||||
|
|
||||||
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
|
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
|
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_bart import TFBartModelTester
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import TFAutoModelForSeq2SeqLM, TFPegasusForConditionalGeneration
|
from transformers import TFAutoModelForSeq2SeqLM, TFPegasusForConditionalGeneration, TFPegasusModel
|
||||||
|
|
||||||
|
|
||||||
class ModelTester(TFBartModelTester):
|
@require_tf
|
||||||
config_updates = dict(
|
class TFPegasusModelTester:
|
||||||
normalize_before=True,
|
|
||||||
static_position_embeddings=True,
|
|
||||||
)
|
|
||||||
hidden_act = "relu"
|
|
||||||
config_cls = PegasusConfig
|
config_cls = PegasusConfig
|
||||||
|
config_updates = {}
|
||||||
|
hidden_act = "gelu"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_labels=False,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
intermediate_size=37,
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
max_position_embeddings=20,
|
||||||
|
eos_token_id=2,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
|
||||||
|
eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
|
||||||
|
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
|
||||||
|
|
||||||
|
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = self.config_cls(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
d_model=self.hidden_size,
|
||||||
|
encoder_layers=self.num_hidden_layers,
|
||||||
|
decoder_layers=self.num_hidden_layers,
|
||||||
|
encoder_attention_heads=self.num_attention_heads,
|
||||||
|
decoder_attention_heads=self.num_attention_heads,
|
||||||
|
encoder_ffn_dim=self.intermediate_size,
|
||||||
|
decoder_ffn_dim=self.intermediate_size,
|
||||||
|
dropout=self.hidden_dropout_prob,
|
||||||
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
eos_token_ids=[2],
|
||||||
|
bos_token_id=self.bos_token_id,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
decoder_start_token_id=self.pad_token_id,
|
||||||
|
**self.config_updates,
|
||||||
|
)
|
||||||
|
inputs_dict = prepare_pegasus_inputs_dict(config, input_ids, decoder_input_ids)
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
def check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||||
|
model = TFPegasusModel(config=config).get_decoder()
|
||||||
|
input_ids = inputs_dict["input_ids"]
|
||||||
|
|
||||||
|
input_ids = input_ids[:1, :]
|
||||||
|
attention_mask = inputs_dict["attention_mask"][:1, :]
|
||||||
|
self.batch_size = 1
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||||
|
|
||||||
|
output, past_key_values = outputs.to_tuple()
|
||||||
|
past_key_values = past_key_values[1]
|
||||||
|
|
||||||
|
# create hypothetical next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||||
|
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
|
||||||
|
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]
|
||||||
|
|
||||||
|
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx]
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_pegasus_inputs_dict(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
decoder_input_ids,
|
||||||
|
attention_mask=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
|
):
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
|
if decoder_attention_mask is None:
|
||||||
|
decoder_attention_mask = tf.concat(
|
||||||
|
[
|
||||||
|
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
|
||||||
|
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
all_model_classes = (TFPegasusForConditionalGeneration, TFPegasusModel) if is_tf_available() else ()
|
||||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||||
model_tester_cls = ModelTester
|
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = self.model_tester_cls(self)
|
self.model_tester = TFPegasusModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=PegasusConfig)
|
self.config_tester = ConfigTester(self, config_class=PegasusConfig)
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_decoder_model_past_large_inputs(self):
|
||||||
# inputs_embeds not supported
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
pass
|
self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_compile_tf_model(self):
|
def test_compile_tf_model(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -110,8 +231,22 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
name = model.get_bias()
|
name = model.get_bias()
|
||||||
assert name is None
|
assert name is None
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_hidden_states_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_attentions_output(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
def test_saved_model_creation(self):
|
def test_saved_model_creation(self):
|
||||||
# This test is too long (>30sec) and makes fail the CI
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_saved_model_creation_extended(self):
|
||||||
|
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
@@ -178,7 +313,25 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
if a is None and b is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
if tf.debugging.assert_near(a, b, atol=atol):
|
||||||
|
return True
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
msg = "{} != {}".format(a, b)
|
||||||
|
if prefix:
|
||||||
|
msg = prefix + ": " + msg
|
||||||
|
raise AssertionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _long_tensor(tok_lst):
|
||||||
|
return tf.constant(tok_lst, dtype=tf.int32)
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TFPegasusIntegrationTests(unittest.TestCase):
|
class TFPegasusIntegrationTests(unittest.TestCase):
|
||||||
@@ -198,7 +351,7 @@ class TFPegasusIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def model(self):
|
def model(self):
|
||||||
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
|
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user