[T5] Add 3D attention mask to T5 model (2) (#9643) (#11197)

* Add 3D attention mask to T5 model (#9643)

Added code for 3D attention mask in T5 model. Similar to BERT model.

* Add test for 3D attention mask

Added test for 3D attention mask: test_decoder_model_past_with_3d_attn_mask()
3D attention mask of the shape [Batch_size, Seq_length, Seq_length] both for
attention mask and decoder attention mask. Test is passing.
This commit is contained in:
lexhuismans
2021-05-13 13:02:27 +02:00
committed by GitHub
parent 6ee1a4fd3e
commit 91cf29153b
2 changed files with 35 additions and 1 deletions

View File

@@ -530,6 +530,34 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_decoder_model_past_with_3d_attn_mask(self):
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = self.model_tester.prepare_config_and_inputs()
attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
vocab_size=2,
)
decoder_attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
vocab_size=2,
)
self.model_tester.create_and_check_decoder_model_attention_mask_past(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)