Fix usage of head masks by TF encoder-decoder models' generate() function (#11775)

* Fix Bart

* Fix Blenderbot{,_small}

* Fix LED

* Fix Marian

* Fix MBart

* Fix Pegasus

* Fix T5

* Add test for generation with head_mask

* Add a common TF test

* Override a test for the LED model as head masking is not yet properly implemented

* Remove all head_masks from input preparation for LED

* Drop masking for T5 as it needs a bit of refactor
This commit is contained in:
Daniel Stancl
2021-05-26 15:02:44 +02:00
committed by GitHub
parent 0b0a598452
commit 0b93358447
11 changed files with 84 additions and 2 deletions

View File

@@ -1195,6 +1195,40 @@ class TFModelTesterMixin:
self.assertEqual(loss.shape, [loss_size])
def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config)
# We want to test only encoder-decoder models
if not config.is_encoder_decoder:
continue
head_masking = {
"head_mask": tf.zeros((config.encoder_layers, config.encoder_attention_heads)),
"decoder_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
"cross_attn_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
}
signature = inspect.signature(model.call)
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
continue
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
out = model.generate(
inputs_dict["input_ids"],
num_beams=1,
max_length=inputs_dict["input_ids"] + 5,
output_attentions=True,
return_dict_in_generate=True,
**{name: mask},
)
# We check the state of decoder_attentions and cross_attentions just from the last step
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []

View File

@@ -370,6 +370,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass
def test_generate_with_headmasking(self):
# TODO: Head-masking not yet implement
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFT5Model.from_pretrained("t5-small")
self.assertIsNotNone(model)
def test_generate_with_headmasking(self):
# TODO: Fix head-masking according to PyTorch T5 model
pass
class TFT5EncoderOnlyModelTester:
def __init__(