Aurevoir PyTorch 1 (#35358)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-12-20 14:36:31 +01:00
committed by GitHub
parent 4567ee8057
commit 05de764e9c
53 changed files with 37 additions and 228 deletions

View File

@@ -20,7 +20,6 @@ from transformers import (
AutoTokenizer,
TableQuestionAnsweringPipeline,
TFAutoModelForTableQuestionAnswering,
is_torch_available,
pipeline,
)
from transformers.testing_utils import (
@@ -33,12 +32,6 @@ from transformers.testing_utils import (
)
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
else:
is_torch_greater_or_equal_than_1_12 = False
@is_pipeline_test
class TQAPipelineTests(unittest.TestCase):
# Putting it there for consistency, but TQA do not have fast tokenizer
@@ -150,7 +143,6 @@ class TQAPipelineTests(unittest.TestCase):
},
)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_small_model_pt(self, torch_dtype="float32"):
model_id = "lysandre/tiny-tapas-random-wtq"
@@ -253,12 +245,10 @@ class TQAPipelineTests(unittest.TestCase):
},
)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_small_model_pt_fp16(self):
self.test_small_model_pt(torch_dtype="float16")
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_slow_tokenizer_sqa_pt(self, torch_dtype="float32"):
model_id = "lysandre/tiny-tapas-random-sqa"
@@ -378,7 +368,6 @@ class TQAPipelineTests(unittest.TestCase):
},
)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@require_torch
def test_slow_tokenizer_sqa_pt_fp16(self):
self.test_slow_tokenizer_sqa_pt(torch_dtype="float16")
@@ -505,7 +494,6 @@ class TQAPipelineTests(unittest.TestCase):
},
)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_wtq_pt(self, torch_dtype="float32"):
@@ -551,7 +539,6 @@ class TQAPipelineTests(unittest.TestCase):
]
self.assertListEqual(results, expected_results)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_wtq_pt_fp16(self):
@@ -606,7 +593,6 @@ class TQAPipelineTests(unittest.TestCase):
]
self.assertListEqual(results, expected_results)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_sqa_pt(self, torch_dtype="float32"):
@@ -632,7 +618,6 @@ class TQAPipelineTests(unittest.TestCase):
]
self.assertListEqual(results, expected_results)
@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+")
@slow
@require_torch
def test_integration_sqa_pt_fp16(self):