diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 8744cb168f..04a6ad99b8 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -879,7 +879,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: return @@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: return diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 481d4b24cd..6bd821859e 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -301,7 +301,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa @require_torch_accelerator @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): import torch for model_class in self.all_model_classes: @@ -353,7 +353,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa @require_torch_accelerator @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): import torch for model_class in self.all_model_classes: diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 1b32f1b16e..8c3aa392ba 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -462,7 +462,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Gemma flash attention does not support right padding") @require_torch_sdpa diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 2e675a2851..432097e9d1 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -466,7 +466,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index efd48d6a9c..98654c5133 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -465,7 +465,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mixtral flash attention does not support right padding") # Ignore copy diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 49da4fec98..21ee694bdc 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -477,7 +477,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index f0794c46dc..95f604d06b 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -461,7 +461,7 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Starcoder2 flash attention does not support right padding") diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index b79f3a2c0d..7ff6387ff2 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -888,7 +888,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): import torch for model_class in self.all_model_classes: @@ -934,7 +934,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): import torch for model_class in self.all_model_classes: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3dbfea719a..7241993b6d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3245,7 +3245,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3260,9 +3260,7 @@ class ModelTesterMixin: ) model_fa.to(torch_device) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) dummy_input = inputs_dict[model.main_input_name][:1] @@ -3340,7 +3338,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3355,9 +3353,7 @@ class ModelTesterMixin: ) model_fa.to(torch_device) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) dummy_input = inputs_dict[model.main_input_name][:1]