* update * update * update * dev-ci * more changes * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -3872,11 +3872,13 @@ class ModelTesterMixin:
|
|||||||
for name, submodule in model.named_modules():
|
for name, submodule in model.named_modules():
|
||||||
class_name = submodule.__class__.__name__
|
class_name = submodule.__class__.__name__
|
||||||
if (
|
if (
|
||||||
"SdpaAttention" in class_name
|
class_name.endswith("Attention")
|
||||||
or "SdpaSelfAttention" in class_name
|
and getattr(submodule, "config", None)
|
||||||
or "FlashAttention" in class_name
|
and submodule.config._attn_implementation != "eager"
|
||||||
):
|
):
|
||||||
raise ValueError(f"The eager model should not have SDPA/FA2 attention layers but got {class_name}")
|
raise ValueError(
|
||||||
|
f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
def test_sdpa_can_dispatch_non_composite_models(self):
|
def test_sdpa_can_dispatch_non_composite_models(self):
|
||||||
@@ -3907,8 +3909,14 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
for name, submodule in model_eager.named_modules():
|
for name, submodule in model_eager.named_modules():
|
||||||
class_name = submodule.__class__.__name__
|
class_name = submodule.__class__.__name__
|
||||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
if (
|
||||||
raise ValueError(f"The eager model should not have SDPA attention layers but got {class_name}")
|
class_name.endswith("Attention")
|
||||||
|
and getattr(submodule, "config", None)
|
||||||
|
and submodule.config._attn_implementation == "sdpa"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"The eager model should not have SDPA attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
def test_sdpa_can_dispatch_composite_models(self):
|
def test_sdpa_can_dispatch_composite_models(self):
|
||||||
@@ -3959,7 +3967,11 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
for name, submodule in model_eager.named_modules():
|
for name, submodule in model_eager.named_modules():
|
||||||
class_name = submodule.__class__.__name__
|
class_name = submodule.__class__.__name__
|
||||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
if (
|
||||||
|
class_name.endswith("Attention")
|
||||||
|
and getattr(submodule, "config", None)
|
||||||
|
and submodule.config._attn_implementation == "sdpa"
|
||||||
|
):
|
||||||
raise ValueError("The eager model should not have SDPA attention layers")
|
raise ValueError("The eager model should not have SDPA attention layers")
|
||||||
|
|
||||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||||
@@ -4446,7 +4458,11 @@ class ModelTesterMixin:
|
|||||||
has_fa2 = False
|
has_fa2 = False
|
||||||
for name, submodule in model_fa2.named_modules():
|
for name, submodule in model_fa2.named_modules():
|
||||||
class_name = submodule.__class__.__name__
|
class_name = submodule.__class__.__name__
|
||||||
if "FlashAttention" in class_name:
|
if (
|
||||||
|
"Attention" in class_name
|
||||||
|
and getattr(submodule, "config", None)
|
||||||
|
and submodule.config._attn_implementation == "flash_attention_2"
|
||||||
|
):
|
||||||
has_fa2 = True
|
has_fa2 = True
|
||||||
break
|
break
|
||||||
if not has_fa2:
|
if not has_fa2:
|
||||||
|
|||||||
Reference in New Issue
Block a user