From 05cda5df3405e6a2ee4ecf8f7e1b2300ebda472e Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 28 Jul 2023 19:52:51 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=F0=9F=9A=A8=20=20Fix?= =?UTF-8?q?=20rescale=20ViVit=20Efficientnet=20(#25174)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix rescaling bug * Add tests * Update integration tests * Fix up * Update src/transformers/image_transforms.py * Update test - new possible order in list --- src/transformers/image_transforms.py | 3 ++- .../image_processing_efficientnet.py | 19 ++++++++++++------- .../models/vivit/image_processing_vivit.py | 19 +++++++++++++------ .../test_image_processing_efficientnet.py | 14 ++++++++++++++ .../vivit/test_image_processing_vivit.py | 14 ++++++++++++++ tests/models/vivit/test_modeling_vivit.py | 2 +- ...ipelines_zero_shot_image_classification.py | 1 + 7 files changed, 57 insertions(+), 15 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 63585d5e03..71afaaf268 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -110,10 +110,11 @@ def rescale( if not isinstance(image, np.ndarray): raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") + image = image.astype(dtype) + rescaled_image = image * scale if data_format is not None: rescaled_image = to_channel_dimension_format(rescaled_image, data_format) - rescaled_image = rescaled_image.astype(dtype) return rescaled_image diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet.py b/src/transformers/models/efficientnet/image_processing_efficientnet.py index eaefb9c101..8873a80069 100644 --- a/src/transformers/models/efficientnet/image_processing_efficientnet.py +++ b/src/transformers/models/efficientnet/image_processing_efficientnet.py @@ -153,7 +153,13 @@ class EfficientNetImageProcessor(BaseImageProcessor): **kwargs, ): """ - Rescale an image by a scale factor. image = image * scale. + Rescale an image by a scale factor. + + If offset is True, the image is rescaled between [-1, 1]. + image = image * scale * 2 - 1 + + If offset is False, the image is rescaled between [0, 1]. + image = image * scale Args: image (`np.ndarray`): @@ -165,13 +171,12 @@ class EfficientNetImageProcessor(BaseImageProcessor): data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. """ + scale = scale * 2 if offset else scale + rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs) + if offset: - rescaled_image = (image - 127.5) * scale - if data_format is not None: - rescaled_image = to_channel_dimension_format(rescaled_image, data_format) - rescaled_image = rescaled_image.astype(np.float32) - else: - rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs) + rescaled_image = rescaled_image - 1 + return rescaled_image def preprocess( diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py index 2aa7a911fe..41666e9999 100644 --- a/src/transformers/models/vivit/image_processing_vivit.py +++ b/src/transformers/models/vivit/image_processing_vivit.py @@ -167,6 +167,7 @@ class VivitImageProcessor(BaseImageProcessor): raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) + # Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale def rescale( self, image: np.ndarray, @@ -178,23 +179,29 @@ class VivitImageProcessor(BaseImageProcessor): """ Rescale an image by a scale factor. - If offset is `True`, image scaled between [-1, 1]: image = (image - 127.5) * scale. If offset is `False`, image - scaled between [0, 1]: image = image * scale + If offset is True, the image is rescaled between [-1, 1]. + image = image * scale * 2 - 1 + + If offset is False, the image is rescaled between [0, 1]. + image = image * scale Args: image (`np.ndarray`): Image to rescale. scale (`int` or `float`): Scale to apply to the image. - offset (`bool`, *optional*): + offset (`bool`, *optional*): Whether to scale the image in both negative and positive directions. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. """ - image = image.astype(np.float32) + scale = scale * 2 if offset else scale + rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs) + if offset: - image = image - (scale / 2) - return rescale(image, scale=scale, data_format=data_format, **kwargs) + rescaled_image = rescaled_image - 1 + + return rescaled_image def _preprocess_image( self, diff --git a/tests/models/efficientnet/test_image_processing_efficientnet.py b/tests/models/efficientnet/test_image_processing_efficientnet.py index 8e4fad3b08..11aee2d01c 100644 --- a/tests/models/efficientnet/test_image_processing_efficientnet.py +++ b/tests/models/efficientnet/test_image_processing_efficientnet.py @@ -193,3 +193,17 @@ class EfficientNetImageProcessorTest(ImageProcessingSavingTestMixin, unittest.Te self.image_processor_tester.size["width"], ), ) + + def test_rescale(self): + # EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1 + image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32) + + image_processor = self.image_processing_class(**self.image_processor_dict) + + rescaled_image = image_processor.rescale(image, scale=1 / 255) + expected_image = image.astype(np.float32) * (2 / 255.0) - 1 + self.assertTrue(np.allclose(rescaled_image, expected_image)) + + rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False) + expected_image = image.astype(np.float32) / 255.0 + self.assertTrue(np.allclose(rescaled_image, expected_image)) diff --git a/tests/models/vivit/test_image_processing_vivit.py b/tests/models/vivit/test_image_processing_vivit.py index f33553db0f..75a3e0264c 100644 --- a/tests/models/vivit/test_image_processing_vivit.py +++ b/tests/models/vivit/test_image_processing_vivit.py @@ -212,3 +212,17 @@ class VivitImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase self.image_processor_tester.crop_size["width"], ), ) + + def test_rescale(self): + # ViVit optionally rescales between -1 and 1 instead of the usual 0 and 1 + image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32) + + image_processor = self.image_processing_class(**self.image_processor_dict) + + rescaled_image = image_processor.rescale(image, scale=1 / 255) + expected_image = image.astype(np.float32) * (2 / 255.0) - 1 + self.assertTrue(np.allclose(rescaled_image, expected_image)) + + rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False) + expected_image = image.astype(np.float32) / 255.0 + self.assertTrue(np.allclose(rescaled_image, expected_image)) diff --git a/tests/models/vivit/test_modeling_vivit.py b/tests/models/vivit/test_modeling_vivit.py index 43db8bad7b..ed032e4bdd 100644 --- a/tests/models/vivit/test_modeling_vivit.py +++ b/tests/models/vivit/test_modeling_vivit.py @@ -345,6 +345,6 @@ class VivitModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) # taken from original model - expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device) + expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4)) diff --git a/tests/pipelines/test_pipelines_zero_shot_image_classification.py b/tests/pipelines/test_pipelines_zero_shot_image_classification.py index fbbfc78cae..197019f42e 100644 --- a/tests/pipelines/test_pipelines_zero_shot_image_classification.py +++ b/tests/pipelines/test_pipelines_zero_shot_image_classification.py @@ -85,6 +85,7 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase): [ [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}], [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}], + [{"score": 0.333, "label": "b"}, {"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}], ], )