Enable decoder_attention_mask in generate function (#20726)
* Enable `decoder_attention_mask` in `generate` function
* Make style corrections
* Run `make repo-consistency`
* Add integration test
This commit is contained in:
@@ -1226,6 +1226,36 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@slow
def test_decoder_attention_mask(self):
    """Run `generate` with an explicit decoder prompt and matching attention mask.

    The decoder prompt is the decoder start token, `padding_size` pad tokens,
    then the BOS token. The attention mask is 0 at every pad position and 1
    elsewhere. The decoded first sequence is compared against a fixed string.
    """
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
    model = model.to(torch_device)
    config = model.config
    tokenizer = self.default_tokenizer

    # Encoder side: a single sentence containing one <mask> slot.
    sentence = "UN Chief Says There Is No <mask> in Syria"
    encoded = tokenizer(sentence, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # Decoder side: start token, a run of padding, then BOS (batch of one).
    padding_size = 3
    decoder_prompt = [config.decoder_start_token_id]
    decoder_prompt += padding_size * [config.pad_token_id]
    decoder_prompt += [config.bos_token_id]
    decoder_input_ids = torch.tensor([decoder_prompt], dtype=torch.long, device=torch_device)

    # Mask out exactly the pad positions of the decoder prompt.
    is_pad = decoder_input_ids == config.pad_token_id
    decoder_attention_mask = torch.where(is_pad, 0, 1).to(torch_device)

    generated_ids = model.generate(
        input_ids=input_ids,
        use_cache=False,
        max_new_tokens=20,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
    )

    decoded = tokenizer.batch_decode(generated_ids)
    generated_sentence = decoded[0]
    expected_sentence = "</s><pad><pad><pad><s>UN Chief Says There Is No Plan B for Peace in Syria</s>"
    self.assertEqual(generated_sentence, expected_sentence)
|
||||
|
||||
|
||||
class BartStandaloneDecoderModelTester:
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user