diff --git a/docs/source/en/internal/image_processing_utils.mdx b/docs/source/en/internal/image_processing_utils.mdx index 6a35736464..f1658e5552 100644 --- a/docs/source/en/internal/image_processing_utils.mdx +++ b/docs/source/en/internal/image_processing_utils.mdx @@ -21,8 +21,16 @@ Most of those are only useful if you are studying the code of the image processo [[autodoc]] image_transforms.center_crop +[[autodoc]] image_transforms.center_to_corners_format + +[[autodoc]] image_transforms.corners_to_center_format + +[[autodoc]] image_transforms.id_to_rgb + [[autodoc]] image_transforms.normalize +[[autodoc]] image_transforms.rgb_to_id + [[autodoc]] image_transforms.rescale [[autodoc]] image_transforms.resize diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 57bed18054..d8d1d60935 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union import numpy as np -from transformers.image_utils import PILImageResampling +from transformers.utils import TensorType from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available @@ -27,6 +27,7 @@ if is_vision_available(): from .image_utils import ( ChannelDimension, + PILImageResampling, get_channel_dimension_axis, get_image_size, infer_channel_dimension_format, @@ -108,7 +109,7 @@ def rescale( def to_pil_image( - image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.Tensor"], + image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"], do_rescale: Optional[bool] = None, ) -> PIL.Image.Image: """ @@ -300,6 +301,9 @@ def normalize( image = to_numpy_array(image) image = rescale(image, scale=1 / 255) + if not isinstance(image, np.ndarray): + raise ValueError("image must be a numpy array") + input_data_format = infer_channel_dimension_format(image) 
channel_axis = get_channel_dimension_axis(image) num_channels = image.shape[channel_axis] @@ -420,3 +424,147 @@ def center_crop( new_image = to_pil_image(new_image) return new_image + + +def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor": + center_x, center_y, width, height = bboxes_center.unbind(-1) + bbox_corners = torch.stack( + # top left x, top left y, bottom right x, bottom right y + [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)], + dim=-1, + ) + return bbox_corners + + +def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray: + center_x, center_y, width, height = bboxes_center.T + bboxes_corners = np.stack( + # top left x, top left y, bottom right x, bottom right y + [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], + axis=-1, + ) + return bboxes_corners + + +def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor": + center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1) + bboxes_corners = tf.stack( + # top left x, top left y, bottom right x, bottom right y + [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height], + axis=-1, + ) + return bboxes_corners + + +# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py +def center_to_corners_format(bboxes_center: TensorType) -> TensorType: + """ + Converts bounding boxes from center format to corners format. 
+ + center format: contains the coordinate for the center of the box and its width, height dimensions + (center_x, center_y, width, height) + corners format: contains the coordinates for the top-left and bottom-right corners of the box + (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + """ + # Function is used during model forward pass, so we use the input framework if possible, without + # converting to numpy + if is_torch_tensor(bboxes_center): + return _center_to_corners_format_torch(bboxes_center) + elif isinstance(bboxes_center, np.ndarray): + return _center_to_corners_format_numpy(bboxes_center) + elif is_tf_tensor(bboxes_center): + return _center_to_corners_format_tf(bboxes_center) + + raise ValueError(f"Unsupported input type {type(bboxes_center)}") + + +def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor": + top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1) + b = [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ] + return torch.stack(b, dim=-1) + + +def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray: + top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T + bboxes_center = np.stack( + [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ], + axis=-1, + ) + return bboxes_center + + +def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor": + top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1) + bboxes_center = tf.stack( + [ + (top_left_x + bottom_right_x) / 2, # center x + (top_left_y + bottom_right_y) / 2, # center y + (bottom_right_x - top_left_x), # width + (bottom_right_y - top_left_y), # height + ], + axis=-1, + ) 
+ return bboxes_center + + +def corners_to_center_format(bboxes_corners: TensorType) -> TensorType: + """ + Converts bounding boxes from corners format to center format. + + corners format: contains the coordinates for the top-left and bottom-right corners of the box + (top_left_x, top_left_y, bottom_right_x, bottom_right_y) + center format: contains the coordinate for the center of the box and its width, height dimensions + (center_x, center_y, width, height) + """ + # Inverse function accepts different input types so implemented here too + if is_torch_tensor(bboxes_corners): + return _corners_to_center_format_torch(bboxes_corners) + elif isinstance(bboxes_corners, np.ndarray): + return _corners_to_center_format_numpy(bboxes_corners) + elif is_tf_tensor(bboxes_corners): + return _corners_to_center_format_tf(bboxes_corners) + + raise ValueError(f"Unsupported input type {type(bboxes_corners)}") + + +# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py +# Copyright (c) 2018, Alexander Kirillov +# All rights reserved. +def rgb_to_id(color): + """ + Converts RGB color to unique ID. + """ + if isinstance(color, np.ndarray) and len(color.shape) == 3: + if color.dtype == np.uint8: + color = color.astype(np.int32) + return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2] + return int(color[0] + 256 * color[1] + 256 * 256 * color[2]) + + +def id_to_rgb(id_map): + """ + Converts unique ID to RGB color. 
+ """ + if isinstance(id_map, np.ndarray): + id_map_copy = id_map.copy() + rgb_shape = tuple(list(id_map.shape) + [3]) + rgb_map = np.zeros(rgb_shape, dtype=np.uint8) + for i in range(3): + rgb_map[..., i] = id_map_copy % 256 + id_map_copy //= 256 + return rgb_map + color = [] + for _ in range(3): + color.append(id_map % 256) + id_map //= 256 + return color diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 28b580945e..d0b7c9ade1 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -36,9 +36,13 @@ if is_vision_available(): from transformers.image_transforms import ( center_crop, + center_to_corners_format, + corners_to_center_format, get_resize_output_image_size, + id_to_rgb, normalize, resize, + rgb_to_id, to_channel_dimension_format, to_pil_image, ) @@ -178,6 +182,11 @@ class ImageTransformsTester(unittest.TestCase): def test_normalize(self): image = np.random.randint(0, 256, (224, 224, 3)) / 255 + # Test that exception is raised if inputs are incorrect + # Not a numpy array image + with self.assertRaises(ValueError): + normalize(5, 5, 5) + # Number of mean values != number of channels with self.assertRaises(ValueError): normalize(image, mean=(0.5, 0.6), std=1) @@ -219,3 +228,64 @@ class ImageTransformsTester(unittest.TestCase): self.assertIsInstance(cropped_image, np.ndarray) self.assertEqual(cropped_image.shape, (300, 260, 3)) self.assertTrue(np.allclose(cropped_image, expected_image)) + + def test_center_to_corners_format(self): + bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) + expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) + self.assertTrue(np.allclose(center_to_corners_format(bbox_center), expected)) + + # Check that the function and inverse function are inverse of each other + self.assertTrue(np.allclose(corners_to_center_format(center_to_corners_format(bbox_center)), bbox_center)) + + def test_corners_to_center_format(self): + bbox_corners = np.array([[8, 16, 12, 24], 
[13.5, 14, 16.5, 18]]) + expected = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) + self.assertTrue(np.allclose(corners_to_center_format(bbox_corners), expected)) + + # Check that the function and inverse function are inverse of each other + self.assertTrue(np.allclose(center_to_corners_format(corners_to_center_format(bbox_corners)), bbox_corners)) + + def test_rgb_to_id(self): + # test list input + rgb = [125, 4, 255] + self.assertEqual(rgb_to_id(rgb), 16712829) + + # test numpy array input + color = np.array( + [ + [ + [213, 54, 165], + [88, 207, 39], + [156, 108, 128], + ], + [ + [183, 194, 46], + [137, 58, 88], + [114, 131, 233], + ], + ] + ) + expected = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]]) + self.assertTrue(np.allclose(rgb_to_id(color), expected)) + + def test_id_to_rgb(self): + # test int input + self.assertEqual(id_to_rgb(16712829), [125, 4, 255]) + + # test array input + id_array = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]]) + color = np.array( + [ + [ + [213, 54, 165], + [88, 207, 39], + [156, 108, 128], + ], + [ + [183, 194, 46], + [137, 58, 88], + [114, 131, 233], + ], + ] + ) + self.assertTrue(np.allclose(id_to_rgb(id_array), color))