[tests] fix test_nemotron_8b_generation_sdpa (#37665)
add max_new_tokens
This commit is contained in:
@@ -195,7 +195,7 @@ class NemotronIntegrationTest(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
|
inputs = tokenizer(text, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
output = model.generate(**inputs, do_sample=False)
|
output = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT, output_text)
|
self.assertEqual(EXPECTED_TEXT, output_text)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user