Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222)
* Add cross attentions to TFGPT2Model * Add TFEncoderDecoderModel * Add TFBaseModelOutputWithPoolingAndCrossAttentions * Add cross attentions to TFBertModel * Fix past or past_key_values argument issue * Fix generation * Fix save and load * Add some checks and comments * Clean the code that deals with past keys/values * Add kwargs to processing_inputs * Add serving_output to TFEncoderDecoderModel * Some cleaning + fix use_cache value issue * Fix tests + add bert2bert/bert2gpt2 tests * Fix more tests * Ignore crossattention.bias when loading GPT2 weights into TFGPT2 * Fix return_dict_in_generate in tf generation * Fix is_token_logit_eos_token bug in tf generation * Finalize the tests after fixing some bugs * Fix another is_token_logit_eos_token bug in tf generation * Add/Update docs * Add TFBertEncoderDecoderModelTest * Clean test script * Add TFEncoderDecoderModel to the library * Add cross attentions to TFRobertaModel * Add TFRobertaEncoderDecoderModelTest * make style * Change the way of position_ids computation * bug fix * Fix copies in tf_albert * Remove some copied from and apply some fix-copies * Remove some copied * Add cross attentions to some other TF models * Remove encoder_hidden_states from TFLayoutLMModel.call for now * Make style * Fix TFRemBertForCausalLM * Revert the change to longformer + Remove copies * Revert the change to albert and convbert + Remove copies * make quality * make style * Add TFRembertEncoderDecoderModelTest * make quality and fix-copies * test TFRobertaForCausalLM * Fixes for failed tests * Fixes for failed tests * fix more tests * Fixes for failed tests * Fix Auto mapping order * Fix TFRemBertEncoder return value * fix tf_rembert * Check copies are OK * Fix missing TFBaseModelOutputWithPastAndCrossAttentions is not defined * Add TFEncoderDecoderModelSaveLoadTests * fix tf weight loading * check the change of use_cache * Revert the change * Add missing test_for_causal_lm for TFRobertaModelTest * Try cleaning past * fix 
_reorder_cache * Revert some files to original versions * Keep as many copies as possible * Apply suggested changes - Use raise ValueError instead of assert * Move import to top * Fix wrong require_torch * Replace more assert by raise ValueError * Add test_pt_tf_model_equivalence (the test won't pass for now) * add test for loading/saving * finish * finish * Remove test_pt_tf_model_equivalence * Update tf modeling template * Remove pooling, added in the prev. commit, from MainLayer * Update tf modeling test template * Move inputs["use_cache"] = False to modeling_tf_utils.py * Fix torch.Tensor in the comment * fix use_cache * Fix missing use_cache in ElectraConfig * Add a note to from_pretrained * Fix style * Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt * Fix TFMLP (in TFGPT2) activation issue * Fix None past_key_values value in serving_output * Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub * Apply review suggestions - style for cross_attns in serving_output * Apply review suggestions - change assert + docstrings * break the error message to respect the char limit * deprecate the argument past * fix docstring style * Update the encoder-decoder rst file * fix Unknown interpreted text role "method" * fix typo Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -24,15 +24,16 @@ import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
MULTIPLE_CHOICE_DUMMY_INPUTS,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPooling,
|
||||
TFCausalLMOutput,
|
||||
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||
TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
TFCausalLMOutputWithCrossAttentions,
|
||||
TFMaskedLMOutput,
|
||||
TFMultipleChoiceModelOutput,
|
||||
TFQuestionAnsweringModelOutput,
|
||||
@@ -116,6 +117,7 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
position_ids: tf.Tensor = None,
|
||||
token_type_ids: tf.Tensor = None,
|
||||
inputs_embeds: tf.Tensor = None,
|
||||
past_key_values_length=0,
|
||||
training: bool = False,
|
||||
) -> tf.Tensor:
|
||||
"""
|
||||
@@ -135,7 +137,9 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
|
||||
position_ids = tf.expand_dims(
|
||||
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
|
||||
)
|
||||
|
||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
|
||||
@@ -174,6 +178,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
@@ -186,16 +192,49 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
hidden_states: tf.Tensor,
|
||||
attention_mask: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
encoder_hidden_states: tf.Tensor,
|
||||
encoder_attention_mask: tf.Tensor,
|
||||
past_key_value: Tuple[tf.Tensor],
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
mixed_query_layer = self.query(inputs=hidden_states)
|
||||
mixed_key_layer = self.key(inputs=hidden_states)
|
||||
mixed_value_layer = self.value(inputs=hidden_states)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
|
||||
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
|
||||
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||
@@ -225,6 +264,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
|
||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -263,6 +304,9 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
input_tensor: tf.Tensor,
|
||||
attention_mask: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
encoder_hidden_states: tf.Tensor,
|
||||
encoder_attention_mask: tf.Tensor,
|
||||
past_key_value: Tuple[tf.Tensor],
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
@@ -270,13 +314,17 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
hidden_states=input_tensor,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
attention_output = self.dense_output(
|
||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||
)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
# add attentions (possibly with past_key_value) if we output them
|
||||
outputs = (attention_output,) + self_outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -327,6 +375,12 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.attention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="attention")
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = TF{{cookiecutter.camelcase_modelname}}Attention(config, name="crossattention")
|
||||
self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate")
|
||||
self.bert_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output")
|
||||
|
||||
@@ -335,20 +389,69 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
||||
hidden_states: tf.Tensor,
|
||||
attention_mask: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
encoder_hidden_states: Optional[tf.Tensor],
|
||||
encoder_attention_mask: Optional[tf.Tensor],
|
||||
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
attention_outputs = self.attention(
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
input_tensor=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
|
||||
"by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
input_tensor=attention_output,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
intermediate_output = self.intermediate(hidden_states=attention_output)
|
||||
layer_output = self.bert_output(hidden_states=intermediate_output, input_tensor=attention_output, training=training)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
layer_output = self.bert_output(
|
||||
hidden_states=intermediate_output, input_tensor=attention_output, training=training
|
||||
)
|
||||
outputs = (layer_output,) + outputs # add attentions if we output them
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -357,7 +460,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
||||
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.layer = [TF{{cookiecutter.camelcase_modelname}}Layer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||
|
||||
def call(
|
||||
@@ -365,39 +468,61 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
hidden_states: tf.Tensor,
|
||||
attention_mask: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
encoder_hidden_states: Optional[tf.Tensor],
|
||||
encoder_attention_mask: Optional[tf.Tensor],
|
||||
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||
use_cache: Optional[bool],
|
||||
output_attentions: bool,
|
||||
output_hidden_states: bool,
|
||||
return_dict: bool,
|
||||
training: bool = False,
|
||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
||||
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
# Add last layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
return tuple(
|
||||
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
|
||||
)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -492,6 +617,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
self.embeddings = TF{{cookiecutter.camelcase_modelname}}Embeddings(config, name="embeddings")
|
||||
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, name="encoder")
|
||||
@@ -521,12 +647,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
||||
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -536,6 +666,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -543,6 +677,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if not self.config.is_decoder:
|
||||
inputs["use_cache"] = False
|
||||
|
||||
if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif inputs["input_ids"] is not None:
|
||||
@@ -552,8 +689,16 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if inputs["past_key_values"] is None:
|
||||
past_key_values_length = 0
|
||||
inputs["past_key_values"] = [None] * len(self.encoder.layer)
|
||||
else:
|
||||
past_key_values_length = shape_list(inputs["past_key_values"][0][0])[-2]
|
||||
|
||||
if inputs["attention_mask"] is None:
|
||||
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
|
||||
inputs["attention_mask"] = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
|
||||
|
||||
if inputs["token_type_ids"] is None:
|
||||
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
|
||||
@@ -563,6 +708,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
position_ids=inputs["position_ids"],
|
||||
token_type_ids=inputs["token_type_ids"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
past_key_values_length=past_key_values_length,
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
@@ -571,7 +717,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
|
||||
attention_mask_shape = shape_list(inputs["attention_mask"])
|
||||
|
||||
mask_seq_length = seq_length + past_key_values_length
|
||||
# Copied from `modeling_tf_t5.py`
|
||||
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
|
||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||
if self.is_decoder:
|
||||
seq_ids = tf.range(mask_seq_length)
|
||||
causal_mask = tf.less_equal(
|
||||
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
|
||||
seq_ids[None, :, None],
|
||||
)
|
||||
causal_mask = tf.cast(causal_mask, dtype=inputs["attention_mask"].dtype)
|
||||
extended_attention_mask = causal_mask * inputs["attention_mask"][:, None, :]
|
||||
attention_mask_shape = shape_list(extended_attention_mask)
|
||||
extended_attention_mask = tf.reshape(
|
||||
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
|
||||
)
|
||||
else:
|
||||
extended_attention_mask = tf.reshape(
|
||||
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -583,6 +751,29 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
|
||||
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
|
||||
|
||||
# Copied from `modeling_tf_t5.py` with -1e9 -> -10000
|
||||
if self.is_decoder and inputs["encoder_attention_mask"] is not None:
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
inputs["encoder_attention_mask"] = tf.cast(
|
||||
inputs["encoder_attention_mask"], dtype=extended_attention_mask.dtype
|
||||
)
|
||||
num_dims_encoder_attention_mask = len(shape_list(inputs["encoder_attention_mask"]))
|
||||
if num_dims_encoder_attention_mask == 3:
|
||||
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, :, :]
|
||||
if num_dims_encoder_attention_mask == 2:
|
||||
encoder_extended_attention_mask = inputs["encoder_attention_mask"][:, None, None, :]
|
||||
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
|
||||
# encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
|
||||
# tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
|
||||
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@@ -597,6 +788,10 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
hidden_states=embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=inputs["head_mask"],
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=inputs["past_key_values"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
@@ -610,10 +805,12 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
sequence_output,
|
||||
) + encoder_outputs[1:]
|
||||
|
||||
return TFBaseModelOutput(
|
||||
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -625,6 +822,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
config_class = {{cookiecutter.camelcase_modelname}}Config
|
||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||
|
||||
@property
def dummy_inputs(self):
    """
    Assemble the dummy inputs used to build the network.

    Returns:
        :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
    """
    input_ids = tf.constant(DUMMY_INPUTS)
    dummy_inputs = {"input_ids": input_ids}

    # Also feed `encoder_hidden_states` so that the cross-attention layers' weights get initialized
    if self.config.add_cross_attention:
        batch_size, seq_len = input_ids.shape
        encoder_shape = (batch_size, seq_len, self.config.hidden_size)
        dummy_inputs["encoder_hidden_states"] = tf.random.uniform(shape=encoder_shape)

    return dummy_inputs
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
|
||||
@@ -732,7 +947,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFBaseModelOutputWithPooling,
|
||||
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
@@ -743,12 +958,36 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
||||
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`):
|
||||
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation.
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
@@ -758,6 +997,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -771,6 +1014,10 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
@@ -779,12 +1026,26 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
|
||||
return outputs
|
||||
|
||||
# Build the output returned by the serving signature.
#
# Optional fields are converted to tensors only when the configuration enables them:
#   - past_key_values: when both `config.use_cache` and `config.is_decoder` are set,
#   - hidden_states:   when `config.output_hidden_states` is set,
#   - attentions:      when `config.output_attentions` is set,
#   - cross_attentions: when present on `output` and both `config.output_attentions`
#     and `config.add_cross_attention` are set.
# Every other optional field is returned as None. `last_hidden_state` and
# `pooler_output` are forwarded unchanged.
#
# This description is kept above the `# Copied from` marker so the copied body below
# stays textually identical to its source.
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(
    self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
    output_cache = self.config.use_cache and self.config.is_decoder
    pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
    hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
    attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
    cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
    if not (self.config.output_attentions and self.config.add_cross_attention):
        cross_attns = None

    return TFBaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=output.last_hidden_state,
        pooler_output=output.pooler_output,
        past_key_values=pkv,
        hidden_states=hs,
        attentions=attns,
        cross_attentions=cross_attns,
    )
|
||||
|
||||
|
||||
@add_start_docstrings("""{{cookiecutter.modelname}} Model with a `language modeling` head on top. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING)
|
||||
@@ -903,10 +1164,22 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
def get_lm_head(self) -> tf.keras.layers.Layer:
    """Return ``self.mlm.predictions``, the layer this model exposes as its LM head."""
    lm_head = self.mlm.predictions
    return lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
    """
    Assemble the keyword arguments fed to the model for one generation step.

    Args:
        inputs: Token ids produced so far. When :obj:`past` is given, only the last position
            (``inputs[:, -1]``, re-expanded on the last axis) is kept.
        past: Cached key/value states from previous steps; forwarded as :obj:`past_key_values`.
        attention_mask: Forwarded unchanged.
        **model_kwargs: Additional generation arguments; only ``use_cache`` is read.

    Returns:
        :obj:`dict` with the keys ``"input_ids"``, ``"attention_mask"``, ``"past_key_values"`` and
        ``"use_cache"``.
    """
    # cut decoder_input_ids if past is used
    if past:
        inputs = tf.expand_dims(inputs[:, -1], -1)

    return {
        "input_ids": inputs,
        "attention_mask": attention_mask,
        "past_key_values": past,
        # `use_cache` is optional: a caller that omits it gets None instead of a KeyError.
        "use_cache": model_kwargs.get("use_cache"),
    }
|
||||
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TFCausalLMOutput,
|
||||
output_type=TFCausalLMOutputWithCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def call(
|
||||
@@ -917,14 +1190,36 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
training: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
|
||||
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers`)
|
||||
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation
|
||||
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
@@ -938,6 +1233,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -952,6 +1251,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
position_ids=inputs["position_ids"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
@@ -971,19 +1274,28 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFCausalLMOutput(
|
||||
return TFCausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
)
|
||||
|
||||
# Build the output returned by the serving signature of the causal-LM model.
#
# `logits` is forwarded unchanged. Optional fields are converted to tensors only when
# the configuration enables them:
#   - past_key_values: when both `config.use_cache` and `config.is_decoder` are set,
#   - hidden_states:   when `config.output_hidden_states` is set,
#   - attentions:      when `config.output_attentions` is set,
#   - cross_attentions: when present on `output` and both `config.output_attentions`
#     and `config.add_cross_attention` are set.
# Every other optional field is returned as None.
#
# This description is kept above the `# Copied from` marker so the copied body below
# stays textually identical to its source.
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions:
    output_cache = self.config.use_cache and self.config.is_decoder
    pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
    hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
    attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
    cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
    if not (self.config.output_attentions and self.config.add_cross_attention):
        cross_attns = None

    return TFCausalLMOutputWithCrossAttentions(
        logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
    )
|
||||
|
||||
|
||||
class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
|
||||
|
||||
@@ -21,7 +21,7 @@ from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Co
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@@ -123,6 +123,33 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
    """
    Extend the standard tester inputs with what a decoder configuration needs.

    Takes the seven values from :obj:`prepare_config_and_inputs`, sets ``config.is_decoder = True`` and appends
    two extra tensors built from the tester's ``batch_size``, ``seq_length`` and ``hidden_size``:
    ``encoder_hidden_states`` (via ``floats_tensor``) and ``encoder_attention_mask`` (via ``ids_tensor`` with
    ``vocab_size=2``).

    Returns:
        A 9-tuple: the seven base values followed by ``encoder_hidden_states`` and ``encoder_attention_mask``.
    """
    prepared = self.prepare_config_and_inputs()
    config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = prepared

    config.is_decoder = True

    # Creation order is kept: hidden states first, then the mask.
    hidden_shape = [self.batch_size, self.seq_length, self.hidden_size]
    encoder_hidden_states = floats_tensor(hidden_shape)
    mask_shape = [self.batch_size, self.seq_length]
    encoder_attention_mask = ids_tensor(mask_shape, vocab_size=2)

    base_inputs = (config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)
    return base_inputs + (encoder_hidden_states, encoder_attention_mask)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user