[generate] model defaults being inherited only happens for newer models (#36881)

This commit is contained in:
Joao Gante
2025-03-21 11:01:09 +00:00
committed by GitHub
parent f19d018bff
commit 94f487626a
2 changed files with 54 additions and 26 deletions

View File

@@ -575,8 +575,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
-- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "google/gemma-3-1b-it"
attn_implementation = "sdpa"
@@ -594,12 +594,16 @@ class Gemma3IntegrationTest(unittest.TestCase):
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
self.assertGreater(input_size, model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=20)
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
out = model.generate(**inputs, generation_config=generation_config)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
output_text = tokenizer.batch_decode(out)
# Generation works beyond sliding window
self.assertGreater(out.shape[1], model.config.sliding_window)
self.assertEqual(out.shape[1], input_size + 5)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
# Note: Auto-inheritance only works for models saved starting from 4.50.0
model.generation_config.transformers_version = "4.49.0"
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
out = model.generate(**inputs, generation_config=generation_config)