Move convert_to_rgb to image_transforms module (#20784)

* Move convert_to_rgb to image_transforms module

* Fix tests
This commit is contained in:
amyeroberts
2022-12-15 18:47:04 +00:00
committed by GitHub
parent 4bc723f87d
commit 491e951875
6 changed files with 58 additions and 61 deletions

View File

@@ -37,6 +37,7 @@ if is_vision_available():
from transformers.image_transforms import (
center_crop,
center_to_corners_format,
convert_to_rgb,
corners_to_center_format,
get_resize_output_image_size,
id_to_rgb,
@@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
)
@require_vision
def test_convert_to_rgb(self):
# Test that an RGBA image is converted to RGB
image = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.uint8)
pil_image = PIL.Image.fromarray(image)
self.assertEqual(pil_image.mode, "RGBA")
self.assertEqual(pil_image.size, (2, 1))
# For the moment, numpy images are returned as is
rgb_image = convert_to_rgb(image)
self.assertEqual(rgb_image.shape, (1, 2, 4))
self.assertTrue(np.allclose(rgb_image, image))
# And PIL images are converted
rgb_image = convert_to_rgb(pil_image)
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[1, 2, 3], [5, 6, 7]]], dtype=np.uint8)))
# Test that a grayscale image is converted to RGB
image = np.array([[0, 255]], dtype=np.uint8)
pil_image = PIL.Image.fromarray(image)
self.assertEqual(pil_image.mode, "L")
self.assertEqual(pil_image.size, (2, 1))
rgb_image = convert_to_rgb(pil_image)
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))