Fix a condition in test_generate_with_head_masking (#11911)

* Fix a condition in test_generate_with_head_masking

* Fix usage of head_mask in bigbirg_pegasus

* Fix head masking for speech2text

* Resolve copy mismatch + drop unwanted print statement

* Fix the condition
This commit is contained in:
Daniel Stancl
2021-06-10 16:28:07 +02:00
committed by GitHub
parent bebbdd0fc9
commit 0eaeae2e36
3 changed files with 10 additions and 2 deletions

View File

@@ -1095,16 +1095,17 @@ class GenerationTesterMixin:
signature = inspect.signature(model.forward)
# We want to test only models where encoder/decoder head masking is implemented
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
continue
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
out = model.generate(
input_ids,
attention_mask=attention_mask,
num_beams=1,
max_length=max_length,
output_attentions=True,
return_dict_in_generate=True,
remove_invalid_values=True,
**{name: mask},
)
# We check the state of decoder_attentions and cross_attentions just from the last step