Byebye torch 1.10 (#28207)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -40,7 +40,6 @@ if is_torch_available():
|
||||
LongT5Model,
|
||||
)
|
||||
from transformers.models.longt5.modeling_longt5 import LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.pytorch_utils import is_torch_less_than_1_11
|
||||
|
||||
|
||||
class LongT5ModelTester:
|
||||
@@ -595,10 +594,6 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
model = LongT5Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_available() or is_torch_less_than_1_11,
|
||||
"Test failed with torch < 1.11 with an exception in a C++ file.",
|
||||
)
|
||||
@slow
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
@@ -28,10 +28,6 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
@@ -85,10 +81,6 @@ class Pix2StructImageProcessingTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
@@ -290,10 +282,6 @@ class Pix2StructImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Pix2StructImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase):
|
||||
|
||||
@@ -49,9 +49,6 @@ if is_torch_available():
|
||||
Pix2StructVisionModel,
|
||||
)
|
||||
from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -746,10 +743,6 @@ def prepare_img():
|
||||
return im
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_vision
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
@@ -19,14 +19,9 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
@@ -39,10 +34,6 @@ if is_vision_available():
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_vision
|
||||
@require_torch
|
||||
class Pix2StructProcessorTest(unittest.TestCase):
|
||||
|
||||
@@ -45,10 +45,6 @@ if is_torch_available():
|
||||
|
||||
from transformers import Pop2PianoForConditionalGeneration
|
||||
from transformers.models.pop2piano.modeling_pop2piano import POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.pytorch_utils import is_torch_1_8_0
|
||||
|
||||
else:
|
||||
is_torch_1_8_0 = False
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -616,10 +612,6 @@ class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_onnx
|
||||
@unittest.skipIf(
|
||||
is_torch_1_8_0,
|
||||
reason="Test has a segmentation fault on torch 1.8.0",
|
||||
)
|
||||
def test_export_to_onnx(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = Pop2PianoForConditionalGeneration(config_and_inputs[0]).to(torch_device)
|
||||
|
||||
@@ -906,7 +906,6 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(labels[0, :, 0].sum(), 270)
|
||||
self.assertEqual(labels[0, :, 1].sum(), 647)
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_speaker_verification(self):
|
||||
@@ -931,5 +930,4 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||
# id10002 vs id10004
|
||||
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.5616, 3)
|
||||
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertAlmostEqual(outputs.loss.item(), 18.5925, 2)
|
||||
|
||||
@@ -1928,7 +1928,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(labels[0, :, 0].sum(), 555)
|
||||
self.assertEqual(labels[0, :, 1].sum(), 299)
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_speaker_verification(self):
|
||||
@@ -1953,7 +1952,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
# id10002 vs id10004
|
||||
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).numpy(), 0.7594, 3)
|
||||
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertAlmostEqual(outputs.loss.item(), 17.7963, 2)
|
||||
|
||||
@require_torchaudio
|
||||
|
||||
@@ -515,7 +515,6 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_HIDDEN_STATES_SLICE = torch.tensor(
|
||||
[[[0.0577, 0.1161], [0.0579, 0.1165]], [[0.0199, 0.1237], [0.0059, 0.0605]]]
|
||||
)
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, atol=5e-2))
|
||||
|
||||
def test_inference_large(self):
|
||||
@@ -567,7 +566,6 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(labels[0, :, 0].sum(), 258)
|
||||
self.assertEqual(labels[0, :, 1].sum(), 647)
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_speaker_verification(self):
|
||||
@@ -592,5 +590,4 @@ class WavLMModelIntegrationTest(unittest.TestCase):
|
||||
# id10002 vs id10004
|
||||
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.4780, 3)
|
||||
|
||||
# TODO: update the tolerance after the CI moves to torch 1.10
|
||||
self.assertAlmostEqual(outputs.loss.item(), 18.4154, 2)
|
||||
|
||||
Reference in New Issue
Block a user