Enable decoder_attention_mask in generate function (#20726)
* Enable `decoder_attention_mask` in `generate` function * Make style corrections * Run `make repo-consistency` * Add integration test
This commit is contained in:
@@ -666,6 +666,9 @@ class GenerationMixin:
|
||||
expand_size, dim=0
|
||||
)
|
||||
model_kwargs["encoder_outputs"] = encoder_outputs
|
||||
decoder_attention_mask = model_kwargs.get("decoder_attention_mask")
|
||||
if decoder_attention_mask is not None:
|
||||
model_kwargs["decoder_attention_mask"] = decoder_attention_mask.repeat_interleave(expand_size, dim=0)
|
||||
|
||||
return input_ids, model_kwargs
|
||||
|
||||
@@ -701,13 +704,21 @@ class GenerationMixin:
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
||||
|
||||
# update attention mask
|
||||
if not is_encoder_decoder:
|
||||
# update attention mask
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||
)
|
||||
else:
|
||||
# update decoder attention mask
|
||||
if "decoder_attention_mask" in model_kwargs:
|
||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||
model_kwargs["decoder_attention_mask"] = torch.cat(
|
||||
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
|
||||
@@ -1420,6 +1420,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
@@ -1437,6 +1438,7 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
|
||||
@@ -2619,6 +2619,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
||||
decoder_input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
@@ -2636,6 +2637,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
||||
"past_key_values": past,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
|
||||
@@ -1226,6 +1226,36 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_decoder_attention_mask(self):
|
||||
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0).to(
|
||||
torch_device
|
||||
)
|
||||
tokenizer = self.default_tokenizer
|
||||
sentence = "UN Chief Says There Is No <mask> in Syria"
|
||||
input_ids = tokenizer(sentence, return_tensors="pt").input_ids.to(torch_device)
|
||||
padding_size = 3
|
||||
decoder_input_ids = torch.tensor(
|
||||
[
|
||||
[model.config.decoder_start_token_id]
|
||||
+ padding_size * [model.config.pad_token_id]
|
||||
+ [model.config.bos_token_id]
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
decoder_attention_mask = torch.where(decoder_input_ids == model.config.pad_token_id, 0, 1).to(torch_device)
|
||||
generated_ids = model.generate(
|
||||
input_ids=input_ids,
|
||||
use_cache=False,
|
||||
max_new_tokens=20,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
generated_sentence = tokenizer.batch_decode(generated_ids)[0]
|
||||
expected_sentence = "</s><pad><pad><pad><s>UN Chief Says There Is No Plan B for Peace in Syria</s>"
|
||||
self.assertEqual(generated_sentence, expected_sentence)
|
||||
|
||||
|
||||
class BartStandaloneDecoderModelTester:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user