Fix usage of head masks by TF encoder-decoder models' generate() function (#11775)
* Fix Bart
* Fix Blenderbot{,_small}
* Fix LED
* Fix Marian
* Fix MBart
* Fix Pegasus
* Fix T5
* Add test for generation with head_mask
* Add a common TF test
* Override a test for the LED model as head masking is not yet properly implemented
* Remove all head_masks from input preparation for LED
* Drop masking for T5 as it needs a bit of refactoring
This commit is contained in:
@@ -1452,6 +1452,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
past,
|
past,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -1487,6 +1489,8 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1476,6 +1476,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
past,
|
past,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -1511,6 +1513,8 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1451,6 +1451,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
past,
|
past,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -1486,6 +1488,8 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2477,7 +2477,15 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
|||||||
encoder_global_attentions=enc_g_attns,
|
encoder_global_attentions=enc_g_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs) -> Dict:
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
decoder_input_ids,
|
||||||
|
past,
|
||||||
|
attention_mask,
|
||||||
|
head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict:
|
||||||
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
|
||||||
if len(past) == 1:
|
if len(past) == 1:
|
||||||
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
|
||||||
@@ -2510,6 +2518,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
|
|||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1480,6 +1480,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
past,
|
past,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -1515,6 +1517,8 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1464,6 +1464,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
past,
|
past,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -1499,6 +1501,8 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1489,6 +1489,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
past,
|
past,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -1524,6 +1526,8 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1464,7 +1464,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||||||
encoder_attentions=enc_attns,
|
encoder_attentions=enc_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
inputs,
|
||||||
|
past,
|
||||||
|
attention_mask,
|
||||||
|
use_cache=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
assert past is not None, "past has to be defined for encoder_outputs"
|
assert past is not None, "past has to be defined for encoder_outputs"
|
||||||
|
|
||||||
# first step
|
# first step
|
||||||
|
|||||||
@@ -1195,6 +1195,40 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
self.assertEqual(loss.shape, [loss_size])
|
self.assertEqual(loss.shape, [loss_size])
|
||||||
|
|
||||||
|
def test_generate_with_headmasking(self):
    """Check that an all-zero head mask zeroes the matching attentions in ``generate()``.

    For every generative model class whose ``call()`` accepts ``head_mask``,
    ``decoder_head_mask`` and ``cross_attn_head_mask``, each mask is passed to
    ``generate()`` on its own and the attention weights it controls
    (encoder, decoder and cross attentions respectively) must sum to zero.
    """
    attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # We want to test only encoder-decoder models. The config is shared by all
    # model classes below, so decide once before building any model.
    if not config.is_encoder_decoder:
        return

    input_ids = inputs_dict["input_ids"]
    # `max_length` has to be an integer length, not a tensor.
    # NOTE(review): assumes the last axis of `input_ids` is the sequence axis - confirm.
    max_length = input_ids.shape[-1] + 5

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        head_masking = {
            "head_mask": tf.zeros((config.encoder_layers, config.encoder_attention_heads)),
            "decoder_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
            "cross_attn_head_mask": tf.zeros((config.decoder_layers, config.decoder_attention_heads)),
        }

        # Skip only the models whose `call()` does NOT accept every head-mask
        # argument; the original condition was inverted and skipped the
        # models that do support head masking.
        signature = inspect.signature(model.call)
        if not set(head_masking) <= set(signature.parameters):
            continue

        # `attention_names` and `head_masking` are ordered so that each mask is
        # paired with the attention output it is expected to zero out.
        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
            out = model.generate(
                input_ids,
                num_beams=1,
                max_length=max_length,
                output_attentions=True,
                return_dict_in_generate=True,
                **{name: mask},
            )
            # We check the state of decoder_attentions and cross_attentions just from the last step
            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
            self.assertEqual(sum(tf.reduce_sum(w).numpy() for w in attn_weights), 0.0)
|
||||||
|
|
||||||
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
# special tokens cannot be bad tokens
|
# special tokens cannot be bad tokens
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
|
|||||||
@@ -370,6 +370,10 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# This test is too long (>30sec) and makes fail the CI
|
# This test is too long (>30sec) and makes fail the CI
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_generate_with_headmasking(self):
|
||||||
|
# TODO: Head-masking not yet implemented
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -310,6 +310,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
model = TFT5Model.from_pretrained("t5-small")
|
model = TFT5Model.from_pretrained("t5-small")
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_generate_with_headmasking(self):
|
||||||
|
# TODO: Fix head-masking according to PyTorch T5 model
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TFT5EncoderOnlyModelTester:
|
class TFT5EncoderOnlyModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user