From 42c8ccfd4c466eedabc5e30ac34f242e2c9f9455 Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Wed, 29 Jan 2025 07:03:45 -0500 Subject: [PATCH] fix `test_generated_length_assisted_generation` (#34935) fix test_generated_length_assisted_generation --- tests/generation/test_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d9b4bbbe8c..26f3b3e324 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3405,7 +3405,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi assistant_model=assistant, min_new_tokens=10, ) - self.assertTrue((input_length + 10) <= out.shape[-1] <= 20) + self.assertTrue((input_length + 10) <= out.shape[-1]) + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=7, + ) + self.assertTrue(out.shape[-1] <= (input_length + 7)) def test_model_kwarg_assisted_decoding_decoder_only(self): # PT-only test: TF doesn't support assisted decoding yet.