Fix a condition in test_generate_with_head_masking (#11911)
* Fix a condition in test_generate_with_head_masking
* Fix usage of head_mask in bigbird_pegasus
* Fix head masking for speech2text
* Resolve copy mismatch + drop unwanted print statement
* Fix the condition
This commit is contained in:
@@ -1174,6 +1174,8 @@ class BigBirdPegasusEncoderAttention(nn.Module):
|
|||||||
from_blocked_mask=None,
|
from_blocked_mask=None,
|
||||||
to_blocked_mask=None,
|
to_blocked_mask=None,
|
||||||
):
|
):
|
||||||
|
# Expand dims to enable multiplication in the self-attention module
|
||||||
|
head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None
|
||||||
|
|
||||||
if self.config.attention_type == "original_full":
|
if self.config.attention_type == "original_full":
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -1372,6 +1374,7 @@ class BigBirdPegasusEncoderLayer(nn.Module):
|
|||||||
self_attention_outputs = self.self_attn(
|
self_attention_outputs = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
band_mask=band_mask,
|
band_mask=band_mask,
|
||||||
from_mask=from_mask,
|
from_mask=from_mask,
|
||||||
|
|||||||
@@ -1352,6 +1352,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
|||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@@ -1366,6 +1368,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1095,16 +1095,17 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
signature = inspect.signature(model.forward)
|
signature = inspect.signature(model.forward)
|
||||||
# We want to test only models where encoder/decoder head masking is implemented
|
# We want to test only models where encoder/decoder head masking is implemented
|
||||||
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||||
out = model.generate(
|
out = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
remove_invalid_values=True,
|
||||||
**{name: mask},
|
**{name: mask},
|
||||||
)
|
)
|
||||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||||
|
|||||||
Reference in New Issue
Block a user