Fix usage of head masks by TF encoder-decoder models' generate() function (#11775)
* Fix Bart
* Fix Blenderbot{,_small}
* Fix LED
* Fix Marian
* Fix MBart
* Fix Pegasus
* Fix T5
* Add test for generation with head_mask
* Add a common TF test
* Override a test for the LED model as head masking is not yet properly implemented
* Remove all head_masks from input preparation for LED
* Drop masking for T5 as it needs a bit of refactor
This commit is contained in:
@@ -1195,6 +1195,40 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
def test_generate_with_headmasking(self):
    """Check that ``generate()`` honours all-zero head masks.

    For every generative encoder-decoder model class, each of the three head
    masks (``head_mask``, ``decoder_head_mask``, ``cross_attn_head_mask``) is
    passed to ``generate()`` on its own as an all-zero tensor, and the
    corresponding attention weights returned by ``generate()`` are asserted to
    sum to zero.

    Model classes whose ``call`` signature does not accept all three masks
    are skipped.
    """
    # Order matters: attention_names[i] is the output produced by the i-th
    # entry of ``head_masking`` below (dicts preserve insertion order).
    attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_generative_model_classes:
        # We want to test only encoder-decoder models. Checked before the
        # model is instantiated so no model is built just to be discarded.
        if not config.is_encoder_decoder:
            continue

        model = model_class(config)

        head_masking = {
            "head_mask": tf.zeros((config.encoder_layers, config.encoder_attention_heads)),
            "decoder_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
            "cross_attn_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
        }

        # Skip models that do NOT accept every head-mask argument. The
        # previous condition was inverted and skipped precisely the models
        # that do support head masking, leaving the assertions below to run
        # only for models that lack those arguments.
        signature = inspect.signature(model.call)
        if not set(head_masking).issubset(signature.parameters):
            continue

        # ``max_length`` must be an integer length, not a tensor: take the
        # sequence dimension of ``input_ids`` and allow 5 extra tokens.
        max_length = inputs_dict["input_ids"].shape[-1] + 5

        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
            out = model.generate(
                inputs_dict["input_ids"],
                num_beams=1,
                max_length=max_length,
                output_attentions=True,
                return_dict_in_generate=True,
                **{name: mask},
            )
            # We check the state of decoder_attentions and cross_attentions just from the last step
            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
            self.assertEqual(sum(tf.reduce_sum(w).numpy() for w in attn_weights), 0.0)
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = []
|
||||
|
||||
Reference in New Issue
Block a user