# From commit 090e3e68963e64b1da55a8b704a3f1679f5968a8 (Sylvain Gugger, Mon, 5 Apr 2021):
# "[PATCH] Add center_crop to ImageFeatureExtractoMixin (#11066)".
# Target: src/transformers/image_utils.py, method added to ``ImageFeatureExtractionMixin``.
def center_crop(self, image, size):
    """
    Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped to
    the size given, it will be padded (so the returned result always has the size asked).

    Args:
        image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
            The image to crop. For arrays and tensors, the last two dimensions are read as (height, width).
        size (:obj:`int` or :obj:`Tuple[int, int]` or :obj:`List[int]`):
            The size to which to crop the image. A single number requests a square crop; a pair is read as
            (height, width).

    Returns:
        The cropped image, of the same kind as :obj:`image`. Arrays and tensors are padded with zeros where the
        requested size exceeds the image; PIL images are padded by :obj:`PIL.Image.Image.crop` itself.
    """
    self._ensure_format_supported(image)

    # Accept lists as well as tuples (a size deserialized from JSON comes back as a list). Anything that is not
    # a sequence is treated as the side of a square crop.
    if isinstance(size, (list, tuple)):
        crop_height, crop_width = size[0], size[1]
    else:
        crop_height = crop_width = size

    is_array = isinstance(image, np.ndarray)
    # An ndarray can never be a PIL image, so the PIL check is only needed for non-arrays.
    is_pil = not is_array and isinstance(image, PIL.Image.Image)

    # PIL Image.size is (width, height) but NumPy arrays and torch Tensors have (height, width)
    if is_pil:
        width, height = image.size
    else:
        height, width = image.shape[-2:]

    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    # Derive the far edges from the near ones: (height + crop_height) // 2 is off by one when the difference
    # between the two sizes is odd.
    bottom = top + crop_height
    right = left + crop_width

    # For PIL Images we have a method to crop directly (it also handles boxes that leave the image).
    if is_pil:
        return image.crop((left, top, right, bottom))

    # Fast path: the crop window lies entirely inside the image.
    if top >= 0 and left >= 0 and bottom <= height and right <= width:
        return image[..., top:bottom, left:right]

    # Otherwise the image is too small along at least one dimension: paste it on a zero canvas that is large
    # enough, then crop that canvas.
    # The padding in front of the image is exactly how far the crop window sticks out (``-top`` / ``-left``).
    # Using ``(canvas - image) // 2`` instead rounds the other way for odd differences, which shifts the window
    # to -1 and yields a result one row/column short of the requested size.
    pad_top = max(0, -top)
    pad_left = max(0, -left)
    canvas_shape = tuple(image.shape[:-2]) + (max(crop_height, height), max(crop_width, width))
    if is_array:
        canvas = np.zeros_like(image, shape=canvas_shape)
    else:
        # Torch tensor (assumes _ensure_format_supported has rejected every other type).
        canvas = image.new_zeros(canvas_shape)
    canvas[..., pad_top : pad_top + height, pad_left : pad_left + width] = image

    # Express the crop window in canvas coordinates; it is now guaranteed to fit.
    top += pad_top
    left += pad_left
    return canvas[..., top : top + crop_height, left : left + crop_width]
# From the patch "Add center_crop to ImageFeatureExtractoMixin (#11066)":
# tests/test_image_utils.py, methods added to ``ImageFeatureExtractionTester``.
def test_center_crop_image(self):
    """Center-cropping a PIL image returns a PIL image with the requested size."""
    feature_extractor = ImageFeatureExtractionMixin()
    image = get_random_image(16, 32)

    # Crop sizes relative to the 16/32 sides requested above: smaller in both dimensions (8), larger in exactly
    # one dimension ((8, 64) and 20), and at least as large in both ((32, 64)).
    crop_sizes = [8, (8, 64), 20, (32, 64)]
    for size in crop_sizes:
        # subTest reports every failing size instead of stopping at the first one.
        with self.subTest(size=size):
            cropped_image = feature_extractor.center_crop(image, size)
            self.assertIsInstance(cropped_image, PIL.Image.Image)

            # PIL Image.size is transposed compared to NumPy or PyTorch (width first instead of height first).
            expected_size = (size, size) if isinstance(size, int) else (size[1], size[0])
            self.assertEqual(cropped_image.size, expected_size)


def test_center_crop_array(self):
    """Center-cropping a NumPy array gives the requested shape and matches the PIL crop of the same image."""
    feature_extractor = ImageFeatureExtractionMixin()
    image = get_random_image(16, 32)
    array = feature_extractor.to_numpy_array(image)

    # Crop sizes relative to the 16/32 sides requested above: smaller in both dimensions (8), larger in exactly
    # one dimension ((8, 64) and 20), and at least as large in both ((32, 64)).
    crop_sizes = [8, (8, 64), 20, (32, 64)]
    for size in crop_sizes:
        with self.subTest(size=size):
            cropped_array = feature_extractor.center_crop(array, size)
            self.assertIsInstance(cropped_array, np.ndarray)

            # Arrays are (height, width) in their last two dimensions, so no transposition is needed here.
            expected_size = (size, size) if isinstance(size, int) else size
            self.assertEqual(cropped_array.shape[-2:], expected_size)

            # Check result is consistent with PIL.Image.crop
            cropped_image = feature_extractor.center_crop(image, size)
            self.assertTrue(np.array_equal(cropped_array, feature_extractor.to_numpy_array(cropped_image)))


@require_torch
def test_center_crop_tensor(self):
    """Center-cropping a torch tensor gives the requested shape and matches the PIL crop of the same image."""
    feature_extractor = ImageFeatureExtractionMixin()
    image = get_random_image(16, 32)
    array = feature_extractor.to_numpy_array(image)
    tensor = torch.tensor(array)

    # Crop sizes relative to the 16/32 sides requested above: smaller in both dimensions (8), larger in exactly
    # one dimension ((8, 64) and 20), and at least as large in both ((32, 64)).
    crop_sizes = [8, (8, 64), 20, (32, 64)]
    for size in crop_sizes:
        with self.subTest(size=size):
            cropped_tensor = feature_extractor.center_crop(tensor, size)
            self.assertIsInstance(cropped_tensor, torch.Tensor)

            # Tensors are (height, width) in their last two dimensions, so no transposition is needed here.
            expected_size = (size, size) if isinstance(size, int) else size
            self.assertEqual(cropped_tensor.shape[-2:], expected_size)

            # Check result is consistent with PIL.Image.crop
            cropped_image = feature_extractor.center_crop(image, size)
            self.assertTrue(
                torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image)))
            )