Add input_data_format argument, image transforms (#25462)

* Enable specifying input data format - overriding inferring

* Add tests
This commit is contained in:
amyeroberts
2023-08-11 15:09:31 +01:00
committed by GitHub
parent 0acf56224b
commit 11757e2bbd
3 changed files with 133 additions and 23 deletions

View File

@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase):
image = to_channel_dimension_format(image, "channels_first")
self.assertEqual(image.shape, (3, 4, 5))
# Can pass in input_data_format and works if data format is ambiguous or unknown.
image = np.random.rand(4, 5, 6)
image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last")
self.assertEqual(image.shape, (6, 4, 5))
def test_get_resize_output_image_size(self):
image = np.random.randint(0, 256, (3, 224, 224))
@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase):
image = np.random.randint(0, 256, (3, 50, 40))
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
# Test output size = (int(size * height / width), size) if size is an int and height > width and
# input has 4 channels
image = np.random.randint(0, 256, (4, 50, 40))
self.assertEqual(
get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"),
(25, 20),
)
# Test correct channel dimension is returned if output size if height == 3
# Defaults to input format - channels first
image = np.random.randint(0, 256, (3, 18, 97))
@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(np.all(resized_image >= 0))
self.assertTrue(np.all(resized_image <= 1))
# Check that an image with 4 channels is resized correctly
image = np.random.randint(0, 256, (4, 224, 224))
resized_image = resize(image, (30, 40), input_data_format="channels_first")
self.assertIsInstance(resized_image, np.ndarray)
self.assertEqual(resized_image.shape, (4, 30, 40))
def test_normalize(self):
image = np.random.randint(0, 256, (224, 224, 3)) / 255
@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(normalized_image.shape, (3, 224, 224))
self.assertTrue(np.allclose(normalized_image, expected_image))
# Test image with 4 channels is normalized correctly
image = np.random.randint(0, 256, (224, 224, 4)) / 255
mean = (0.5, 0.6, 0.7, 0.8)
std = (0.1, 0.2, 0.3, 0.4)
expected_image = (image - mean) / std
self.assertTrue(
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image)
)
def test_center_crop(self):
image = np.random.randint(0, 256, (3, 224, 224))
@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(cropped_image.shape, (300, 260, 3))
self.assertTrue(np.allclose(cropped_image, expected_image))
# Test image with 4 channels is cropped correctly
image = np.random.randint(0, 256, (224, 224, 4))
expected_image = image[52:172, 82:142, :]
self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image))
def test_center_to_corners_format(self):
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase):
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
)
# Test we can pad on an image with 2 channels
# fmt: off
image = np.array([
[[0, 1], [2, 3]],
])
expected_image = np.array([
[[0, 0], [0, 1], [2, 3]],
[[0, 0], [0, 0], [0, 0]],
])
# fmt: on
self.assertTrue(
np.allclose(
expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last")
)
)
@require_vision
def test_convert_to_rgb(self):
# Test that an RGBA image is converted to RGB
@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
)
# Can flip when the image has 2 channels
# fmt: off
img_channels_first = np.array([
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
])
# fmt: on
flipped_img_channels_first = img_channels_first[::-1, :, :]
self.assertTrue(
np.allclose(
flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first
)
)