Add input_data_format argument, image transforms (#25462)
* Enable specifying input data format - overriding inferring * Add tests
This commit is contained in:
@@ -63,6 +63,8 @@ def to_channel_dimension_format(
|
|||||||
The image to have its channel dimension set.
|
The image to have its channel dimension set.
|
||||||
channel_dim (`ChannelDimension`):
|
channel_dim (`ChannelDimension`):
|
||||||
The channel dimension format to use.
|
The channel dimension format to use.
|
||||||
|
input_channel_dim (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
|
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
|
||||||
@@ -88,7 +90,11 @@ def to_channel_dimension_format(
|
|||||||
|
|
||||||
|
|
||||||
def rescale(
|
def rescale(
|
||||||
image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32
|
image: np.ndarray,
|
||||||
|
scale: float,
|
||||||
|
data_format: Optional[ChannelDimension] = None,
|
||||||
|
dtype: np.dtype = np.float32,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Rescales `image` by `scale`.
|
Rescales `image` by `scale`.
|
||||||
@@ -103,6 +109,8 @@ def rescale(
|
|||||||
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
||||||
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
|
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
|
||||||
extractors.
|
extractors.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`np.ndarray`: The rescaled image.
|
`np.ndarray`: The rescaled image.
|
||||||
@@ -112,7 +120,7 @@ def rescale(
|
|||||||
|
|
||||||
rescaled_image = image * scale
|
rescaled_image = image * scale
|
||||||
if data_format is not None:
|
if data_format is not None:
|
||||||
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
|
rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
|
||||||
|
|
||||||
rescaled_image = rescaled_image.astype(dtype)
|
rescaled_image = rescaled_image.astype(dtype)
|
||||||
|
|
||||||
@@ -149,6 +157,7 @@ def _rescale_for_pil_conversion(image):
|
|||||||
def to_pil_image(
|
def to_pil_image(
|
||||||
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
|
||||||
do_rescale: Optional[bool] = None,
|
do_rescale: Optional[bool] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
) -> "PIL.Image.Image":
|
) -> "PIL.Image.Image":
|
||||||
"""
|
"""
|
||||||
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
|
||||||
@@ -161,6 +170,8 @@ def to_pil_image(
|
|||||||
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
|
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
|
||||||
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
|
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
|
||||||
and `False` otherwise.
|
and `False` otherwise.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`PIL.Image.Image`: The converted image.
|
`PIL.Image.Image`: The converted image.
|
||||||
@@ -179,7 +190,7 @@ def to_pil_image(
|
|||||||
raise ValueError("Input image type not supported: {}".format(type(image)))
|
raise ValueError("Input image type not supported: {}".format(type(image)))
|
||||||
|
|
||||||
# If the channel has been moved to first dim, we put it back at the end.
|
# If the channel has been moved to first dim, we put it back at the end.
|
||||||
image = to_channel_dimension_format(image, ChannelDimension.LAST)
|
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
|
||||||
|
|
||||||
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
|
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
|
||||||
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
||||||
@@ -200,6 +211,7 @@ def get_resize_output_image_size(
|
|||||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||||
default_to_square: bool = True,
|
default_to_square: bool = True,
|
||||||
max_size: Optional[int] = None,
|
max_size: Optional[int] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
|
||||||
@@ -225,6 +237,8 @@ def get_resize_output_image_size(
|
|||||||
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
|
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
|
||||||
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
|
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
|
||||||
than `size`. Only used if `default_to_square` is `False`.
|
than `size`. Only used if `default_to_square` is `False`.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`tuple`: The target (height, width) dimension of the output image after resizing.
|
`tuple`: The target (height, width) dimension of the output image after resizing.
|
||||||
@@ -241,7 +255,7 @@ def get_resize_output_image_size(
|
|||||||
if default_to_square:
|
if default_to_square:
|
||||||
return (size, size)
|
return (size, size)
|
||||||
|
|
||||||
height, width = get_image_size(input_image)
|
height, width = get_image_size(input_image, input_data_format)
|
||||||
short, long = (width, height) if width <= height else (height, width)
|
short, long = (width, height) if width <= height else (height, width)
|
||||||
requested_new_short = size
|
requested_new_short = size
|
||||||
|
|
||||||
@@ -266,6 +280,7 @@ def resize(
|
|||||||
reducing_gap: Optional[int] = None,
|
reducing_gap: Optional[int] = None,
|
||||||
data_format: Optional[ChannelDimension] = None,
|
data_format: Optional[ChannelDimension] = None,
|
||||||
return_numpy: bool = True,
|
return_numpy: bool = True,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
|
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
|
||||||
@@ -285,6 +300,8 @@ def resize(
|
|||||||
return_numpy (`bool`, *optional*, defaults to `True`):
|
return_numpy (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
|
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
|
||||||
returned.
|
returned.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`np.ndarray`: The resized image.
|
`np.ndarray`: The resized image.
|
||||||
@@ -298,14 +315,16 @@ def resize(
|
|||||||
|
|
||||||
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
|
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
|
||||||
# The resized image from PIL will always have channels last, so find the input format first.
|
# The resized image from PIL will always have channels last, so find the input format first.
|
||||||
data_format = infer_channel_dimension_format(image) if data_format is None else data_format
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
data_format = input_data_format if data_format is None else data_format
|
||||||
|
|
||||||
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
|
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
|
||||||
# the pillow library to resize the image and then convert back to numpy
|
# the pillow library to resize the image and then convert back to numpy
|
||||||
do_rescale = False
|
do_rescale = False
|
||||||
if not isinstance(image, PIL.Image.Image):
|
if not isinstance(image, PIL.Image.Image):
|
||||||
do_rescale = _rescale_for_pil_conversion(image)
|
do_rescale = _rescale_for_pil_conversion(image)
|
||||||
image = to_pil_image(image, do_rescale=do_rescale)
|
image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
|
||||||
height, width = size
|
height, width = size
|
||||||
# PIL images are in the format (width, height)
|
# PIL images are in the format (width, height)
|
||||||
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
|
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
|
||||||
@@ -330,6 +349,7 @@ def normalize(
|
|||||||
mean: Union[float, Iterable[float]],
|
mean: Union[float, Iterable[float]],
|
||||||
std: Union[float, Iterable[float]],
|
std: Union[float, Iterable[float]],
|
||||||
data_format: Optional[ChannelDimension] = None,
|
data_format: Optional[ChannelDimension] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
|
Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
|
||||||
@@ -345,12 +365,15 @@ def normalize(
|
|||||||
The standard deviation to use for normalization.
|
The standard deviation to use for normalization.
|
||||||
data_format (`ChannelDimension`, *optional*):
|
data_format (`ChannelDimension`, *optional*):
|
||||||
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
The channel dimension format of the output image. If unset, will use the inferred format from the input.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format of the input image. If unset, will use the inferred format from the input.
|
||||||
"""
|
"""
|
||||||
if not isinstance(image, np.ndarray):
|
if not isinstance(image, np.ndarray):
|
||||||
raise ValueError("image must be a numpy array")
|
raise ValueError("image must be a numpy array")
|
||||||
|
|
||||||
input_data_format = infer_channel_dimension_format(image)
|
if input_data_format is None:
|
||||||
channel_axis = get_channel_dimension_axis(image)
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
|
||||||
num_channels = image.shape[channel_axis]
|
num_channels = image.shape[channel_axis]
|
||||||
|
|
||||||
if isinstance(mean, Iterable):
|
if isinstance(mean, Iterable):
|
||||||
@@ -372,7 +395,7 @@ def normalize(
|
|||||||
else:
|
else:
|
||||||
image = ((image.T - mean) / std).T
|
image = ((image.T - mean) / std).T
|
||||||
|
|
||||||
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
@@ -380,6 +403,7 @@ def center_crop(
|
|||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
size: Tuple[int, int],
|
size: Tuple[int, int],
|
||||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
return_numpy: Optional[bool] = None,
|
return_numpy: Optional[bool] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -396,6 +420,11 @@ def center_crop(
|
|||||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
If unset, will use the inferred format of the input image.
|
If unset, will use the inferred format of the input image.
|
||||||
|
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format for the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
If unset, will use the inferred format of the input image.
|
||||||
return_numpy (`bool`, *optional*):
|
return_numpy (`bool`, *optional*):
|
||||||
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
|
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
|
||||||
previous ImageFeatureExtractionMixin method.
|
previous ImageFeatureExtractionMixin method.
|
||||||
@@ -418,13 +447,14 @@ def center_crop(
|
|||||||
if not isinstance(size, Iterable) or len(size) != 2:
|
if not isinstance(size, Iterable) or len(size) != 2:
|
||||||
raise ValueError("size must have 2 elements representing the height and width of the output image")
|
raise ValueError("size must have 2 elements representing the height and width of the output image")
|
||||||
|
|
||||||
input_data_format = infer_channel_dimension_format(image)
|
if input_data_format is None:
|
||||||
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
output_data_format = data_format if data_format is not None else input_data_format
|
output_data_format = data_format if data_format is not None else input_data_format
|
||||||
|
|
||||||
# We perform the crop in (C, H, W) format and then convert to the output format
|
# We perform the crop in (C, H, W) format and then convert to the output format
|
||||||
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
|
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
|
||||||
|
|
||||||
orig_height, orig_width = get_image_size(image)
|
orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
|
||||||
crop_height, crop_width = size
|
crop_height, crop_width = size
|
||||||
crop_height, crop_width = int(crop_height), int(crop_width)
|
crop_height, crop_width = int(crop_height), int(crop_width)
|
||||||
|
|
||||||
@@ -438,7 +468,7 @@ def center_crop(
|
|||||||
# Check if cropped area is within image boundaries
|
# Check if cropped area is within image boundaries
|
||||||
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
|
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
|
||||||
image = image[..., top:bottom, left:right]
|
image = image[..., top:bottom, left:right]
|
||||||
image = to_channel_dimension_format(image, output_data_format)
|
image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
# Otherwise, we may need to pad if the image is too small. Oh joy...
|
||||||
@@ -460,7 +490,7 @@ def center_crop(
|
|||||||
right += left_pad
|
right += left_pad
|
||||||
|
|
||||||
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
|
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
|
||||||
new_image = to_channel_dimension_format(new_image, output_data_format)
|
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
|
||||||
|
|
||||||
if not return_numpy:
|
if not return_numpy:
|
||||||
new_image = to_pil_image(new_image)
|
new_image = to_pil_image(new_image)
|
||||||
@@ -705,7 +735,7 @@ def pad(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid padding mode: {mode}")
|
raise ValueError(f"Invalid padding mode: {mode}")
|
||||||
|
|
||||||
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
@@ -728,7 +758,11 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
|
def flip_channel_order(
|
||||||
|
image: np.ndarray,
|
||||||
|
data_format: Optional[ChannelDimension] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Flips the channel order of the image.
|
Flips the channel order of the image.
|
||||||
|
|
||||||
@@ -742,9 +776,14 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
|
|||||||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
If unset, will use same as the input image.
|
If unset, will use same as the input image.
|
||||||
|
input_data_format (`ChannelDimension`, *optional*):
|
||||||
|
The channel dimension format for the input image. Can be one of:
|
||||||
|
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
If unset, will use the inferred format of the input image.
|
||||||
"""
|
"""
|
||||||
|
input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
|
||||||
|
|
||||||
input_data_format = infer_channel_dimension_format(image)
|
|
||||||
if input_data_format == ChannelDimension.LAST:
|
if input_data_format == ChannelDimension.LAST:
|
||||||
image = image[..., ::-1]
|
image = image[..., ::-1]
|
||||||
elif input_data_format == ChannelDimension.FIRST:
|
elif input_data_format == ChannelDimension.FIRST:
|
||||||
@@ -753,5 +792,5 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
|
|||||||
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
|
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
|
||||||
|
|
||||||
if data_format is not None:
|
if data_format is not None:
|
||||||
image = to_channel_dimension_format(image, data_format)
|
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
return image
|
return image
|
||||||
|
|||||||
@@ -176,23 +176,28 @@ def infer_channel_dimension_format(
|
|||||||
raise ValueError("Unable to infer channel dimension format")
|
raise ValueError("Unable to infer channel dimension format")
|
||||||
|
|
||||||
|
|
||||||
def get_channel_dimension_axis(image: np.ndarray) -> int:
|
def get_channel_dimension_axis(
|
||||||
|
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the channel dimension axis of the image.
|
Returns the channel dimension axis of the image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image (`np.ndarray`):
|
image (`np.ndarray`):
|
||||||
The image to get the channel dimension axis of.
|
The image to get the channel dimension axis of.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The channel dimension axis of the image.
|
The channel dimension axis of the image.
|
||||||
"""
|
"""
|
||||||
channel_dim = infer_channel_dimension_format(image)
|
if input_data_format is None:
|
||||||
if channel_dim == ChannelDimension.FIRST:
|
input_data_format = infer_channel_dimension_format(image)
|
||||||
|
if input_data_format == ChannelDimension.FIRST:
|
||||||
return image.ndim - 3
|
return image.ndim - 3
|
||||||
elif channel_dim == ChannelDimension.LAST:
|
elif input_data_format == ChannelDimension.LAST:
|
||||||
return image.ndim - 1
|
return image.ndim - 1
|
||||||
raise ValueError(f"Unsupported data format: {channel_dim}")
|
raise ValueError(f"Unsupported data format: {input_data_format}")
|
||||||
|
|
||||||
|
|
||||||
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
|
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
|
||||||
|
|||||||
@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
image = to_channel_dimension_format(image, "channels_first")
|
image = to_channel_dimension_format(image, "channels_first")
|
||||||
self.assertEqual(image.shape, (3, 4, 5))
|
self.assertEqual(image.shape, (3, 4, 5))
|
||||||
|
|
||||||
|
# Can pass in input_data_format and works if data format is ambiguous or unknown.
|
||||||
|
image = np.random.rand(4, 5, 6)
|
||||||
|
image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last")
|
||||||
|
self.assertEqual(image.shape, (6, 4, 5))
|
||||||
|
|
||||||
def test_get_resize_output_image_size(self):
|
def test_get_resize_output_image_size(self):
|
||||||
image = np.random.randint(0, 256, (3, 224, 224))
|
image = np.random.randint(0, 256, (3, 224, 224))
|
||||||
|
|
||||||
@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
image = np.random.randint(0, 256, (3, 50, 40))
|
image = np.random.randint(0, 256, (3, 50, 40))
|
||||||
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
|
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
|
||||||
|
|
||||||
|
# Test output size = (int(size * height / width), size) if size is an int and height > width and
|
||||||
|
# input has 4 channels
|
||||||
|
image = np.random.randint(0, 256, (4, 50, 40))
|
||||||
|
self.assertEqual(
|
||||||
|
get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"),
|
||||||
|
(25, 20),
|
||||||
|
)
|
||||||
|
|
||||||
# Test correct channel dimension is returned when the output size has height == 3
|
# Test correct channel dimension is returned when the output size has height == 3
|
||||||
# Defaults to input format - channels first
|
# Defaults to input format - channels first
|
||||||
image = np.random.randint(0, 256, (3, 18, 97))
|
image = np.random.randint(0, 256, (3, 18, 97))
|
||||||
@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertTrue(np.all(resized_image >= 0))
|
self.assertTrue(np.all(resized_image >= 0))
|
||||||
self.assertTrue(np.all(resized_image <= 1))
|
self.assertTrue(np.all(resized_image <= 1))
|
||||||
|
|
||||||
|
# Check that an image with 4 channels is resized correctly
|
||||||
|
image = np.random.randint(0, 256, (4, 224, 224))
|
||||||
|
resized_image = resize(image, (30, 40), input_data_format="channels_first")
|
||||||
|
self.assertIsInstance(resized_image, np.ndarray)
|
||||||
|
self.assertEqual(resized_image.shape, (4, 30, 40))
|
||||||
|
|
||||||
def test_normalize(self):
|
def test_normalize(self):
|
||||||
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
||||||
|
|
||||||
@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertEqual(normalized_image.shape, (3, 224, 224))
|
self.assertEqual(normalized_image.shape, (3, 224, 224))
|
||||||
self.assertTrue(np.allclose(normalized_image, expected_image))
|
self.assertTrue(np.allclose(normalized_image, expected_image))
|
||||||
|
|
||||||
|
# Test image with 4 channels is normalized correctly
|
||||||
|
image = np.random.randint(0, 256, (224, 224, 4)) / 255
|
||||||
|
mean = (0.5, 0.6, 0.7, 0.8)
|
||||||
|
std = (0.1, 0.2, 0.3, 0.4)
|
||||||
|
expected_image = (image - mean) / std
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image)
|
||||||
|
)
|
||||||
|
|
||||||
def test_center_crop(self):
|
def test_center_crop(self):
|
||||||
image = np.random.randint(0, 256, (3, 224, 224))
|
image = np.random.randint(0, 256, (3, 224, 224))
|
||||||
|
|
||||||
@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
||||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||||
|
|
||||||
|
# Test image with 4 channels is cropped correctly
|
||||||
|
image = np.random.randint(0, 256, (224, 224, 4))
|
||||||
|
expected_image = image[52:172, 82:142, :]
|
||||||
|
self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image))
|
||||||
|
|
||||||
def test_center_to_corners_format(self):
|
def test_center_to_corners_format(self):
|
||||||
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
|
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
|
||||||
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
|
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
|
||||||
@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test we can pad on an image with 2 channels
|
||||||
|
# fmt: off
|
||||||
|
image = np.array([
|
||||||
|
[[0, 1], [2, 3]],
|
||||||
|
])
|
||||||
|
expected_image = np.array([
|
||||||
|
[[0, 0], [0, 1], [2, 3]],
|
||||||
|
[[0, 0], [0, 0], [0, 0]],
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(
|
||||||
|
expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
def test_convert_to_rgb(self):
|
def test_convert_to_rgb(self):
|
||||||
# Test that an RGBA image is converted to RGB
|
# Test that an RGBA image is converted to RGB
|
||||||
@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
|
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Can flip when the image has 2 channels
|
||||||
|
# fmt: off
|
||||||
|
img_channels_first = np.array([
|
||||||
|
[[ 0, 1, 2, 3],
|
||||||
|
[ 4, 5, 6, 7]],
|
||||||
|
|
||||||
|
[[ 8, 9, 10, 11],
|
||||||
|
[12, 13, 14, 15]],
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
flipped_img_channels_first = img_channels_first[::-1, :, :]
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(
|
||||||
|
flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user