From fe52679e74be29c6984ea15b318e0074703f5c77 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 4 Feb 2025 18:04:47 +0100 Subject: [PATCH] Update tests regarding attention types after #35235 (#36024) * update * update * update * dev-ci * more changes * fix * fix * fix --------- Co-authored-by: ydshieh --- tests/test_modeling_common.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0f47767e41..7fedc4e754 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3872,11 +3872,13 @@ class ModelTesterMixin: for name, submodule in model.named_modules(): class_name = submodule.__class__.__name__ if ( - "SdpaAttention" in class_name - or "SdpaSelfAttention" in class_name - or "FlashAttention" in class_name + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation != "eager" ): - raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}") + raise ValueError( + f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`" + ) @require_torch_sdpa def test_sdpa_can_dispatch_non_composite_models(self): @@ -3907,8 +3909,14 @@ class ModelTesterMixin: for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError(f"The eager model should not have SDPA attention layers but got {class_name}") + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): + raise ValueError( + f"The eager model should not have SDPA attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`" + ) @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): @@ -3959,7 +3967,11 @@ class ModelTesterMixin: for name, submodule in model_eager.named_modules(): class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + if ( + class_name.endswith("Attention") + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "sdpa" + ): raise ValueError("The eager model should not have SDPA attention layers") @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @@ -4446,7 +4458,11 @@ class ModelTesterMixin: has_fa2 = False for name, submodule in model_fa2.named_modules(): class_name = submodule.__class__.__name__ - if "FlashAttention" in class_name: + if ( + "Attention" in class_name + and getattr(submodule, "config", None) + and submodule.config._attn_implementation == "flash_attention_2" + ): has_fa2 = True break if not has_fa2: