fix test_generated_length_assisted_generation (#34935)

fix test_generated_length_assisted_generation
This commit is contained in:
Nadav Timor
2025-01-29 07:03:45 -05:00
committed by GitHub
parent ec7afad609
commit 42c8ccfd4c

View File

@@ -3405,7 +3405,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
assistant_model=assistant, assistant_model=assistant,
min_new_tokens=10, min_new_tokens=10,
) )
self.assertTrue((input_length + 10) <= out.shape[-1] <= 20) self.assertTrue((input_length + 10) <= out.shape[-1])
out = model.generate(
input_ids,
assistant_model=assistant,
max_new_tokens=7,
)
self.assertTrue(out.shape[-1] <= (input_length + 7))
def test_model_kwarg_assisted_decoding_decoder_only(self): def test_model_kwarg_assisted_decoding_decoder_only(self):
# PT-only test: TF doesn't support assisted decoding yet. # PT-only test: TF doesn't support assisted decoding yet.