From 0b933584473c7e7a1d1e231ffa3e95b71c3e2139 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 26 May 2021 15:02:44 +0200 Subject: [PATCH] Fix usage of head masks by TF encoder-decoder models' `generate()` function (#11775) * Fix Bart * Fix Blenderbot{,_small} * Fix LED * Fix Marian * Fix MBart * Fix Pegasus * Fix T5 * Add test for generation with head_mask * Add a common TF test * Override a test for the LED model as head masking is not yet properly implemented * Remove all head_masks from input preparation for LED * Drop masking for T5 as it needs a bit of refactor --- .../models/bart/modeling_tf_bart.py | 4 +++ .../blenderbot/modeling_tf_blenderbot.py | 4 +++ .../modeling_tf_blenderbot_small.py | 4 +++ .../models/led/modeling_tf_led.py | 11 +++++- .../models/marian/modeling_tf_marian.py | 4 +++ .../models/mbart/modeling_tf_mbart.py | 4 +++ .../models/pegasus/modeling_tf_pegasus.py | 4 +++ src/transformers/models/t5/modeling_tf_t5.py | 9 ++++- tests/test_modeling_tf_common.py | 34 +++++++++++++++++++ tests/test_modeling_tf_led.py | 4 +++ tests/test_modeling_tf_t5.py | 4 +++ 11 files changed, 84 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 41f5f95918..0d925c652a 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1452,6 +1452,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode past, attention_mask, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, **kwargs, ) -> Dict: @@ -1487,6 +1489,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 687cd2c7b8..3e25194806 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1476,6 +1476,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal past, attention_mask, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, **kwargs, ) -> Dict: @@ -1511,6 +1513,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 49bc59757b..ef0bb6e4f3 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -1451,6 +1451,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel past, attention_mask, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, **kwargs, ) -> Dict: @@ -1486,6 +1488,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 7752044c22..3719893991 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -2477,7 +2477,15 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): encoder_global_attentions=enc_g_attns, ) - def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict: + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past, + attention_mask, + head_mask=None, + use_cache=None, + **kwargs, + ) -> Dict: assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" if len(past) == 1: assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}" @@ -2510,6 +2518,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "head_mask": head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 81ad6b8185..b9e951e5c3 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -1480,6 +1480,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): past, attention_mask, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, **kwargs, ) -> Dict: @@ -1515,6 +1517,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index a17d9ad1a0..7f42002d2f 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1464,6 +1464,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo past, attention_mask, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, **kwargs, ) -> Dict: @@ -1499,6 +1501,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 3fadffad18..2829954ea5 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1489,6 +1489,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua past, attention_mask, head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, use_cache=None, **kwargs, ) -> Dict: @@ -1524,6 +1526,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 4d70cb2c3e..284fdb1573 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1464,7 +1464,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling encoder_attentions=enc_attns, ) - def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs): + def prepare_inputs_for_generation( + self, + inputs, + past, + attention_mask, + use_cache=None, + **kwargs, + ): assert past is not None, "past has to be defined for encoder_outputs" # first step diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 36ce1fbf17..b46ac03129 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1195,6 +1195,40 @@ class TFModelTesterMixin: self.assertEqual(loss.shape, [loss_size]) + def test_generate_with_headmasking(self): + attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + # We want to test only encoder-decoder models + if not config.is_encoder_decoder: + continue + + head_masking = { + "head_mask": tf.zeros((config.encoder_layers, config.encoder_attention_heads)), + "decoder_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)), + "cross_attn_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)), + } + + signature = inspect.signature(model.call) + if set(head_masking.keys()) < set([*signature.parameters.keys()]): + continue + + for attn_name, (name, mask) in zip(attention_names, head_masking.items()): + out = model.generate( + inputs_dict["input_ids"], + num_beams=1, + max_length=inputs_dict["input_ids"] + 5, + output_attentions=True, + return_dict_in_generate=True, + **{name: mask}, + ) + # We check the state of decoder_attentions and cross_attentions just from the last step + attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1] + self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0) + def _generate_random_bad_tokens(self, num_bad_tokens, model): # special tokens cannot be bad tokens special_tokens = [] diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index a10ceb6f2d..41d132c80b 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -370,6 +370,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase): # This test is too long (>30sec) and makes fail the CI pass + def test_generate_with_headmasking(self): + # TODO: Head-masking not yet implement + pass + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 28b501a7ab..a902363fbd 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): model = TFT5Model.from_pretrained("t5-small") self.assertIsNotNone(model) + def test_generate_with_headmasking(self): + # TODO: Fix head-masking according to PyTorch T5 model + pass + class TFT5EncoderOnlyModelTester: def __init__(