[VLMs] fix flash-attention tests (#37603)
* fix one test
* fa2 ln test
* remove keys from config recursively
* fix
* fixup
This commit is contained in:
committed by
GitHub
parent
02baa61fab
commit
1cfcbfcab8
@@ -843,29 +843,16 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
):
|
):
|
||||||
serializable_config_dict[key] = value
|
serializable_config_dict[key] = value
|
||||||
|
|
||||||
|
self._remove_keys_not_serialized(serializable_config_dict)
|
||||||
|
|
||||||
if hasattr(self, "quantization_config"):
|
if hasattr(self, "quantization_config"):
|
||||||
serializable_config_dict["quantization_config"] = (
|
serializable_config_dict["quantization_config"] = (
|
||||||
self.quantization_config.to_dict()
|
self.quantization_config.to_dict()
|
||||||
if not isinstance(self.quantization_config, dict)
|
if not isinstance(self.quantization_config, dict)
|
||||||
else self.quantization_config
|
else self.quantization_config
|
||||||
)
|
)
|
||||||
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
|
||||||
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
|
|
||||||
|
|
||||||
self.dict_torch_dtype_to_str(serializable_config_dict)
|
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||||
|
|
||||||
if "_attn_implementation_internal" in serializable_config_dict:
|
|
||||||
del serializable_config_dict["_attn_implementation_internal"]
|
|
||||||
# Do not serialize `base_model_tp_plan` for now
|
|
||||||
if "base_model_tp_plan" in serializable_config_dict:
|
|
||||||
del serializable_config_dict["base_model_tp_plan"]
|
|
||||||
# Do not serialize `base_model_pp_plan` for now
|
|
||||||
if "base_model_pp_plan" in serializable_config_dict:
|
|
||||||
del serializable_config_dict["base_model_pp_plan"]
|
|
||||||
|
|
||||||
if "_name_or_path" in serializable_config_dict:
|
|
||||||
del serializable_config_dict["_name_or_path"]
|
|
||||||
|
|
||||||
return serializable_config_dict
|
return serializable_config_dict
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
@@ -878,18 +865,6 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
output = copy.deepcopy(self.__dict__)
|
output = copy.deepcopy(self.__dict__)
|
||||||
if hasattr(self.__class__, "model_type"):
|
if hasattr(self.__class__, "model_type"):
|
||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
if "_auto_class" in output:
|
|
||||||
del output["_auto_class"]
|
|
||||||
if "_commit_hash" in output:
|
|
||||||
del output["_commit_hash"]
|
|
||||||
if "_attn_implementation_internal" in output:
|
|
||||||
del output["_attn_implementation_internal"]
|
|
||||||
# Do not serialize `base_model_tp_plan` for now
|
|
||||||
if "base_model_tp_plan" in output:
|
|
||||||
del output["base_model_tp_plan"]
|
|
||||||
# Do not serialize `base_model_pp_plan` for now
|
|
||||||
if "base_model_pp_plan" in output:
|
|
||||||
del output["base_model_pp_plan"]
|
|
||||||
|
|
||||||
# Transformers version when serializing the model
|
# Transformers version when serializing the model
|
||||||
output["transformers_version"] = __version__
|
output["transformers_version"] = __version__
|
||||||
@@ -902,16 +877,14 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
|
|
||||||
output[key] = value
|
output[key] = value
|
||||||
|
|
||||||
|
self._remove_keys_not_serialized(output)
|
||||||
|
|
||||||
if hasattr(self, "quantization_config"):
|
if hasattr(self, "quantization_config"):
|
||||||
output["quantization_config"] = (
|
output["quantization_config"] = (
|
||||||
self.quantization_config.to_dict()
|
self.quantization_config.to_dict()
|
||||||
if not isinstance(self.quantization_config, dict)
|
if not isinstance(self.quantization_config, dict)
|
||||||
else self.quantization_config
|
else self.quantization_config
|
||||||
)
|
)
|
||||||
|
|
||||||
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
|
||||||
_ = output.pop("_pre_quantization_dtype", None)
|
|
||||||
|
|
||||||
self.dict_torch_dtype_to_str(output)
|
self.dict_torch_dtype_to_str(output)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -1011,6 +984,33 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
self.dict_torch_dtype_to_str(value)
|
self.dict_torch_dtype_to_str(value)
|
||||||
|
|
||||||
|
def _remove_keys_not_serialized(self, d: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Checks and removes if there are any keys in the dict that should not be serialized when saving the config.
|
||||||
|
Runs recursive check on the dict, to remove from all sub configs.
|
||||||
|
"""
|
||||||
|
if hasattr(self, "quantization_config"):
|
||||||
|
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||||
|
_ = d.pop("_pre_quantization_dtype", None)
|
||||||
|
|
||||||
|
if "_auto_class" in d:
|
||||||
|
del d["_auto_class"]
|
||||||
|
if "_commit_hash" in d:
|
||||||
|
del d["_commit_hash"]
|
||||||
|
if "_attn_implementation_internal" in d:
|
||||||
|
del d["_attn_implementation_internal"]
|
||||||
|
# Do not serialize `base_model_tp_plan` for now
|
||||||
|
if "base_model_tp_plan" in d:
|
||||||
|
del d["base_model_tp_plan"]
|
||||||
|
# Do not serialize `base_model_pp_plan` for now
|
||||||
|
if "base_model_pp_plan" in d:
|
||||||
|
del d["base_model_pp_plan"]
|
||||||
|
if "_name_or_path" in d:
|
||||||
|
del d["_name_or_path"]
|
||||||
|
for value in d.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
self._remove_keys_not_serialized(value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_for_auto_class(cls, auto_class="AutoConfig"):
|
def register_for_auto_class(cls, auto_class="AutoConfig"):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4444,7 +4444,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# once the weights have been quantized
|
# once the weights have been quantized
|
||||||
# Note that once you have loaded a quantized model, you can't change its dtype so this will
|
# Note that once you have loaded a quantized model, you can't change its dtype so this will
|
||||||
# remain a single source of truth
|
# remain a single source of truth
|
||||||
config._pre_quantization_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
|
original_dtype = torch_dtype if torch_dtype is not None else torch.get_default_dtype()
|
||||||
|
|
||||||
|
def _assign_original_dtype(module):
|
||||||
|
for child in module.children():
|
||||||
|
if isinstance(child, PreTrainedModel):
|
||||||
|
child.config._pre_quantization_dtype = original_dtype
|
||||||
|
_assign_original_dtype(child)
|
||||||
|
|
||||||
|
config._pre_quantization_dtype = original_dtype
|
||||||
|
_assign_original_dtype(model)
|
||||||
|
|
||||||
# Prepare the full device map
|
# Prepare the full device map
|
||||||
if device_map is not None:
|
if device_map is not None:
|
||||||
|
|||||||
@@ -125,6 +125,9 @@ class InternVLVisionAttention(nn.Module):
|
|||||||
proj_dropout = config.projection_dropout
|
proj_dropout = config.projection_dropout
|
||||||
qk_norm = config.use_qk_norm
|
qk_norm = config.use_qk_norm
|
||||||
|
|
||||||
|
# Needed for flash attention
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
|
self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||||
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
|
self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||||
self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
|
self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||||
@@ -134,9 +137,6 @@ class InternVLVisionAttention(nn.Module):
|
|||||||
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
self.q_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
||||||
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
self.k_norm = InternVLVisionRMSNorm(self.embed_dim) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
# Needed for flash attention
|
|
||||||
self.is_causal = False
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@@ -344,6 +344,7 @@ class JanusVisionAttention(nn.Module):
|
|||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
proj_dropout = config.projection_dropout
|
proj_dropout = config.projection_dropout
|
||||||
qk_norm = config.use_qk_norm
|
qk_norm = config.use_qk_norm
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
|
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
|
||||||
self.num_key_value_groups = 1
|
self.num_key_value_groups = 1
|
||||||
@@ -398,7 +399,7 @@ class JanusVisionAttention(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
dropout=0.0 if not self.training else self.attention_dropout,
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
scaling=self.scale,
|
scaling=self.scale,
|
||||||
is_causal=False,
|
is_causal=self.is_causal,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
||||||
|
|||||||
@@ -509,6 +509,7 @@ class JanusVisionAttention(nn.Module):
|
|||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
proj_dropout = config.projection_dropout
|
proj_dropout = config.projection_dropout
|
||||||
qk_norm = config.use_qk_norm
|
qk_norm = config.use_qk_norm
|
||||||
|
self.is_causal = False
|
||||||
|
|
||||||
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
|
# Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
|
||||||
self.num_key_value_groups = 1
|
self.num_key_value_groups = 1
|
||||||
@@ -563,7 +564,7 @@ class JanusVisionAttention(nn.Module):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
dropout=0.0 if not self.training else self.attention_dropout,
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
scaling=self.scale,
|
scaling=self.scale,
|
||||||
is_causal=False,
|
is_causal=self.is_causal,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
||||||
|
|||||||
@@ -316,10 +316,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
def test_sdpa_can_compile_dynamic(self):
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# todo: yoni - fix or improve the test
|
# todo: yoni - fix or improve the test
|
||||||
@unittest.skip("Difference is slightly higher than the threshold")
|
@unittest.skip("Difference is slightly higher than the threshold")
|
||||||
def test_batching_equivalence(self):
|
def test_batching_equivalence(self):
|
||||||
|
|||||||
@@ -222,8 +222,10 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
@unittest.skip(
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
|
||||||
|
)
|
||||||
|
def test_past_key_values_format(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -407,10 +407,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -367,10 +367,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -302,10 +302,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -331,10 +331,6 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -302,10 +302,6 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -334,10 +334,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -331,10 +331,6 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -378,10 +378,6 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
|
|||||||
def test_generate_methods_with_logits_to_keep(self):
|
def test_generate_methods_with_logits_to_keep(self):
|
||||||
super().test_generate_methods_with_logits_to_keep()
|
super().test_generate_methods_with_logits_to_keep()
|
||||||
|
|
||||||
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip
|
@unittest.skip
|
||||||
def test_training_gradient_checkpointing(self):
|
def test_training_gradient_checkpointing(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -225,10 +225,6 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -305,10 +305,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
|
||||||
def test_flash_attn_2_fp32_ln(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user