Support batched input for decoder start ids (#28887)
* support batched input for decoder start ids
* Fix typos
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* minor changes
* fix: decoder_start_id as list
* empty commit
* empty commit
* empty commit
* empty commit
* empty commit
* empty commit
* empty commit
* empty commit
* empty commit
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
cc309fd406
commit
d628664688
@@ -3163,6 +3163,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
def test_batched_decoder_start_id(self):
    """Check that a per-row list of decoder start ids matches the scalar form.

    The same padded batch is generated twice: once with the scalar
    `decoder_start_token_id` read from the model's generation config, and once
    with a list repeating that id for every row of the batch. Both calls must
    produce identical token sequences.
    """
    # PT-only test: TF doesn't support batched_decoder_start_id
    checkpoint = "hf-internal-testing/tiny-random-bart"
    articles = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]

    bart_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    bart_model = BartForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)

    # Pad so both articles fit in one rectangular batch, then move it to the test device.
    encoded = bart_tokenizer(articles, return_tensors="pt", padding=True)
    input_ids = encoded.input_ids.to(torch_device)

    # One start id per batch row, all equal to the configured scalar value.
    start_id = bart_model.generation_config.decoder_start_token_id
    batch_size = input_ids.shape[0]
    start_ids_per_row = [start_id] * batch_size

    scalar_outputs = bart_model.generate(input_ids, decoder_start_token_id=start_id)
    batched_outputs = bart_model.generate(input_ids, decoder_start_token_id=start_ids_per_row)

    self.assertListEqual(scalar_outputs.tolist(), batched_outputs.tolist())
|
||||
|
||||
def test_contrastive_search_batched(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
|
||||
|
||||
Reference in New Issue
Block a user