Fix TF T5/LED missing cross attn in retrun values (#15511)
* add cross attn to outputs * add cross attn to outputs for TFLED * add undo padding * remove unused import * fix style Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -29,7 +29,7 @@ from ...file_utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import TFBaseModelOutputWithPast
|
from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
|
||||||
|
|
||||||
# Public API
|
# Public API
|
||||||
from ...modeling_tf_utils import (
|
from ...modeling_tf_utils import (
|
||||||
@@ -1220,7 +1220,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
|
|||||||
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
encoder_layer_head_mask: Optional[tf.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||||
training=False,
|
training=False,
|
||||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||||
@@ -1254,12 +1254,13 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
cross_attn_present_key_value = None
|
||||||
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
@@ -1285,6 +1286,7 @@ class TFLEDDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
|
cross_attn_weights,
|
||||||
present_key_value,
|
present_key_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1808,6 +1810,14 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
|||||||
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
|
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
|
||||||
hidden_states = self.compute_hidden_states(hidden_states, padding_len)
|
hidden_states = self.compute_hidden_states(hidden_states, padding_len)
|
||||||
|
|
||||||
|
# undo padding
|
||||||
|
if inputs["output_attentions"]:
|
||||||
|
all_attentions = (
|
||||||
|
tuple([state[:, :, :-padding_len, :] for state in all_attentions])
|
||||||
|
if padding_len > 0
|
||||||
|
else all_attentions
|
||||||
|
)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
encoder_states = encoder_states + (hidden_states,)
|
encoder_states = encoder_states + (hidden_states,)
|
||||||
|
|
||||||
@@ -2038,6 +2048,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
# decoder layers
|
# decoder layers
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_self_attns = ()
|
all_self_attns = ()
|
||||||
|
all_cross_attentions = ()
|
||||||
present_key_values = ()
|
present_key_values = ()
|
||||||
|
|
||||||
# check if head_mask has a correct number of layers specified if desired
|
# check if head_mask has a correct number of layers specified if desired
|
||||||
@@ -2059,7 +2070,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
|
||||||
|
|
||||||
hidden_states, layer_self_attn, present_key_value = decoder_layer(
|
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=combined_attention_mask,
|
attention_mask=combined_attention_mask,
|
||||||
encoder_hidden_states=inputs["encoder_hidden_states"],
|
encoder_hidden_states=inputs["encoder_hidden_states"],
|
||||||
@@ -2076,24 +2087,31 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
all_cross_attentions += (layer_cross_attn,)
|
||||||
|
|
||||||
if inputs["output_hidden_states"]:
|
if inputs["output_hidden_states"]:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
else:
|
else:
|
||||||
all_hidden_states = None
|
all_hidden_states = None
|
||||||
|
|
||||||
all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
|
all_self_attns = all_self_attns if inputs["output_attentions"] else None
|
||||||
|
all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
|
||||||
|
|
||||||
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
|
||||||
|
|
||||||
if not inputs["return_dict"]:
|
if not inputs["return_dict"]:
|
||||||
return hidden_states, present_key_values, all_hidden_states, all_self_attns
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_values,
|
past_key_values=present_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -2223,6 +2241,7 @@ class TFLEDMainLayer(tf.keras.layers.Layer):
|
|||||||
past_key_values=decoder_outputs.past_key_values,
|
past_key_values=decoder_outputs.past_key_values,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -2475,6 +2494,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
|||||||
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
past_key_values=outputs.past_key_values, # index 1 of d outputs
|
||||||
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
|
||||||
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
|
||||||
|
cross_attentions=outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
|
||||||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from ...file_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_tf_outputs import (
|
from ...modeling_tf_outputs import (
|
||||||
TFBaseModelOutput,
|
TFBaseModelOutput,
|
||||||
TFBaseModelOutputWithPast,
|
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||||
TFSeq2SeqLMOutput,
|
TFSeq2SeqLMOutput,
|
||||||
TFSeq2SeqModelOutput,
|
TFSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
@@ -771,6 +771,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
||||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||||
all_attentions = () if inputs["output_attentions"] else None
|
all_attentions = () if inputs["output_attentions"] else None
|
||||||
|
all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
|
||||||
position_bias = None
|
position_bias = None
|
||||||
encoder_decoder_position_bias = None
|
encoder_decoder_position_bias = None
|
||||||
|
|
||||||
@@ -814,6 +815,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
all_attentions = all_attentions + (layer_outputs[3],)
|
all_attentions = all_attentions + (layer_outputs[3],)
|
||||||
|
if self.is_decoder:
|
||||||
|
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
|
||||||
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
@@ -831,14 +834,17 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if inputs["output_attentions"]:
|
if inputs["output_attentions"]:
|
||||||
outputs = outputs + (all_attentions,)
|
outputs = outputs + (all_attentions,)
|
||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
if self.is_decoder:
|
||||||
|
outputs + (all_cross_attentions,)
|
||||||
|
return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions)
|
||||||
|
|
||||||
if self.is_decoder:
|
if self.is_decoder:
|
||||||
return TFBaseModelOutputWithPast(
|
return TFBaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=present_key_value_states,
|
past_key_values=present_key_value_states,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_attentions,
|
attentions=all_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return TFBaseModelOutput(
|
return TFBaseModelOutput(
|
||||||
@@ -1264,6 +1270,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
@@ -1508,6 +1515,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
past_key_values=past,
|
past_key_values=past,
|
||||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||||
decoder_attentions=decoder_outputs.attentions,
|
decoder_attentions=decoder_outputs.attentions,
|
||||||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||||
|
|||||||
@@ -322,7 +322,7 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(attentions[0].shape[-3:]),
|
list(attentions[0].shape[-3:]),
|
||||||
[self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
|
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
list(global_attentions[0].shape[-3:]),
|
list(global_attentions[0].shape[-3:]),
|
||||||
|
|||||||
Reference in New Issue
Block a user