Rename _supports_flash_attn_2 in examples and tests (#39471)
* delete `_supports_flash_attn_2` from examples and tests * simplify docs
This commit is contained in:
committed by
GitHub
parent
3a152e3a5c
commit
8c102e2eb1
@@ -478,7 +478,7 @@ class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -516,7 +516,7 @@ class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
@mark.flash_attn_test
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -524,7 +524,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
max_new_tokens = 30
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -934,7 +934,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support flash_attention_2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -991,7 +991,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support flash_attention_2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -715,7 +715,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -753,7 +753,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
@mark.flash_attn_test
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -318,7 +318,7 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -548,7 +548,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
max_new_tokens = 30
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn_2",
|
||||
"flash_attention_2": "_supports_flash_attn",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
@@ -292,7 +292,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -372,7 +372,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -948,7 +948,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1096,7 +1096,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -300,7 +300,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -382,7 +382,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence_right_padding
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -948,7 +948,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1096,7 +1096,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -610,7 +610,7 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -99,7 +99,7 @@ class Siglip2ModelTesterMixin(ModelTesterMixin):
|
||||
dtype = torch.float16
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
# Prepare inputs
|
||||
|
||||
@@ -359,7 +359,7 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -273,7 +273,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -867,7 +867,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
import torch
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -913,7 +913,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
import torch
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
if not model_class._supports_flash_attn:
|
||||
self.skipTest(reason="Model does not support flash_attention_2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user