Skip some TQAPipelineTests tests in past CI (#24267)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-06-14 14:25:24 +02:00
committed by GitHub
parent 91b62f5a78
commit eac8dede83

View File

@@ -20,6 +20,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
TableQuestionAnsweringPipeline, TableQuestionAnsweringPipeline,
TFAutoModelForTableQuestionAnswering, TFAutoModelForTableQuestionAnswering,
is_torch_available,
pipeline, pipeline,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -32,6 +33,12 @@ from transformers.testing_utils import (
) )
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
else:
is_torch_greater_or_equal_than_1_12 = False
@is_pipeline_test @is_pipeline_test
class TQAPipelineTests(unittest.TestCase): class TQAPipelineTests(unittest.TestCase):
# Putting it there for consistency, but TQA do not have fast tokenizer # Putting it there for consistency, but TQA do not have fast tokenizer
@@ -143,6 +150,7 @@ class TQAPipelineTests(unittest.TestCase):
}, },
) )
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
model_id = "lysandre/tiny-tapas-random-wtq" model_id = "lysandre/tiny-tapas-random-wtq"
@@ -245,6 +253,7 @@ class TQAPipelineTests(unittest.TestCase):
}, },
) )
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch @require_torch
def test_slow_tokenizer_sqa_pt(self): def test_slow_tokenizer_sqa_pt(self):
model_id = "lysandre/tiny-tapas-random-sqa" model_id = "lysandre/tiny-tapas-random-sqa"
@@ -486,6 +495,7 @@ class TQAPipelineTests(unittest.TestCase):
}, },
) )
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow @slow
@require_torch @require_torch
def test_integration_wtq_pt(self): def test_integration_wtq_pt(self):
@@ -580,6 +590,7 @@ class TQAPipelineTests(unittest.TestCase):
] ]
self.assertListEqual(results, expected_results) self.assertListEqual(results, expected_results)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow @slow
@require_torch @require_torch
def test_integration_sqa_pt(self): def test_integration_sqa_pt(self):