Fix a condition in test_generate_with_head_masking (#11911)
* Fix a condition in test_generate_with_head_masking * Fix usage of head_mask in bigbirg_pegasus * Fix head masking for speech2text * Resolve copy mismatch + drop unwanted print statement * Fix the condition
This commit is contained in:
@@ -1095,16 +1095,17 @@ class GenerationTesterMixin:
|
||||
|
||||
signature = inspect.signature(model.forward)
|
||||
# We want to test only models where encoder/decoder head masking is implemented
|
||||
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||
if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||
continue
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
|
||||
Reference in New Issue
Block a user