Move convert_to_rgb to image_transforms module (#20784)
* Move convert_to_rgb to image_transforms module * Fix tests
This commit is contained in:
@@ -20,6 +20,7 @@ import numpy as np
|
|||||||
|
|
||||||
from transformers.image_utils import (
|
from transformers.image_utils import (
|
||||||
ChannelDimension,
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
get_channel_dimension_axis,
|
get_channel_dimension_axis,
|
||||||
get_image_size,
|
get_image_size,
|
||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
@@ -687,3 +688,22 @@ def pad(
|
|||||||
|
|
||||||
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
|
||||||
|
def convert_to_rgb(image: ImageInput) -> ImageInput:
|
||||||
|
"""
|
||||||
|
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
|
||||||
|
as is.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image):
|
||||||
|
The image to convert.
|
||||||
|
"""
|
||||||
|
requires_backends(convert_to_rgb, ["vision"])
|
||||||
|
|
||||||
|
if not isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
|
||||||
|
image = image.convert("RGB")
|
||||||
|
return image
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for BiT."""
|
"""Image processor class for BiT."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
|||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
center_crop,
|
center_crop,
|
||||||
|
convert_to_rgb,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
normalize,
|
normalize,
|
||||||
rescale,
|
rescale,
|
||||||
@@ -41,20 +42,6 @@ if is_vision_available():
|
|||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
|
||||||
"""
|
|
||||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (`PIL.Image.Image`):
|
|
||||||
The image to convert.
|
|
||||||
"""
|
|
||||||
if not isinstance(image, PIL.Image.Image):
|
|
||||||
return image
|
|
||||||
|
|
||||||
return image.convert("RGB")
|
|
||||||
|
|
||||||
|
|
||||||
class BitImageProcessor(BaseImageProcessor):
|
class BitImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a BiT image processor.
|
Constructs a BiT image processor.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for Chinese-CLIP."""
|
"""Image processor class for Chinese-CLIP."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
|||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
center_crop,
|
center_crop,
|
||||||
|
convert_to_rgb,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
normalize,
|
normalize,
|
||||||
rescale,
|
rescale,
|
||||||
@@ -41,20 +42,6 @@ if is_vision_available():
|
|||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
|
||||||
"""
|
|
||||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (`PIL.Image.Image`):
|
|
||||||
The image to convert.
|
|
||||||
"""
|
|
||||||
if not isinstance(image, PIL.Image.Image):
|
|
||||||
return image
|
|
||||||
|
|
||||||
return image.convert("RGB")
|
|
||||||
|
|
||||||
|
|
||||||
class ChineseCLIPImageProcessor(BaseImageProcessor):
|
class ChineseCLIPImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a Chinese-CLIP image processor.
|
Constructs a Chinese-CLIP image processor.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for CLIP."""
|
"""Image processor class for CLIP."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
|||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
center_crop,
|
center_crop,
|
||||||
|
convert_to_rgb,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
normalize,
|
normalize,
|
||||||
rescale,
|
rescale,
|
||||||
@@ -41,20 +42,6 @@ if is_vision_available():
|
|||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
|
||||||
"""
|
|
||||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (`PIL.Image.Image`):
|
|
||||||
The image to convert.
|
|
||||||
"""
|
|
||||||
if not isinstance(image, PIL.Image.Image):
|
|
||||||
return image
|
|
||||||
|
|
||||||
return image.convert("RGB")
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPImageProcessor(BaseImageProcessor):
|
class CLIPImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a CLIP image processor.
|
Constructs a CLIP image processor.
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Image processor class for ViT hybrid."""
|
"""Image processor class for ViT hybrid."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
|
|||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import (
|
from ...image_transforms import (
|
||||||
center_crop,
|
center_crop,
|
||||||
|
convert_to_rgb,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
normalize,
|
normalize,
|
||||||
rescale,
|
rescale,
|
||||||
@@ -41,21 +42,6 @@ if is_vision_available():
|
|||||||
import PIL
|
import PIL
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bit.image_processing_bit.convert_to_rgb
|
|
||||||
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
|
|
||||||
"""
|
|
||||||
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (`PIL.Image.Image`):
|
|
||||||
The image to convert.
|
|
||||||
"""
|
|
||||||
if not isinstance(image, PIL.Image.Image):
|
|
||||||
return image
|
|
||||||
|
|
||||||
return image.convert("RGB")
|
|
||||||
|
|
||||||
|
|
||||||
class ViTHybridImageProcessor(BaseImageProcessor):
|
class ViTHybridImageProcessor(BaseImageProcessor):
|
||||||
r"""
|
r"""
|
||||||
Constructs a ViT Hybrid image processor.
|
Constructs a ViT Hybrid image processor.
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ if is_vision_available():
|
|||||||
from transformers.image_transforms import (
|
from transformers.image_transforms import (
|
||||||
center_crop,
|
center_crop,
|
||||||
center_to_corners_format,
|
center_to_corners_format,
|
||||||
|
convert_to_rgb,
|
||||||
corners_to_center_format,
|
corners_to_center_format,
|
||||||
get_resize_output_image_size,
|
get_resize_output_image_size,
|
||||||
id_to_rgb,
|
id_to_rgb,
|
||||||
@@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
|
|||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
def test_convert_to_rgb(self):
|
||||||
|
# Test that an RGBA image is converted to RGB
|
||||||
|
image = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.uint8)
|
||||||
|
pil_image = PIL.Image.fromarray(image)
|
||||||
|
self.assertEqual(pil_image.mode, "RGBA")
|
||||||
|
self.assertEqual(pil_image.size, (2, 1))
|
||||||
|
|
||||||
|
# For the moment, numpy images are returned as is
|
||||||
|
rgb_image = convert_to_rgb(image)
|
||||||
|
self.assertEqual(rgb_image.shape, (1, 2, 4))
|
||||||
|
self.assertTrue(np.allclose(rgb_image, image))
|
||||||
|
|
||||||
|
# And PIL images are converted
|
||||||
|
rgb_image = convert_to_rgb(pil_image)
|
||||||
|
self.assertEqual(rgb_image.mode, "RGB")
|
||||||
|
self.assertEqual(rgb_image.size, (2, 1))
|
||||||
|
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[1, 2, 3], [5, 6, 7]]], dtype=np.uint8)))
|
||||||
|
|
||||||
|
# Test that a grayscale image is converted to RGB
|
||||||
|
image = np.array([[0, 255]], dtype=np.uint8)
|
||||||
|
pil_image = PIL.Image.fromarray(image)
|
||||||
|
self.assertEqual(pil_image.mode, "L")
|
||||||
|
self.assertEqual(pil_image.size, (2, 1))
|
||||||
|
rgb_image = convert_to_rgb(pil_image)
|
||||||
|
self.assertEqual(rgb_image.mode, "RGB")
|
||||||
|
self.assertEqual(rgb_image.size, (2, 1))
|
||||||
|
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
|
||||||
|
|||||||
Reference in New Issue
Block a user