Make test_generate_with_static_cache even less flaky (#34995)
* fix * fix * fix * fix * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -37,6 +37,9 @@ from transformers.testing_utils import (
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_sdpa,
|
||||
set_config_for_less_flaky_test,
|
||||
set_model_for_less_flaky_test,
|
||||
set_model_tester_for_less_flaky_test,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -1921,11 +1924,13 @@ class GenerationTesterMixin:
|
||||
Tests that generating with static cache give almost same results as with dynamic cache, and the output cache
|
||||
has the expected shapes
|
||||
"""
|
||||
set_model_tester_for_less_flaky_test(self)
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest(reason="This model does not support the static cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
set_config_for_less_flaky_test(config)
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
@@ -1938,6 +1943,8 @@ class GenerationTesterMixin:
|
||||
|
||||
for dtype in (torch.float32, torch.float16):
|
||||
model = model_class(config).to(torch_device).to(dtype).eval()
|
||||
set_model_for_less_flaky_test(model)
|
||||
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
|
||||
Reference in New Issue
Block a user