[VLMs] fix flash-attention tests (#37603)

* fix one test

* fa2 ln test

* remove keys from config recursively

* fix

* fixup
This commit is contained in:
Raushan Turganbay
2025-04-24 11:48:11 +02:00
committed by GitHub
parent 02baa61fab
commit 1cfcbfcab8
17 changed files with 52 additions and 83 deletions

View File

@@ -843,29 +843,16 @@ class PretrainedConfig(PushToHubMixin):
): ):
serializable_config_dict[key] = value serializable_config_dict[key] = value
self._remove_keys_not_serialized(serializable_config_dict)
if hasattr(self, "quantization_config"): if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = ( serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict() self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict) if not isinstance(self.quantization_config, dict)
else self.quantization_config else self.quantization_config
) )
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(serializable_config_dict) self.dict_torch_dtype_to_str(serializable_config_dict)
if "_attn_implementation_internal" in serializable_config_dict:
del serializable_config_dict["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_pp_plan"]
if "_name_or_path" in serializable_config_dict:
del serializable_config_dict["_name_or_path"]
return serializable_config_dict return serializable_config_dict
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
@@ -878,18 +865,6 @@ class PretrainedConfig(PushToHubMixin):
output = copy.deepcopy(self.__dict__) output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"): if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in output:
del output["base_model_pp_plan"]
# Transformers version when serializing the model # Transformers version when serializing the model
output["transformers_version"] = __version__ output["transformers_version"] = __version__
@@ -902,16 +877,14 @@ class PretrainedConfig(PushToHubMixin):
output[key] = value output[key] = value
self._remove_keys_not_serialized(output)
if hasattr(self, "quantization_config"): if hasattr(self, "quantization_config"):
output["quantization_config"] = ( output["quantization_config"] = (
self.quantization_config.to_dict() self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict) if not isinstance(self.quantization_config, dict)
else self.quantization_config else self.quantization_config
) )
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = output.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(output) self.dict_torch_dtype_to_str(output)
return output return output
@@ -1011,6 +984,33 @@ class PretrainedConfig(PushToHubMixin):
if isinstance(value, dict): if isinstance(value, dict):
self.dict_torch_dtype_to_str(value) self.dict_torch_dtype_to_str(value)
def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
"""
Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
Runs recursive check on the dict, to remove from all sub configs.
"""
if hasattr(self, "quantization_config"):
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = d.pop("_pre_quantization_dtype", None)
if "_auto_class" in d:
del d["_auto_class"]
if "_commit_hash" in d:
del d["_commit_hash"]
if "_attn_implementation_internal" in d:
del d["_attn_implementation_internal"]
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in d:
del d["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in d:
del d["base_model_pp_plan"]
if "_name_or_path" in d:
del d["_name_or_path"]
for value in d.values():
if isinstance(value, dict):
self._remove_keys_not_serialized(value)
@classmethod @classmethod
def register_for_auto_class(cls, auto_class="AutoConfig"): def register_for_auto_class(cls, auto_class="AutoConfig"):
""" """

View File

@@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# once the weights have been quantized # once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will # Note that once you have loaded a quantized model, you can't change its dtype so this will
# remain a single source of truth # remain a single source of truth
config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype() original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
def _assign_original_dtype(module):
for child in module.children():
if isinstance(child, PreTrainedModel):
child.config._pre_quantization_dtype = original_dtype
_assign_original_dtype(child)
config._pre_quantization_dtype = original_dtype
_assign_original_dtype(model)
# Prepare the full device map # Prepare the full device map
if device_map is not None: if device_map is not None:

View File

@@ -125,6 +125,9 @@ class InternVLVisionAttention(nn.Module):
proj_dropout = config.projection_dropout proj_dropout = config.projection_dropout
qk_norm = config.use_qk_norm qk_norm = config.use_qk_norm
# Needed for flash attention
self.is_causal = False
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
@@ -134,9 +137,6 @@ class InternVLVisionAttention(nn.Module):
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity() self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
# Needed for flash attention
self.is_causal = False
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,

View File

@@ -344,6 +344,7 @@ class JanusVisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
proj_dropout = config.projection_dropout proj_dropout = config.projection_dropout
qk_norm = config.use_qk_norm qk_norm = config.use_qk_norm
self.is_causal = False
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1. # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
self.num_key_value_groups = 1 self.num_key_value_groups = 1
@@ -398,7 +399,7 @@ class JanusVisionAttention(nn.Module):
attention_mask, attention_mask,
dropout=0.0 if not self.training else self.attention_dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale, scaling=self.scale,
is_causal=False, is_causal=self.is_causal,
**kwargs, **kwargs,
) )
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

View File

@@ -509,6 +509,7 @@ class JanusVisionAttention(nn.Module):
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
proj_dropout = config.projection_dropout proj_dropout = config.projection_dropout
qk_norm = config.use_qk_norm qk_norm = config.use_qk_norm
self.is_causal = False
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1. # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
self.num_key_value_groups = 1 self.num_key_value_groups = 1
@@ -563,7 +564,7 @@ class JanusVisionAttention(nn.Module):
attention_mask, attention_mask,
dropout=0.0 if not self.training else self.attention_dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scale, scaling=self.scale,
is_causal=False, is_causal=self.is_causal,
**kwargs, **kwargs,
) )
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim) attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)

View File

@@ -316,10 +316,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_compile_dynamic(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
# todo: yoni - fix or improve the test # todo: yoni - fix or improve the test
@unittest.skip("Difference is slightly higher than the threshold") @unittest.skip("Difference is slightly higher than the threshold")
def test_batching_equivalence(self): def test_batching_equivalence(self):

View File

@@ -222,8 +222,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type") @unittest.skip(
def test_flash_attn_2_fp32_ln(self): reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
)
def test_past_key_values_format(self):
pass pass

View File

@@ -407,10 +407,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_prompt_lookup_decoding_matches_greedy_search(self): def test_prompt_lookup_decoding_matches_greedy_search(self):
pass pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate @pytest.mark.generate
@require_torch_sdpa @require_torch_sdpa
@slow @slow

View File

@@ -367,10 +367,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_prompt_lookup_decoding_matches_greedy_search(self): def test_prompt_lookup_decoding_matches_greedy_search(self):
pass pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate @pytest.mark.generate
@require_torch_sdpa @require_torch_sdpa
@slow @slow

View File

@@ -302,10 +302,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@@ -331,10 +331,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@@ -302,10 +302,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@@ -334,10 +334,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@@ -331,10 +331,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_generate_from_inputs_embeds_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@@ -378,10 +378,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
def test_generate_methods_with_logits_to_keep(self): def test_generate_methods_with_logits_to_keep(self):
super().test_generate_methods_with_logits_to_keep() super().test_generate_methods_with_logits_to_keep()
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip @unittest.skip
def test_training_gradient_checkpointing(self): def test_training_gradient_checkpointing(self):
pass pass

View File

@@ -225,10 +225,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )

View File

@@ -305,10 +305,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip( @unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
) )