From eac8dede838c6ef965866689e85916d81613400a Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 14 Jun 2023 14:25:24 +0200 Subject: [PATCH] Skip some `TQAPipelineTests` tests in past CI (#24267) fix Co-authored-by: ydshieh --- .../test_pipelines_table_question_answering.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py index 6c427d840c..a30763fc09 100644 --- a/tests/pipelines/test_pipelines_table_question_answering.py +++ b/tests/pipelines/test_pipelines_table_question_answering.py @@ -20,6 +20,7 @@ from transformers import ( AutoTokenizer, TableQuestionAnsweringPipeline, TFAutoModelForTableQuestionAnswering, + is_torch_available, pipeline, ) from transformers.testing_utils import ( @@ -32,6 +33,12 @@ from transformers.testing_utils import ( ) +if is_torch_available(): + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 +else: + is_torch_greater_or_equal_than_1_12 = False + + @is_pipeline_test class TQAPipelineTests(unittest.TestCase): # Putting it there for consistency, but TQA do not have fast tokenizer @@ -143,6 +150,7 @@ class TQAPipelineTests(unittest.TestCase): }, ) + @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch def test_small_model_pt(self): model_id = "lysandre/tiny-tapas-random-wtq" @@ -245,6 +253,7 @@ class TQAPipelineTests(unittest.TestCase): }, ) + @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch def test_slow_tokenizer_sqa_pt(self): model_id = "lysandre/tiny-tapas-random-sqa" @@ -486,6 +495,7 @@ class TQAPipelineTests(unittest.TestCase): }, ) + @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @slow @require_torch def test_integration_wtq_pt(self): @@ -580,6 +590,7 @@ class TQAPipelineTests(unittest.TestCase): ] self.assertListEqual(results, expected_results) + @unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @slow @require_torch def test_integration_sqa_pt(self):