Add input_data_format argument, image transforms (#25462)
* Enable specifying input data format - overriding inferring * Add tests
This commit is contained in:
@@ -63,6 +63,8 @@ def to_channel_dimension_format(
|
||||
The image to have its channel dimension set.
|
||||
channel_dim (`ChannelDimension`):
|
||||
The channel dimension format to use.
|
||||
input_channel_dim (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
|
||||
@@ -88,7 +90,11 @@ def to_channel_dimension_format(
|
||||
|
||||
|
||||
def rescale(
|
||||
image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32
|
||||
image: np.ndarray,
|
||||
scale: float,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
dtype: np.dtype = np.float32,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rescales `image` by `scale`.
|
||||
@@ -103,6 +109,8 @@ def rescale(
|
||||
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
||||
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
|
||||
extractors.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The rescaled image.
|
||||
@@ -112,7 +120,7 @@ def rescale(
|
||||
|
||||
rescaled_image = image * scale
|
||||
if data_format is not None:
|
||||
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
|
||||
rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
|
||||
|
||||
rescaled_image = rescaled_image.astype(dtype)
|
||||
|
||||
@@ -149,6 +157,7 @@ def _rescale_for_pil_conversion(image):
|
||||
def to_pil_image(
|
||||
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
||||
do_rescale: Optional[bool] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> "PIL.Image.Image":
|
||||
"""
|
||||
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
||||
@@ -161,6 +170,8 @@ def to_pil_image(
|
||||
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
|
||||
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
|
||||
and `False` otherwise.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`: The converted image.
|
||||
@@ -179,7 +190,7 @@ def to_pil_image(
|
||||
raise ValueError("Input image type not supported: {}".format(type(image)))
|
||||
|
||||
# If the channel as been moved to first dim, we put it back at the end.
|
||||
image = to_channel_dimension_format(image, ChannelDimension.LAST)
|
||||
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
|
||||
|
||||
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
|
||||
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
||||
@@ -200,6 +211,7 @@ def get_resize_output_image_size(
|
||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
default_to_square: bool = True,
|
||||
max_size: Optional[int] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
||||
@@ -225,6 +237,8 @@ def get_resize_output_image_size(
|
||||
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
|
||||
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
|
||||
than `size`. Only used if `default_to_square` is `False`.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||
|
||||
Returns:
|
||||
`tuple`: The target (height, width) dimension of the output image after resizing.
|
||||
@@ -241,7 +255,7 @@ def get_resize_output_image_size(
|
||||
if default_to_square:
|
||||
return (size, size)
|
||||
|
||||
height, width = get_image_size(input_image)
|
||||
height, width = get_image_size(input_image, input_data_format)
|
||||
short, long = (width, height) if width <= height else (height, width)
|
||||
requested_new_short = size
|
||||
|
||||
@@ -266,6 +280,7 @@ def resize(
|
||||
reducing_gap: Optional[int] = None,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
return_numpy: bool = True,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
|
||||
@@ -285,6 +300,8 @@ def resize(
|
||||
return_numpy (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
|
||||
returned.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
@@ -298,14 +315,16 @@ def resize(
|
||||
|
||||
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
|
||||
# The resized image from PIL will always have channels last, so find the input format first.
|
||||
data_format = infer_channel_dimension_format(image) if data_format is None else data_format
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
data_format = input_data_format if data_format is None else data_format
|
||||
|
||||
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
|
||||
# the pillow library to resize the image and then convert back to numpy
|
||||
do_rescale = False
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
do_rescale = _rescale_for_pil_conversion(image)
|
||||
image = to_pil_image(image, do_rescale=do_rescale)
|
||||
image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
|
||||
height, width = size
|
||||
# PIL images are in the format (width, height)
|
||||
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
|
||||
@@ -330,6 +349,7 @@ def normalize(
|
||||
mean: Union[float, Iterable[float]],
|
||||
std: Union[float, Iterable[float]],
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
|
||||
@@ -345,12 +365,15 @@ def normalize(
|
||||
The standard deviation to use for normalization.
|
||||
data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||
"""
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError("image must be a numpy array")
|
||||
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
channel_axis = get_channel_dimension_axis(image)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
|
||||
num_channels = image.shape[channel_axis]
|
||||
|
||||
if isinstance(mean, Iterable):
|
||||
@@ -372,7 +395,7 @@ def normalize(
|
||||
else:
|
||||
image = ((image.T - mean) / std).T
|
||||
|
||||
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
||||
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
||||
return image
|
||||
|
||||
|
||||
@@ -380,6 +403,7 @@ def center_crop(
|
||||
image: np.ndarray,
|
||||
size: Tuple[int, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
return_numpy: Optional[bool] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
@@ -396,6 +420,11 @@ def center_crop(
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
return_numpy (`bool`, *optional*):
|
||||
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
|
||||
previous ImageFeatureExtractionMixin method.
|
||||
@@ -418,13 +447,14 @@ def center_crop(
|
||||
if not isinstance(size, Iterable) or len(size) != 2:
|
||||
raise ValueError("size must have 2 elements representing the height and width of the output image")
|
||||
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
output_data_format = data_format if data_format is not None else input_data_format
|
||||
|
||||
# We perform the crop in (C, H, W) format and then convert to the output format
|
||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
|
||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
|
||||
|
||||
orig_height, orig_width = get_image_size(image)
|
||||
orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
|
||||
crop_height, crop_width = size
|
||||
crop_height, crop_width = int(crop_height), int(crop_width)
|
||||
|
||||
@@ -438,7 +468,7 @@ def center_crop(
|
||||
# Check if cropped area is within image boundaries
|
||||
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
|
||||
image = image[..., top:bottom, left:right]
|
||||
image = to_channel_dimension_format(image, output_data_format)
|
||||
image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
|
||||
return image
|
||||
|
||||
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
||||
@@ -460,7 +490,7 @@ def center_crop(
|
||||
right += left_pad
|
||||
|
||||
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
|
||||
new_image = to_channel_dimension_format(new_image, output_data_format)
|
||||
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
|
||||
|
||||
if not return_numpy:
|
||||
new_image = to_pil_image(new_image)
|
||||
@@ -705,7 +735,7 @@ def pad(
|
||||
else:
|
||||
raise ValueError(f"Invalid padding mode: {mode}")
|
||||
|
||||
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
||||
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
||||
return image
|
||||
|
||||
|
||||
@@ -728,7 +758,11 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
|
||||
return image
|
||||
|
||||
|
||||
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
|
||||
def flip_channel_order(
|
||||
image: np.ndarray,
|
||||
data_format: Optional[ChannelDimension] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Flips the channel order of the image.
|
||||
|
||||
@@ -742,9 +776,14 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use same as the input image.
|
||||
input_data_format (`ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
"""
|
||||
input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
|
||||
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
if input_data_format == ChannelDimension.LAST:
|
||||
image = image[..., ::-1]
|
||||
elif input_data_format == ChannelDimension.FIRST:
|
||||
@@ -753,5 +792,5 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
|
||||
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
|
||||
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format)
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
return image
|
||||
|
||||
@@ -176,23 +176,28 @@ def infer_channel_dimension_format(
|
||||
raise ValueError("Unable to infer channel dimension format")
|
||||
|
||||
|
||||
def get_channel_dimension_axis(image: np.ndarray) -> int:
|
||||
def get_channel_dimension_axis(
|
||||
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Returns the channel dimension axis of the image.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
The image to get the channel dimension axis of.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
|
||||
|
||||
Returns:
|
||||
The channel dimension axis of the image.
|
||||
"""
|
||||
channel_dim = infer_channel_dimension_format(image)
|
||||
if channel_dim == ChannelDimension.FIRST:
|
||||
if input_data_format is None:
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
return image.ndim - 3
|
||||
elif channel_dim == ChannelDimension.LAST:
|
||||
elif input_data_format == ChannelDimension.LAST:
|
||||
return image.ndim - 1
|
||||
raise ValueError(f"Unsupported data format: {channel_dim}")
|
||||
raise ValueError(f"Unsupported data format: {input_data_format}")
|
||||
|
||||
|
||||
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
|
||||
|
||||
@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
image = to_channel_dimension_format(image, "channels_first")
|
||||
self.assertEqual(image.shape, (3, 4, 5))
|
||||
|
||||
# Can pass in input_data_format and works if data format is ambiguous or unknown.
|
||||
image = np.random.rand(4, 5, 6)
|
||||
image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last")
|
||||
self.assertEqual(image.shape, (6, 4, 5))
|
||||
|
||||
def test_get_resize_output_image_size(self):
|
||||
image = np.random.randint(0, 256, (3, 224, 224))
|
||||
|
||||
@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
image = np.random.randint(0, 256, (3, 50, 40))
|
||||
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
|
||||
|
||||
# Test output size = (int(size * height / width), size) if size is an int and height > width and
|
||||
# input has 4 channels
|
||||
image = np.random.randint(0, 256, (4, 50, 40))
|
||||
self.assertEqual(
|
||||
get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"),
|
||||
(25, 20),
|
||||
)
|
||||
|
||||
# Test correct channel dimension is returned if output size if height == 3
|
||||
# Defaults to input format - channels first
|
||||
image = np.random.randint(0, 256, (3, 18, 97))
|
||||
@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertTrue(np.all(resized_image >= 0))
|
||||
self.assertTrue(np.all(resized_image <= 1))
|
||||
|
||||
# Check that an image with 4 channels is resized correctly
|
||||
image = np.random.randint(0, 256, (4, 224, 224))
|
||||
resized_image = resize(image, (30, 40), input_data_format="channels_first")
|
||||
self.assertIsInstance(resized_image, np.ndarray)
|
||||
self.assertEqual(resized_image.shape, (4, 30, 40))
|
||||
|
||||
def test_normalize(self):
|
||||
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
||||
|
||||
@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertEqual(normalized_image.shape, (3, 224, 224))
|
||||
self.assertTrue(np.allclose(normalized_image, expected_image))
|
||||
|
||||
# Test image with 4 channels is normalized correctly
|
||||
image = np.random.randint(0, 256, (224, 224, 4)) / 255
|
||||
mean = (0.5, 0.6, 0.7, 0.8)
|
||||
std = (0.1, 0.2, 0.3, 0.4)
|
||||
expected_image = (image - mean) / std
|
||||
self.assertTrue(
|
||||
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image)
|
||||
)
|
||||
|
||||
def test_center_crop(self):
|
||||
image = np.random.randint(0, 256, (3, 224, 224))
|
||||
|
||||
@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||
|
||||
# Test image with 4 channels is cropped correctly
|
||||
image = np.random.randint(0, 256, (224, 224, 4))
|
||||
expected_image = image[52:172, 82:142, :]
|
||||
self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image))
|
||||
|
||||
def test_center_to_corners_format(self):
|
||||
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
|
||||
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
|
||||
@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
||||
)
|
||||
|
||||
# Test we can pad on an image with 2 channels
|
||||
# fmt: off
|
||||
image = np.array([
|
||||
[[0, 1], [2, 3]],
|
||||
])
|
||||
expected_image = np.array([
|
||||
[[0, 0], [0, 1], [2, 3]],
|
||||
[[0, 0], [0, 0], [0, 0]],
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last")
|
||||
)
|
||||
)
|
||||
|
||||
@require_vision
|
||||
def test_convert_to_rgb(self):
|
||||
# Test that an RGBA image is converted to RGB
|
||||
@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
|
||||
)
|
||||
|
||||
# Can flip when the image has 2 channels
|
||||
# fmt: off
|
||||
img_channels_first = np.array([
|
||||
[[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7]],
|
||||
|
||||
[[ 8, 9, 10, 11],
|
||||
[12, 13, 14, 15]],
|
||||
])
|
||||
# fmt: on
|
||||
flipped_img_channels_first = img_channels_first[::-1, :, :]
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user