From 11757e2bbd8ef89391ccb9ce0416420e16fa36f9 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Fri, 11 Aug 2023 15:09:31 +0100 Subject: [PATCH] Add input_data_format argument, image transforms (#25462) * Enable specifying input data format - overriding inferring * Add tests --- src/transformers/image_transforms.py | 75 +++++++++++++++++++++------- src/transformers/image_utils.py | 15 ++++-- tests/test_image_transforms.py | 66 ++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 23 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index d1621492d6..3cea0c2d17 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -63,6 +63,8 @@ def to_channel_dimension_format( The image to have its channel dimension set. channel_dim (`ChannelDimension`): The channel dimension format to use. + input_channel_dim (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. Returns: `np.ndarray`: The image with the channel dimension set to `channel_dim`. @@ -88,7 +90,11 @@ def to_channel_dimension_format( def rescale( - image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32 + image: np.ndarray, + scale: float, + data_format: Optional[ChannelDimension] = None, + dtype: np.dtype = np.float32, + input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """ Rescales `image` by `scale`. @@ -103,6 +109,8 @@ def rescale( dtype (`np.dtype`, *optional*, defaults to `np.float32`): The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature extractors. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. Returns: `np.ndarray`: The rescaled image. @@ -112,7 +120,7 @@ def rescale( rescaled_image = image * scale if data_format is not None: - rescaled_image = to_channel_dimension_format(rescaled_image, data_format) + rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format) rescaled_image = rescaled_image.astype(dtype) @@ -149,6 +157,7 @@ def _rescale_for_pil_conversion(image): def to_pil_image( image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"], do_rescale: Optional[bool] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> "PIL.Image.Image": """ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if @@ -161,6 +170,8 @@ def to_pil_image( Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default to `True` if the image type is a floating type and casting to `int` would result in a loss of precision, and `False` otherwise. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. Returns: `PIL.Image.Image`: The converted image. @@ -179,7 +190,7 @@ def to_pil_image( raise ValueError("Input image type not supported: {}".format(type(image))) # If the channel as been moved to first dim, we put it back at the end. - image = to_channel_dimension_format(image, ChannelDimension.LAST) + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) # If there is a single channel, we squeeze it, as otherwise PIL can't handle it. image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image @@ -200,6 +211,7 @@ def get_resize_output_image_size( size: Union[int, Tuple[int, int], List[int], Tuple[int]], default_to_square: bool = True, max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> tuple: """ Find the target (height, width) dimension of the output image after resizing given the input image and the desired @@ -225,6 +237,8 @@ def get_resize_output_image_size( than `max_size` after being resized according to `size`, then the image is resized again so that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter than `size`. Only used if `default_to_square` is `False`. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. Returns: `tuple`: The target (height, width) dimension of the output image after resizing. @@ -241,7 +255,7 @@ def get_resize_output_image_size( if default_to_square: return (size, size) - height, width = get_image_size(input_image) + height, width = get_image_size(input_image, input_data_format) short, long = (width, height) if width <= height else (height, width) requested_new_short = size @@ -266,6 +280,7 @@ def resize( reducing_gap: Optional[int] = None, data_format: Optional[ChannelDimension] = None, return_numpy: bool = True, + input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """ Resizes `image` to `(height, width)` specified by `size` using the PIL library. @@ -285,6 +300,8 @@ def resize( return_numpy (`bool`, *optional*, defaults to `True`): Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is returned. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. Returns: `np.ndarray`: The resized image. @@ -298,14 +315,16 @@ def resize( # For all transformations, we want to keep the same data format as the input image unless otherwise specified. # The resized image from PIL will always have channels last, so find the input format first. - data_format = infer_channel_dimension_format(image) if data_format is None else data_format + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + data_format = input_data_format if data_format is None else data_format # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use # the pillow library to resize the image and then convert back to numpy do_rescale = False if not isinstance(image, PIL.Image.Image): do_rescale = _rescale_for_pil_conversion(image) - image = to_pil_image(image, do_rescale=do_rescale) + image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format) height, width = size # PIL images are in the format (width, height) resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap) @@ -330,6 +349,7 @@ def normalize( mean: Union[float, Iterable[float]], std: Union[float, Iterable[float]], data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: """ Normalizes `image` using the mean and standard deviation specified by `mean` and `std`. @@ -345,12 +365,15 @@ def normalize( The standard deviation to use for normalization. data_format (`ChannelDimension`, *optional*): The channel dimension format of the output image. If unset, will use the inferred format from the input. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format of the input image. If unset, will use the inferred format from the input. """ if not isinstance(image, np.ndarray): raise ValueError("image must be a numpy array") - input_data_format = infer_channel_dimension_format(image) - channel_axis = get_channel_dimension_axis(image) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format) num_channels = image.shape[channel_axis] if isinstance(mean, Iterable): @@ -372,7 +395,7 @@ def normalize( else: image = ((image.T - mean) / std).T - image = to_channel_dimension_format(image, data_format) if data_format is not None else image + image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image return image @@ -380,6 +403,7 @@ def center_crop( image: np.ndarray, size: Tuple[int, int], data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, return_numpy: Optional[bool] = None, ) -> np.ndarray: """ @@ -396,6 +420,11 @@ def center_crop( - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. If unset, will use the inferred format of the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. return_numpy (`bool`, *optional*): Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the previous ImageFeatureExtractionMixin method. @@ -418,13 +447,14 @@ def center_crop( if not isinstance(size, Iterable) or len(size) != 2: raise ValueError("size must have 2 elements representing the height and width of the output image") - input_data_format = infer_channel_dimension_format(image) + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) output_data_format = data_format if data_format is not None else input_data_format # We perform the crop in (C, H, W) format and then convert to the output format - image = to_channel_dimension_format(image, ChannelDimension.FIRST) + image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format) - orig_height, orig_width = get_image_size(image) + orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST) crop_height, crop_width = size crop_height, crop_width = int(crop_height), int(crop_width) @@ -438,7 +468,7 @@ def center_crop( # Check if cropped area is within image boundaries if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width: image = image[..., top:bottom, left:right] - image = to_channel_dimension_format(image, output_data_format) + image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST) return image # Otherwise, we may need to pad if the image is too small. Oh joy... @@ -460,7 +490,7 @@ def center_crop( right += left_pad new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] - new_image = to_channel_dimension_format(new_image, output_data_format) + new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST) if not return_numpy: new_image = to_pil_image(new_image) @@ -705,7 +735,7 @@ def pad( else: raise ValueError(f"Invalid padding mode: {mode}") - image = to_channel_dimension_format(image, data_format) if data_format is not None else image + image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image return image @@ -728,7 +758,11 @@ def convert_to_rgb(image: ImageInput) -> ImageInput: return image -def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray: +def flip_channel_order( + image: np.ndarray, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> np.ndarray: """ Flips the channel order of the image. @@ -742,9 +776,14 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format. If unset, will use same as the input image. + input_data_format (`ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. """ + input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format - input_data_format = infer_channel_dimension_format(image) if input_data_format == ChannelDimension.LAST: image = image[..., ::-1] elif input_data_format == ChannelDimension.FIRST: @@ -753,5 +792,5 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension raise ValueError(f"Unsupported channel dimension: {input_data_format}") if data_format is not None: - image = to_channel_dimension_format(image, data_format) + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) return image diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 72ed45492b..a86938e1ff 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -176,23 +176,28 @@ def infer_channel_dimension_format( raise ValueError("Unable to infer channel dimension format") -def get_channel_dimension_axis(image: np.ndarray) -> int: +def get_channel_dimension_axis( + image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None +) -> int: """ Returns the channel dimension axis of the image. Args: image (`np.ndarray`): The image to get the channel dimension axis of. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the image. If `None`, will infer the channel dimension from the image. Returns: The channel dimension axis of the image. """ - channel_dim = infer_channel_dimension_format(image) - if channel_dim == ChannelDimension.FIRST: + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + if input_data_format == ChannelDimension.FIRST: return image.ndim - 3 - elif channel_dim == ChannelDimension.LAST: + elif input_data_format == ChannelDimension.LAST: return image.ndim - 1 - raise ValueError(f"Unsupported data format: {channel_dim}") + raise ValueError(f"Unsupported data format: {input_data_format}") def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]: diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 70db390394..2941685e69 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase): image = to_channel_dimension_format(image, "channels_first") self.assertEqual(image.shape, (3, 4, 5)) + # Can pass in input_data_format and works if data format is ambiguous or unknown. + image = np.random.rand(4, 5, 6) + image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last") + self.assertEqual(image.shape, (6, 4, 5)) + def test_get_resize_output_image_size(self): image = np.random.randint(0, 256, (3, 224, 224)) @@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase): image = np.random.randint(0, 256, (3, 50, 40)) self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17)) + # Test output size = (int(size * height / width), size) if size is an int and height > width and + # input has 4 channels + image = np.random.randint(0, 256, (4, 50, 40)) + self.assertEqual( + get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"), + (25, 20), + ) + # Test correct channel dimension is returned if output size if height == 3 # Defaults to input format - channels first image = np.random.randint(0, 256, (3, 18, 97)) @@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase): self.assertTrue(np.all(resized_image >= 0)) self.assertTrue(np.all(resized_image <= 1)) + # Check that an image with 4 channels is resized correctly + image = np.random.randint(0, 256, (4, 224, 224)) + resized_image = resize(image, (30, 40), input_data_format="channels_first") + self.assertIsInstance(resized_image, np.ndarray) + self.assertEqual(resized_image.shape, (4, 30, 40)) + def test_normalize(self): image = np.random.randint(0, 256, (224, 224, 3)) / 255 @@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase): self.assertEqual(normalized_image.shape, (3, 224, 224)) self.assertTrue(np.allclose(normalized_image, expected_image)) + # Test image with 4 channels is normalized correctly + image = np.random.randint(0, 256, (224, 224, 4)) / 255 + mean = (0.5, 0.6, 0.7, 0.8) + std = (0.1, 0.2, 0.3, 0.4) + expected_image = (image - mean) / std + self.assertTrue( + np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image) + ) + def test_center_crop(self): image = np.random.randint(0, 256, (3, 224, 224)) @@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase): self.assertEqual(cropped_image.shape, (300, 260, 3)) self.assertTrue(np.allclose(cropped_image, expected_image)) + # Test image with 4 channels is cropped correctly + image = np.random.randint(0, 256, (224, 224, 4)) + expected_image = image[52:172, 82:142, :] + self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image)) + def test_center_to_corners_format(self): bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) @@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase): np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last")) ) + # Test we can pad on an image with 2 channels + # fmt: off + image = np.array([ + [[0, 1], [2, 3]], + ]) + expected_image = np.array([ + [[0, 0], [0, 1], [2, 3]], + [[0, 0], [0, 0], [0, 0]], + ]) + # fmt: on + self.assertTrue( + np.allclose( + expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last") + ) + ) + @require_vision def test_convert_to_rgb(self): # Test that an RGBA image is converted to RGB @@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase): self.assertTrue( np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first) ) + + # Can flip when the image has 2 channels + # fmt: off + img_channels_first = np.array([ + [[ 0, 1, 2, 3], + [ 4, 5, 6, 7]], + + [[ 8, 9, 10, 11], + [12, 13, 14, 15]], + ]) + # fmt: on + flipped_img_channels_first = img_channels_first[::-1, :, :] + self.assertTrue( + np.allclose( + flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first + ) + )