Fix usage of head masks by PT encoder-decoder models' generate() function (#11621)

* Add missing head masking for generate() function

* Add head_mask, decoder_head_mask and cross_attn_head_mask
into prepare_inputs_for_generation for generate() function
for multiple encoder-decoder models.

* Add test_genereate_with_head_masking

* [WIP] Update the new test and handle special cases

* make style

* Omit ProphetNet test so far

* make fix-copies
This commit is contained in:
Daniel Stancl
2021-05-19 01:44:53 +02:00
committed by GitHub
parent ca33278fdb
commit 680d181ce8
16 changed files with 148 additions and 4 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import inspect
import unittest
from transformers import is_torch_available
@@ -1072,6 +1073,40 @@ class GenerationTesterMixin:
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
)
def test_generate_with_head_masking(self):
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
model = model_class(config)
# We want to test only encoder-decoder models
if not config.is_encoder_decoder:
continue
head_masking = {
"head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads),
"decoder_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
"cross_attn_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
}
signature = inspect.signature(model.forward)
# We want to test only models where encoder/decoder head masking is implemented
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
continue
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
out = model.generate(
input_ids,
num_beams=1,
max_length=max_length,
output_attentions=True,
return_dict_in_generate=True,
**{name: mask},
)
# We check the state of decoder_attentions and cross_attentions just from the last step
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences

View File

@@ -1088,6 +1088,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(encoder_attentions.grad)
def test_generate_with_head_masking(self):
"""Generating with head_masking has not been implemented for ProphetNet models yet."""
pass
@require_torch
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

View File

@@ -600,6 +600,37 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
input_names=["input_ids", "decoder_input_ids"],
)
def test_generate_with_head_masking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
max_length = config_and_inputs[1].shape[-1] + 3
model = T5ForConditionalGeneration(config)
head_masking = {
"head_mask": torch.zeros(config.num_layers, config.num_heads),
"decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
"cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads),
}
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
head_masks = {name: mask}
# Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
if name == "head_mask":
head_masks["decoder_head_mask"] = torch.ones(config.num_decoder_layers, config.num_heads)
out = model.generate(
config_and_inputs[1],
num_beams=1,
max_length=max_length,
output_attentions=True,
return_dict_in_generate=True,
**head_masks,
)
# We check the state of decoder_attentions and cross_attentions just from the last step
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
class T5EncoderOnlyModelTester:
def __init__(