From 1cfcbfcab8da6681a73b18e3e198d56417815223 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 24 Apr 2025 11:48:11 +0200 Subject: [PATCH] [VLMs] fix flash-attention tests (#37603) * fix one test * fa2 ln test * remove keys from config recursively * fix * fixup --- src/transformers/configuration_utils.py | 62 +++++++++---------- src/transformers/modeling_utils.py | 11 +++- .../models/internvl/modeling_internvl.py | 6 +- .../models/janus/modeling_janus.py | 3 +- .../models/janus/modular_janus.py | 3 +- .../aya_vision/test_modeling_aya_vision.py | 4 -- .../models/got_ocr2/test_modeling_got_ocr2.py | 6 +- .../models/idefics2/test_modeling_idefics2.py | 4 -- .../models/idefics3/test_modeling_idefics3.py | 4 -- tests/models/llava/test_modeling_llava.py | 4 -- .../llava_next/test_modeling_llava_next.py | 4 -- .../test_modeling_llava_onevision.py | 4 -- .../paligemma/test_modeling_paligemma.py | 4 -- .../paligemma2/test_modeling_paligemma2.py | 4 -- tests/models/smolvlm/test_modeling_smolvlm.py | 4 -- .../video_llava/test_modeling_video_llava.py | 4 -- .../models/vipllava/test_modeling_vipllava.py | 4 -- 17 files changed, 52 insertions(+), 83 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 8fa8dc46c3..c50f9a6506 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -843,29 +843,16 @@ class PretrainedConfig(PushToHubMixin): ): serializable_config_dict[key] = value + self._remove_keys_not_serialized(serializable_config_dict) + if hasattr(self, "quantization_config"): serializable_config_dict["quantization_config"] = ( self.quantization_config.to_dict() if not isinstance(self.quantization_config, dict) else self.quantization_config ) - # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. - _ = serializable_config_dict.pop("_pre_quantization_dtype", None) - self.dict_torch_dtype_to_str(serializable_config_dict) - if "_attn_implementation_internal" in serializable_config_dict: - del serializable_config_dict["_attn_implementation_internal"] - # Do not serialize `base_model_tp_plan` for now - if "base_model_tp_plan" in serializable_config_dict: - del serializable_config_dict["base_model_tp_plan"] - # Do not serialize `base_model_pp_plan` for now - if "base_model_pp_plan" in serializable_config_dict: - del serializable_config_dict["base_model_pp_plan"] - - if "_name_or_path" in serializable_config_dict: - del serializable_config_dict["_name_or_path"] - return serializable_config_dict def to_dict(self) -> dict[str, Any]: @@ -878,18 +865,6 @@ class PretrainedConfig(PushToHubMixin): output = copy.deepcopy(self.__dict__) if hasattr(self.__class__, "model_type"): output["model_type"] = self.__class__.model_type - if "_auto_class" in output: - del output["_auto_class"] - if "_commit_hash" in output: - del output["_commit_hash"] - if "_attn_implementation_internal" in output: - del output["_attn_implementation_internal"] - # Do not serialize `base_model_tp_plan` for now - if "base_model_tp_plan" in output: - del output["base_model_tp_plan"] - # Do not serialize `base_model_pp_plan` for now - if "base_model_pp_plan" in output: - del output["base_model_pp_plan"] # Transformers version when serializing the model output["transformers_version"] = __version__ @@ -902,16 +877,14 @@ class PretrainedConfig(PushToHubMixin): output[key] = value + self._remove_keys_not_serialized(output) + if hasattr(self, "quantization_config"): output["quantization_config"] = ( self.quantization_config.to_dict() if not isinstance(self.quantization_config, dict) else self.quantization_config ) - - # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. - _ = output.pop("_pre_quantization_dtype", None) - self.dict_torch_dtype_to_str(output) return output @@ -1011,6 +984,33 @@ class PretrainedConfig(PushToHubMixin): if isinstance(value, dict): self.dict_torch_dtype_to_str(value) + def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None: + """ + Checks and removes if there are any keys in the dict that should not be serialized when saving the config. + Runs recursive check on the dict, to remove from all sub configs. + """ + if hasattr(self, "quantization_config"): + # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = d.pop("_pre_quantization_dtype", None) + + if "_auto_class" in d: + del d["_auto_class"] + if "_commit_hash" in d: + del d["_commit_hash"] + if "_attn_implementation_internal" in d: + del d["_attn_implementation_internal"] + # Do not serialize `base_model_tp_plan` for now + if "base_model_tp_plan" in d: + del d["base_model_tp_plan"] + # Do not serialize `base_model_pp_plan` for now + if "base_model_pp_plan" in d: + del d["base_model_pp_plan"] + if "_name_or_path" in d: + del d["_name_or_path"] + for value in d.values(): + if isinstance(value, dict): + self._remove_keys_not_serialized(value) + @classmethod def register_for_auto_class(cls, auto_class="AutoConfig"): """ diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0dbc97781e..6999d46e49 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # once the weights have been quantized # Note that once you have loaded a quantized model, you can't change its dtype so this will # remain a single source of truth - config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype() + original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype() + + def _assign_original_dtype(module): + for child in module.children(): + if isinstance(child, PreTrainedModel): + child.config._pre_quantization_dtype = original_dtype + _assign_original_dtype(child) + + config._pre_quantization_dtype = original_dtype + _assign_original_dtype(model) # Prepare the full device map if device_map is not None: diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index b6af62afd5..c181976e0e 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -125,6 +125,9 @@ class InternVLVisionAttention(nn.Module): proj_dropout = config.projection_dropout qk_norm = config.use_qk_norm + # Needed for flash attention + self.is_causal = False + self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) @@ -134,9 +137,6 @@ class InternVLVisionAttention(nn.Module): self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() - # Needed for flash attention - self.is_causal = False - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 4fa538ee7c..6a30954345 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -344,6 +344,7 @@ class JanusVisionAttention(nn.Module): self.attention_dropout = config.attention_dropout proj_dropout = config.projection_dropout qk_norm = config.use_qk_norm + self.is_causal = False # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1. self.num_key_value_groups = 1 @@ -398,7 +399,7 @@ class JanusVisionAttention(nn.Module): attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scale, - is_causal=False, + is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 3a0efff5ae..36499e43ea 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -509,6 +509,7 @@ class JanusVisionAttention(nn.Module): self.attention_dropout = config.attention_dropout proj_dropout = config.projection_dropout qk_norm = config.use_qk_norm + self.is_causal = False # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1. self.num_key_value_groups = 1 @@ -563,7 +564,7 @@ class JanusVisionAttention(nn.Module): attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scale, - is_causal=False, + is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) diff --git a/tests/models/aya_vision/test_modeling_aya_vision.py b/tests/models/aya_vision/test_modeling_aya_vision.py index b5ff4f15a8..5858321ee4 100644 --- a/tests/models/aya_vision/test_modeling_aya_vision.py +++ b/tests/models/aya_vision/test_modeling_aya_vision.py @@ -316,10 +316,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_sdpa_can_compile_dynamic(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - # todo: yoni - fix or improve the test @unittest.skip("Difference is slightly higher than the threshold") def test_batching_equivalence(self): diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index f604dbf036..ed0a25f7b1 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -222,8 +222,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): + @unittest.skip( + reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format" + ) + def test_past_key_values_format(self): pass diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 3c1ac9bb51..325d971434 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -407,10 +407,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest def test_prompt_lookup_decoding_matches_greedy_search(self): pass - @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @pytest.mark.generate @require_torch_sdpa @slow diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py index f1fcfd3943..69a0f85ace 100644 --- a/tests/models/idefics3/test_modeling_idefics3.py +++ b/tests/models/idefics3/test_modeling_idefics3.py @@ -367,10 +367,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest def test_prompt_lookup_decoding_matches_greedy_search(self): pass - @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @pytest.mark.generate @require_torch_sdpa @slow diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index f329d1b211..1072d9043e 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -302,10 +302,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" ) diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index ea134aee9e..3b9fc36521 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -331,10 +331,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" ) diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index 021739976b..dfb3b01395 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -302,10 +302,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" ) diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 84b78f7264..dee84a53f3 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -334,10 +334,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" ) diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index e7d60a8849..6938e50a26 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -331,10 +331,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" ) diff --git a/tests/models/smolvlm/test_modeling_smolvlm.py b/tests/models/smolvlm/test_modeling_smolvlm.py index 8e2839166b..cdeb0d95ec 100644 --- a/tests/models/smolvlm/test_modeling_smolvlm.py +++ b/tests/models/smolvlm/test_modeling_smolvlm.py @@ -378,10 +378,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste def test_generate_methods_with_logits_to_keep(self): super().test_generate_methods_with_logits_to_keep() - @unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip def test_training_gradient_checkpointing(self): pass diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 772356e097..bda728b391 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -225,10 +225,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" ) diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index e79593efbf..60460ddfb5 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -305,10 +305,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") - def test_flash_attn_2_fp32_ln(self): - pass - @unittest.skip( "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" )