From 569f6c7d43a9bfee1fafe9d57f8d951141080a06 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Mon, 1 Apr 2024 08:51:00 +0100 Subject: [PATCH] Fix FA2 tests (#29909) * fix FA2 tests * refactor inference test name --- tests/models/bark/test_modeling_bark.py | 4 ++-- tests/models/distilbert/test_modeling_distilbert.py | 4 ++-- tests/models/gemma/test_modeling_gemma.py | 2 +- tests/models/mistral/test_modeling_mistral.py | 2 +- tests/models/mixtral/test_modeling_mixtral.py | 2 +- tests/models/qwen2/test_modeling_qwen2.py | 2 +- tests/models/starcoder2/test_modeling_starcoder2.py | 2 +- tests/models/whisper/test_modeling_whisper.py | 4 ++-- tests/test_modeling_common.py | 12 ++++-------- 9 files changed, 15 insertions(+), 19 deletions(-) diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 8744cb168f..04a6ad99b8 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -879,7 +879,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: return @@ -936,7 +936,7 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase): @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: return diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 481d4b24cd..6bd821859e 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -301,7 +301,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa @require_torch_accelerator @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): import torch for model_class in self.all_model_classes: @@ -353,7 +353,7 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa @require_torch_accelerator @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): import torch for model_class in self.all_model_classes: diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 1b32f1b16e..8c3aa392ba 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -462,7 +462,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Gemma flash attention does not support right padding") @require_torch_sdpa diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 2e675a2851..432097e9d1 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -466,7 +466,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mistral flash attention does not support right padding") diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index efd48d6a9c..98654c5133 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -465,7 +465,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Mixtral flash attention does not support right padding") # Ignore copy diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 49da4fec98..21ee694bdc 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -477,7 +477,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Qwen2 flash attention does not support right padding") diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index f0794c46dc..95f604d06b 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -461,7 +461,7 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("Starcoder2 flash attention does not support right padding") diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index b79f3a2c0d..7ff6387ff2 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -888,7 +888,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): import torch for model_class in self.all_model_classes: @@ -934,7 +934,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_gpu @pytest.mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): import torch for model_class in self.all_model_classes: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3dbfea719a..7241993b6d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3245,7 +3245,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - def test_flash_attn_2_inference(self): + def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3260,9 +3260,7 @@ class ModelTesterMixin: ) model_fa.to(torch_device) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) dummy_input = inputs_dict[model.main_input_name][:1] @@ -3340,7 +3338,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - def test_flash_attn_2_inference_padding_right(self): + def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") @@ -3355,9 +3353,7 @@ class ModelTesterMixin: ) model_fa.to(torch_device) - model = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) model.to(torch_device) dummy_input = inputs_dict[model.main_input_name][:1]