From 90ca8d36e9df6dbc15fc7743792b576360ded0ed Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 15 Jan 2021 12:40:27 +0100
Subject: [PATCH] [TF Led] Fix wrong decoder attention mask behavior (#9601)

* fix tf led

* remove loop file
---
 src/transformers/models/led/modeling_tf_led.py | 18 +++---------------
 1 file changed, 3 insertions(+), 15 deletions(-)

diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py
index c93569d5f8..b8eb89ce95 100644
--- a/src/transformers/models/led/modeling_tf_led.py
+++ b/src/transformers/models/led/modeling_tf_led.py
@@ -1862,7 +1862,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         hidden_states = inputs["inputs_embeds"]
 
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         else:
@@ -1870,20 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )
 
-        if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
-            inputs["attention_mask"] = tf.cast(
-                tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
-            )
-            inputs["attention_mask"] = tf.concat(
-                [
-                    tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
-                    inputs["attention_mask"],
-                ],
-                axis=-1,
-            )
-        else:
-            inputs["attention_mask"] = tf.ones(
-                (input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
+        if inputs["attention_mask"] is not None and input_shape[-1] > 1:
+            combined_attention_mask = combined_attention_mask + _expand_mask(
+                inputs["attention_mask"], tgt_len=input_shape[-1]
             )
 
         if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None: