Pix2StructImageProcessor requires torch>=1.11.0 (#24270)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -192,7 +192,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
|
||||
mel = torch.tensor(mel[None, None, :])
|
||||
mel_shrink = torch.nn.functional.interpolate(
|
||||
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False
|
||||
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
|
||||
)
|
||||
mel_shrink = mel_shrink[0][0].numpy()
|
||||
mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
|
||||
|
||||
@@ -43,11 +43,23 @@ if is_vision_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
DEFAULT_FONT_PATH = "ybelkada/fonts"
|
||||
|
||||
|
||||
def _check_torch_version():
|
||||
if is_torch_available() and not is_torch_greater_or_equal_than_1_11:
|
||||
raise ImportError(
|
||||
f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use "
|
||||
"Pix2StructImageProcessor. Please upgrade torch."
|
||||
)
|
||||
|
||||
|
||||
# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
|
||||
def torch_extract_patches(image_tensor, patch_height, patch_width):
|
||||
"""
|
||||
@@ -63,6 +75,7 @@ def torch_extract_patches(image_tensor, patch_height, patch_width):
|
||||
The width of the patches to extract.
|
||||
"""
|
||||
requires_backends(torch_extract_patches, ["torch"])
|
||||
_check_torch_version()
|
||||
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
|
||||
@@ -240,6 +253,7 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
||||
A sequence of `max_patches` flattened patches.
|
||||
"""
|
||||
requires_backends(self.extract_flattened_patches, "torch")
|
||||
_check_torch_version()
|
||||
|
||||
# convert to torch
|
||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
|
||||
|
||||
@@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_
|
||||
|
||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
||||
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
|
||||
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
|
||||
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
|
||||
|
||||
|
||||
@@ -28,6 +28,10 @@ from ...test_image_processing_common import ImageProcessingSavingTestMixin, prep
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
@@ -70,6 +74,10 @@ class Pix2StructImageProcessingTester(unittest.TestCase):
|
||||
return raw_image
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
|
||||
@@ -237,6 +245,10 @@ class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_torch
|
||||
@require_vision
|
||||
class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase):
|
||||
|
||||
@@ -48,6 +48,9 @@ if is_torch_available():
|
||||
Pix2StructVisionModel,
|
||||
)
|
||||
from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@@ -697,6 +700,10 @@ def prepare_img():
|
||||
return im
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_vision
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
@@ -19,9 +19,14 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_11 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
@@ -34,6 +39,10 @@ if is_vision_available():
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_1_11,
|
||||
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||
)
|
||||
@require_vision
|
||||
@require_torch
|
||||
class Pix2StructProcessorTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user