From 9af845afc2607172b6830610ab465bdd31f258cd Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 14 Apr 2023 09:57:25 +0100 Subject: [PATCH] Generate: pin number of beams in BART test (#22763) --- tests/models/bart/test_modeling_bart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 36837c9556..a4b77e8431 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1230,7 +1230,7 @@ class BartModelIntegrationTests(unittest.TestCase): article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="pt" ).input_ids.to(torch_device) - outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, num_beams=1) generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) self.assertListEqual(