From 1d94d575461a76cb1dcb3ebe6e85f1c85d1dafcd Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Wed, 2 Feb 2022 09:44:22 +0100
Subject: [PATCH] Add option to resize like torchvision's Resize (#15419)

* Add torchvision's resize

* Rename torch_resize to default_to_square

* Apply suggestions from code review

* Add support for default_to_square and tuple of length 1
---
 src/transformers/image_utils.py | 51 +++++++++++++++++++++++++++----
 tests/test_image_utils.py       | 53 ++++++++++++++++++++++++++++++++-
 2 files changed, 97 insertions(+), 7 deletions(-)

diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index e94a360462..951d682c94 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -184,7 +184,7 @@ class ImageFeatureExtractionMixin:
         else:
             return (image - mean) / std
 
-    def resize(self, image, size, resample=PIL.Image.BILINEAR):
+    def resize(self, image, size, resample=PIL.Image.BILINEAR, default_to_square=True, max_size=None):
         """
         Resizes `image`. Note that this will trigger a conversion of `image` to a PIL Image.
 
@@ -192,19 +192,58 @@ class ImageFeatureExtractionMixin:
             image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                 The image to resize.
             size (`int` or `Tuple[int, int]`):
-                The size to use for resizing the image.
+                The size to use for resizing the image. If `size` is a sequence like (w, h), output size will be
+                matched to this.
+
+                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
+                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
+                this number. i.e., if height > width, then image will be rescaled to (size, size * height / width).
             resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
                 The filter to user for resampling.
+            default_to_square (`bool`, *optional*, defaults to `True`):
+                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
+                square (`size`, `size`). If set to `False`, will replicate
+                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
+                with support for resizing only the smallest edge and providing an optional `max_size`.
+            max_size (`int`, *optional*, defaults to `None`):
+                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
+                greater than `max_size` after being resized according to `size`, then the image is resized again so
+                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e. the smaller
+                edge may be shorter than `size`. Only used if `default_to_square` is `False`.
         """
         self._ensure_format_supported(image)
 
-        if isinstance(size, int):
-            size = (size, size)
-        elif isinstance(size, list):
-            size = tuple(size)
         if not isinstance(image, PIL.Image.Image):
             image = self.to_pil_image(image)
 
+        if isinstance(size, list):
+            size = tuple(size)
+
+        if isinstance(size, int) or len(size) == 1:
+            if default_to_square:
+                size = (size, size) if isinstance(size, int) else (size[0], size[0])
+            else:
+                width, height = image.size
+                # specified size only for the smallest edge
+                short, long = (width, height) if width <= height else (height, width)
+                requested_new_short = size if isinstance(size, int) else size[0]
+
+                if short == requested_new_short:
+                    return image
+
+                new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+                if max_size is not None:
+                    if max_size <= requested_new_short:
+                        raise ValueError(
+                            f"max_size = {max_size} must be strictly greater than the requested "
+                            f"size for the smaller edge size = {size}"
+                        )
+                    if new_long > max_size:
+                        new_short, new_long = int(max_size * new_short / new_long), max_size
+
+                size = (new_short, new_long) if width <= height else (new_long, new_short)
+
         return image.resize(size, resample=resample)
 
     def center_crop(self, image, size):
diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py
index 702387ee5f..6c870e3341 100644
--- a/tests/test_image_utils.py
+++ b/tests/test_image_utils.py
@@ -219,7 +219,7 @@ class ImageFeatureExtractionTester(unittest.TestCase):
         self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
         self.assertEqual(resized_image1.size, (8, 16))
 
-        # Passing and array converts it to a PIL Image.
+        # Passing an array converts it to a PIL Image.
         resized_image2 = feature_extractor.resize(array, 8)
         self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
         self.assertEqual(resized_image2.size, (8, 8))
@@ -230,6 +230,57 @@ class ImageFeatureExtractionTester(unittest.TestCase):
         self.assertEqual(resized_image3.size, (8, 16))
         self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
 
+    def test_resize_image_and_array_non_default_to_square(self):
+        feature_extractor = ImageFeatureExtractionMixin()
+
+        heights_widths = [
+            # height, width
+            # square image
+            (28, 28),
+            (27, 27),
+            # rectangular image: h < w
+            (28, 34),
+            (29, 35),
+            # rectangular image: h > w
+            (34, 28),
+            (35, 29),
+        ]
+
+        # single integer or single integer in tuple/list
+        sizes = [22, 27, 28, 36, [22], (27,)]
+
+        for (height, width), size in zip(heights_widths, sizes):
+            for max_size in (None, 37, 1000):
+                image = get_random_image(height, width)
+                array = np.array(image)
+
+                # `size` may be an int or a single-element list/tuple; it is passed to `resize` unchanged below.
+                edge = size[0] if isinstance(size, (list, tuple)) else size
+                # With `default_to_square=False` the smaller edge of the image is matched to `edge`,
+                # i.e. if height > width, the image is rescaled to (width, height) = (edge, edge * height / width).
+                if height < width:
+                    exp_w, exp_h = (int(edge * width / height), edge)
+                    if max_size is not None and max_size < exp_w:
+                        exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
+                elif width < height:
+                    exp_w, exp_h = (edge, int(edge * height / width))
+                    if max_size is not None and max_size < exp_h:
+                        exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
+                else:
+                    exp_w, exp_h = (edge, edge)
+                    if max_size is not None and max_size < edge:
+                        exp_w, exp_h = max_size, max_size
+
+                resized_image = feature_extractor.resize(image, size=size, default_to_square=False, max_size=max_size)
+                self.assertTrue(isinstance(resized_image, PIL.Image.Image))
+                self.assertEqual(resized_image.size, (exp_w, exp_h))
+
+                # Passing an array converts it to a PIL Image.
+                resized_image2 = feature_extractor.resize(array, size=size, default_to_square=False, max_size=max_size)
+                self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
+                self.assertEqual(resized_image2.size, (exp_w, exp_h))
+                self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
+
     @require_torch
     def test_resize_tensor(self):
         feature_extractor = ImageFeatureExtractionMixin()