diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c9e7f9f4f1..c6fd59f68b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1694,6 +1694,7 @@ class GenerationTesterMixin: """Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc""" # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` # if fails, you should probably update the `prepare_inputs_for_generation` function + set_model_tester_for_less_flaky_test(self) for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() @@ -1703,8 +1704,10 @@ class GenerationTesterMixin: continue config.is_decoder = True + set_config_for_less_flaky_test(config) # Skip models without explicit support model = model_class(config).to(torch_device).eval() + set_model_for_less_flaky_test(model) if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): continue @@ -2298,13 +2301,14 @@ class GenerationTesterMixin: NOTE: despite the test logic being the same, different implementations actually need different decorators, hence this separate function. """ - max_new_tokens = 30 + max_new_tokens = 3 support_flag = { "sdpa": "_supports_sdpa", "flash_attention_2": "_supports_flash_attn_2", "flash_attention_3": "_supports_flash_attn_3", } + set_model_tester_for_less_flaky_test(self) for model_class in self.all_generative_model_classes: if not getattr(model_class, support_flag[attn_implementation]): self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`") @@ -2330,6 +2334,7 @@ class GenerationTesterMixin: if hasattr(config, "max_position_embeddings"): config.max_position_embeddings = max_new_tokens + main_input.shape[1] + 1 + set_config_for_less_flaky_test(config) model = model_class(config) with tempfile.TemporaryDirectory() as tmpdirname: @@ -2350,6 +2355,7 @@ class GenerationTesterMixin: torch_dtype=torch.float16, attn_implementation="eager", ).to(torch_device) + set_model_for_less_flaky_test(model_eager) res_eager = model_eager.generate(**inputs_dict, **generate_kwargs) del model_eager gc.collect() @@ -2359,6 +2365,7 @@ class GenerationTesterMixin: torch_dtype=torch.float16, attn_implementation=attn_implementation, ).to(torch_device) + set_model_for_less_flaky_test(model_attn) res_attn = model_attn.generate(**inputs_dict, **generate_kwargs) del model_attn gc.collect()