[Image Transformers] to_pil fix float edge cases (#20406)

* Correct type checking

* up
This commit is contained in:
Patrick von Platen
2022-11-23 13:47:59 +01:00
committed by GitHub
parent 1c6309bf79
commit 81c46679bd
2 changed files with 24 additions and 1 deletions

View File

@@ -61,6 +61,8 @@ class ImageTransformsTester(unittest.TestCase):
[
("numpy_float_channels_first", (3, 4, 5), np.float32),
("numpy_float_channels_last", (4, 5, 3), np.float32),
("numpy_float_channels_first", (3, 4, 5), np.float64),
("numpy_float_channels_last", (4, 5, 3), np.float64),
("numpy_int_channels_first", (3, 4, 5), np.int32),
("numpy_uint_channels_first", (3, 4, 5), np.uint8),
]
@@ -72,6 +74,27 @@ class ImageTransformsTester(unittest.TestCase):
self.assertIsInstance(pil_image, PIL.Image.Image)
self.assertEqual(pil_image.size, (5, 4))
# make sure image is correctly rescaled
self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0)
@parameterized.expand(
[
("numpy_float_channels_first", (3, 4, 5), np.float32),
("numpy_float_channels_first", (3, 4, 5), np.float64),
("numpy_float_channels_last", (4, 5, 3), np.float32),
("numpy_float_channels_last", (4, 5, 3), np.float64),
]
)
@require_vision
def test_to_pil_image_from_float(self, name, image_shape, dtype):
image = np.random.rand(*image_shape).astype(dtype)
pil_image = to_pil_image(image)
self.assertIsInstance(pil_image, PIL.Image.Image)
self.assertEqual(pil_image.size, (5, 4))
# make sure image is correctly rescaled
self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0)
@require_tf
def test_to_pil_image_from_tensorflow(self):
# channels_first