Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (#27940)
fix sdpa dispatch
This commit is contained in:
@@ -83,6 +83,7 @@ from transformers.utils import (
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_fx_available,
|
||||
is_torch_sdpa_available,
|
||||
)
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
@@ -778,7 +779,7 @@ class ModelTesterMixin:
|
||||
configs_no_init.torchscript = True
|
||||
for model_class in self.all_model_classes:
|
||||
for attn_implementation in ["eager", "sdpa"]:
|
||||
if attn_implementation == "sdpa" and not model_class._supports_sdpa:
|
||||
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
|
||||
continue
|
||||
|
||||
configs_no_init._attn_implementation = attn_implementation
|
||||
|
||||
Reference in New Issue
Block a user