TF BART models - Add cross_attentions to model output and fix cross-attention head masking (#10699)
* Add cross_attn_head_mask to BART * Fix cross_attentions in TFBart-like models * This commit enables returning of `cross_attentions` for TFBart-like models * It also fixes attention head masking in cross-attenion module * Update TF model templates * Fix missing , in TF model templates * Fix typo: congig -> config
This commit is contained in:
@@ -147,7 +147,6 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
|
||||
return final_embeddings
|
||||
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
@@ -352,6 +351,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs):
|
||||
@@ -625,7 +625,6 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
|
||||
|
||||
|
||||
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
||||
@@ -885,6 +884,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
|
||||
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
|
||||
)
|
||||
@@ -1728,16 +1728,18 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
|
||||
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (:obj:`tf.Tensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
hidden_states=hidden_states, attention_mask=attention_mask
|
||||
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
|
||||
)
|
||||
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
@@ -1798,6 +1800,8 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
@@ -1809,6 +1813,10 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
||||
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
`(decoder_attention_heads,)`
|
||||
cross_attn_layer_head_mask (:obj:`tf.Tensor`): mask for heads of the cross-attention module.
|
||||
`(decoder_attention_heads,)`
|
||||
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
@@ -1821,6 +1829,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
hidden_states = residual + hidden_states
|
||||
@@ -1828,15 +1837,17 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
||||
|
||||
# Cross-Attention Block
|
||||
cross_attn_present_key_value = None
|
||||
cross_attn_weights = None
|
||||
if encoder_hidden_states is not None:
|
||||
residual = hidden_states
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
layer_head_mask=cross_attn_layer_head_mask,
|
||||
past_key_value=cross_attn_past_key_value,
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states, training=training)
|
||||
@@ -1858,6 +1869,7 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
||||
return (
|
||||
hidden_states,
|
||||
self_attn_weights,
|
||||
cross_attn_layer_head_mask,
|
||||
present_key_value,
|
||||
)
|
||||
|
||||
@@ -1965,6 +1977,24 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
||||
the right for denoising pre-training following the paper.
|
||||
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
|
||||
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
|
||||
@@ -2013,7 +2043,6 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@@ -2034,6 +2063,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -2058,6 +2088,12 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
|
||||
@@ -2082,6 +2118,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
config=self.config,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
@@ -2115,6 +2152,16 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
encoder_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
if inputs["head_mask"] is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["head_mask"])[0],
|
||||
len(self.layers),
|
||||
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
|
||||
)
|
||||
|
||||
# encoder layers
|
||||
for encoder_layer in self.layers:
|
||||
|
||||
@@ -2125,7 +2172,11 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
|
||||
continue
|
||||
|
||||
hidden_states, attn = encoder_layer(hidden_states, inputs["attention_mask"])
|
||||
hidden_states, attn = encoder_layer(
|
||||
hidden_states,
|
||||
inputs["attention_mask"],
|
||||
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions += (attn,)
|
||||
@@ -2181,6 +2232,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -2218,6 +2271,18 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||
decoding.
|
||||
@@ -2252,6 +2317,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
@@ -2297,8 +2364,20 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
# decoder layers
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_self_attns = () if inputs["output_attentions"] else None
|
||||
all_cross_attns = () if (inputs["output_attentions"] and inputs["encoder_hidden_states"] is not None) else None
|
||||
present_key_values = () if inputs["use_cache"] else None
|
||||
|
||||
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
|
||||
# The tf.debugging asserts are not compliant with XLA then they
|
||||
# have to be disabled in other modes than eager.
|
||||
for attn_mask in ["head_mask", "cross_attn_head_mask"]:
|
||||
if inputs[attn_mask] is not None and tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs[attn_mask])[0],
|
||||
len(self.layers),
|
||||
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs[attn_mask])[0]}.",
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if inputs["output_hidden_states"]:
|
||||
@@ -2311,11 +2390,15 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
|
||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||
|
||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
||||
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=combined_attention_mask,
|
||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||
encoder_attention_mask=inputs["encoder_attention_mask"],
|
||||
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
|
||||
cross_attn_layer_head_mask=inputs["cross_attn_head_mask"][idx]
|
||||
if inputs["cross_attn_head_mask"] is not None
|
||||
else None,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
@@ -2325,23 +2408,30 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
if inputs["output_attentions"]:
|
||||
all_self_attns += (layer_self_attn,)
|
||||
|
||||
if inputs["encoder_hidden_states"] is not None:
|
||||
all_cross_attns += (layer_cross_attn,)
|
||||
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_self_attns = list(all_self_attns)
|
||||
|
||||
if inputs["encoder_hidden_states"] is not None:
|
||||
all_cross_attns = list(all_cross_attns)
|
||||
|
||||
if inputs["use_cache"]:
|
||||
present_key_values = (inputs["encoder_hidden_states"], present_key_values)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
|
||||
else:
|
||||
return TFBaseModelOutputWithPast(
|
||||
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=present_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attns,
|
||||
)
|
||||
|
||||
@tf.function
|
||||
@@ -2413,6 +2503,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -2431,6 +2524,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -2450,6 +2546,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
inputs["encoder_outputs"] = self.encoder(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
@@ -2472,6 +2569,8 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
attention_mask=inputs["decoder_attention_mask"],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][0],
|
||||
encoder_attention_mask=inputs["attention_mask"],
|
||||
head_mask=inputs["decoder_head_mask"],
|
||||
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
use_cache=inputs["use_cache"],
|
||||
@@ -2489,6 +2588,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
@@ -2524,6 +2624,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -2542,6 +2645,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -2559,6 +2665,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
attention_mask=inputs["attention_mask"],
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
@@ -2577,6 +2686,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -2585,6 +2695,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
past_key_values=pkv,
|
||||
decoder_hidden_states=dec_hs,
|
||||
decoder_attentions=dec_attns,
|
||||
cross_attentions=cross_attns,
|
||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||
encoder_hidden_states=enc_hs,
|
||||
encoder_attentions=enc_attns,
|
||||
@@ -2637,6 +2748,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
encoder_outputs: Optional[TFBaseModelOutput] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
@@ -2672,6 +2786,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
head_mask=head_mask,
|
||||
decoder_head_mask=decoder_head_mask,
|
||||
cross_attn_head_mask=cross_attn_head_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -2698,6 +2815,9 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
decoder_input_ids=inputs["decoder_input_ids"],
|
||||
encoder_outputs=inputs["encoder_outputs"],
|
||||
decoder_attention_mask=inputs["decoder_attention_mask"],
|
||||
head_mask=inputs["head_mask"],
|
||||
decoder_head_mask=inputs["decoder_head_mask"],
|
||||
cross_attn_head_mask=inputs["cross_attn_head_mask"],
|
||||
past_key_values=inputs["past_key_values"],
|
||||
inputs_embeds=inputs["inputs_embeds"],
|
||||
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
|
||||
@@ -2720,6 +2840,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||
cross_attentions=outputs.cross_attentions, # index 4 of d outputs
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||
@@ -2730,6 +2851,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -2738,6 +2860,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
||||
past_key_values=pkv,
|
||||
decoder_hidden_states=dec_hs,
|
||||
decoder_attentions=dec_attns,
|
||||
cross_attentions=cross_attns,
|
||||
encoder_last_hidden_state=output.encoder_last_hidden_state,
|
||||
encoder_hidden_states=enc_hs,
|
||||
encoder_attentions=enc_attns,
|
||||
|
||||
Reference in New Issue
Block a user