From 0eaeae2e3675bba04157301fa1d93b7ce3c78dc0 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 10 Jun 2021 16:28:07 +0200 Subject: [PATCH] Fix a condition in test_generate_with_head_masking (#11911) * Fix a condition in test_generate_with_head_masking * Fix usage of head_mask in bigbird_pegasus * Fix head masking for speech2text * Resolve copy mismatch + drop unwanted print statement * Fix the condition --- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 3 +++ .../models/speech_to_text/modeling_speech_to_text.py | 4 ++++ tests/test_generation_utils.py | 5 +++-- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index dddfd434b6..3f548ecfc2 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1174,6 +1174,8 @@ class BigBirdPegasusEncoderAttention(nn.Module): from_blocked_mask=None, to_blocked_mask=None, ): + # Expand dims to enable multiplication in the self-attention module + head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None if self.config.attention_type == "original_full": self_outputs = self.self( @@ -1372,6 +1374,7 @@ class BigBirdPegasusEncoderLayer(nn.Module): self_attention_outputs = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, + head_mask=layer_head_mask, output_attentions=output_attentions, band_mask=band_mask, from_mask=from_mask, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index dfbea1cf4c..dde154ab46 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1352,6 +1352,8 @@ class 
Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): past=None, attention_mask=None, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, **kwargs @@ -1366,6 +1368,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 289fa4882c..ed28c77c07 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -1095,16 +1095,17 @@ class GenerationTesterMixin: signature = inspect.signature(model.forward) # We want to test only models where encoder/decoder head masking is implemented - if set(head_masking.keys()) < set([*signature.parameters.keys()]): + if not set(head_masking.keys()) < set([*signature.parameters.keys()]): continue for attn_name, (name, mask) in zip(attention_names, head_masking.items()): out = model.generate( input_ids, + attention_mask=attention_mask, num_beams=1, - max_length=max_length, output_attentions=True, return_dict_in_generate=True, + remove_invalid_values=True, **{name: mask}, ) # We check the state of decoder_attentions and cross_attentions just from the last step