Rename _supports_flash_attn_2 in examples and tests (#39471)
* delete `_supports_flash_attn_2` from examples and tests * simplify docs
This commit is contained in:
committed by
GitHub
parent
3a152e3a5c
commit
8c102e2eb1
@@ -497,8 +497,7 @@ class Multimodal2VisionPreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "multimodal2_vision"
|
base_model_prefix = "multimodal2_vision"
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn = True
|
||||||
_supports_flash_attn_3 = True
|
|
||||||
_supports_flex_attn = True
|
_supports_flex_attn = True
|
||||||
_supports_attention_backend = True
|
_supports_attention_backend = True
|
||||||
|
|
||||||
|
|||||||
@@ -289,8 +289,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["MyNewModel2DecoderLayer"]
|
_no_split_modules = ["MyNewModel2DecoderLayer"]
|
||||||
_skip_keys_device_placement = ["past_key_values"]
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn = True
|
||||||
_supports_flash_attn_3 = True
|
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_flex_attn = True
|
_supports_flex_attn = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
|||||||
@@ -95,8 +95,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
|||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_quantized_cache = True
|
_supports_quantized_cache = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn = True
|
||||||
_supports_flash_attn_3 = True
|
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_flex_attn = True
|
_supports_flex_attn = True
|
||||||
_supports_attention_backend = True
|
_supports_attention_backend = True
|
||||||
|
|||||||
@@ -288,8 +288,7 @@ class SuperPreTrainedModel(PreTrainedModel):
|
|||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["SuperDecoderLayer"]
|
_no_split_modules = ["SuperDecoderLayer"]
|
||||||
_skip_keys_device_placement = ["past_key_values"]
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn = True
|
||||||
_supports_flash_attn_3 = True
|
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_supports_flex_attn = True
|
_supports_flex_attn = True
|
||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
|
|||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_skip_keys_device_placement = ["past_key_values"]
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
_no_split_modules = ["ModernBertDecoderLayer"]
|
_no_split_modules = ["ModernBertDecoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn = True
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
_supports_gradient_checkpointing = True
|
_supports_gradient_checkpointing = True
|
||||||
_supports_static_cache = False
|
_supports_static_cache = False
|
||||||
|
|||||||
@@ -398,7 +398,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
|
|||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_skip_keys_device_placement = ["past_key_values"]
|
_skip_keys_device_placement = ["past_key_values"]
|
||||||
_no_split_modules = ["ModernBertDecoderLayer"]
|
_no_split_modules = ["ModernBertDecoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn = True
|
||||||
_supports_sdpa = False
|
_supports_sdpa = False
|
||||||
_supports_gradient_checkpointing = True
|
_supports_gradient_checkpointing = True
|
||||||
_supports_static_cache = False
|
_supports_static_cache = False
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_equivalence(self):
|
def test_flash_attn_2_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -2297,8 +2297,8 @@ class GenerationTesterMixin:
|
|||||||
max_new_tokens = 3
|
max_new_tokens = 3
|
||||||
support_flag = {
|
support_flag = {
|
||||||
"sdpa": "_supports_sdpa",
|
"sdpa": "_supports_sdpa",
|
||||||
"flash_attention_2": "_supports_flash_attn_2",
|
"flash_attention_2": "_supports_flash_attn",
|
||||||
"flash_attention_3": "_supports_flash_attn_3",
|
"flash_attention_3": "_supports_flash_attn",
|
||||||
}
|
}
|
||||||
|
|
||||||
set_model_tester_for_less_flaky_test(self)
|
set_model_tester_for_less_flaky_test(self)
|
||||||
|
|||||||
@@ -478,7 +478,7 @@ class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -516,7 +516,7 @@ class Aimv2ModelTest(Aimv2ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -524,7 +524,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
max_new_tokens = 30
|
max_new_tokens = 30
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -934,7 +934,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(reason="Model does not support flash_attention_2")
|
self.skipTest(reason="Model does not support flash_attention_2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -991,7 +991,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(reason="Model does not support flash_attention_2")
|
self.skipTest(reason="Model does not support flash_attention_2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -715,7 +715,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -753,7 +753,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ class EsmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_equivalence(self):
|
def test_flash_attn_2_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -548,7 +548,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||||||
max_new_tokens = 30
|
max_new_tokens = 30
|
||||||
support_flag = {
|
support_flag = {
|
||||||
"sdpa": "_supports_sdpa",
|
"sdpa": "_supports_sdpa",
|
||||||
"flash_attention_2": "_supports_flash_attn_2",
|
"flash_attention_2": "_supports_flash_attn",
|
||||||
}
|
}
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -372,7 +372,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -948,7 +948,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -1096,7 +1096,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -300,7 +300,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence
|
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -382,7 +382,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence_right_padding
|
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence_right_padding
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -948,7 +948,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -1096,7 +1096,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
|
||||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -610,7 +610,7 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test
|
|||||||
@slow
|
@slow
|
||||||
def test_flash_attn_2_inference_equivalence(self):
|
def test_flash_attn_2_inference_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class Siglip2ModelTesterMixin(ModelTesterMixin):
|
|||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
|
|||||||
@@ -359,7 +359,7 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ class ViTMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -867,7 +867,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(reason="Model does not support Flash Attention 2")
|
self.skipTest(reason="Model does not support Flash Attention 2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -913,7 +913,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(reason="Model does not support flash_attention_2")
|
self.skipTest(reason="Model does not support flash_attention_2")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -3471,9 +3471,7 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
if not model_class._supports_flash_attn:
|
||||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
|
||||||
):
|
|
||||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -3969,22 +3967,12 @@ class ModelTesterMixin:
|
|||||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
sub_models_supporting_fa = [
|
sub_models_supporting_fa = [
|
||||||
(
|
module._supports_flash_attn
|
||||||
module._supports_flash_attn_3
|
|
||||||
if attn_implementation == "flash_attention_3"
|
|
||||||
else module._supports_flash_attn_2
|
|
||||||
)
|
|
||||||
for name, module in model.named_modules()
|
for name, module in model.named_modules()
|
||||||
if isinstance(module, PreTrainedModel) and name != ""
|
if isinstance(module, PreTrainedModel) and name != ""
|
||||||
]
|
]
|
||||||
supports_fa_all_modules = (
|
supports_fa_all_modules = (
|
||||||
all(sub_models_supporting_fa)
|
all(sub_models_supporting_fa) if len(sub_models_supporting_fa) > 0 else model._supports_flash_attn
|
||||||
if len(sub_models_supporting_fa) > 0
|
|
||||||
else (
|
|
||||||
model._supports_flash_attn_3
|
|
||||||
if attn_implementation == "flash_attention_3"
|
|
||||||
else model._supports_flash_attn_2
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if not supports_fa_all_modules:
|
if not supports_fa_all_modules:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@@ -4037,7 +4025,7 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if not model_class._supports_flash_attn_2:
|
if not model_class._supports_flash_attn:
|
||||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
@@ -4104,9 +4092,8 @@ class ModelTesterMixin:
|
|||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
config._attn_implementation = "flash_attention_2"
|
cls = self._torch_compile_train_cls # e.g. LlamaForCausalLM
|
||||||
cls = self._torch_compile_train_cls
|
model = cls(config, attn_implementation="flash_attention_2").to(device=torch_device, dtype=torch_dtype)
|
||||||
model = cls(config).to(device=torch_device, dtype=torch_dtype)
|
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
"input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||||
@@ -4268,9 +4255,7 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if (attn_implementation == "flash_attention_2" and not model_class._supports_flash_attn_2) or (
|
if not model_class._supports_flash_attn:
|
||||||
attn_implementation == "flash_attention_3" and not model_class._supports_flash_attn_3
|
|
||||||
):
|
|
||||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user