Add object detection + segmentation transforms (#20003)
* Add transforms for object detection * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Better var names & docstring * Remove unused var desc in docstring * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -21,8 +21,16 @@ Most of those are only useful if you are studying the code of the image processo
|
||||
|
||||
[[autodoc]] image_transforms.center_crop
|
||||
|
||||
[[autodoc]] image_transforms.center_to_corners_format
|
||||
|
||||
[[autodoc]] image_transforms.corners_to_center_format
|
||||
|
||||
[[autodoc]] image_transforms.id_to_rgb
|
||||
|
||||
[[autodoc]] image_transforms.normalize
|
||||
|
||||
[[autodoc]] image_transforms.rgb_to_id
|
||||
|
||||
[[autodoc]] image_transforms.rescale
|
||||
|
||||
[[autodoc]] image_transforms.resize
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import TensorType
|
||||
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ if is_vision_available():
|
||||
|
||||
from .image_utils import (
|
||||
ChannelDimension,
|
||||
PILImageResampling,
|
||||
get_channel_dimension_axis,
|
||||
get_image_size,
|
||||
infer_channel_dimension_format,
|
||||
@@ -108,7 +109,7 @@ def rescale(
|
||||
|
||||
|
||||
def to_pil_image(
|
||||
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.Tensor"],
|
||||
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
||||
do_rescale: Optional[bool] = None,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
@@ -300,6 +301,9 @@ def normalize(
|
||||
image = to_numpy_array(image)
|
||||
image = rescale(image, scale=1 / 255)
|
||||
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError("image must be a numpy array")
|
||||
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
channel_axis = get_channel_dimension_axis(image)
|
||||
num_channels = image.shape[channel_axis]
|
||||
@@ -420,3 +424,147 @@ def center_crop(
|
||||
new_image = to_pil_image(new_image)
|
||||
|
||||
return new_image
|
||||
|
||||
|
||||
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
|
||||
center_x, center_y, width, height = bboxes_center.unbind(-1)
|
||||
bbox_corners = torch.stack(
|
||||
# top left x, top left y, bottom right x, bottom right y
|
||||
[(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
|
||||
dim=-1,
|
||||
)
|
||||
return bbox_corners
|
||||
|
||||
|
||||
def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
|
||||
center_x, center_y, width, height = bboxes_center.T
|
||||
bboxes_corners = np.stack(
|
||||
# top left x, top left y, bottom right x, bottom right y
|
||||
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
|
||||
axis=-1,
|
||||
)
|
||||
return bboxes_corners
|
||||
|
||||
|
||||
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
|
||||
center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
|
||||
bboxes_corners = tf.stack(
|
||||
# top left x, top left y, bottom right x, bottom right y
|
||||
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
|
||||
axis=-1,
|
||||
)
|
||||
return bboxes_corners
|
||||
|
||||
|
||||
# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
|
||||
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
|
||||
"""
|
||||
Converts bounding boxes from center format to corners format.
|
||||
|
||||
center format: contains the coordinate for the center of the box and its width, height dimensions
|
||||
(center_x, center_y, width, height)
|
||||
corners format: contains the coodinates for the top-left and bottom-right corners of the box
|
||||
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
|
||||
"""
|
||||
# Function is used during model forward pass, so we use the input framework if possible, without
|
||||
# converting to numpy
|
||||
if is_torch_tensor(bboxes_center):
|
||||
return _center_to_corners_format_torch(bboxes_center)
|
||||
elif isinstance(bboxes_center, np.ndarray):
|
||||
return _center_to_corners_format_numpy(bboxes_center)
|
||||
elif is_tf_tensor(bboxes_center):
|
||||
return _center_to_corners_format_tf(bboxes_center)
|
||||
|
||||
raise ValueError(f"Unsupported input type {type(bboxes_center)}")
|
||||
|
||||
|
||||
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
|
||||
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
|
||||
b = [
|
||||
(top_left_x + bottom_right_x) / 2, # center x
|
||||
(top_left_y + bottom_right_y) / 2, # center y
|
||||
(bottom_right_x - top_left_x), # width
|
||||
(bottom_right_y - top_left_y), # height
|
||||
]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
|
||||
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
|
||||
bboxes_center = np.stack(
|
||||
[
|
||||
(top_left_x + bottom_right_x) / 2, # center x
|
||||
(top_left_y + bottom_right_y) / 2, # center y
|
||||
(bottom_right_x - top_left_x), # width
|
||||
(bottom_right_y - top_left_y), # height
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
return bboxes_center
|
||||
|
||||
|
||||
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
|
||||
top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
|
||||
bboxes_center = tf.stack(
|
||||
[
|
||||
(top_left_x + bottom_right_x) / 2, # center x
|
||||
(top_left_y + bottom_right_y) / 2, # center y
|
||||
(bottom_right_x - top_left_x), # width
|
||||
(bottom_right_y - top_left_y), # height
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
return bboxes_center
|
||||
|
||||
|
||||
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
|
||||
"""
|
||||
Converts bounding boxes from corners format to center format.
|
||||
|
||||
corners format: contains the coodinates for the top-left and bottom-right corners of the box
|
||||
(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
|
||||
center format: contains the coordinate for the center of the box and its the width, height dimensions
|
||||
(center_x, center_y, width, height)
|
||||
"""
|
||||
# Inverse function accepts different input types so implemented here too
|
||||
if is_torch_tensor(bboxes_corners):
|
||||
return _corners_to_center_format_torch(bboxes_corners)
|
||||
elif isinstance(bboxes_corners, np.ndarray):
|
||||
return _corners_to_center_format_numpy(bboxes_corners)
|
||||
elif is_tf_tensor(bboxes_corners):
|
||||
return _corners_to_center_format_tf(bboxes_corners)
|
||||
|
||||
raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
|
||||
|
||||
|
||||
# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
|
||||
# Copyright (c) 2018, Alexander Kirillov
|
||||
# All rights reserved.
|
||||
def rgb_to_id(color):
|
||||
"""
|
||||
Converts RGB color to unique ID.
|
||||
"""
|
||||
if isinstance(color, np.ndarray) and len(color.shape) == 3:
|
||||
if color.dtype == np.uint8:
|
||||
color = color.astype(np.int32)
|
||||
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
|
||||
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
|
||||
|
||||
|
||||
def id_to_rgb(id_map):
|
||||
"""
|
||||
Converts unique ID to RGB color.
|
||||
"""
|
||||
if isinstance(id_map, np.ndarray):
|
||||
id_map_copy = id_map.copy()
|
||||
rgb_shape = tuple(list(id_map.shape) + [3])
|
||||
rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
|
||||
for i in range(3):
|
||||
rgb_map[..., i] = id_map_copy % 256
|
||||
id_map_copy //= 256
|
||||
return rgb_map
|
||||
color = []
|
||||
for _ in range(3):
|
||||
color.append(id_map % 256)
|
||||
id_map //= 256
|
||||
return color
|
||||
|
||||
@@ -36,9 +36,13 @@ if is_vision_available():
|
||||
|
||||
from transformers.image_transforms import (
|
||||
center_crop,
|
||||
center_to_corners_format,
|
||||
corners_to_center_format,
|
||||
get_resize_output_image_size,
|
||||
id_to_rgb,
|
||||
normalize,
|
||||
resize,
|
||||
rgb_to_id,
|
||||
to_channel_dimension_format,
|
||||
to_pil_image,
|
||||
)
|
||||
@@ -178,6 +182,11 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
def test_normalize(self):
|
||||
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
||||
|
||||
# Test that exception is raised if inputs are incorrect
|
||||
# Not a numpy array image
|
||||
with self.assertRaises(ValueError):
|
||||
normalize(5, 5, 5)
|
||||
|
||||
# Number of mean values != number of channels
|
||||
with self.assertRaises(ValueError):
|
||||
normalize(image, mean=(0.5, 0.6), std=1)
|
||||
@@ -219,3 +228,64 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertIsInstance(cropped_image, np.ndarray)
|
||||
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||
|
||||
def test_center_to_corners_format(self):
|
||||
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
|
||||
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
|
||||
self.assertTrue(np.allclose(center_to_corners_format(bbox_center), expected))
|
||||
|
||||
# Check that the function and inverse function are inverse of each other
|
||||
self.assertTrue(np.allclose(corners_to_center_format(center_to_corners_format(bbox_center)), bbox_center))
|
||||
|
||||
def test_corners_to_center_format(self):
|
||||
bbox_corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
|
||||
expected = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
|
||||
self.assertTrue(np.allclose(corners_to_center_format(bbox_corners), expected))
|
||||
|
||||
# Check that the function and inverse function are inverse of each other
|
||||
self.assertTrue(np.allclose(center_to_corners_format(corners_to_center_format(bbox_corners)), bbox_corners))
|
||||
|
||||
def test_rgb_to_id(self):
|
||||
# test list input
|
||||
rgb = [125, 4, 255]
|
||||
self.assertEqual(rgb_to_id(rgb), 16712829)
|
||||
|
||||
# test numpy array input
|
||||
color = np.array(
|
||||
[
|
||||
[
|
||||
[213, 54, 165],
|
||||
[88, 207, 39],
|
||||
[156, 108, 128],
|
||||
],
|
||||
[
|
||||
[183, 194, 46],
|
||||
[137, 58, 88],
|
||||
[114, 131, 233],
|
||||
],
|
||||
]
|
||||
)
|
||||
expected = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
|
||||
self.assertTrue(np.allclose(rgb_to_id(color), expected))
|
||||
|
||||
def test_id_to_rgb(self):
|
||||
# test int input
|
||||
self.assertEqual(id_to_rgb(16712829), [125, 4, 255])
|
||||
|
||||
# test array input
|
||||
id_array = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
|
||||
color = np.array(
|
||||
[
|
||||
[
|
||||
[213, 54, 165],
|
||||
[88, 207, 39],
|
||||
[156, 108, 128],
|
||||
],
|
||||
[
|
||||
[183, 194, 46],
|
||||
[137, 58, 88],
|
||||
[114, 131, 233],
|
||||
],
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(id_to_rgb(id_array), color))
|
||||
|
||||
Reference in New Issue
Block a user