Fix TF s2s models (#9478)
* Fix Seq2Seq models for serving * Apply style * Fix Longformer * Fix mBart/Pegasus/Blenderbot * Apply style * Add a main intermediate layer * Apply style * Remove import * Apply tf.function to Longformer * Fix utils check_copy * Update S2S template * Fix BART + Blenderbot * Fix BlenderbotSmall * Fix BlenderbotSmall * Fix BlenderbotSmall * Fix MBart * Fix Marian * Fix Pegasus + template * Apply style * Fix common attributes test * Forgot to fix the LED test * Apply Patrick's comment on LED Decoder
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers
|
||||
|
||||
from transformers.modeling_tf_outputs import TFCausalLMOutput
|
||||
|
||||
@@ -1915,29 +1916,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
def get_input_embeddings(self):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
return base_model.shared
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
|
||||
try:
|
||||
base_model.shared.weight = value
|
||||
except AttributeError:
|
||||
self(self.dummy_inputs)
|
||||
base_model.shared.weight = value
|
||||
|
||||
base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
|
||||
embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
|
||||
base_model.encoder.set_embed_tokens(embed_tokens)
|
||||
base_model.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
@@ -1948,6 +1926,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
}
|
||||
]
|
||||
)
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartPretrainedModel.serving
|
||||
def serving(self, inputs):
|
||||
output = self.call(inputs)
|
||||
|
||||
@@ -2080,6 +2059,9 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
@@ -2148,7 +2130,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs["inputs_embeds"] + embed_pos
|
||||
@@ -2158,9 +2140,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
# check attention mask and invert
|
||||
if inputs["attention_mask"] is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(inputs["attention_mask"])
|
||||
else:
|
||||
attention_mask = None
|
||||
inputs["attention_mask"] = _expand_mask(inputs["attention_mask"])
|
||||
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
@@ -2175,7 +2155,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
|
||||
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -2219,9 +2199,12 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
def get_embed_tokens(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -2321,20 +2304,13 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
|
||||
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
|
||||
|
||||
hidden_states = inputs["inputs_embeds"]
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
combined_attention_mask = _expand_mask(
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(inputs["attention_mask"], tgt_len=input_shape[-1])
|
||||
inputs["attention_mask"], combined_attention_mask = self.compute_combined_attns_mask(
|
||||
inputs, input_shape, past_key_values_length
|
||||
)
|
||||
|
||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
@@ -2344,13 +2320,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = ()
|
||||
all_self_attns = ()
|
||||
present_key_values = ()
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_self_attns = () if inputs["output_attentions"] else None
|
||||
present_key_values = () if inputs["use_cache"] else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop):
|
||||
@@ -2374,12 +2352,12 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
else:
|
||||
all_hidden_states = None
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_self_attns = list(all_self_attns)
|
||||
|
||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
||||
|
||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
||||
@@ -2390,18 +2368,43 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
@tf.function
|
||||
def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length):
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
combined_attention_mask = _expand_mask(
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
||||
attention_mask = tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
||||
)
|
||||
attention_mask = tf.concat(
|
||||
[
|
||||
tf.ones((input_shape[0], past_key_values_length), dtype=attention_mask.dtype),
|
||||
attention_mask,
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
|
||||
|
||||
return attention_mask, combined_attention_mask
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
|
||||
)
|
||||
@keras_serializable
|
||||
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")
|
||||
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
@@ -2414,20 +2417,21 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
|
||||
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder")
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.shared.weight = new_embeddings
|
||||
self.shared.vocab_size = self.shared.weight.shape[0]
|
||||
# retrieve correct absolute scope for embed token wrapper
|
||||
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
|
||||
pass
|
||||
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
|
||||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
@@ -2467,12 +2471,6 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
if inputs["decoder_input_ids"] is None and inputs["decoder_inputs_embeds"] is None:
|
||||
inputs["use_cache"] = False
|
||||
|
||||
inputs["output_hidden_states"] = (
|
||||
inputs["output_hidden_states"]
|
||||
if inputs["output_hidden_states"] is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if inputs["encoder_outputs"] is None:
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
@@ -2520,10 +2518,88 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
|
||||
)
|
||||
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||
output_type=TFSeq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs
|
||||
):
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
@@ -2552,7 +2628,8 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
self.model = TF{{cookiecutter.camelcase_modelname}}Model(config, name="model")
|
||||
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
|
||||
self.model._set_save_spec(inputs=self.serving.input_signature)
|
||||
self.use_cache = config.use_cache
|
||||
# final_logits_bias is registered as a buffer in PyTorch, so it is not trainable, for the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
@@ -2675,7 +2752,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
Reference in New Issue
Block a user