From 15c68c67f4890ef62ce73310d0e1982d5ea91477 Mon Sep 17 00:00:00 2001 From: samuelpullely <51292066+samuelpullely@users.noreply.github.com> Date: Tue, 3 Jan 2023 15:59:08 +0100 Subject: [PATCH] Enable `decoder_attention_mask` in `generate` function (#20726) * Enable `decoder_attention_mask` in `generate` function * Make style corrections * Run `make repo-consistency` * Add integration test --- src/transformers/generation/utils.py | 13 +++++++- src/transformers/models/bart/modeling_bart.py | 2 ++ .../modeling_bigbird_pegasus.py | 2 ++ tests/models/bart/test_modeling_bart.py | 30 +++++++++++++++++++ 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0db1005d95..1f7ac3b2e9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -666,6 +666,9 @@ class GenerationMixin: expand_size, dim=0 ) model_kwargs["encoder_outputs"] = encoder_outputs + decoder_attention_mask = model_kwargs.get("decoder_attention_mask") + if decoder_attention_mask is not None: + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0) return input_ids, model_kwargs @@ -701,13 +704,21 @@ class GenerationMixin: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - # update attention mask if not is_encoder_decoder: + # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs["decoder_attention_mask"] = torch.cat( + [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], + dim=-1, + ) return model_kwargs diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d62740274c..b85885638e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1420,6 +1420,7 @@ class BartForConditionalGeneration(BartPretrainedModel): decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -1437,6 +1438,7 @@ class BartForConditionalGeneration(BartPretrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 0ca41ba9c8..1bff4a6c62 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2619,6 +2619,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): decoder_input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -2636,6 +2637,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 7679c55e2b..d6474c372f 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1226,6 +1226,36 @@ class BartModelIntegrationTests(unittest.TestCase): ], ) + @slow + def test_decoder_attention_mask(self): + model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0).to( + torch_device + ) + tokenizer = self.default_tokenizer + sentence = "UN Chief Says There Is No in Syria" + input_ids = tokenizer(sentence, return_tensors="pt").input_ids.to(torch_device) + padding_size = 3 + decoder_input_ids = torch.tensor( + [ + [model.config.decoder_start_token_id] + + padding_size * [model.config.pad_token_id] + + [model.config.bos_token_id] + ], + dtype=torch.long, + device=torch_device, + ) + decoder_attention_mask = torch.where(decoder_input_ids == model.config.pad_token_id, 0, 1).to(torch_device) + generated_ids = model.generate( + input_ids=input_ids, + use_cache=False, + max_new_tokens=20, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + ) + generated_sentence = tokenizer.batch_decode(generated_ids)[0] + expected_sentence = "UN Chief Says There Is No Plan B for Peace in Syria" + self.assertEqual(generated_sentence, expected_sentence) + class BartStandaloneDecoderModelTester: def __init__(