[Image Transformers] to_pil fix float edge cases (#20406)
* Correct type checking * up
This commit is contained in:
committed by
GitHub
parent
1c6309bf79
commit
81c46679bd
@@ -145,7 +145,7 @@ def to_pil_image(
|
|||||||
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
|
||||||
|
|
||||||
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
|
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
|
||||||
do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale
|
do_rescale = isinstance(image.flat[0], (float, np.float32, np.float64)) if do_rescale is None else do_rescale
|
||||||
if do_rescale:
|
if do_rescale:
|
||||||
image = rescale(image, 255)
|
image = rescale(image, 255)
|
||||||
image = image.astype(np.uint8)
|
image = image.astype(np.uint8)
|
||||||
|
|||||||
@@ -61,6 +61,8 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
[
|
[
|
||||||
("numpy_float_channels_first", (3, 4, 5), np.float32),
|
("numpy_float_channels_first", (3, 4, 5), np.float32),
|
||||||
("numpy_float_channels_last", (4, 5, 3), np.float32),
|
("numpy_float_channels_last", (4, 5, 3), np.float32),
|
||||||
|
("numpy_float_channels_first", (3, 4, 5), np.float64),
|
||||||
|
("numpy_float_channels_last", (4, 5, 3), np.float64),
|
||||||
("numpy_int_channels_first", (3, 4, 5), np.int32),
|
("numpy_int_channels_first", (3, 4, 5), np.int32),
|
||||||
("numpy_uint_channels_first", (3, 4, 5), np.uint8),
|
("numpy_uint_channels_first", (3, 4, 5), np.uint8),
|
||||||
]
|
]
|
||||||
@@ -72,6 +74,27 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertIsInstance(pil_image, PIL.Image.Image)
|
self.assertIsInstance(pil_image, PIL.Image.Image)
|
||||||
self.assertEqual(pil_image.size, (5, 4))
|
self.assertEqual(pil_image.size, (5, 4))
|
||||||
|
|
||||||
|
# make sure image is correctly rescaled
|
||||||
|
self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0)
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
("numpy_float_channels_first", (3, 4, 5), np.float32),
|
||||||
|
("numpy_float_channels_first", (3, 4, 5), np.float64),
|
||||||
|
("numpy_float_channels_last", (4, 5, 3), np.float32),
|
||||||
|
("numpy_float_channels_last", (4, 5, 3), np.float64),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@require_vision
|
||||||
|
def test_to_pil_image_from_float(self, name, image_shape, dtype):
|
||||||
|
image = np.random.rand(*image_shape).astype(dtype)
|
||||||
|
pil_image = to_pil_image(image)
|
||||||
|
self.assertIsInstance(pil_image, PIL.Image.Image)
|
||||||
|
self.assertEqual(pil_image.size, (5, 4))
|
||||||
|
|
||||||
|
# make sure image is correctly rescaled
|
||||||
|
self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_to_pil_image_from_tensorflow(self):
|
def test_to_pil_image_from_tensorflow(self):
|
||||||
# channels_first
|
# channels_first
|
||||||
|
|||||||
Reference in New Issue
Block a user