From af6120b3eb2470b994c21421bb6eaa76576128b0 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 19 Jun 2025 15:11:01 +0200 Subject: [PATCH] Skip sdpa tests if submodule does not support sdpa (#38907) --- tests/test_modeling_common.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8a69b2e0a3..4c7cef05c3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3799,8 +3799,20 @@ class ModelTesterMixin: self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input") if config.model_type in ["sam"]: self.skipTest(reason="SAM requires an attention_mask input for relative positional embeddings") + model = model_class(config) + sub_models_supporting_sdpa = [ + module._supports_sdpa + for name, module in model.named_modules() + if isinstance(module, PreTrainedModel) and name != "" + ] + supports_sdpa_all_modules = ( + all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa + ) + if not supports_sdpa_all_modules: + self.skipTest(reason="This models' submodels does not support sdpa") + with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa") @@ -3848,8 +3860,20 @@ class ModelTesterMixin: "Cannot compile forward without an existing cache with Hybrid, as `torch._dynamo.mark_static_address` " "is a forbidden call." ) + model = model_class(config) + sub_models_supporting_sdpa = [ + module._supports_sdpa + for name, module in model.named_modules() + if isinstance(module, PreTrainedModel) and name != "" + ] + supports_sdpa_all_modules = ( + all(sub_models_supporting_sdpa) if len(sub_models_supporting_sdpa) > 0 else model._supports_sdpa + ) + if not supports_sdpa_all_modules: + self.skipTest(reason="This models' submodels does not support sdpa") + with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa")