Pix2StructImageProcessor requires torch>=1.11.0 (#24270)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -192,7 +192,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
|
|
||||||
mel = torch.tensor(mel[None, None, :])
|
mel = torch.tensor(mel[None, None, :])
|
||||||
mel_shrink = torch.nn.functional.interpolate(
|
mel_shrink = torch.nn.functional.interpolate(
|
||||||
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False
|
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
|
||||||
)
|
)
|
||||||
mel_shrink = mel_shrink[0][0].numpy()
|
mel_shrink = mel_shrink[0][0].numpy()
|
||||||
mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
|
mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
|
||||||
|
|||||||
@@ -43,11 +43,23 @@ if is_vision_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_1_11 = False
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
DEFAULT_FONT_PATH = "ybelkada/fonts"
|
DEFAULT_FONT_PATH = "ybelkada/fonts"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_torch_version():
|
||||||
|
if is_torch_available() and not is_torch_greater_or_equal_than_1_11:
|
||||||
|
raise ImportError(
|
||||||
|
f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use "
|
||||||
|
"Pix2StructImageProcessor. Please upgrade torch."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
|
# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
|
||||||
def torch_extract_patches(image_tensor, patch_height, patch_width):
|
def torch_extract_patches(image_tensor, patch_height, patch_width):
|
||||||
"""
|
"""
|
||||||
@@ -63,6 +75,7 @@ def torch_extract_patches(image_tensor, patch_height, patch_width):
|
|||||||
The width of the patches to extract.
|
The width of the patches to extract.
|
||||||
"""
|
"""
|
||||||
requires_backends(torch_extract_patches, ["torch"])
|
requires_backends(torch_extract_patches, ["torch"])
|
||||||
|
_check_torch_version()
|
||||||
|
|
||||||
image_tensor = image_tensor.unsqueeze(0)
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
|
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
|
||||||
@@ -240,6 +253,7 @@ class Pix2StructImageProcessor(BaseImageProcessor):
|
|||||||
A sequence of `max_patches` flattened patches.
|
A sequence of `max_patches` flattened patches.
|
||||||
"""
|
"""
|
||||||
requires_backends(self.extract_flattened_patches, "torch")
|
requires_backends(self.extract_flattened_patches, "torch")
|
||||||
|
_check_torch_version()
|
||||||
|
|
||||||
# convert to torch
|
# convert to torch
|
||||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
|
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_
|
|||||||
|
|
||||||
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
|
||||||
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
|
||||||
|
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
|
||||||
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
|
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
|
||||||
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
|
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ from ...test_image_processing_common import ImageProcessingSavingTestMixin, prep
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_1_11 = False
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -70,6 +74,10 @@ class Pix2StructImageProcessingTester(unittest.TestCase):
|
|||||||
return raw_image
|
return raw_image
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_torch_greater_or_equal_than_1_11,
|
||||||
|
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||||
|
)
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
|
class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
|
||||||
@@ -237,6 +245,10 @@ class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_torch_greater_or_equal_than_1_11,
|
||||||
|
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||||
|
)
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_vision
|
@require_vision
|
||||||
class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase):
|
class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase):
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ if is_torch_available():
|
|||||||
Pix2StructVisionModel,
|
Pix2StructVisionModel,
|
||||||
)
|
)
|
||||||
from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_1_11 = False
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -697,6 +700,10 @@ def prepare_img():
|
|||||||
return im
|
return im
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_torch_greater_or_equal_than_1_11,
|
||||||
|
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||||
|
)
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -19,9 +19,14 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers.testing_utils import require_torch, require_vision
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
from transformers.utils import is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
|
||||||
|
else:
|
||||||
|
is_torch_greater_or_equal_than_1_11 = False
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -34,6 +39,10 @@ if is_vision_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(
|
||||||
|
not is_torch_greater_or_equal_than_1_11,
|
||||||
|
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
|
||||||
|
)
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
class Pix2StructProcessorTest(unittest.TestCase):
|
class Pix2StructProcessorTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user