[tests] fix flaky pattern in test_generate_continue_from_past_key_values (#37724)

This commit is contained in:
Joao Gante
2025-04-29 12:20:42 +01:00
committed by GitHub
parent 4abeb50f6e
commit 3a1acc36ed

View File

@@ -1097,25 +1097,18 @@ class GenerationTesterMixin:
# test output equality of low versus high memory # test output equality of low versus high memory
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
generate_kwargs = {
"top_k": 4,
"penalty_alpha": 0.6,
"max_new_tokens": self.max_new_tokens,
"use_cache": True,
"return_dict_in_generate": True,
"output_scores": True,
}
low_output = model.generate( low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
top_k=4, high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
penalty_alpha=0.6, self._check_similar_generate_outputs(low_output, high_output)
low_memory=True,
max_new_tokens=self.max_new_tokens,
**inputs_dict,
use_cache=True,
)
high_output = model.generate(
top_k=4,
penalty_alpha=0.6,
low_memory=False,
max_new_tokens=self.max_new_tokens,
**inputs_dict,
use_cache=True,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
@parameterized.expand([("random",), ("same",)]) @parameterized.expand([("random",), ("same",)])
@pytest.mark.generate @pytest.mark.generate
@@ -1863,22 +1856,29 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device) model = model_class(config).to(torch_device)
model.eval() model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.encoder_no_repeat_ngram_size = 0
model.generation_config.use_cache = True
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs) outputs = model(**inputs)
if "past_key_values" not in outputs: if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`") self.skipTest(reason="This model doesn't return `past_key_values`")
generate_kwargs = {
"pad_token_id": -1,
"eos_token_id": -1,
"forced_eos_token_id": None,
"encoder_no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True) outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask). # inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
# Continue from the tokens generated above, preparing the inputs accordingly # Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values inputs["past_key_values"] = outputs_cached.past_key_values
@@ -1901,10 +1901,13 @@ class GenerationTesterMixin:
mode="constant", mode="constant",
value=1, value=1,
) )
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True) first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other # The two sets of generated text and past kv should be equal to each other
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist()) self._check_similar_generate_outputs(outputs, outputs_cached)
for layer_idx in range(len(outputs_cached.past_key_values)): for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue( self.assertTrue(
@@ -1930,6 +1933,8 @@ class GenerationTesterMixin:
if config.is_encoder_decoder: if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder") self.skipTest(reason="This model is encoder-decoder")
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
if not hasattr(config, "use_cache"): if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -1990,32 +1995,6 @@ class GenerationTesterMixin:
) )
) )
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_accelerator
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
# Most cache classes have their own tests except for some that are tested here
# The ones here do not need special treatment when passing `cache_implementation`
# and are not bound to specific models only
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
@pytest.mark.generate @pytest.mark.generate
def test_generate_with_static_cache(self): def test_generate_with_static_cache(self):
""" """