Add head_mask/decoder_head_mask for TF BART models (#9639)

* Add head_mask/decoder_head_mask for TF BART models

* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as a TF counterpart to the PR #9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation

* Remove redundant #TODO note

Remove redundant #TODO note from tests/test_modeling_tf_common.py

* Fix assertions

* Make style

* Fix ...Model input args and adjust one new test

* Add back head_mask and decoder_head_mask to BART-based ...Model
after the last commit

* Remove head_mask ande decoder_head_mask from input_dict
in TF test_train_pipeline_custom_model as these two have different
shape than other input args (Necessary for passing this test)

* Revert adding global_rng in test_modeling_tf_common.py
This commit is contained in:
Daniel Stancl
2021-01-26 09:50:00 +01:00
committed by GitHub
parent cb73ab5a38
commit 1867d9a8d7
32 changed files with 849 additions and 36 deletions

View File

@@ -164,6 +164,7 @@ class TFBartAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -230,6 +231,17 @@ class TFBartAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -266,16 +278,18 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
"""
residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
@@ -331,6 +345,8 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -342,6 +358,10 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -354,6 +374,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -370,6 +391,7 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -527,6 +549,18 @@ BART_INPUTS_DOCSTRING = r"""
the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -593,6 +627,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -617,6 +652,12 @@ class TFBartEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -635,6 +676,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -670,8 +712,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -680,7 +729,11 @@ class TFBartEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -737,6 +790,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -774,6 +829,19 @@ class TFBartDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -802,6 +870,8 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -858,6 +928,13 @@ class TFBartDecoder(tf.keras.layers.Layer):
all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -875,6 +952,10 @@ class TFBartDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -945,6 +1026,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -963,6 +1046,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -993,6 +1078,7 @@ class TFBartMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1015,6 +1101,8 @@ class TFBartMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1067,6 +1155,8 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1085,6 +1175,8 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1102,6 +1194,8 @@ class TFBartModel(TFBartPretrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1179,6 +1273,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1207,6 +1303,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1233,6 +1331,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1277,7 +1377,15 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
encoder_attentions=enc_attns,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1309,6 +1417,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

View File

@@ -167,6 +167,7 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -233,6 +234,17 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -270,17 +282,19 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
@@ -336,6 +350,8 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -347,6 +363,10 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -360,6 +380,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -376,6 +397,7 @@ class TFBlenderbotDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -524,6 +546,18 @@ BLENDERBOT_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`).
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -590,6 +624,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -614,6 +649,12 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -632,6 +673,7 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -666,8 +708,15 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -676,7 +725,11 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -735,6 +788,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -772,6 +827,19 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -800,6 +868,8 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -855,6 +925,14 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -871,6 +949,10 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -943,6 +1025,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -961,6 +1045,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -983,6 +1069,7 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1005,6 +1092,8 @@ class TFBlenderbotMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1070,6 +1159,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1088,6 +1179,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1105,6 +1198,8 @@ class TFBlenderbotModel(TFBlenderbotPreTrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1196,6 +1291,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1224,6 +1321,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1249,6 +1348,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1295,7 +1396,15 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1327,6 +1436,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

View File

@@ -166,6 +166,7 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -232,6 +233,17 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -269,16 +281,18 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
"""
residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
@@ -335,6 +349,8 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -346,6 +362,10 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -358,6 +378,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -374,6 +395,7 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -529,6 +551,18 @@ BLENDERBOT_SMALL_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`).
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -595,6 +629,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -619,6 +654,12 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -637,6 +678,7 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -672,8 +714,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -682,7 +731,11 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -740,6 +793,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -777,6 +832,19 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -805,6 +873,8 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -859,6 +929,13 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
all_self_attns = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -875,6 +952,10 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -945,6 +1026,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -963,6 +1046,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -985,6 +1070,7 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1007,6 +1093,8 @@ class TFBlenderbotSmallMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1059,6 +1147,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1077,6 +1167,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1094,6 +1186,8 @@ class TFBlenderbotSmallModel(TFBlenderbotSmallPreTrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1172,6 +1266,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1200,6 +1296,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1225,6 +1323,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1271,7 +1371,15 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1303,6 +1411,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

View File

@@ -196,6 +196,7 @@ class TFMarianAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -262,6 +263,17 @@ class TFMarianAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -299,16 +311,18 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
"""
residual = hidden_states
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
@@ -365,6 +379,8 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -376,6 +392,10 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -388,6 +408,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -404,6 +425,7 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -548,6 +570,18 @@ MARIAN_INPUTS_DOCSTRING = r"""
:obj:`past_key_values`).
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -612,6 +646,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -636,6 +671,12 @@ class TFMarianEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -654,6 +695,7 @@ class TFMarianEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -688,8 +730,15 @@ class TFMarianEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -698,7 +747,11 @@ class TFMarianEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -753,6 +806,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -790,6 +845,19 @@ class TFMarianDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -818,6 +886,8 @@ class TFMarianDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -872,6 +942,14 @@ class TFMarianDecoder(tf.keras.layers.Layer):
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -888,6 +966,10 @@ class TFMarianDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -958,6 +1040,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -976,6 +1060,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1001,6 +1087,7 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1023,6 +1110,8 @@ class TFMarianMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1075,6 +1164,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1092,6 +1183,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
@@ -1110,6 +1203,8 @@ class TFMarianModel(TFMarianPreTrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1188,6 +1283,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1216,6 +1313,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1242,6 +1341,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1288,7 +1389,15 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1320,6 +1429,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

View File

@@ -170,6 +170,7 @@ class TFMBartAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -236,6 +237,17 @@ class TFMBartAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -272,17 +284,19 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
@@ -337,6 +351,8 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -348,6 +364,10 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -361,6 +381,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -377,6 +398,7 @@ class TFMBartDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -505,6 +527,18 @@ MBART_INPUTS_DOCSTRING = r"""
the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -601,6 +635,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -625,6 +660,12 @@ class TFMBartEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -643,6 +684,7 @@ class TFMBartEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -678,8 +720,15 @@ class TFMBartEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -688,7 +737,11 @@ class TFMBartEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -748,6 +801,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -785,6 +840,19 @@ class TFMBartDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -813,6 +881,8 @@ class TFMBartDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -868,6 +938,14 @@ class TFMBartDecoder(tf.keras.layers.Layer):
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -884,6 +962,10 @@ class TFMBartDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -956,6 +1038,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -974,6 +1058,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1002,6 +1088,7 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1024,6 +1111,8 @@ class TFMBartMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1076,6 +1165,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1094,6 +1185,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1111,6 +1204,8 @@ class TFMBartModel(TFMBartPreTrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1189,6 +1284,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1217,6 +1314,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1241,6 +1340,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1287,7 +1388,15 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1319,6 +1428,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}

View File

@@ -197,6 +197,7 @@ class TFPegasusAttention(tf.keras.layers.Layer):
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
@@ -263,6 +264,17 @@ class TFPegasusAttention(tf.keras.layers.Layer):
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
@@ -300,17 +312,19 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
self.fc2 = tf.keras.layers.Dense(self.embed_dim, name="fc2")
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, training=False):
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
"""
Args:
hidden_states (:obj:`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (:obj:`tf.Tensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
@@ -366,6 +380,8 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
encoder_hidden_states: Optional[tf.Tensor] = None,
encoder_attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
@@ -377,6 +393,10 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
encoder_hidden_states (:obj:`tf.Tensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (:obj:`tf.Tensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`tf.Tensor`): mask for attention heads in a given layer of size
`(decoder_attention_heads,)`
encoder_layer_head_mask (:obj:`tf.Tensor`): mask for encoder attention heads in a given layer of size
`(encoder_attention_heads,)`
past_key_value (:obj:`Tuple(tf.Tensor)`): cached past key and value projection states
"""
residual = hidden_states
@@ -390,6 +410,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
@@ -406,6 +427,7 @@ class TFPegasusDecoderLayer(tf.keras.layers.Layer):
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
past_key_value=cross_attn_past_key_value,
)
hidden_states = self.dropout(hidden_states, training=training)
@@ -553,6 +575,18 @@ PEGASUS_INPUTS_DOCSTRING = r"""
the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
decoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
@@ -618,6 +652,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -642,6 +677,12 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
@@ -660,6 +701,7 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -694,8 +736,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
encoder_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for encoder_layer in self.layers:
for idx, encoder_layer in enumerate(self.layers):
if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)
@@ -704,7 +753,11 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
if inputs["training"] and (dropout_probability < self.layerdrop): # skip the layer
continue
hidden_states, attn = encoder_layer(hidden_states, attention_mask)
hidden_states, attn = encoder_layer(
hidden_states,
attention_mask,
inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
)
if inputs["output_attentions"]:
all_attentions += (attn,)
@@ -762,6 +815,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
@@ -799,6 +854,19 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`tf.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`tf.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`Tuple[Tuple[tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
@@ -827,6 +895,8 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
@@ -881,6 +951,14 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
if inputs["head_mask"] is not None:
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@@ -897,6 +975,10 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
encoder_attention_mask=inputs["encoder_attention_mask"],
layer_head_mask=inputs["head_mask"][idx] if inputs["head_mask"] is not None else None,
encoder_layer_head_mask=inputs["encoder_head_mask"][idx]
if inputs["encoder_head_mask"] is not None
else None,
past_key_value=past_key_value,
)
@@ -969,6 +1051,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -987,6 +1071,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1012,6 +1098,7 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
inputs["encoder_outputs"] = self.encoder(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
@@ -1034,6 +1121,8 @@ class TFPegasusMainLayer(tf.keras.layers.Layer):
attention_mask=inputs["decoder_attention_mask"],
encoder_hidden_states=inputs["encoder_outputs"][0],
encoder_attention_mask=inputs["attention_mask"],
head_mask=inputs["decoder_head_mask"],
encoder_head_mask=inputs["head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
@@ -1086,6 +1175,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1104,6 +1195,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1121,6 +1214,8 @@ class TFPegasusModel(TFPegasusPreTrainedModel):
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
@@ -1199,6 +1294,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1227,6 +1324,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1253,6 +1352,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
decoder_input_ids=inputs["decoder_input_ids"],
encoder_outputs=inputs["encoder_outputs"],
decoder_attention_mask=inputs["decoder_attention_mask"],
head_mask=inputs["head_mask"],
decoder_head_mask=inputs["decoder_head_mask"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
@@ -1299,7 +1400,15 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past,
attention_mask,
head_mask=None,
use_cache=None,
**kwargs,
) -> Dict:
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
if len(past) == 1:
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
@@ -1331,6 +1440,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel):
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}