Update tests regarding attention types after #35235 (#36024)

* update

* update

* update

* dev-ci

* more changes

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-02-04 18:04:47 +01:00
committed by GitHub
parent 014a1fa2c8
commit fe52679e74

View File

@@ -3872,11 +3872,13 @@ class ModelTesterMixin:
for name, submodule in model.named_modules(): for name, submodule in model.named_modules():
class_name = submodule.__class__.__name__ class_name = submodule.__class__.__name__
if ( if (
"SdpaAttention" in class_name class_name.endswith("Attention")
or "SdpaSelfAttention" in class_name and getattr(submodule, "config", None)
or "FlashAttention" in class_name and submodule.config._attn_implementation != "eager"
): ):
raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}") raise ValueError(
f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
)
@require_torch_sdpa @require_torch_sdpa
def test_sdpa_can_dispatch_non_composite_models(self): def test_sdpa_can_dispatch_non_composite_models(self):
@@ -3907,8 +3909,14 @@ class ModelTesterMixin:
for name, submodule in model_eager.named_modules(): for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__ class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: if (
raise ValueError(f"The eager model should not have SDPA attention layers but got {class_name}") class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError(
f"The eager model should not have SDPA attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
)
@require_torch_sdpa @require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self): def test_sdpa_can_dispatch_composite_models(self):
@@ -3959,7 +3967,11 @@ class ModelTesterMixin:
for name, submodule in model_eager.named_modules(): for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__ class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: if (
class_name.endswith("Attention")
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "sdpa"
):
raise ValueError("The eager model should not have SDPA attention layers") raise ValueError("The eager model should not have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@@ -4446,7 +4458,11 @@ class ModelTesterMixin:
has_fa2 = False has_fa2 = False
for name, submodule in model_fa2.named_modules(): for name, submodule in model_fa2.named_modules():
class_name = submodule.__class__.__name__ class_name = submodule.__class__.__name__
if "FlashAttention" in class_name: if (
"Attention" in class_name
and getattr(submodule, "config", None)
and submodule.config._attn_implementation == "flash_attention_2"
):
has_fa2 = True has_fa2 = True
break break
if not has_fa2: if not has_fa2: