From 85665a4263681d2d8eaf85de40583ad7f46bc976 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Apr 2025 11:12:18 +0100 Subject: [PATCH] [tests] Stricter generate + compilation test -- no recompilations allowed (#37629) * tmp commit * stricter compilation test * trigger tests * rm todo --- src/transformers/generation/utils.py | 18 +-- src/transformers/models/opt/modeling_opt.py | 8 +- .../models/whisper/modeling_whisper.py | 3 + tests/generation/test_utils.py | 110 ++++++++++-------- tests/models/aria/test_modeling_aria.py | 4 - tests/models/idefics/test_modeling_idefics.py | 4 - tests/models/janus/test_modeling_janus.py | 4 + .../llava_next/test_modeling_llava_next.py | 4 - .../test_modeling_llava_next_video.py | 4 - .../test_modeling_llava_onevision.py | 4 - .../paligemma/test_modeling_paligemma.py | 5 - .../paligemma2/test_modeling_paligemma2.py | 5 - .../test_modeling_qwen2_5_omni.py | 2 + tests/models/whisper/test_modeling_whisper.py | 2 +- 14 files changed, 87 insertions(+), 90 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ee5b79e6d3..2f9118b0da 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -563,17 +563,17 @@ class GenerationMixin: device = model_inputs[input_ids_key].device # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create - # the 4D causal mask exists, it should be present in the base model (XXXModel class). - base_model = getattr(self, self.base_model_prefix, None) - if base_model is None: + # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder. + base_model = getattr(self, self.base_model_prefix, self) + decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None + causal_mask_creation_function = getattr( + base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None + ) + if causal_mask_creation_function is None and decoder is not None: # it may be in the decoder causal_mask_creation_function = getattr( - self, "_prepare_4d_causal_attention_mask_with_cache_position", None + decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None ) - else: - causal_mask_creation_function = getattr( - base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None - ) - if causal_mask_creation_function is None: + if causal_mask_creation_function is None: # can't be found logger.warning_once( f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method " "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're " diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 173fe89f6e..7e097cca31 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -1012,7 +1012,7 @@ class OPTModel(OPTPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1091,7 +1091,7 @@ class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1279,7 +1279,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1398,7 +1398,7 @@ class OPTForQuestionAnswering(OPTPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 52ad48cc2d..b7a454839c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1837,6 +1837,9 @@ class WhisperDecoderWrapper(WhisperPreTrainedModel): def set_input_embeddings(self, value): self.decoder.embed_tokens = value + def get_decoder(self): + return self.decoder + def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index caae8738e3..2bc1c351b8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -28,8 +28,9 @@ import pytest from packaging import version from parameterized import parameterized -from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, pipeline +from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline from transformers.testing_utils import ( + CaptureLogger, is_flaky, require_accelerate, require_flash_attn, @@ -38,6 +39,7 @@ from transformers.testing_utils import ( require_torch, require_torch_accelerator, require_torch_gpu, + require_torch_greater_or_equal, require_torch_multi_accelerator, require_torch_multi_gpu, require_torch_sdpa, @@ -81,6 +83,7 @@ if is_torch_available(): BeamSampleEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, + CompileConfig, DisjunctiveConstraint, GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput, @@ -2109,22 +2112,34 @@ class GenerationTesterMixin: model.generate(**generation_kwargs, **inputs_dict) @pytest.mark.generate + @require_torch_greater_or_equal("2.6") # Uses torch.compiler.set_stance def test_generate_compile_model_forward(self): """ - Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. + Tests that `.generate` is compatible with torch.compile, keeping the same results. Also confirms that + `.forward` called from `.generate` sees no graph breaks or recompilations when compiled. + ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ """ for model_class in self.all_generative_model_classes: + # 1. Test exclusion criteria if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") + # 2. Prepares two sets of inputs config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4) - model = model_class(config).to(torch_device) model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time - main_input = inputs_dict[model.main_input_name].to(torch_device) + # Some composite models have a custom generate and will call an inner model's generate -> that inner model + # is the one that gets compiled. + # (Note for the future: if BLIP starts causing problems, let's stop testing it) + if "blip" in model.__class__.__name__.lower(): + model_to_be_compiled = model.language_model + else: + model_to_be_compiled = model + # creates two sets of *different* inputs with the same shape + main_input = inputs_dict[model.main_input_name].to(torch_device) half_batch_size = main_input.shape[0] // 2 input_1 = {} input_2 = {} @@ -2140,66 +2155,69 @@ class GenerationTesterMixin: model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape ) - # compilation-specific setup + # 3. compilation-specific setup and generation parameterization torch.compiler.reset() # prevent cached compilation from being used in the test has_defined_cache_implementation = model.generation_config.cache_implementation is not None - - # BLIP is the only exception with custom generate which call `self.lm.generate()` - # We should avoid such calls in all subsequent multimodal models and try to make `generate()` - # compatible with multimodality - if "blip" in model.__class__.__name__.lower(): - model.language_model.generation_config.compile_config._compile_all_devices = True - else: - # force compilation (e.g. fast CI, CPU - model.generation_config.compile_config._compile_all_devices = True + compile_config = CompileConfig(dynamic=False) # Error out on dynamic shapes + compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) generation_kwargs = { "do_sample": False, "max_new_tokens": 5, "return_dict_in_generate": True, "output_scores": True, + "compile_config": compile_config, } - # get eager + dynamic cache results for future comparison + # 4. get eager + dynamic cache results for future comparison dynamic_outputs = [] - for model_inputs in model_input_sets: - gen_out = model.generate(**model_inputs, **generation_kwargs) - dynamic_outputs.append(gen_out) - # sanity checks for the default cache implementation - if not has_defined_cache_implementation: + # Ignores all `torch.compile` usage, useful to test models that that have non-default compilable caches + # (who would have used compilation in this section) + with torch.compiler.set_stance("force_eager"): + for model_inputs in model_input_sets: + gen_out = model.generate(**model_inputs, **generation_kwargs) + dynamic_outputs.append(gen_out) + # sanity checks for the default cache implementation + if not has_defined_cache_implementation: + decoder_cache = ( + gen_out.past_key_values.self_attention_cache + if config.is_encoder_decoder + else gen_out.past_key_values + ) + self.assertTrue(isinstance(decoder_cache, DynamicCache)) + self.assertFalse(decoder_cache.is_compileable) + # our auto compile should NOT have been called + self.assertFalse(hasattr(model_to_be_compiled, "_compiled_call")) + + # 5. get compiled results -- relies on the automatic compilation triggered by specific compilable caches + if not has_defined_cache_implementation: + generation_kwargs["cache_implementation"] = "static" + + compiled_outputs = [] + # Uses a context manager to catch recompilation logs. If there is any recompilation, this test fails. + torch._logging.set_logs(recompiles_verbose=True) + logger = logging.get_logger("torch._dynamo.guards") + with CaptureLogger(logger) as cl: + for model_inputs in model_input_sets: + # with torch.compiler.set_stance("fail_on_recompile"): + gen_out = model.generate(**model_inputs, **generation_kwargs) + compiled_outputs.append(gen_out) + # sanity checks decoder_cache = ( gen_out.past_key_values.self_attention_cache if config.is_encoder_decoder else gen_out.past_key_values ) - self.assertTrue(isinstance(decoder_cache, DynamicCache)) - self.assertFalse(decoder_cache.is_compileable) - self.assertFalse(hasattr(model, "_compiled_call")) # our auto compile should NOT have been called + self.assertFalse(isinstance(decoder_cache, DynamicCache)) + self.assertTrue(decoder_cache.is_compileable) + # our auto compile should have been called + self.assertTrue(hasattr(model_to_be_compiled, "_compiled_call")) - # get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation" - if not has_defined_cache_implementation: - generation_kwargs["cache_implementation"] = "static" - - compiled_outputs = [] - for model_inputs in model_input_sets: - gen_out = model.generate(**model_inputs, **generation_kwargs) - compiled_outputs.append(gen_out) - # sanity checks - decoder_cache = ( - gen_out.past_key_values.self_attention_cache - if config.is_encoder_decoder - else gen_out.past_key_values + if "Recompiling" in cl.out or ("guard" in cl.out and "failure" in cl.out): + raise RuntimeError( + f"`torch.compile` recompiled part of the forward pass in {model.__class__.__name__}. " + "See the test logs for more details." ) - self.assertFalse(isinstance(decoder_cache, DynamicCache)) - self.assertTrue(decoder_cache.is_compileable) - - # BLIP is the only exception with custom generate which call `self.lm.generate()` - # We should avoid such calls in all subsequent multimodal models and try to make `generate()` - # compatible with multimodality - if "blip" in model.__class__.__name__.lower(): - self.assertTrue(hasattr(model.language_model, "_compiled_call")) - else: - self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): self._check_similar_generate_outputs(dynamic_result, compiled_result) diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index 60716a51b2..63c812180a 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -280,10 +280,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip(reason="Dynamic control flow due to MoE") - def test_generate_compile_model_forward(self): - pass - @require_torch class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 169ec75d00..5f6a0f1832 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -840,10 +840,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni def test_generate_with_static_cache(self): pass - @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") - def test_generate_compile_model_forward(self): - pass - @unittest.skip(reason="We only test the model that takes in multiple images") def test_model(self): pass diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 03208f388e..48cf7ebc2f 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -335,6 +335,10 @@ class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, uni else: pass + @unittest.skip("There are recompilations in Janus") # TODO (joao, raushan): fix me + def test_generate_compile_model_forward(self): + pass + class JanusVQModelTester: def __init__( diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index dcba5fb8af..ea134aee9e 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -341,10 +341,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip("LLaVA Next has dynamic control flow in unpadding") - def test_generate_compile_model_forward(self): - pass - @require_torch class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 23efdb7de0..47c71d9c75 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -356,10 +356,6 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip("LLaVA Next Video has dynamic control flow in unpadding") - def test_generate_compile_model_forward(self): - pass - @require_torch class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index f311f2b455..021739976b 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -312,10 +312,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip("LLaVA OneVision has dynamic control flow in unpadding") - def test_generate_compile_model_forward(self): - pass - @require_torch class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 924aaaf2af..84b78f7264 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -344,11 +344,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow - @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") - def test_generate_compile_model_forward(self): - pass - def test_attention_mask_with_token_types(self): """Test that attention masking works correctly both with and without token type IDs.""" config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index 3f60ee33a4..e7d60a8849 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -341,11 +341,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow - @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") - def test_generate_compile_model_forward(self): - pass - @unittest.skip("Low memory will be removed soon so no need to fix it") def test_beam_search_low_memory(self): pass diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index 116425f349..5070383228 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -365,6 +365,8 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene def test_generate_from_inputs_embeds_with_static_cache(self): pass + # TODO (joao, raushan): there are multiple standardization issues in this model that prevent this test from + # passing, fix me @unittest.skip("Cannot handle 4D attention mask") def test_generate_compile_model_forward(self): pass diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index ff13b5eef5..7d2a5e54bd 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1431,7 +1431,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) - # TODO (joao, eustache): fix me :) + # TODO (joao, eustache): fix me :) The model is not returning a `Cache` by default @unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types") def test_generate_compile_model_forward(self): pass