fix some flaky tests in tests/generation/test_utils.py (#39254)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1694,6 +1694,7 @@ class GenerationTesterMixin:
|
||||
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
+set_model_tester_for_less_flaky_test(self)
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
@@ -1703,8 +1704,10 @@ class GenerationTesterMixin:
|
||||
continue
|
||||
config.is_decoder = True
|
||||
|
||||
+set_config_for_less_flaky_test(config)
|
||||
# Skip models without explicit support
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
+set_model_for_less_flaky_test(model)
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
continue
|
||||
|
||||
@@ -2298,13 +2301,14 @@ class GenerationTesterMixin:
|
||||
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
|
||||
this separate function.
|
||||
"""
|
||||
-max_new_tokens = 30
|
||||
+max_new_tokens = 3
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn_2",
|
||||
"flash_attention_3": "_supports_flash_attn_3",
|
||||
}
|
||||
|
||||
+set_model_tester_for_less_flaky_test(self)
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
|
||||
@@ -2330,6 +2334,7 @@ class GenerationTesterMixin:
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1
|
||||
|
||||
+set_config_for_less_flaky_test(config)
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -2350,6 +2355,7 @@ class GenerationTesterMixin:
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="eager",
|
||||
).to(torch_device)
|
||||
+set_model_for_less_flaky_test(model_eager)
|
||||
res_eager = model_eager.generate(**inputs_dict, **generate_kwargs)
|
||||
del model_eager
|
||||
gc.collect()
|
||||
@@ -2359,6 +2365,7 @@ class GenerationTesterMixin:
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(torch_device)
|
||||
+set_model_for_less_flaky_test(model_attn)
|
||||
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
|
||||
del model_attn
|
||||
gc.collect()
|
||||
|
||||
Reference in New Issue
Block a user