Enable decoder_attention_mask in generate function (#20726)

* Enable `decoder_attention_mask` in `generate` function

* Make style corrections

* Run `make repo-consistency`

* Add integration test
This commit is contained in:
samuelpullely
2023-01-03 15:59:08 +01:00
committed by GitHub
parent a9653400d3
commit 15c68c67f4
4 changed files with 46 additions and 1 deletions

View File

@@ -1226,6 +1226,36 @@ class BartModelIntegrationTests(unittest.TestCase):
],
)
@slow
def test_decoder_attention_mask(self):
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0).to(
torch_device
)
tokenizer = self.default_tokenizer
sentence = "UN Chief Says There Is No <mask> in Syria"
input_ids = tokenizer(sentence, return_tensors="pt").input_ids.to(torch_device)
padding_size = 3
decoder_input_ids = torch.tensor(
[
[model.config.decoder_start_token_id]
+ padding_size * [model.config.pad_token_id]
+ [model.config.bos_token_id]
],
dtype=torch.long,
device=torch_device,
)
decoder_attention_mask = torch.where(decoder_input_ids == model.config.pad_token_id, 0, 1).to(torch_device)
generated_ids = model.generate(
input_ids=input_ids,
use_cache=False,
max_new_tokens=20,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
generated_sentence = tokenizer.batch_decode(generated_ids)[0]
expected_sentence = "</s><pad><pad><pad><s>UN Chief Says There Is No Plan B for Peace in Syria</s>"
self.assertEqual(generated_sentence, expected_sentence)
class BartStandaloneDecoderModelTester:
def __init__(