From 5041bc3511d098814598cf1cfc6c6bd20e72c144 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:15:01 +0100 Subject: [PATCH] Image transforms add center crop (#19718) * Add center crop to transforms library * Return PIL images if PIL image input by default * Fixup and add docstring * Trigger CI * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * PR comments - move comments; unindent Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../en/internal/image_processing_utils.mdx | 2 + src/transformers/image_transforms.py | 94 +++++++++++++++++++ tests/test_image_transforms.py | 24 +++++ 3 files changed, 120 insertions(+) diff --git a/docs/source/en/internal/image_processing_utils.mdx b/docs/source/en/internal/image_processing_utils.mdx index 857d48f0fe..6450913053 100644 --- a/docs/source/en/internal/image_processing_utils.mdx +++ b/docs/source/en/internal/image_processing_utils.mdx @@ -19,6 +19,8 @@ Most of those are only useful if you are studying the code of the image processo ## Image Transformations +[[autodoc]] image_transforms.center_crop + [[autodoc]] image_transforms.normalize [[autodoc]] image_transforms.rescale diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 289db3ae78..d4826307cb 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -317,3 +317,97 @@ def normalize( image = to_channel_dimension_format(image, data_format) if data_format is not None else image return image + + +def center_crop( + image: np.ndarray, + size: Tuple[int, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + return_numpy: Optional[bool] = None, +) -> np.ndarray: + """ + Crops the `image` to the 
specified `size` using a center crop. Note that if the image is too small to be cropped to + the size given, it will be padded (so the returned result will always be of size `size`). + + Args: + image (`np.ndarray`): + The image to crop. + size (`Tuple[int, int]`): + The target size for the cropped image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + return_numpy (`bool`, *optional*): + Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the + previous ImageFeatureExtractionMixin method. + - Unset: will return the same type as the input image. + - `True`: will return a numpy array. + - `False`: will return a `PIL.Image.Image` object. + Returns: + `np.ndarray`: The cropped image. + """ + if isinstance(image, PIL.Image.Image): + warnings.warn( + "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. 
Please use numpy arrays instead.", + FutureWarning, + ) + image = to_numpy_array(image) + return_numpy = False if return_numpy is None else return_numpy + else: + return_numpy = True if return_numpy is None else return_numpy + + if not isinstance(image, np.ndarray): + raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") + + if not isinstance(size, Iterable) or len(size) != 2: + raise ValueError("size must have 2 elements representing the height and width of the output image") + + input_data_format = infer_channel_dimension_format(image) + output_data_format = data_format if data_format is not None else input_data_format + + # We perform the crop in (C, H, W) format and then convert to the output format + image = to_channel_dimension_format(image, ChannelDimension.FIRST) + + orig_height, orig_width = get_image_size(image) + crop_height, crop_width = size + + # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result. + top = (orig_height - crop_height) // 2 + bottom = top + crop_height + # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result. + left = (orig_width - crop_width) // 2 + right = left + crop_width + + # Check if cropped area is within image boundaries + if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width: + image = image[..., top:bottom, left:right] + image = to_channel_dimension_format(image, output_data_format) + return image + + # Otherwise, we may need to pad if the image is too small. Oh joy... 
+ new_height = max(crop_height, orig_height) + new_width = max(crop_width, orig_width) + new_shape = image.shape[:-2] + (new_height, new_width) + new_image = np.zeros_like(image, shape=new_shape) + + # If the image is too small, pad it with zeros + top_pad = (new_height - orig_height) // 2 + bottom_pad = top_pad + orig_height + left_pad = (new_width - orig_width) // 2 + right_pad = left_pad + orig_width + new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image + + top += top_pad + bottom += top_pad + left += left_pad + right += left_pad + + new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] + new_image = to_channel_dimension_format(new_image, output_data_format) + + if not return_numpy: + new_image = to_pil_image(new_image) + + return new_image diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index ee51bd358f..28b580945e 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -35,6 +35,7 @@ if is_vision_available(): import PIL.Image from transformers.image_transforms import ( + center_crop, get_resize_output_image_size, normalize, resize, @@ -195,3 +196,26 @@ class ImageTransformsTester(unittest.TestCase): self.assertIsInstance(normalized_image, np.ndarray) self.assertEqual(normalized_image.shape, (3, 224, 224)) self.assertTrue(np.allclose(normalized_image, expected_image)) + + def test_center_crop(self): + image = np.random.randint(0, 256, (3, 224, 224)) + + # Test that exception is raised if inputs are incorrect + with self.assertRaises(ValueError): + center_crop(image, 10) + + # Test result is correct - output data format is channels_last and center crop + # correctly computed + expected_image = image[:, 52:172, 82:142].transpose(1, 2, 0) + cropped_image = center_crop(image, (120, 60), data_format="channels_last") + self.assertIsInstance(cropped_image, np.ndarray) + self.assertEqual(cropped_image.shape, (120, 60, 3)) + 
self.assertTrue(np.allclose(cropped_image, expected_image)) + + # Test that image is padded with zeros if crop size is larger than image size + expected_image = np.zeros((300, 260, 3)) + expected_image[38:262, 18:242, :] = image.transpose((1, 2, 0)) + cropped_image = center_crop(image, (300, 260), data_format="channels_last") + self.assertIsInstance(cropped_image, np.ndarray) + self.assertEqual(cropped_image.shape, (300, 260, 3)) + self.assertTrue(np.allclose(cropped_image, expected_image))