Fix a condition in test_generate_with_head_masking (#11911)
* Fix a condition in test_generate_with_head_masking * Fix usage of head_mask in bigbirg_pegasus * Fix head masking for speech2text * Resolve copy mismatch + drop unwanted print statement * Fix the condition
This commit is contained in:
@@ -1174,6 +1174,8 @@ class BigBirdPegasusEncoderAttention(nn.Module):
|
||||
from_blocked_mask=None,
|
||||
to_blocked_mask=None,
|
||||
):
|
||||
# Expand dims to enable multiplication in the self-attention module
|
||||
head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None
|
||||
|
||||
if self.config.attention_type == "original_full":
|
||||
self_outputs = self.self(
|
||||
@@ -1372,6 +1374,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
|
||||
self_attention_outputs = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
band_mask=band_mask,
|
||||
from_mask=from_mask,
|
||||
|
||||
@@ -1352,6 +1352,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
use_cache=None,
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
@@ -1366,6 +1368,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||
}
|
||||
|
||||
|
||||
@@ -1095,16 +1095,17 @@ class GenerationTesterMixin:
|
||||
|
||||
signature = inspect.signature(model.forward)
|
||||
# We want to test only models where encoder/decoder head masking is implemented
|
||||
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||
if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||
continue
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
|
||||
Reference in New Issue
Block a user