From 671569ddf71cd2b3bfd6e116b17052e4d251ef73 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 3 Nov 2021 14:53:05 +0100 Subject: [PATCH] Put `load_image` function in `image_utils.py` & fix image rotation issue (#14062) * Fix img load rotation * Add `load_image` to `image_utils.py` * Implement LoadImageTester * Use hf-internal-testing dataset * Add img utils comments * Refactor LoadImageTester * Import load_image under is_vision_available --- src/transformers/image_utils.py | 37 ++++++++++ .../pipelines/image_classification.py | 31 ++------- .../pipelines/image_segmentation.py | 29 +------- .../pipelines/object_detection.py | 30 +-------- tests/test_image_utils.py | 67 +++++++++++++++++++ 5 files changed, 114 insertions(+), 80 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 6489ee508b..d6cf5badbe 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -13,10 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import List, Union import numpy as np import PIL.Image +import PIL.ImageOps + +import requests from .file_utils import _is_torch, is_torch_available @@ -35,6 +39,39 @@ def is_torch_tensor(obj): return _is_torch(obj) if is_torch_available() else False +def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image": + """ + Loads :obj:`image` to a PIL Image. + + Args: + image (:obj:`str` or :obj:`PIL.Image.Image`): + The image to convert to the PIL Image format. + + Returns: + :obj:`PIL.Image.Image`: A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + # We need to actually check for a real protocol, otherwise it's impossible to use a local file + # like http_huggingface_co.png + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image + + # In the future we can add a TF implementation here when we have TF models. class ImageFeatureExtractionMixin: """ diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index e89940e890..2d2ab68cba 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -1,8 +1,5 @@ -import os from typing import List, Union -import requests - from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline @@ -11,6 +8,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline if is_vision_available(): from PIL import Image + from ..image_utils import load_image + if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING @@ -39,35 +38,13 @@ class ImageClassificationPipeline(Pipeline): requires_backends(self, "vision") self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) - @staticmethod - def load_image(image: Union[str, "Image.Image"]): - if isinstance(image, str): - if image.startswith("http://") or image.startswith("https://"): - # We need to actually check for a real protocol, otherwise it's impossible to use a local file - # like http_huggingface_co.png - image = Image.open(requests.get(image, stream=True).raw) - elif os.path.isfile(image): - image = Image.open(image) - else: - raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" - ) - elif isinstance(image, Image.Image): - image = image - else: - raise ValueError( - "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." - ) - image = image.convert("RGB") - return image - def _sanitize_parameters(self, top_k=None): postprocess_params = {} if top_k is not None: postprocess_params["top_k"] = top_k return {}, {}, postprocess_params - def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs): + def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): """ Assign labels to the image(s) passed as inputs. @@ -99,7 +76,7 @@ class ImageClassificationPipeline(Pipeline): return super().__call__(images, **kwargs) def preprocess(self, image): - image = self.load_image(image) + image = load_image(image) model_inputs = self.feature_extractor(images=image, return_tensors="pt") return model_inputs diff --git a/src/transformers/pipelines/image_segmentation.py b/src/transformers/pipelines/image_segmentation.py index 923cbd7236..84a3e67ef6 100644 --- a/src/transformers/pipelines/image_segmentation.py +++ b/src/transformers/pipelines/image_segmentation.py @@ -1,12 +1,9 @@ import base64 import io -import os from typing import Any, Dict, List, Union import numpy as np -import requests - from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline @@ -15,6 +12,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline if is_vision_available(): from PIL import Image + from ..image_utils import load_image + if is_torch_available(): import torch @@ -49,28 +48,6 @@ class ImageSegmentationPipeline(Pipeline): requires_backends(self, "vision") self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING) - @staticmethod - def load_image(image: Union[str, "Image.Image"]): - if isinstance(image, str): - if image.startswith("http://") or image.startswith("https://"): - # We need to actually check for a real protocol, otherwise it's impossible to use a local file - # like http_huggingface_co.png - image = Image.open(requests.get(image, stream=True).raw) - elif os.path.isfile(image): - image = Image.open(image) - else: - raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" - ) - elif isinstance(image, Image.Image): - pass - else: - raise ValueError( - "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image." - ) - image = image.convert("RGB") - return image - def _sanitize_parameters(self, **kwargs): postprocess_kwargs = {} if "threshold" in kwargs: @@ -118,7 +95,7 @@ class ImageSegmentationPipeline(Pipeline): return torch.no_grad def preprocess(self, image): - image = self.load_image(image) + image = load_image(image) target_size = torch.IntTensor([[image.height, image.width]]) inputs = self.feature_extractor(images=[image], return_tensors="pt") inputs["target_size"] = target_size diff --git a/src/transformers/pipelines/object_detection.py b/src/transformers/pipelines/object_detection.py index 6ecdc41f38..0d8df38575 100644 --- a/src/transformers/pipelines/object_detection.py +++ b/src/transformers/pipelines/object_detection.py @@ -1,15 +1,13 @@ -import os from typing import Any, Dict, List, Union -import requests - from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline if is_vision_available(): - from PIL import Image + from ..image_utils import load_image + if is_torch_available(): import torch @@ -45,28 +43,6 @@ class ObjectDetectionPipeline(Pipeline): requires_backends(self, "vision") self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING) - @staticmethod - def load_image(image: Union[str, "Image.Image"]): - if isinstance(image, str): - if image.startswith("http://") or image.startswith("https://"): - # We need to actually check for a real protocol, otherwise it's impossible to use a local file - # like http_huggingface_co.png - image = Image.open(requests.get(image, stream=True).raw) - elif os.path.isfile(image): - image = Image.open(image) - else: - raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" - ) - elif isinstance(image, Image.Image): - pass - else: - raise ValueError( - "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image." - ) - image = image.convert("RGB") - return image - def _sanitize_parameters(self, **kwargs): postprocess_kwargs = {} if "threshold" in kwargs: @@ -105,7 +81,7 @@ class ObjectDetectionPipeline(Pipeline): return super().__call__(*args, **kwargs) def preprocess(self, image): - image = self.load_image(image) + image = load_image(image) target_size = torch.IntTensor([[image.height, image.width]]) inputs = self.feature_extractor(images=[image], return_tensors="pt") inputs["target_size"] = target_size diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index 584cf3f251..702387ee5f 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -15,6 +15,7 @@ import unittest +import datasets import numpy as np from transformers import is_torch_available, is_vision_available @@ -28,6 +29,7 @@ if is_vision_available(): import PIL.Image from transformers import ImageFeatureExtractionMixin + from transformers.image_utils import load_image def get_random_image(height, width): @@ -367,3 +369,68 @@ class ImageFeatureExtractionTester(unittest.TestCase): # Check result is consistent with PIL.Image.crop cropped_image = feature_extractor.center_crop(image, size) self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image)))) + + +@require_vision +class LoadImageTester(unittest.TestCase): + def test_load_img_local(self): + img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png") + img_arr = np.array(img) + + self.assertEqual( + img_arr.shape, + (480, 640, 3), + ) + + def test_load_img_rgba(self): + dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test") + + img = load_image(dataset[0]["file"]) # img with mode RGBA + img_arr = np.array(img) + + self.assertEqual( + img_arr.shape, + (512, 512, 3), + ) + + def test_load_img_la(self): + dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test") + + img = load_image(dataset[1]["file"]) # img with mode LA + img_arr = np.array(img) + + self.assertEqual( + img_arr.shape, + (512, 768, 3), + ) + + def test_load_img_l(self): + dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test") + + img = load_image(dataset[2]["file"]) # img with mode L + img_arr = np.array(img) + + self.assertEqual( + img_arr.shape, + (381, 225, 3), + ) + + def test_load_img_exif_transpose(self): + dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test") + img_file = dataset[3]["file"] + + img_without_exif_transpose = PIL.Image.open(img_file) + img_arr_without_exif_transpose = np.array(img_without_exif_transpose) + + self.assertEqual( + img_arr_without_exif_transpose.shape, + (333, 500, 3), + ) + + img_with_exif_transpose = load_image(img_file) + img_arr_with_exif_transpose = np.array(img_with_exif_transpose) + + self.assertEqual( + img_arr_with_exif_transpose.shape, + (500, 333, 3), + )