Generate: pin number of beams in BART test (#22763)
This commit is contained in:
@@ -1230,7 +1230,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
|
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt"
|
||||||
).input_ids.to(torch_device)
|
).input_ids.to(torch_device)
|
||||||
|
|
||||||
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
|
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, num_beams=1)
|
||||||
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
|
|||||||
Reference in New Issue
Block a user