From 3c2d4d60d7fd96cb297d1c0c23036ff263bf9311 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 24 Jun 2024 08:09:21 +0100 Subject: [PATCH] Correct @is_flaky test decoration (#31480) * Correct @is_flaky decorator --- tests/models/gemma/test_modeling_gemma.py | 2 +- tests/test_modeling_common.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 6aeb5f23c3..445a48e45a 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -493,7 +493,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test - @is_flaky + @is_flaky() @slow def test_flash_attn_2_equivalence(self): for model_class in self.all_model_classes: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f81bd5c8c3..f7f0db79a8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3407,7 +3407,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - @is_flaky + @is_flaky() def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: @@ -3501,7 +3501,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - @is_flaky + @is_flaky() def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: @@ -3591,7 +3591,7 @@ class ModelTesterMixin: @require_torch_gpu @mark.flash_attn_test @slow - @is_flaky + @is_flaky() def test_flash_attn_2_generate_left_padding(self): for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: @@ -3635,7 +3635,7 @@ class ModelTesterMixin: @require_flash_attn @require_torch_gpu @mark.flash_attn_test - @is_flaky + @is_flaky() @slow def test_flash_attn_2_generate_padding_right(self): for model_class in self.all_generative_model_classes: