Put load_image function in image_utils.py & fix image rotation issue (#14062)
* Fix img load rotation * Add `load_image` to `image_utils.py` * Implement LoadImageTester * Use hf-internal-testing dataset * Add img utils comments * Refactor LoadImageTester * Import load_image under is_vision_available
This commit is contained in:
@@ -13,10 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
|
||||
import requests
|
||||
|
||||
from .file_utils import _is_torch, is_torch_available
|
||||
|
||||
@@ -35,6 +39,39 @@ def is_torch_tensor(obj):
|
||||
return _is_torch(obj) if is_torch_available() else False
|
||||
|
||||
|
||||
def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
|
||||
"""
|
||||
Loads :obj:`image` to a PIL Image.
|
||||
|
||||
Args:
|
||||
image (:obj:`str` or :obj:`PIL.Image.Image`):
|
||||
The image to convert to the PIL Image format.
|
||||
|
||||
Returns:
|
||||
:obj:`PIL.Image.Image`: A PIL Image.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||
# like http_huggingface_co.png
|
||||
image = PIL.Image.open(requests.get(image, stream=True).raw)
|
||||
elif os.path.isfile(image):
|
||||
image = PIL.Image.open(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
||||
)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = image
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
|
||||
)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
# In the future we can add a TF implementation here when we have TF models.
|
||||
class ImageFeatureExtractionMixin:
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
import requests
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
@@ -11,6 +8,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
@@ -39,35 +38,13 @@ class ImageClassificationPipeline(Pipeline):
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||
|
||||
@staticmethod
|
||||
def load_image(image: Union[str, "Image.Image"]):
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||
# like http_huggingface_co.png
|
||||
image = Image.open(requests.get(image, stream=True).raw)
|
||||
elif os.path.isfile(image):
|
||||
image = Image.open(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
||||
)
|
||||
elif isinstance(image, Image.Image):
|
||||
image = image
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
|
||||
)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
def _sanitize_parameters(self, top_k=None):
|
||||
postprocess_params = {}
|
||||
if top_k is not None:
|
||||
postprocess_params["top_k"] = top_k
|
||||
return {}, {}, postprocess_params
|
||||
|
||||
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
||||
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||
"""
|
||||
Assign labels to the image(s) passed as inputs.
|
||||
|
||||
@@ -99,7 +76,7 @@ class ImageClassificationPipeline(Pipeline):
|
||||
return super().__call__(images, **kwargs)
|
||||
|
||||
def preprocess(self, image):
|
||||
image = self.load_image(image)
|
||||
image = load_image(image)
|
||||
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
|
||||
return model_inputs
|
||||
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import requests
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
@@ -15,6 +12,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from ..image_utils import load_image
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@@ -49,28 +48,6 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)
|
||||
|
||||
@staticmethod
|
||||
def load_image(image: Union[str, "Image.Image"]):
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||
# like http_huggingface_co.png
|
||||
image = Image.open(requests.get(image, stream=True).raw)
|
||||
elif os.path.isfile(image):
|
||||
image = Image.open(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
||||
)
|
||||
elif isinstance(image, Image.Image):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
|
||||
)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
postprocess_kwargs = {}
|
||||
if "threshold" in kwargs:
|
||||
@@ -118,7 +95,7 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
return torch.no_grad
|
||||
|
||||
def preprocess(self, image):
|
||||
image = self.load_image(image)
|
||||
image = load_image(image)
|
||||
target_size = torch.IntTensor([[image.height, image.width]])
|
||||
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
||||
inputs["target_size"] = target_size
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import requests
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||
from ..utils import logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
from ..image_utils import load_image
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -45,28 +43,6 @@ class ObjectDetectionPipeline(Pipeline):
|
||||
requires_backends(self, "vision")
|
||||
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
|
||||
|
||||
@staticmethod
|
||||
def load_image(image: Union[str, "Image.Image"]):
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||
# like http_huggingface_co.png
|
||||
image = Image.open(requests.get(image, stream=True).raw)
|
||||
elif os.path.isfile(image):
|
||||
image = Image.open(image)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
||||
)
|
||||
elif isinstance(image, Image.Image):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
|
||||
)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
postprocess_kwargs = {}
|
||||
if "threshold" in kwargs:
|
||||
@@ -105,7 +81,7 @@ class ObjectDetectionPipeline(Pipeline):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
def preprocess(self, image):
|
||||
image = self.load_image(image)
|
||||
image = load_image(image)
|
||||
target_size = torch.IntTensor([[image.height, image.width]])
|
||||
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
||||
inputs["target_size"] = target_size
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
@@ -28,6 +29,7 @@ if is_vision_available():
|
||||
import PIL.Image
|
||||
|
||||
from transformers import ImageFeatureExtractionMixin
|
||||
from transformers.image_utils import load_image
|
||||
|
||||
|
||||
def get_random_image(height, width):
|
||||
@@ -367,3 +369,68 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
||||
# Check result is consistent with PIL.Image.crop
|
||||
cropped_image = feature_extractor.center_crop(image, size)
|
||||
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
|
||||
|
||||
|
||||
@require_vision
|
||||
class LoadImageTester(unittest.TestCase):
|
||||
def test_load_img_local(self):
|
||||
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(480, 640, 3),
|
||||
)
|
||||
|
||||
def test_load_img_rgba(self):
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
|
||||
img = load_image(dataset[0]["file"]) # img with mode RGBA
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(512, 512, 3),
|
||||
)
|
||||
|
||||
def test_load_img_la(self):
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
|
||||
img = load_image(dataset[1]["file"]) # img with mode LA
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(512, 768, 3),
|
||||
)
|
||||
|
||||
def test_load_img_l(self):
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
|
||||
img = load_image(dataset[2]["file"]) # img with mode L
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(381, 225, 3),
|
||||
)
|
||||
|
||||
def test_load_img_exif_transpose(self):
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
img_file = dataset[3]["file"]
|
||||
|
||||
img_without_exif_transpose = PIL.Image.open(img_file)
|
||||
img_arr_without_exif_transpose = np.array(img_without_exif_transpose)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr_without_exif_transpose.shape,
|
||||
(333, 500, 3),
|
||||
)
|
||||
|
||||
img_with_exif_transpose = load_image(img_file)
|
||||
img_arr_with_exif_transpose = np.array(img_with_exif_transpose)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr_with_exif_transpose.shape,
|
||||
(500, 333, 3),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user