From a04ebc8b33314c42349c3e12885960a292c9c9dd Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 14 Jun 2023 17:05:40 +0200 Subject: [PATCH] `Pix2StructImageProcessor` requires `torch>=1.11.0` (#24270) * fix * fix * fix --------- Co-authored-by: ydshieh --- .../models/clap/feature_extraction_clap.py | 2 +- .../pix2struct/image_processing_pix2struct.py | 14 ++++++++++++++ src/transformers/pytorch_utils.py | 1 + .../pix2struct/test_image_processing_pix2struct.py | 12 ++++++++++++ .../models/pix2struct/test_modeling_pix2struct.py | 7 +++++++ .../models/pix2struct/test_processor_pix2struct.py | 11 ++++++++++- 6 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/clap/feature_extraction_clap.py b/src/transformers/models/clap/feature_extraction_clap.py index d33307ffbd..5b9df8225b 100644 --- a/src/transformers/models/clap/feature_extraction_clap.py +++ b/src/transformers/models/clap/feature_extraction_clap.py @@ -192,7 +192,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor): mel = torch.tensor(mel[None, None, :]) mel_shrink = torch.nn.functional.interpolate( - mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False + mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False ) mel_shrink = mel_shrink[0][0].numpy() mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) diff --git a/src/transformers/models/pix2struct/image_processing_pix2struct.py b/src/transformers/models/pix2struct/image_processing_pix2struct.py index b85d123841..3e72bc5ad5 100644 --- a/src/transformers/models/pix2struct/image_processing_pix2struct.py +++ b/src/transformers/models/pix2struct/image_processing_pix2struct.py @@ -43,11 +43,23 @@ if is_vision_available(): if is_torch_available(): import torch + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11 +else: + is_torch_greater_or_equal_than_1_11 = False + logger = logging.get_logger(__name__) DEFAULT_FONT_PATH = "ybelkada/fonts" +def _check_torch_version(): + if is_torch_available() and not is_torch_greater_or_equal_than_1_11: + raise ImportError( + f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use " + "Pix2StructImageProcessor. Please upgrade torch." + ) + + # adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2 def torch_extract_patches(image_tensor, patch_height, patch_width): """ @@ -63,6 +75,7 @@ def torch_extract_patches(image_tensor, patch_height, patch_width): The width of the patches to extract. """ requires_backends(torch_extract_patches, ["torch"]) + _check_torch_version() image_tensor = image_tensor.unsqueeze(0) patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) @@ -240,6 +253,7 @@ class Pix2StructImageProcessor(BaseImageProcessor): A sequence of `max_patches` flattened patches. """ requires_backends(self.extract_flattened_patches, "torch") + _check_torch_version() # convert to torch image = to_channel_dimension_format(image, ChannelDimension.FIRST) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 8392616634..3beaf31efa 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_ is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") +is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11") is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10") is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11") diff --git a/tests/models/pix2struct/test_image_processing_pix2struct.py b/tests/models/pix2struct/test_image_processing_pix2struct.py index fd805da696..51a4708a76 100644 --- a/tests/models/pix2struct/test_image_processing_pix2struct.py +++ b/tests/models/pix2struct/test_image_processing_pix2struct.py @@ -28,6 +28,10 @@ from ...test_image_processing_common import ImageProcessingSavingTestMixin, prep if is_torch_available(): import torch + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11 +else: + is_torch_greater_or_equal_than_1_11 = False + if is_vision_available(): from PIL import Image @@ -70,6 +74,10 @@ class Pix2StructImageProcessingTester(unittest.TestCase): return raw_image +@unittest.skipIf( + not is_torch_greater_or_equal_than_1_11, + reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.", +) @require_torch @require_vision class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase): @@ -237,6 +245,10 @@ class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes ) +@unittest.skipIf( + not is_torch_greater_or_equal_than_1_11, + reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.", +) @require_torch @require_vision class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase): diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 4d028d111d..e40e6c81ea 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -48,6 +48,9 @@ if is_torch_available(): Pix2StructVisionModel, ) from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11 +else: + is_torch_greater_or_equal_than_1_11 = False if is_vision_available(): @@ -697,6 +700,10 @@ def prepare_img(): return im +@unittest.skipIf( + not is_torch_greater_or_equal_than_1_11, + reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.", +) @require_vision @require_torch @slow diff --git a/tests/models/pix2struct/test_processor_pix2struct.py b/tests/models/pix2struct/test_processor_pix2struct.py index 318e6f301f..e0ee398b3a 100644 --- a/tests/models/pix2struct/test_processor_pix2struct.py +++ b/tests/models/pix2struct/test_processor_pix2struct.py @@ -19,9 +19,14 @@ import numpy as np import pytest from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_vision_available +from transformers.utils import is_torch_available, is_vision_available +if is_torch_available(): + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11 +else: + is_torch_greater_or_equal_than_1_11 = False + if is_vision_available(): from PIL import Image @@ -34,6 +39,10 @@ if is_vision_available(): ) +@unittest.skipIf( + not is_torch_greater_or_equal_than_1_11, + reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.", +) @require_vision @require_torch class Pix2StructProcessorTest(unittest.TestCase):