[TF Led] Fix wrong decoder attention mask behavior (#9601)
* fix tf led * remove loop file
This commit is contained in:
committed by
GitHub
parent
85788bae5c
commit
90ca8d36e9
@@ -1862,7 +1862,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
hidden_states = inputs["inputs_embeds"]
|
hidden_states = inputs["inputs_embeds"]
|
||||||
|
|
||||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||||
combined_attention_mask = None
|
|
||||||
if input_shape[-1] > 1:
|
if input_shape[-1] > 1:
|
||||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||||
else:
|
else:
|
||||||
@@ -1870,20 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||||
inputs["attention_mask"] = tf.cast(
|
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||||
)
|
|
||||||
inputs["attention_mask"] = tf.concat(
|
|
||||||
[
|
|
||||||
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
|
|
||||||
inputs["attention_mask"],
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
inputs["attention_mask"] = tf.ones(
|
|
||||||
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user