From 81c46679bd447d5cf81cfe3d797ccdc8ea40e5a4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 23 Nov 2022 13:47:59 +0100 Subject: [PATCH] [Image Transformers] to_pil fix float edge cases (#20406) * Correct type checking * up --- src/transformers/image_transforms.py | 2 +- tests/test_image_transforms.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 1909d04e2a..a0a52d6fd4 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -145,7 +145,7 @@ def to_pil_image( image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image # PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed. - do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale + do_rescale = isinstance(image.flat[0], (float, np.float32, np.float64)) if do_rescale is None else do_rescale if do_rescale: image = rescale(image, 255) image = image.astype(np.uint8) diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 618181b004..89a6b135bf 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -61,6 +61,8 @@ class ImageTransformsTester(unittest.TestCase): [ ("numpy_float_channels_first", (3, 4, 5), np.float32), ("numpy_float_channels_last", (4, 5, 3), np.float32), + ("numpy_float_channels_first", (3, 4, 5), np.float64), + ("numpy_float_channels_last", (4, 5, 3), np.float64), ("numpy_int_channels_first", (3, 4, 5), np.int32), ("numpy_uint_channels_first", (3, 4, 5), np.uint8), ] @@ -72,6 +74,27 @@ class ImageTransformsTester(unittest.TestCase): self.assertIsInstance(pil_image, PIL.Image.Image) self.assertEqual(pil_image.size, (5, 4)) + # make sure image is correctly rescaled + self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0) + + @parameterized.expand( + [ + ("numpy_float_channels_first", (3, 4, 5), np.float32), + ("numpy_float_channels_first", (3, 4, 5), np.float64), + ("numpy_float_channels_last", (4, 5, 3), np.float32), + ("numpy_float_channels_last", (4, 5, 3), np.float64), + ] + ) + @require_vision + def test_to_pil_image_from_float(self, name, image_shape, dtype): + image = np.random.rand(*image_shape).astype(dtype) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + # make sure image is correctly rescaled + self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0) + @require_tf def test_to_pil_image_from_tensorflow(self): # channels_first