Add input_data_format argument, image transforms (#25462)
* Enable specifying input data format - overriding inferring * Add tests
This commit is contained in:
@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
image = to_channel_dimension_format(image, "channels_first")
|
||||
self.assertEqual(image.shape, (3, 4, 5))
|
||||
|
||||
# Can pass in input_data_format and works if data format is ambiguous or unknown.
|
||||
image = np.random.rand(4, 5, 6)
|
||||
image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last")
|
||||
self.assertEqual(image.shape, (6, 4, 5))
|
||||
|
||||
def test_get_resize_output_image_size(self):
|
||||
image = np.random.randint(0, 256, (3, 224, 224))
|
||||
|
||||
@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
image = np.random.randint(0, 256, (3, 50, 40))
|
||||
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
|
||||
|
||||
# Test output size = (int(size * height / width), size) if size is an int and height > width and
|
||||
# input has 4 channels
|
||||
image = np.random.randint(0, 256, (4, 50, 40))
|
||||
self.assertEqual(
|
||||
get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"),
|
||||
(25, 20),
|
||||
)
|
||||
|
||||
# Test correct channel dimension is returned if output size if height == 3
|
||||
# Defaults to input format - channels first
|
||||
image = np.random.randint(0, 256, (3, 18, 97))
|
||||
@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertTrue(np.all(resized_image >= 0))
|
||||
self.assertTrue(np.all(resized_image <= 1))
|
||||
|
||||
# Check that an image with 4 channels is resized correctly
|
||||
image = np.random.randint(0, 256, (4, 224, 224))
|
||||
resized_image = resize(image, (30, 40), input_data_format="channels_first")
|
||||
self.assertIsInstance(resized_image, np.ndarray)
|
||||
self.assertEqual(resized_image.shape, (4, 30, 40))
|
||||
|
||||
def test_normalize(self):
|
||||
image = np.random.randint(0, 256, (224, 224, 3)) / 255
|
||||
|
||||
@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertEqual(normalized_image.shape, (3, 224, 224))
|
||||
self.assertTrue(np.allclose(normalized_image, expected_image))
|
||||
|
||||
# Test image with 4 channels is normalized correctly
|
||||
image = np.random.randint(0, 256, (224, 224, 4)) / 255
|
||||
mean = (0.5, 0.6, 0.7, 0.8)
|
||||
std = (0.1, 0.2, 0.3, 0.4)
|
||||
expected_image = (image - mean) / std
|
||||
self.assertTrue(
|
||||
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image)
|
||||
)
|
||||
|
||||
def test_center_crop(self):
|
||||
image = np.random.randint(0, 256, (3, 224, 224))
|
||||
|
||||
@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||
|
||||
# Test image with 4 channels is cropped correctly
|
||||
image = np.random.randint(0, 256, (224, 224, 4))
|
||||
expected_image = image[52:172, 82:142, :]
|
||||
self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image))
|
||||
|
||||
def test_center_to_corners_format(self):
|
||||
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
|
||||
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
|
||||
@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
||||
)
|
||||
|
||||
# Test we can pad on an image with 2 channels
|
||||
# fmt: off
|
||||
image = np.array([
|
||||
[[0, 1], [2, 3]],
|
||||
])
|
||||
expected_image = np.array([
|
||||
[[0, 0], [0, 1], [2, 3]],
|
||||
[[0, 0], [0, 0], [0, 0]],
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last")
|
||||
)
|
||||
)
|
||||
|
||||
@require_vision
|
||||
def test_convert_to_rgb(self):
|
||||
# Test that an RGBA image is converted to RGB
|
||||
@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
|
||||
)
|
||||
|
||||
# Can flip when the image has 2 channels
|
||||
# fmt: off
|
||||
img_channels_first = np.array([
|
||||
[[ 0, 1, 2, 3],
|
||||
[ 4, 5, 6, 7]],
|
||||
|
||||
[[ 8, 9, 10, 11],
|
||||
[12, 13, 14, 15]],
|
||||
])
|
||||
# fmt: on
|
||||
flipped_img_channels_first = img_channels_first[::-1, :, :]
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user