Fix TF s2s models (#9478)
* Fix Seq2Seq models for serving * Apply style * Fix Longformer * Fix mBart/Pegasus/Blenderbot * Apply style * Add a main intermediate layer * Apply style * Remove import * Apply tf.function to Longformer * Fix utils check_copy * Update S2S template * Fix BART + Blenderbot * Fix BlenderbotSmall * Fix BlenderbotSmall * Fix BlenderbotSmall * Fix MBart * Fix Marian * Fix Pegasus + template * Apply style * Fix common attributes test * Forgot to fix the LED test * Apply Patrick's comment on LED Decoder
This commit is contained in:
@@ -322,6 +322,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
"""
|
||||
signature = dict(inspect.signature(func).parameters)
|
||||
signature.pop("kwargs", None)
|
||||
signature.pop("self", None)
|
||||
parameter_names = list(signature.keys())
|
||||
output = {}
|
||||
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
|
||||
@@ -346,6 +347,8 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
|
||||
)
|
||||
|
||||
kwargs.pop("kwargs_call")
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
@@ -356,8 +359,8 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
for i, input in enumerate(input_ids):
|
||||
# EagerTensors don't allow the use of the .name property, so we check for a real Tensor
|
||||
if type(input) == tf.Tensor:
|
||||
# Tensor names have always the pattern name:device_id then we check only the
|
||||
# name and not the device id
|
||||
# Tensor names always have the pattern `name:id`, so we check only the
|
||||
# `name` part
|
||||
tensor_name = input.name.split(":")[0]
|
||||
|
||||
if tensor_name in parameter_names:
|
||||
|
||||
@@ -411,29 +411,6 @@ class TFBartPretrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -605,6 +582,9 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
||||
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -744,6 +724,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -871,13 +854,15 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_self_attns = () if inputs["output_attentions"] else None
|
||||
present_key_values = () if inputs["use_cache"] else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop):
|
||||
@@ -901,12 +886,12 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
||||
if inputs["output_attentions"]:
|
||||
all_self_attns = list(all_self_attns)
|
||||
|
||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
||||
@@ -919,16 +904,14 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BART Model outputting raw hidden-states without any specific head on top.",
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFBartModel(TFBartPretrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TFBartMainLayer(tf.keras.layers.Layer):
|
||||
config_class = BartConfig
|
||||
|
||||
def __init__(self, config: BartConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -942,19 +925,20 @@ class TFBartModel(TFBartPretrainedModel):
|
||||
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/bart-large",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -1053,8 +1037,86 @@ class TFBartModel(TFBartPretrainedModel):
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BART Model outputting raw hidden-states without any specific head on top.",
|
||||
BART_START_DOCSTRING,
|
||||
)
|
||||
class TFBartModel(TFBartPretrainedModel):
|
||||
def __init__(self, config: BartConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.model = TFBartMainLayer(config, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/bart-large",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1083,7 +1145,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFBartModel(config, name="model")
|
||||
self.model = TFBartMainLayer(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable, for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -1199,7 +1261,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
|
||||
)
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -24,6 +24,7 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@@ -47,7 +48,6 @@ from ...modeling_tf_utils import (
|
||||
shape_list,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ..blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
|
||||
from .configuration_blenderbot import BlenderbotConfig
|
||||
|
||||
|
||||
@@ -416,31 +416,6 @@ class TFBlenderbotPreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -604,6 +579,9 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
||||
self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -744,6 +722,9 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -921,16 +902,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.",
|
||||
BLENDERBOT_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TFBlenderbotMainLayer(tf.keras.layers.Layer):
|
||||
config_class = BlenderbotConfig
|
||||
|
||||
def __init__(self, config: BlenderbotConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: BlenderbotConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -944,22 +923,20 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
self.encoder = TFBlenderbotEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFBlenderbotDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
if pretrained_model_name_or_path == "facebook/blenderbot-90M":
|
||||
warnings.warn(
|
||||
"The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallModel.from_pretrained('facebook/small_blenderbot-90M')` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return TFBlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path)
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
return super(TFBlenderbotModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -977,22 +954,6 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BlenderbotTokenizer, TFBlenderbotModel
|
||||
|
||||
>>> model = TFBlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
>>> tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -1066,9 +1027,100 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BLENDERBOT Model outputting raw hidden-states without any specific head on top.",
|
||||
BLENDERBOT_START_DOCSTRING,
|
||||
)
|
||||
class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
|
||||
def __init__(self, config: BlenderbotConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.model = TFBlenderbotMainLayer(config, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
if pretrained_model_name_or_path == "facebook/blenderbot-90M":
|
||||
from ..blenderbot_small import TFBlenderbotSmallModel
|
||||
|
||||
warnings.warn(
|
||||
"The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return TFBlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path)
|
||||
|
||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/blenderbot-400M-distill",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1097,25 +1149,43 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFBlenderbotModel(config, name="model")
|
||||
self.model = TFBlenderbotMainLayer(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable, for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
if pretrained_model_name_or_path == "facebook/blenderbot-90M":
|
||||
from ..blenderbot_small import TFBlenderbotSmallForConditionalGeneration
|
||||
|
||||
warnings.warn(
|
||||
"The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical checkpoint `facebook/small_blenderbot-90M` with `TFBlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return TFBlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
|
||||
|
||||
return super(TFBlenderbotForConditionalGeneration, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@@ -1208,7 +1278,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1283,21 +1353,6 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
|
||||
else:
|
||||
return logits
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.set_input_embeddings(value)
|
||||
|
||||
def get_bias(self):
|
||||
return {"final_logits_bias": self.final_logits_bias}
|
||||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.compute_loss
|
||||
def compute_loss(self, labels, logits):
|
||||
"""CrossEntropyLoss that ignores pad tokens"""
|
||||
|
||||
@@ -52,7 +52,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
||||
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
|
||||
|
||||
else:
|
||||
import importlib
|
||||
|
||||
@@ -22,6 +22,7 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@@ -414,31 +415,6 @@ class TFBlenderbotSmallPreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -608,6 +584,9 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
||||
self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -748,6 +727,9 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -922,16 +904,14 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BLENDERBOT_SMALL Model outputting raw hidden-states without any specific head on top.",
|
||||
BLENDERBOT_SMALL_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
|
||||
config_class = BlenderbotSmallConfig
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: BlenderbotSmallConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -945,14 +925,20 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
self.encoder = TFBlenderbotSmallEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFBlenderbotSmallDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -970,22 +956,6 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BlenderbotSmallTokenizer, TFBlenderbotSmallModel
|
||||
|
||||
>>> model = TFBlenderbotSmallModel.from_pretrained("facebook/blenderbot_small-90M")
|
||||
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot_small-90M")
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -1059,9 +1029,87 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BLENDERBOT_SMALL Model outputting raw hidden-states without any specific head on top.",
|
||||
BLENDERBOT_SMALL_START_DOCSTRING,
|
||||
)
|
||||
class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
|
||||
def __init__(self, config: BlenderbotSmallConfig, *inputs, **kwargs):
    """Initialise the parent model, then attach the main layer as ``self.model``."""
    super().__init__(config, *inputs, **kwargs)

    # All encoder/decoder logic lives in the main layer, created under the name "model".
    main_layer = TFBlenderbotSmallMainLayer(config, name="model")
    self.model = main_layer
|
||||
|
||||
def get_encoder(self):
    """Return the encoder owned by the wrapped main layer (``self.model``)."""
    main_layer = self.model
    return main_layer.encoder
|
||||
|
||||
def get_decoder(self):
    """Return the decoder owned by the wrapped main layer (``self.model``)."""
    main_layer = self.model
    return main_layer.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLENDERBOT_SMALL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    tokenizer_class=_TOKENIZER_FOR_DOC,
    checkpoint="facebook/blenderbot_small-90M",
    output_type=TFSeq2SeqModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs
):
    # Deliberately no literal docstring: `__doc__` must stay exactly what the
    # decorators above produce.

    # Collect every call argument into one mapping keyed by parameter name.
    inputs = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    # Forward the processed values, by keyword, to the wrapped main layer.
    forwarded = (
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "past_key_values",
        "inputs_embeds",
        "decoder_inputs_embeds",
        "use_cache",
        "output_attentions",
        "output_hidden_states",
        "return_dict",
        "training",
    )
    return self.model(**{name: inputs[name] for name in forwarded})
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1090,7 +1138,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFBlenderbotSmallModel(config, name="model")
|
||||
self.model = TFBlenderbotSmallMainLayer(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -1206,7 +1254,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -320,6 +320,8 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
|
||||
# because of the concat Line 713.
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
|
||||
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
|
||||
@@ -882,6 +884,7 @@ class TFLEDEncoderAttention(tf.keras.layers.Layer):
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn],
|
||||
training=training,
|
||||
)
|
||||
|
||||
attention_output = self.output_dense(self_outputs[0], training=training)
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
|
||||
@@ -1046,15 +1049,16 @@ class TFLEDEncoderLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(hidden_states),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
|
||||
)
|
||||
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = self.activation_dropout(hidden_states, training=training)
|
||||
@@ -1182,29 +1186,6 @@ class TFLEDPreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
def get_input_embeddings(self):
    """Return the shared embedding layer of the base model.

    The base model is the attribute named by ``self.base_model_prefix``; when
    no such attribute exists, ``self`` itself is used.
    """
    return getattr(self, self.base_model_prefix, self).shared
|
||||
|
||||
def set_input_embeddings(self, value):
    """Replace the weight of the base model's shared embedding layer.

    The base model is the attribute named by ``self.base_model_prefix`` (or
    ``self`` when that attribute is missing). If assigning the weight raises
    ``AttributeError``, the model is first run once on ``self.dummy_inputs``
    and the assignment is retried. Afterwards the shared layer's
    ``vocab_size`` is refreshed and the encoder and decoder receive a new
    wrapper around the shared layer.

    Args:
        value: the new embedding weight.
    """
    base_model = getattr(self, self.base_model_prefix, self)

    try:
        base_model.shared.weight = value
    except AttributeError:
        # Run the model once on dummy inputs, then retry the assignment.
        self(self.dummy_inputs)
        base_model.shared.weight = value

    base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]

    # Use the LED scope name: the shared embedding of this model family is
    # created and scoped as "led.shared", not "model.shared".
    with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
        pass

    embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
    base_model.encoder.set_embed_tokens(embed_tokens)
    base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -1521,6 +1502,9 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
|
||||
def get_embed_tokens(self):
    """Return the token-embedding object currently stored on this layer."""
    tokens = self.embed_tokens
    return tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
    """Store ``embed_tokens`` as this layer's token-embedding object."""
    self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -1624,20 +1608,17 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
# check attention mask and invert
|
||||
if inputs["attention_mask"] is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
|
||||
attention_mask = attention_mask[:, :, None, None]
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])[:, 0, 0, :]
|
||||
inputs["attention_mask"] = inputs["attention_mask"][:, :, None, None]
|
||||
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
all_global_attentions = () if inputs["output_attentions"] and is_global_attn else None
|
||||
all_attentions = all_global_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
hidden_states_to_add = hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||
hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
|
||||
encoder_states = encoder_states + (hidden_states_to_add,)
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
@@ -1646,7 +1627,7 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=inputs["attention_mask"],
|
||||
is_index_masked=is_index_masked,
|
||||
is_index_global_attn=is_index_global_attn,
|
||||
is_global_attn=is_global_attn,
|
||||
@@ -1658,14 +1639,12 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
# bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
|
||||
all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
|
||||
|
||||
if is_global_attn:
|
||||
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
|
||||
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
|
||||
# bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
|
||||
all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
|
||||
|
||||
# undo padding
|
||||
if padding_len > 0:
|
||||
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
|
||||
hidden_states = hidden_states[:, :-padding_len]
|
||||
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
|
||||
hidden_states = self.compute_hidden_states(hidden_states, padding_len)
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -1679,6 +1658,11 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
||||
global_attentions=all_global_attentions,
|
||||
)
|
||||
|
||||
@tf.function
def compute_hidden_states(self, hidden_states, padding_len):
    # Drop the last `padding_len` positions along the second axis when padding
    # was added; otherwise hand the tensor back untouched. The conditional
    # expression is kept as a single expression on purpose, since this method
    # is traced by `tf.function`.
    return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
|
||||
|
||||
@tf.function
|
||||
def _pad_to_window_size(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1777,19 +1761,14 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
Args:
|
||||
input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.LEDTokenizer`. See
|
||||
provide it. Indices can be obtained using :class:`~transformers.LEDTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||
for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
for details. `What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
@@ -1800,13 +1779,10 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last
|
||||
decoding. If :obj:`past_key_values` are used, the user can optionally input only the last
|
||||
:obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
|
||||
shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size,
|
||||
sequence_length)`.
|
||||
@@ -1930,16 +1906,13 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LED Model outputting raw hidden-states without any specific head on top.",
|
||||
LED_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFLEDModel(TFLEDPreTrainedModel):
|
||||
base_model_prefix = "led"
|
||||
class TFLEDMainLayer(tf.keras.layers.Layer):
|
||||
config_class = LEDConfig
|
||||
|
||||
def __init__(self, config: LEDConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: LEDConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="led.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
|
||||
@@ -1953,19 +1926,20 @@ class TFLEDModel(TFLEDPreTrainedModel):
|
||||
self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
    """Return the encoder sub-layer stored on this layer."""
    encoder = self.encoder
    return encoder
|
||||
def get_input_embeddings(self):
    """Return the shared embedding layer stored as ``self.shared``."""
    shared_embeddings = self.shared
    return shared_embeddings
|
||||
|
||||
def get_decoder(self):
    """Return the decoder sub-layer stored on this layer."""
    decoder = self.decoder
    return decoder
|
||||
def set_input_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the shared embedding weight.

    The shared layer's ``vocab_size`` is refreshed from the new weight's first
    dimension, and both the encoder and the decoder receive a freshly wrapped
    view of the shared layer.
    """
    shared = self.shared
    shared.weight = new_embeddings
    shared.vocab_size = shared.weight.shape[0]

    # Enter (and immediately leave) the scope only to capture its absolute name
    # for the embedding wrapper.
    with tf.compat.v1.variable_scope("led.shared") as shared_abs_scope_name:
        pass

    # Wrap the layer to avoid problems with weight restoring and to make sure
    # we are in the correct TF scope.
    wrapped = TFWrappedEmbeddings(shared, abs_scope_name=shared_abs_scope_name)
    self.encoder.set_embed_tokens(wrapped)
    self.decoder.set_embed_tokens(wrapped)
|
||||
|
||||
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="allenai/led-base-16384",
|
||||
output_type=TFLEDSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -2007,12 +1981,6 @@ class TFLEDModel(TFLEDPreTrainedModel):
|
||||
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
|
||||
inputs["use_cache"] = False
|
||||
|
||||
inputs["output_hidden_states"] = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if inputs["encoder_outputs"] is None:
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
@@ -2063,8 +2031,88 @@ class TFLEDModel(TFLEDPreTrainedModel):
|
||||
encoder_global_attentions=inputs["encoder_outputs"].global_attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LED Model outputting raw hidden-states without any specific head on top.",
|
||||
LED_START_DOCSTRING,
|
||||
)
|
||||
class TFLEDModel(TFLEDPreTrainedModel):
|
||||
def __init__(self, config, *inputs, **kwargs):
    """Initialise the parent model, then attach the main layer as ``self.led``."""
    super().__init__(config, *inputs, **kwargs)

    # All encoder/decoder logic lives in the main layer, created under the name "led".
    main_layer = TFLEDMainLayer(config, name="led")
    self.led = main_layer
|
||||
|
||||
def get_encoder(self):
    """Return the encoder owned by the wrapped main layer (``self.led``)."""
    main_layer = self.led
    return main_layer.encoder
|
||||
|
||||
def get_decoder(self):
    """Return the decoder owned by the wrapped main layer (``self.led``)."""
    main_layer = self.led
    return main_layer.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    tokenizer_class=_TOKENIZER_FOR_DOC,
    checkpoint="allenai/led-base-16384",
    output_type=TFLEDSeq2SeqModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
    global_attention_mask=None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs
):
    # Deliberately no literal docstring: `__doc__` must stay exactly what the
    # decorators above produce.

    # Collect every call argument into one mapping keyed by parameter name.
    inputs = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        global_attention_mask=global_attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    # Forward the processed values, by keyword, to the wrapped main layer.
    forwarded = (
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "global_attention_mask",
        "past_key_values",
        "inputs_embeds",
        "decoder_inputs_embeds",
        "use_cache",
        "output_attentions",
        "output_hidden_states",
        "return_dict",
        "training",
    )
    return self.led(**{name: inputs[name] for name in forwarded})
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -2095,7 +2143,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.led = TFLEDModel(config, name="led")
|
||||
self.led = TFLEDMainLayer(config, name="led")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -2157,6 +2205,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
>>> probs = tf.nn.softmax(logits[0])
|
||||
>>> # probs[5] is associated with the mask token
|
||||
"""
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -2221,7 +2270,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
||||
)
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -974,6 +974,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
# make sure that local attention probabilities are set to 0 for indices of global attn
|
||||
# When is_global_attn is True, the last dimension is always self.one_sided_attn_window_size * 2 + 1 + 1
|
||||
# because of the concat Line 713.
|
||||
attn_probs = tf.where(
|
||||
tf.broadcast_to(is_index_global_attn[:, :, None, None], shape_list(attn_probs)),
|
||||
tf.zeros(shape_list(attn_probs), dtype=tf.dtypes.float32),
|
||||
|
||||
@@ -23,6 +23,7 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@@ -444,31 +445,6 @@ class TFMarianPreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
|
||||
def get_input_embeddings(self):
    """Return the shared embedding layer of the base model.

    The base model is the attribute named by ``self.base_model_prefix``; when
    no such attribute exists, ``self`` itself is used.
    """
    return getattr(self, self.base_model_prefix, self).shared
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
    """Replace the weight of the base model's shared embedding layer.

    The base model is the attribute named by ``self.base_model_prefix`` (or
    ``self`` when that attribute is missing). If assigning the weight raises
    ``AttributeError``, the model is first run once on ``self.dummy_inputs``
    and the assignment is retried. Afterwards the shared layer's
    ``vocab_size`` is refreshed and the encoder and decoder receive a new
    wrapper around the shared layer.
    """
    core = getattr(self, self.base_model_prefix, self)

    try:
        core.shared.weight = value
    except AttributeError:
        # Run the model once on dummy inputs, then retry the assignment.
        self(self.dummy_inputs)
        core.shared.weight = value

    core.shared.vocab_size = shape_list(core.shared.weight)[0]

    # Enter (and immediately leave) the scope only to capture its absolute name.
    with tf.compat.v1.variable_scope("model.shared") as scope:
        pass

    wrapped = TFWrappedEmbeddings(core.shared, abs_scope_name=scope)
    core.encoder.set_embed_tokens(wrapped)
    core.decoder.set_embed_tokens(wrapped)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -625,6 +601,9 @@ class TFMarianEncoder(tf.keras.layers.Layer):
|
||||
)
|
||||
self.layers = [TFMarianEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
||||
def get_embed_tokens(self):
    """Return the token-embedding object currently stored on this layer."""
    tokens = self.embed_tokens
    return tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
    """Store ``embed_tokens`` as this layer's token-embedding object."""
    self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -761,6 +740,9 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
    """Return the token-embedding object currently stored on this layer."""
    tokens = self.embed_tokens
    return tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
    """Store ``embed_tokens`` as this layer's token-embedding object."""
    self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -935,16 +917,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare MARIAN Model outputting raw hidden-states without any specific head on top.",
|
||||
MARIAN_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFMarianModel(TFMarianPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TFMarianMainLayer(tf.keras.layers.Layer):
|
||||
config_class = MarianConfig
|
||||
|
||||
def __init__(self, config: MarianConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: MarianConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -958,14 +938,20 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
||||
self.encoder = TFMarianEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFMarianDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
    """Return the encoder sub-layer stored on this layer."""
    encoder = self.encoder
    return encoder
|
||||
def get_input_embeddings(self):
    """Return the shared embedding layer stored as ``self.shared``."""
    shared_embeddings = self.shared
    return shared_embeddings
|
||||
|
||||
def get_decoder(self):
    """Return the decoder sub-layer stored on this layer."""
    decoder = self.decoder
    return decoder
|
||||
def set_input_embeddings(self, new_embeddings):
    """Install ``new_embeddings`` as the shared embedding weight.

    The shared layer's ``vocab_size`` is refreshed from the new weight's first
    dimension, and both the encoder and the decoder receive a freshly wrapped
    view of the shared layer.
    """
    shared = self.shared
    shared.weight = new_embeddings
    shared.vocab_size = shared.weight.shape[0]

    # Enter (and immediately leave) the scope only to capture its absolute name
    # for the embedding wrapper.
    with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
        pass

    # Wrap the layer to avoid problems with weight restoring and to make sure
    # we are in the correct TF scope.
    wrapped = TFWrappedEmbeddings(shared, abs_scope_name=shared_abs_scope_name)
    self.encoder.set_embed_tokens(wrapped)
    self.decoder.set_embed_tokens(wrapped)
|
||||
|
||||
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -983,24 +969,6 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import MarianTokenizer, TFMarianModel
|
||||
|
||||
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
>>> model = TFMarianModel.from_pretrained('Helsinki-NLP/opus-mt-en-de')
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("<pad> Studien haben gezeigt dass es hilfreich ist einen Hund zu besitzen",
|
||||
... return_tensors="tf", add_special_tokens=False).input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -1077,9 +1045,87 @@ class TFMarianModel(TFMarianPreTrainedModel):
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare MARIAN Model outputting raw hidden-states without any specific head on top.",
|
||||
MARIAN_START_DOCSTRING,
|
||||
)
|
||||
class TFMarianModel(TFMarianPreTrainedModel):
|
||||
def __init__(self, config: MarianConfig, *inputs, **kwargs):
    """Initialise the parent model, then attach the main layer as ``self.model``."""
    super().__init__(config, *inputs, **kwargs)

    # All encoder/decoder logic lives in the main layer, created under the name "model".
    main_layer = TFMarianMainLayer(config, name="model")
    self.model = main_layer
|
||||
|
||||
def get_encoder(self):
    """Return the encoder owned by the wrapped main layer (``self.model``)."""
    main_layer = self.model
    return main_layer.encoder
|
||||
|
||||
def get_decoder(self):
    """Return the decoder owned by the wrapped main layer (``self.model``)."""
    main_layer = self.model
    return main_layer.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    tokenizer_class=_TOKENIZER_FOR_DOC,
    checkpoint="Helsinki-NLP/opus-mt-en-de",
    output_type=TFSeq2SeqModelOutput,
    config_class=_CONFIG_FOR_DOC,
)
def call(
    self,
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
    past_key_values=None,
    inputs_embeds=None,
    decoder_inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs
):
    # Deliberately no literal docstring: `__doc__` must stay exactly what the
    # decorators above produce.

    # Collect every call argument into one mapping keyed by parameter name.
    inputs = input_processing(
        func=self.call,
        config=self.config,
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
        kwargs_call=kwargs,
    )

    # Forward the processed values, by keyword, to the wrapped main layer.
    forwarded = (
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "past_key_values",
        "inputs_embeds",
        "decoder_inputs_embeds",
        "use_cache",
        "output_attentions",
        "output_hidden_states",
        "return_dict",
        "training",
    )
    return self.model(**{name: inputs[name] for name in forwarded})
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1108,7 +1154,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFMarianModel(config, name="model")
|
||||
self.model = TFMarianMainLayer(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -1225,7 +1271,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -417,31 +417,6 @@ class TFMBartPreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -615,6 +590,9 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -757,6 +735,9 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -934,16 +915,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare MBART Model outputting raw hidden-states without any specific head on top.",
|
||||
MBART_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFMBartModel(TFMBartPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TFMBartMainLayer(tf.keras.layers.Layer):
|
||||
config_class = MBartConfig
|
||||
|
||||
def __init__(self, config: MBartConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: MBartConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -957,19 +936,20 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
||||
self.encoder = TFMBartEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFMBartDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/mbart-large-cc25",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -1066,9 +1046,87 @@ class TFMBartModel(TFMBartPreTrainedModel):
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare MBART Model outputting raw hidden-states without any specific head on top.",
|
||||
MBART_START_DOCSTRING,
|
||||
)
|
||||
class TFMBartModel(TFMBartPreTrainedModel):
|
||||
def __init__(self, config: MBartConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.model = TFMBartMainLayer(config, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/mbart-large-cc25",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1097,7 +1155,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFMBartModel(config, name="model")
|
||||
self.model = TFMBartMainLayer(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -1212,7 +1270,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -23,6 +23,7 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@@ -445,31 +446,6 @@ class TFPegasusPreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -631,6 +607,9 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
|
||||
self.layers = [TFPegasusEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -770,6 +749,9 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -946,16 +928,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
|
||||
PEGASUS_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TFPegasusMainLayer(tf.keras.layers.Layer):
|
||||
config_class = PegasusConfig
|
||||
|
||||
def __init__(self, config: PegasusConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: PegasusConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -969,14 +949,20 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
self.encoder = TFPegasusEncoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TFPegasusDecoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -994,22 +980,6 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import PegasusTokenizer, TFPegasusModel
|
||||
|
||||
>>> tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-large")
|
||||
>>> model = TFPegasusModel.from_pretrained("google/pegasus-large")
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -1086,9 +1056,87 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare PEGASUS Model outputting raw hidden-states without any specific head on top.",
|
||||
PEGASUS_START_DOCSTRING,
|
||||
)
|
||||
class TFPegasusModel(TFPegasusPreTrainedModel):
|
||||
def __init__(self, config: PegasusConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.model = TFPegasusMainLayer(config, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/pegasus-large",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1117,7 +1165,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TFPegasusModel(config, name="model")
|
||||
self.model = TFPegasusMainLayer(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -1234,7 +1282,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -1207,7 +1207,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
)
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None,)
|
||||
pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -1437,7 +1437,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
)
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = (tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None,)
|
||||
pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
Reference in New Issue
Block a user