[generate] model defaults being inherited only happens for newer models (#36881)
This commit is contained in:
@@ -575,8 +575,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_generation_beyond_sliding_window_with_generation_config(self):
|
||||
"""
|
||||
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
||||
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||
Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
|
||||
-- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||
"""
|
||||
model_id = "google/gemma-3-1b-it"
|
||||
attn_implementation = "sdpa"
|
||||
@@ -594,12 +594,16 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
||||
|
||||
# Make sure prefill is larger than sliding window
|
||||
input_size = inputs.input_ids.shape[-1]
|
||||
self.assertTrue(input_size > model.config.sliding_window)
|
||||
self.assertGreater(input_size, model.config.sliding_window)
|
||||
|
||||
generation_config = GenerationConfig(max_new_tokens=20)
|
||||
generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
|
||||
out = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
||||
output_text = tokenizer.batch_decode(out)
|
||||
# Generation works beyond sliding window
|
||||
self.assertGreater(out.shape[1], model.config.sliding_window)
|
||||
self.assertEqual(out.shape[1], input_size + 5)
|
||||
|
||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||
# Note: Auto-inheritance only works for models saved starting from 4.50.0
|
||||
model.generation_config.transformers_version = "4.49.0"
|
||||
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
|
||||
out = model.generate(**inputs, generation_config=generation_config)
|
||||
|
||||
Reference in New Issue
Block a user