From 2638d54e7851f1323dc78a8b513b041835aba27b Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Fri, 21 Mar 2025 12:36:39 +0100 Subject: [PATCH] Gemma 3 tests expect greedy decoding (#36882) tests expect greedy decoding --- tests/models/gemma3/test_modeling_gemma3.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 7904b3f8eb..25e620f61f 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -567,7 +567,7 @@ class Gemma3IntegrationTest(unittest.TestCase): input_size = inputs.input_ids.shape[-1] self.assertTrue(input_size > model.config.sliding_window) - out = model.generate(**inputs, max_new_tokens=20)[:, input_size:] + out = model.generate(**inputs, max_new_tokens=20, do_sample=False)[:, input_size:] output_text = tokenizer.batch_decode(out) EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip @@ -599,6 +599,11 @@ class Gemma3IntegrationTest(unittest.TestCase): generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5) out = model.generate(**inputs, generation_config=generation_config) + out = model.generate(**inputs, generation_config=generation_config, do_sample=False)[:, input_size:] + output_text = tokenizer.batch_decode(out) + EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip + self.assertEqual(output_text, EXPECTED_COMPLETIONS) + # Generation works beyond sliding window self.assertGreater(out.shape[1], model.config.sliding_window) self.assertEqual(out.shape[1], input_size + 5)