Make test_generate_with_static_cache even less flaky (#34995)

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-12-20 16:03:26 +01:00
committed by GitHub
parent 0fc2970363
commit 504c4d3692
5 changed files with 93 additions and 32 deletions

View File

@@ -37,6 +37,9 @@ from transformers.testing_utils import (
require_torch_multi_accelerator,
require_torch_multi_gpu,
require_torch_sdpa,
set_config_for_less_flaky_test,
set_model_for_less_flaky_test,
set_model_tester_for_less_flaky_test,
slow,
torch_device,
)
@@ -1921,11 +1924,13 @@ class GenerationTesterMixin:
Tests that generating with static cache gives almost the same results as with dynamic cache, and that the output
cache has the expected shapes
"""
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(reason="This model does not support the static cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
set_config_for_less_flaky_test(config)
main_input = inputs_dict[model_class.main_input_name]
if config.is_encoder_decoder:
@@ -1938,6 +1943,8 @@ class GenerationTesterMixin:
for dtype in (torch.float32, torch.float16):
model = model_class(config).to(torch_device).to(dtype).eval()
set_model_for_less_flaky_test(model)
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"return_dict_in_generate": True, # Required to return `past_key_values`