From 233113149b02756460fe07cfc047aff1d9c7dfeb Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 13 Jun 2023 20:33:26 +0200 Subject: [PATCH] Skip `GPT-J` fx tests for torch < 1.12 (#24256) * fix * fix --------- Co-authored-by: ydshieh --- tests/models/gptj/test_modeling_gptj.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py index 5fe0fec391..7fd6a40e17 100644 --- a/tests/models/gptj/test_modeling_gptj.py +++ b/tests/models/gptj/test_modeling_gptj.py @@ -37,6 +37,9 @@ if is_torch_available(): GPTJForSequenceClassification, GPTJModel, ) + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 +else: + is_torch_greater_or_equal_than_1_12 = False class GPTJModelTester: @@ -385,6 +388,18 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin test_model_parallel = False test_head_masking = False + @unittest.skipIf( + not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+." + ) + def test_torch_fx(self): + super().test_torch_fx() + + @unittest.skipIf( + not is_torch_greater_or_equal_than_1_12, reason="PR #22069 made changes that require torch v1.12+." + ) + def test_torch_fx_output_loss(self): + super().test_torch_fx_output_loss() + # TODO: Fix the failed tests def is_pipeline_test_to_skip( self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name