Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (#27940)

fix sdpa dispatch
This commit is contained in:
fxmarty
2023-12-11 10:56:38 +01:00
committed by GitHub
parent 7ea21f1f03
commit 9f18cc6df0
2 changed files with 10 additions and 8 deletions

View File

@@ -83,6 +83,7 @@ from transformers.utils import (
is_flax_available,
is_tf_available,
is_torch_fx_available,
is_torch_sdpa_available,
)
from transformers.utils.generic import ModelOutput
@@ -778,7 +779,7 @@ class ModelTesterMixin:
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and not model_class._supports_sdpa:
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
continue
configs_no_init._attn_implementation = attn_implementation