Add object detection + segmentation transforms (#20003)
* Add transforms for object detection * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Better var names & docstring * Remove unused var desc in docstring * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -21,8 +21,16 @@ Most of those are only useful if you are studying the code of the image processo
|
|||||||
|
|
||||||
[[autodoc]] image_transforms.center_crop
|
[[autodoc]] image_transforms.center_crop
|
||||||
|
|
||||||
|
[[autodoc]] image_transforms.center_to_corners_format
|
||||||
|
|
||||||
|
[[autodoc]] image_transforms.corners_to_center_format
|
||||||
|
|
||||||
|
[[autodoc]] image_transforms.id_to_rgb
|
||||||
|
|
||||||
[[autodoc]] image_transforms.normalize
|
[[autodoc]] image_transforms.normalize
|
||||||
|
|
||||||
|
[[autodoc]] image_transforms.rgb_to_id
|
||||||
|
|
||||||
[[autodoc]] image_transforms.rescale
|
[[autodoc]] image_transforms.rescale
|
||||||
|
|
||||||
[[autodoc]] image_transforms.resize
|
[[autodoc]] image_transforms.resize
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.image_utils import PILImageResampling
|
from transformers.utils import TensorType
|
||||||
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
@@ -27,6 +27,7 @@ if is_vision_available():
|
|||||||
|
|
||||||
from .image_utils import (
|
from .image_utils import (
|
||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
|
PILImageResampling,
|
||||||
get_channel_dimension_axis,
|
get_channel_dimension_axis,
|
||||||
get_image_size,
|
get_image_size,
|
||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
@@ -108,7 +109,7 @@ def rescale(
|
|||||||
|
|
||||||
|
|
||||||
def to_pil_image(
|
def to_pil_image(
|
||||||
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.Tensor"],
|
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
||||||
do_rescale: Optional[bool] = None,
|
do_rescale: Optional[bool] = None,
|
||||||
) -> PIL.Image.Image:
|
) -> PIL.Image.Image:
|
||||||
"""
|
"""
|
||||||
@@ -300,6 +301,9 @@ def normalize(
|
|||||||
image = to_numpy_array(image)
|
image = to_numpy_array(image)
|
||||||
image = rescale(image, scale=1 / 255)
|
image = rescale(image, scale=1 / 255)
|
||||||
|
|
||||||
|
if not isinstance(image, np.ndarray):
|
||||||
|
raise ValueError("image must be a numpy array")
|
||||||
|
|
||||||
input_data_format = infer_channel_dimension_format(image)
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
channel_axis = get_channel_dimension_axis(image)
|
channel_axis = get_channel_dimension_axis(image)
|
||||||
num_channels = image.shape[channel_axis]
|
num_channels = image.shape[channel_axis]
|
||||||
@@ -420,3 +424,147 @@ def center_crop(
|
|||||||
new_image = to_pil_image(new_image)
|
new_image = to_pil_image(new_image)
|
||||||
|
|
||||||
return new_image
|
return new_image
|
||||||
|
|
||||||
|
|
||||||
|
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
|
||||||
|
center_x, center_y, width, height = bboxes_center.unbind(-1)
|
||||||
|
bbox_corners = torch.stack(
|
||||||
|
# top left x, top left y, bottom right x, bottom right y
|
||||||
|
[(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return bbox_corners
|
||||||
|
|
||||||
|
|
||||||
|
def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
|
||||||
|
center_x, center_y, width, height = bboxes_center.T
|
||||||
|
bboxes_corners = np.stack(
|
||||||
|
# top left x, top left y, bottom right x, bottom right y
|
||||||
|
[center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
return bboxes_corners
|
||||||
|
|
||||||
|
|
||||||
|
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
    """
    TensorFlow implementation of the center -> corners bounding box conversion.

    The last axis of `bboxes_center` is unstacked into (center_x, center_y, width, height); the returned tensor
    stores (top_left_x, top_left_y, bottom_right_x, bottom_right_y) along its last axis.
    """
    center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
    half_width = 0.5 * width
    half_height = 0.5 * height
    corners = [
        center_x - half_width,  # top left x
        center_y - half_height,  # top left y
        center_x + half_width,  # bottom right x
        center_y + half_height,  # bottom right y
    ]
    return tf.stack(corners, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
|
||||||
|
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
    """
    Converts bounding boxes from center format to corners format.

    center format: contains the coordinates of the center of the box and its width and height dimensions
        (center_x, center_y, width, height)
    corners format: contains the coordinates of the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)

    The input is dispatched to a torch, numpy or tensorflow implementation depending on its type; any other
    input type raises a `ValueError`.
    """
    # This runs during the model forward pass, so stay in the caller's framework whenever possible instead of
    # converting to numpy
    if is_torch_tensor(bboxes_center):
        convert = _center_to_corners_format_torch
    elif isinstance(bboxes_center, np.ndarray):
        convert = _center_to_corners_format_numpy
    elif is_tf_tensor(bboxes_center):
        convert = _center_to_corners_format_tf
    else:
        raise ValueError(f"Unsupported input type {type(bboxes_center)}")

    return convert(bboxes_center)
|
||||||
|
|
||||||
|
|
||||||
|
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
|
||||||
|
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
|
||||||
|
b = [
|
||||||
|
(top_left_x + bottom_right_x) / 2, # center x
|
||||||
|
(top_left_y + bottom_right_y) / 2, # center y
|
||||||
|
(bottom_right_x - top_left_x), # width
|
||||||
|
(bottom_right_y - top_left_y), # height
|
||||||
|
]
|
||||||
|
return torch.stack(b, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
|
||||||
|
top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
|
||||||
|
bboxes_center = np.stack(
|
||||||
|
[
|
||||||
|
(top_left_x + bottom_right_x) / 2, # center x
|
||||||
|
(top_left_y + bottom_right_y) / 2, # center y
|
||||||
|
(bottom_right_x - top_left_x), # width
|
||||||
|
(bottom_right_y - top_left_y), # height
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
return bboxes_center
|
||||||
|
|
||||||
|
|
||||||
|
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
    """
    TensorFlow implementation of the corners -> center bounding box conversion.

    The last axis of `bboxes_corners` is unstacked into (top_left_x, top_left_y, bottom_right_x,
    bottom_right_y); the returned tensor stores (center_x, center_y, width, height) along its last axis.
    """
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
    center_x = (top_left_x + bottom_right_x) / 2
    center_y = (top_left_y + bottom_right_y) / 2
    width = bottom_right_x - top_left_x
    height = bottom_right_y - top_left_y
    return tf.stack([center_x, center_y, width, height], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Converts bounding boxes from corners format to center format.

    corners format: contains the coordinates of the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    center format: contains the coordinates of the center of the box and its width and height dimensions
        (center_x, center_y, width, height)

    The input is dispatched to a torch, numpy or tensorflow implementation depending on its type; any other
    input type raises a `ValueError`.
    """
    # The inverse function accepts several input types, so the same dispatch is implemented here
    if is_torch_tensor(bboxes_corners):
        convert = _corners_to_center_format_torch
    elif isinstance(bboxes_corners, np.ndarray):
        convert = _corners_to_center_format_numpy
    elif is_tf_tensor(bboxes_corners):
        convert = _corners_to_center_format_tf
    else:
        raise ValueError(f"Unsupported input type {type(bboxes_corners)}")

    return convert(bboxes_corners)
|
||||||
|
|
||||||
|
|
||||||
|
# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
|
||||||
|
# Copyright (c) 2018, Alexander Kirillov
|
||||||
|
# All rights reserved.
|
||||||
|
def rgb_to_id(color):
    """
    Converts RGB color to unique ID.

    The ID is computed as `R + 256 * G + 256 * 256 * B`.

    Args:
        color:
            Either a 3-dimensional `np.ndarray` whose last axis holds the (R, G, B) channels, or any indexable
            holding a single (R, G, B) triple (list, tuple, 1-dimensional array).

    Returns:
        A 2-dimensional array of IDs when `color` is a 3-dimensional array, otherwise a Python `int`.
    """
    if isinstance(color, np.ndarray):
        # Widen narrow integer dtypes (uint8 in particular) before doing any arithmetic: the 256 and 256 * 256
        # multipliers do not fit in them. This has to happen for the single-color path as well, not only for
        # 3-dimensional maps, otherwise the elements picked out below are still narrow scalars.
        if np.issubdtype(color.dtype, np.integer) and color.dtype.itemsize < 4:
            color = color.astype(np.int32)
        if color.ndim == 3:
            return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
|
||||||
|
|
||||||
|
|
||||||
|
def id_to_rgb(id_map):
    """
    Converts unique ID to RGB color.

    The ID is decomposed base 256: the lowest digit is R, the next one G and the third one B.

    Args:
        id_map:
            Either an `np.ndarray` of IDs or a single ID.

    Returns:
        For an array input, a `np.uint8` array with one extra trailing axis of size 3 holding (R, G, B); the input
        array itself is left untouched. For any other input, a list `[R, G, B]`.
    """
    if not isinstance(id_map, np.ndarray):
        remaining = id_map
        channels = []
        for _ in range(3):
            channels.append(remaining % 256)
            remaining //= 256
        return channels

    # Work on a copy because the base-256 digits are peeled off with an in-place floor division
    remaining = id_map.copy()
    rgb_map = np.zeros((*id_map.shape, 3), dtype=np.uint8)
    for channel in range(3):
        rgb_map[..., channel] = remaining % 256
        remaining //= 256
    return rgb_map
|
||||||
|
|||||||
@@ -36,9 +36,13 @@ if is_vision_available():
|
|||||||
|
|
||||||
from transformers.image_transforms import (
|
from transformers.image_transforms import (
|
||||||
center_crop,
|
center_crop,
|
||||||
|
center_to_corners_format,
|
||||||
|
corners_to_center_format,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
|
id_to_rgb,
|
||||||
normalize,
|
normalize,
|
||||||
resize,
|
resize,
|
||||||
|
rgb_to_id,
|
||||||
to_channel_dimension_format,
|
to_channel_dimension_format,
|
||||||
to_pil_image,
|
to_pil_image,
|
||||||
)
|
)
|
||||||
@@ -178,6 +182,11 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
def test_normalize(self):
|
def test_normalize(self):
|
||||||
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
||||||
|
|
||||||
|
# Test that exception is raised if inputs are incorrect
|
||||||
|
# Not a numpy array image
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
normalize(5, 5, 5)
|
||||||
|
|
||||||
# Number of mean values != number of channels
|
# Number of mean values != number of channels
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
normalize(image, mean=(0.5, 0.6), std=1)
|
normalize(image, mean=(0.5, 0.6), std=1)
|
||||||
@@ -219,3 +228,64 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertIsInstance(cropped_image, np.ndarray)
|
self.assertIsInstance(cropped_image, np.ndarray)
|
||||||
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
||||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||||
|
|
||||||
|
def test_center_to_corners_format(self):
    """Center-format boxes convert to the expected corners and survive a round trip."""
    centers = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
    expected_corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])

    corners = center_to_corners_format(centers)
    self.assertTrue(np.allclose(corners, expected_corners))

    # Converting to corners and back must return the original boxes
    round_trip = corners_to_center_format(center_to_corners_format(centers))
    self.assertTrue(np.allclose(round_trip, centers))
|
||||||
|
|
||||||
|
def test_corners_to_center_format(self):
    """Corner-format boxes convert to the expected centers and survive a round trip."""
    corners = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
    expected_centers = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])

    centers = corners_to_center_format(corners)
    self.assertTrue(np.allclose(centers, expected_centers))

    # Converting to centers and back must return the original boxes
    round_trip = center_to_corners_format(corners_to_center_format(corners))
    self.assertTrue(np.allclose(round_trip, corners))
|
||||||
|
|
||||||
|
def test_rgb_to_id(self):
    """`rgb_to_id` handles both a single RGB triple and a 3-dimensional color map."""
    # A plain list holding one color yields a single integer ID
    single_color = [125, 4, 255]
    self.assertEqual(rgb_to_id(single_color), 16712829)

    # A (2, 3, 3) color map yields a (2, 3) map of IDs
    first_row = [[213, 54, 165], [88, 207, 39], [156, 108, 128]]
    second_row = [[183, 194, 46], [137, 58, 88], [114, 131, 233]]
    color_map = np.array([first_row, second_row])
    expected_ids = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
    self.assertTrue(np.allclose(rgb_to_id(color_map), expected_ids))
|
||||||
|
|
||||||
|
def test_id_to_rgb(self):
    """`id_to_rgb` handles both a single integer ID and an array of IDs."""
    # A single integer ID yields an [R, G, B] list
    self.assertEqual(id_to_rgb(16712829), [125, 4, 255])

    # A (2, 3) map of IDs yields a (2, 3, 3) color map
    id_map = np.array([[10827477, 2608984, 8416412], [3064503, 5782153, 15303538]])
    first_row = [[213, 54, 165], [88, 207, 39], [156, 108, 128]]
    second_row = [[183, 194, 46], [137, 58, 88], [114, 131, 233]]
    expected_colors = np.array([first_row, second_row])
    self.assertTrue(np.allclose(id_to_rgb(id_map), expected_colors))
|
||||||
|
|||||||
Reference in New Issue
Block a user