Fix usage of head masks by PT encoder-decoder models' generate() function (#11621)
* Add missing head masking for generate() function * Add head_mask, decoder_head_mask and cross_attn_head_mask into prepare_inputs_for_generation for generate() function for multiple encoder-decoder models. * Add test_genereate_with_head_masking * [WIP] Update the new test and handle special cases * make style * Omit ProphetNet test so far * make fix-copies
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
@@ -1072,6 +1073,40 @@ class GenerationTesterMixin:
|
||||
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
|
||||
)
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
model = model_class(config)
|
||||
# We want to test only encoder-decoder models
|
||||
if not config.is_encoder_decoder:
|
||||
continue
|
||||
|
||||
head_masking = {
|
||||
"head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads),
|
||||
"decoder_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
|
||||
"cross_attn_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
|
||||
}
|
||||
|
||||
signature = inspect.signature(model.forward)
|
||||
# We want to test only models where encoder/decoder head masking is implemented
|
||||
if set(head_masking.keys()) < set([*signature.parameters.keys()]):
|
||||
continue
|
||||
|
||||
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
|
||||
out = model.generate(
|
||||
input_ids,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
**{name: mask},
|
||||
)
|
||||
# We check the state of decoder_attentions and cross_attentions just from the last step
|
||||
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
|
||||
self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
Reference in New Issue
Block a user