From 8c102e2eb1c3d6c590eb3ace22fc6c249f1b69a5 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 21 Jul 2025 14:02:57 +0200 Subject: [PATCH] Rename `_supports_flash_attn_2` in examples and tests (#39471) * delete `_supports_flash_attn_2` from examples and tests * simplify docs --- .../modeling_multimodal2.py | 3 +- .../modeling_my_new_model2.py | 3 +- .../modeling_new_task_model.py | 3 +- .../modular-transformers/modeling_super.py | 3 +- .../modeling_modernbert_decoder.py | 2 +- .../modular_modernbert_decoder.py | 2 +- tests/causal_lm_tester.py | 2 +- tests/generation/test_utils.py | 4 +-- tests/models/aimv2/test_modeling_aimv2.py | 4 +-- tests/models/bamba/test_modeling_bamba.py | 2 +- tests/models/bark/test_modeling_bark.py | 4 +-- tests/models/clip/test_modeling_clip.py | 4 +-- tests/models/esm/test_modeling_esm.py | 2 +- .../test_modeling_kyutai_speech_to_text.py | 2 +- .../models/musicgen/test_modeling_musicgen.py | 8 ++--- .../test_modeling_musicgen_melody.py | 8 ++--- tests/models/siglip/test_modeling_siglip.py | 2 +- tests/models/siglip2/test_modeling_siglip2.py | 2 +- .../models/videomae/test_modeling_videomae.py | 2 +- tests/models/vit_mae/test_modeling_vit_mae.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 4 +-- tests/test_modeling_common.py | 29 +++++-------------- 22 files changed, 39 insertions(+), 58 deletions(-) diff --git a/examples/modular-transformers/modeling_multimodal2.py b/examples/modular-transformers/modeling_multimodal2.py index d7592047fc..01aca3a4d5 100644 --- a/examples/modular-transformers/modeling_multimodal2.py +++ b/examples/modular-transformers/modeling_multimodal2.py @@ -497,8 +497,7 @@ class Multimodal2VisionPreTrainedModel(PreTrainedModel): base_model_prefix = "multimodal2_vision" supports_gradient_checkpointing = True _supports_sdpa = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_flex_attn = True _supports_attention_backend = True diff --git 
a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index b87977cf4c..981d40bb6d 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -289,8 +289,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MyNewModel2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 5bcb2f4fb5..21aabae4d3 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -95,8 +95,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 6f78eb972c..c44f12f02f 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -288,8 +288,7 @@ class SuperPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SuperDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_flash_attn_3 = True + _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py 
b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 56b53446fd..c3d90771e1 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -221,7 +221,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): base_model_prefix = "model" _skip_keys_device_placement = ["past_key_values"] _no_split_modules = ["ModernBertDecoderLayer"] - _supports_flash_attn_2 = True + _supports_flash_attn = True _supports_sdpa = False _supports_gradient_checkpointing = True _supports_static_cache = False diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index f82fb1573a..d215ccbf0b 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -398,7 +398,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): base_model_prefix = "model" _skip_keys_device_placement = ["past_key_values"] _no_split_modules = ["ModernBertDecoderLayer"] - _supports_flash_attn_2 = True + _supports_flash_attn = True _supports_sdpa = False _supports_gradient_checkpointing = True _supports_static_cache = False diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index 9807c88560..b13f824bf7 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -422,7 +422,7 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM @slow def test_flash_attn_2_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/generation/test_utils.py 
b/tests/generation/test_utils.py index 531bf70d5e..fab1672b5c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2297,8 +2297,8 @@ class GenerationTesterMixin: max_new_tokens = 3 support_flag = { "sdpa": "_supports_sdpa", - "flash_attention_2": "_supports_flash_attn_2", - "flash_attention_3": "_supports_flash_attn_3", + "flash_attention_2": "_supports_flash_attn", + "flash_attention_3": "_supports_flash_attn", } set_model_tester_for_less_flaky_test(self) diff --git a/tests/models/aimv2/test_modeling_aimv2.py b/tests/models/aimv2/test_modeling_aimv2.py index 77893985f9..86ac5087dc 100644 --- a/tests/models/aimv2/test_modeling_aimv2.py +++ b/tests/models/aimv2/test_modeling_aimv2.py @@ -478,7 +478,7 @@ class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa @slow def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -516,7 +516,7 @@ class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa @mark.flash_attn_test def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 0afc7bdbf4..e1d8128a2c 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -524,7 +524,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi max_new_tokens = 30 for 
model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 701cd7938c..f8dabd2fa3 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -934,7 +934,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): @slow def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support flash_attention_2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -991,7 +991,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): @slow def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support flash_attention_2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 82e04fc454..90506e26db 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -715,7 +715,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase @slow def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() @@ -753,7 +753,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase @mark.flash_attn_test def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/esm/test_modeling_esm.py b/tests/models/esm/test_modeling_esm.py index 18887bb592..79dd701efd 100644 --- a/tests/models/esm/test_modeling_esm.py +++ b/tests/models/esm/test_modeling_esm.py @@ -318,7 +318,7 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): @slow def test_flash_attn_2_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py index ad516904ef..c9aae79438 100644 --- a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -548,7 +548,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel max_new_tokens = 30 support_flag = { "sdpa": "_supports_sdpa", - "flash_attention_2": "_supports_flash_attn_2", + "flash_attention_2": "_supports_flash_attn", } for model_class in self.all_generative_model_classes: diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 9356ddf92e..e7eee02ce8 100644 --- 
a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -292,7 +292,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -372,7 +372,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -948,7 +948,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -1096,7 +1096,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding def 
test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 4aa812a0ae..3d7b45b643 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -300,7 +300,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -382,7 +382,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence_right_padding def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -948,7 +948,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence 
def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -1096,7 +1096,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 8a5b1037eb..4bff15040e 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -610,7 +610,7 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test @slow def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/siglip2/test_modeling_siglip2.py b/tests/models/siglip2/test_modeling_siglip2.py index 3963d7e48e..f6825308ff 100644 --- a/tests/models/siglip2/test_modeling_siglip2.py +++ b/tests/models/siglip2/test_modeling_siglip2.py @@ -99,7 +99,7 @@ class Siglip2ModelTesterMixin(ModelTesterMixin): dtype = torch.float16 for model_class in self.all_model_classes: - if not 
model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") # Prepare inputs diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py index 2c592290a6..f8d6f0efae 100644 --- a/tests/models/videomae/test_modeling_videomae.py +++ b/tests/models/videomae/test_modeling_videomae.py @@ -359,7 +359,7 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index c2265b076e..f0494cc059 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -273,7 +273,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index df442342ee..e6526852b9 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -867,7 +867,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi import torch for model_class in self.all_model_classes: - if 
not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -913,7 +913,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi import torch for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(reason="Model does not support flash_attention_2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 26f6a032d6..62641ad356 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3471,9 +3471,7 @@ class ModelTesterMixin: self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_model_classes: - if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( - attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 - ): + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -3969,22 +3967,12 @@ class ModelTesterMixin: model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) sub_models_supporting_fa = [ - ( - module._supports_flash_attn_3 - if attn_implementation == "flash_attention_3" - else module._supports_flash_attn_2 - ) + module._supports_flash_attn for name, module in model.named_modules() if isinstance(module, PreTrainedModel) and name != "" ] supports_fa_all_modules = ( - all(sub_models_supporting_fa) - if len(sub_models_supporting_fa) > 0 - else ( - model._supports_flash_attn_3 - if attn_implementation == "flash_attention_3" - else model._supports_flash_attn_2 - ) + 
all(sub_models_supporting_fa) if len(sub_models_supporting_fa) > 0 else model._supports_flash_attn ) if not supports_fa_all_modules: with self.assertRaises(ValueError): @@ -4037,7 +4025,7 @@ class ModelTesterMixin: self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -4104,9 +4092,8 @@ class ModelTesterMixin: torch_dtype = torch.float16 config, _ = self.model_tester.prepare_config_and_inputs_for_common() - config._attn_implementation = "flash_attention_2" - cls = self._torch_compile_train_cls - model = cls(config).to(device=torch_device, dtype=torch_dtype) + cls = self._torch_compile_train_cls # e.g. LlamaForCausalLM + model = cls(config, attn_implementation="flash_attention_2").to(device=torch_device, dtype=torch_dtype) inputs = { "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device), @@ -4268,9 +4255,7 @@ class ModelTesterMixin: self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_generative_model_classes: - if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or ( - attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3 - ): + if not model_class._supports_flash_attn: self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()