[tests] fix flaky pattern in test_generate_continue_from_past_key_values (#37724)
This commit is contained in:
@@ -1097,25 +1097,18 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# test output equality of low versus high memory
|
# test output equality of low versus high memory
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
generate_kwargs = {
|
||||||
|
"top_k": 4,
|
||||||
|
"penalty_alpha": 0.6,
|
||||||
|
"max_new_tokens": self.max_new_tokens,
|
||||||
|
"use_cache": True,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
"output_scores": True,
|
||||||
|
}
|
||||||
|
|
||||||
low_output = model.generate(
|
low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
|
||||||
top_k=4,
|
high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
|
||||||
penalty_alpha=0.6,
|
self._check_similar_generate_outputs(low_output, high_output)
|
||||||
low_memory=True,
|
|
||||||
max_new_tokens=self.max_new_tokens,
|
|
||||||
**inputs_dict,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
high_output = model.generate(
|
|
||||||
top_k=4,
|
|
||||||
penalty_alpha=0.6,
|
|
||||||
low_memory=False,
|
|
||||||
max_new_tokens=self.max_new_tokens,
|
|
||||||
**inputs_dict,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@@ -1863,22 +1856,29 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
|
||||||
model.generation_config.forced_eos_token_id = None
|
|
||||||
model.generation_config.encoder_no_repeat_ngram_size = 0
|
|
||||||
model.generation_config.use_cache = True
|
|
||||||
|
|
||||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
if "past_key_values" not in outputs:
|
if "past_key_values" not in outputs:
|
||||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||||
|
|
||||||
|
generate_kwargs = {
|
||||||
|
"pad_token_id": -1,
|
||||||
|
"eos_token_id": -1,
|
||||||
|
"forced_eos_token_id": None,
|
||||||
|
"encoder_no_repeat_ngram_size": 0,
|
||||||
|
"use_cache": True,
|
||||||
|
"do_sample": False,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
"output_scores": True,
|
||||||
|
}
|
||||||
|
|
||||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||||
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
|
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
|
||||||
|
|
||||||
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
||||||
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
||||||
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
|
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
|
||||||
|
|
||||||
# Continue from the tokens generated above, preparing the inputs accordingly
|
# Continue from the tokens generated above, preparing the inputs accordingly
|
||||||
inputs["past_key_values"] = outputs_cached.past_key_values
|
inputs["past_key_values"] = outputs_cached.past_key_values
|
||||||
@@ -1901,10 +1901,13 @@ class GenerationTesterMixin:
|
|||||||
mode="constant",
|
mode="constant",
|
||||||
value=1,
|
value=1,
|
||||||
)
|
)
|
||||||
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
|
first_caches_scores = outputs_cached.scores
|
||||||
|
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
|
||||||
|
full_cached_scores = first_caches_scores + outputs_cached.scores
|
||||||
|
outputs_cached.scores = full_cached_scores
|
||||||
|
|
||||||
# The two sets of generated text and past kv should be equal to each other
|
# The two sets of generated text and past kv should be equal to each other
|
||||||
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
|
self._check_similar_generate_outputs(outputs, outputs_cached)
|
||||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@@ -1930,6 +1933,8 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
self.skipTest(reason="This model is encoder-decoder")
|
self.skipTest(reason="This model is encoder-decoder")
|
||||||
|
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
|
||||||
|
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
@@ -1990,32 +1995,6 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
|
||||||
@require_torch_accelerator
|
|
||||||
@pytest.mark.generate
|
|
||||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
|
||||||
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
if not model_class._supports_cache_class:
|
|
||||||
self.skipTest(reason="This model does not support the new cache format")
|
|
||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
generation_kwargs = {
|
|
||||||
"max_new_tokens": 5,
|
|
||||||
"use_cache": True,
|
|
||||||
"cache_implementation": cache_implementation,
|
|
||||||
}
|
|
||||||
|
|
||||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
|
||||||
|
|
||||||
# Most cache classes have their own tests except for some that are tested here
|
|
||||||
# The ones here do not need special treatment when passing `cache_implementation`
|
|
||||||
# and are not bound to specific models only
|
|
||||||
new_results = model.generate(**generation_kwargs, **inputs_dict)
|
|
||||||
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
|
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_generate_with_static_cache(self):
|
def test_generate_with_static_cache(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user