🚨🚨🚨 Fix rescaling in the ViViT and EfficientNet image processors (#25174)

* Fix rescaling bug

* Add tests

* Update integration tests

* Fix up

* Update src/transformers/image_transforms.py

* Update test - accept a new possible ordering of the equally scored labels in the expected list
This commit is contained in:
amyeroberts
2023-07-28 19:52:51 +01:00
committed by GitHub
parent 03f98f9683
commit 05cda5df34
7 changed files with 57 additions and 15 deletions

View File

@@ -110,10 +110,11 @@ def rescale(
if not isinstance(image, np.ndarray): if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
image = image.astype(dtype)
rescaled_image = image * scale rescaled_image = image * scale
if data_format is not None: if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format) rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image return rescaled_image

View File

@@ -153,7 +153,13 @@ class EfficientNetImageProcessor(BaseImageProcessor):
**kwargs, **kwargs,
): ):
""" """
Rescale an image by a scale factor. image = image * scale. Rescale an image by a scale factor.
If offset is True, the image is rescaled between [-1, 1].
image = image * scale * 2 - 1
If offset is False, the image is rescaled between [0, 1].
image = image * scale
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
@@ -165,13 +171,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
scale = scale * 2 if offset else scale
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
if offset: if offset:
rescaled_image = (image - 127.5) * scale rescaled_image = rescaled_image - 1
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(np.float32)
else:
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
return rescaled_image return rescaled_image
def preprocess( def preprocess(

View File

@@ -167,6 +167,7 @@ class VivitImageProcessor(BaseImageProcessor):
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
# Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
def rescale( def rescale(
self, self,
image: np.ndarray, image: np.ndarray,
@@ -178,23 +179,29 @@ class VivitImageProcessor(BaseImageProcessor):
""" """
Rescale an image by a scale factor. Rescale an image by a scale factor.
If offset is `True`, image scaled between [-1, 1]: image = (image - 127.5) * scale. If offset is `False`, image If offset is True, the image is rescaled between [-1, 1].
scaled between [0, 1]: image = image * scale image = image * scale * 2 - 1
If offset is False, the image is rescaled between [0, 1].
image = image * scale
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
Image to rescale. Image to rescale.
scale (`int` or `float`): scale (`int` or `float`):
Scale to apply to the image. Scale to apply to the image.
offset (`bool`, *optional*): offset (`bool`, *optional*):
Whether to scale the image in both negative and positive directions. Whether to scale the image in both negative and positive directions.
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
image = image.astype(np.float32) scale = scale * 2 if offset else scale
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
if offset: if offset:
image = image - (scale / 2) rescaled_image = rescaled_image - 1
return rescale(image, scale=scale, data_format=data_format, **kwargs)
return rescaled_image
def _preprocess_image( def _preprocess_image(
self, self,

View File

@@ -193,3 +193,17 @@ class EfficientNetImageProcessorTest(ImageProcessingSavingTestMixin, unittest.Te
self.image_processor_tester.size["width"], self.image_processor_tester.size["width"],
), ),
) )
def test_rescale(self):
    """Check `rescale` against hand-computed values, with and without `offset=False`."""
    processor = self.image_processing_class(**self.image_processor_dict)

    # One uint8 sample of every value 0..255, laid out as a (1, 8, 32) array.
    pixels = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
    pixels_f32 = pixels.astype(np.float32)

    # EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1.
    # No `offset` is passed here; the reference assumes the [-1, 1] mapping x * 2 / 255 - 1.
    with_offset = processor.rescale(pixels, scale=1 / 255)
    self.assertTrue(np.allclose(with_offset, pixels_f32 * (2 / 255.0) - 1))

    # With offset=False the reference is the plain [0, 1] mapping x / 255.
    without_offset = processor.rescale(pixels, scale=1 / 255, offset=False)
    self.assertTrue(np.allclose(without_offset, pixels_f32 / 255.0))

View File

@@ -212,3 +212,17 @@ class VivitImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase
self.image_processor_tester.crop_size["width"], self.image_processor_tester.crop_size["width"],
), ),
) )
def test_rescale(self):
    """Check `rescale` against hand-computed values, with and without `offset=False`."""
    processor = self.image_processing_class(**self.image_processor_dict)

    # One uint8 sample of every value 0..255, laid out as a (1, 8, 32) array.
    pixels = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
    pixels_f32 = pixels.astype(np.float32)

    # ViViT optionally rescales between -1 and 1 instead of the usual 0 and 1.
    # No `offset` is passed here; the reference assumes the [-1, 1] mapping x * 2 / 255 - 1.
    with_offset = processor.rescale(pixels, scale=1 / 255)
    self.assertTrue(np.allclose(with_offset, pixels_f32 * (2 / 255.0) - 1))

    # With offset=False the reference is the plain [0, 1] mapping x / 255.
    without_offset = processor.rescale(pixels, scale=1 / 255, offset=False)
    self.assertTrue(np.allclose(without_offset, pixels_f32 / 255.0))

View File

@@ -345,6 +345,6 @@ class VivitModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
# taken from original model # taken from original model
expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device) expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))

View File

@@ -85,6 +85,7 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
[ [
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}], [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
[{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}], [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}],
[{"score": 0.333, "label": "b"}, {"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}],
], ],
) )