fix some flaky tests in tests/generation/test_utils.py (#39254)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-07-07 19:49:41 +02:00
committed by GitHub
parent 93747d89ea
commit 41e865bb8d

View File

@@ -1694,6 +1694,7 @@ class GenerationTesterMixin:
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# If this test fails, you should probably update the `prepare_inputs_for_generation` function
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -1703,8 +1704,10 @@ class GenerationTesterMixin:
continue
config.is_decoder = True
set_config_for_less_flaky_test(config)
# Skip models without explicit support
model = model_class(config).to(torch_device).eval()
set_model_for_less_flaky_test(model)
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
continue
@@ -2298,13 +2301,14 @@ class GenerationTesterMixin:
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
this separate function.
"""
max_new_tokens = 30
max_new_tokens = 3
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
"flash_attention_3": "_supports_flash_attn_3",
}
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
if not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
@@ -2330,6 +2334,7 @@ class GenerationTesterMixin:
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
set_config_for_less_flaky_test(config)
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -2350,6 +2355,7 @@ class GenerationTesterMixin:
torch_dtype=torch.float16,
attn_implementation="eager",
).to(torch_device)
set_model_for_less_flaky_test(model_eager)
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
del model_eager
gc.collect()
@@ -2359,6 +2365,7 @@ class GenerationTesterMixin:
torch_dtype=torch.float16,
attn_implementation=attn_implementation,
).to(torch_device)
set_model_for_less_flaky_test(model_attn)
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
del model_attn
gc.collect()