From 91cf29153bd441320437c5a84acfa8d3481dcefe Mon Sep 17 00:00:00 2001 From: lexhuismans <43178421+lexhuismans@users.noreply.github.com> Date: Thu, 13 May 2021 13:02:27 +0200 Subject: [PATCH] [T5] Add 3D attention mask to T5 model (2) (#9643) (#11197) * Add 3D attention mask to T5 model (#9643) Added code for 3D attention mask in T5 model. Similar to BERT model. * Add test for 3D attention mask Added test for 3D attention mask: test_decoder_model_past_with_3d_attn_mask() 3D attention mask of the shape [Batch_size, Seq_length, Seq_length] both for attention mask and decoder attention mask. Test is passing. --- src/transformers/models/t5/modeling_t5.py | 8 ++++++- tests/test_modeling_t5.py | 28 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index adf9430d9e..97838a5bdf 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -914,7 +914,13 @@ class T5Stack(T5PreTrainedModel): # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is not None: + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index e72c05e90f..31b712b075 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -530,6 +530,34 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)