Put load_image function in image_utils.py & fix image rotation issue (#14062)
* Fix img load rotation
* Add `load_image` to `image_utils.py`
* Implement LoadImageTester
* Use hf-internal-testing dataset
* Add img utils comments
* Refactor LoadImageTester
* Import load_image under is_vision_available
This commit is contained in:
@@ -13,10 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
import PIL.ImageOps
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
from .file_utils import _is_torch, is_torch_available
|
from .file_utils import _is_torch, is_torch_available
|
||||||
|
|
||||||
@@ -35,6 +39,39 @@ def is_torch_tensor(obj):
|
|||||||
return _is_torch(obj) if is_torch_available() else False
|
return _is_torch(obj) if is_torch_available() else False
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Union[float, None] = None) -> "PIL.Image.Image":
    """
    Loads :obj:`image` to a PIL Image.

    Args:
        image (:obj:`str` or :obj:`PIL.Image.Image`):
            The image to convert to the PIL Image format. A string is treated as a URL when it starts with
            ``http://`` or ``https://``, and as a local file path otherwise.
        timeout (:obj:`float`, `optional`):
            Timeout, in seconds, forwarded to :obj:`requests.get` when :obj:`image` is a URL. The default
            (:obj:`None`) does not set a timeout, which matches the behavior before this argument existed.

    Returns:
        :obj:`PIL.Image.Image`: A PIL Image, passed through :obj:`PIL.ImageOps.exif_transpose` and converted to
        RGB.

    Raises:
        :obj:`ValueError`: If :obj:`image` is a string that is neither an ``http(s)`` URL nor an existing file, or
            if it is neither a string nor a PIL image.
        :obj:`requests.HTTPError`: If :obj:`image` is a URL and the server answers with an error status.
    """
    if isinstance(image, str):
        # We need to actually check for a real protocol, otherwise it's impossible to use a local file
        # like http_huggingface_co.png
        if image.startswith(("http://", "https://")):
            response = requests.get(image, stream=True, timeout=timeout)
            # Fail on 4xx/5xx here rather than letting PIL choke on the body of an error page.
            response.raise_for_status()
            # The raw stream is not content-decoded by default; without this, a gzip/deflate encoded response
            # would hand compressed bytes to PIL.
            response.raw.decode_content = True
            image = PIL.Image.open(response.raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
        )
    # Apply the EXIF orientation tag (if any) before dropping metadata-dependent state with the RGB conversion.
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
|
||||||
|
|
||||||
|
|
||||||
# In the future we can add a TF implementation here when we have TF models.
|
# In the future we can add a TF implementation here when we have TF models.
|
||||||
class ImageFeatureExtractionMixin:
|
class ImageFeatureExtractionMixin:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
import os
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
@@ -11,6 +8,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
|
|||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
@@ -39,35 +38,13 @@ class ImageClassificationPipeline(Pipeline):
|
|||||||
requires_backends(self, "vision")
|
requires_backends(self, "vision")
|
||||||
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_image(image: Union[str, "Image.Image"]):
|
|
||||||
if isinstance(image, str):
|
|
||||||
if image.startswith("http://") or image.startswith("https://"):
|
|
||||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
|
||||||
# like http_huggingface_co.png
|
|
||||||
image = Image.open(requests.get(image, stream=True).raw)
|
|
||||||
elif os.path.isfile(image):
|
|
||||||
image = Image.open(image)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
|
||||||
)
|
|
||||||
elif isinstance(image, Image.Image):
|
|
||||||
image = image
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
|
|
||||||
)
|
|
||||||
image = image.convert("RGB")
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _sanitize_parameters(self, top_k=None):
|
def _sanitize_parameters(self, top_k=None):
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
postprocess_params["top_k"] = top_k
|
postprocess_params["top_k"] = top_k
|
||||||
return {}, {}, postprocess_params
|
return {}, {}, postprocess_params
|
||||||
|
|
||||||
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
|
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
|
||||||
"""
|
"""
|
||||||
Assign labels to the image(s) passed as inputs.
|
Assign labels to the image(s) passed as inputs.
|
||||||
|
|
||||||
@@ -99,7 +76,7 @@ class ImageClassificationPipeline(Pipeline):
|
|||||||
return super().__call__(images, **kwargs)
|
return super().__call__(images, **kwargs)
|
||||||
|
|
||||||
def preprocess(self, image):
|
def preprocess(self, image):
|
||||||
image = self.load_image(image)
|
image = load_image(image)
|
||||||
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
|
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
@@ -15,6 +12,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
|
|||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..image_utils import load_image
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -49,28 +48,6 @@ class ImageSegmentationPipeline(Pipeline):
|
|||||||
requires_backends(self, "vision")
|
requires_backends(self, "vision")
|
||||||
self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)
|
self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_image(image: Union[str, "Image.Image"]):
|
|
||||||
if isinstance(image, str):
|
|
||||||
if image.startswith("http://") or image.startswith("https://"):
|
|
||||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
|
||||||
# like http_huggingface_co.png
|
|
||||||
image = Image.open(requests.get(image, stream=True).raw)
|
|
||||||
elif os.path.isfile(image):
|
|
||||||
image = Image.open(image)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
|
||||||
)
|
|
||||||
elif isinstance(image, Image.Image):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
|
|
||||||
)
|
|
||||||
image = image.convert("RGB")
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(self, **kwargs):
|
||||||
postprocess_kwargs = {}
|
postprocess_kwargs = {}
|
||||||
if "threshold" in kwargs:
|
if "threshold" in kwargs:
|
||||||
@@ -118,7 +95,7 @@ class ImageSegmentationPipeline(Pipeline):
|
|||||||
return torch.no_grad
|
return torch.no_grad
|
||||||
|
|
||||||
def preprocess(self, image):
|
def preprocess(self, image):
|
||||||
image = self.load_image(image)
|
image = load_image(image)
|
||||||
target_size = torch.IntTensor([[image.height, image.width]])
|
target_size = torch.IntTensor([[image.height, image.width]])
|
||||||
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
||||||
inputs["target_size"] = target_size
|
inputs["target_size"] = target_size
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
import os
|
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from ..image_utils import load_image
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -45,28 +43,6 @@ class ObjectDetectionPipeline(Pipeline):
|
|||||||
requires_backends(self, "vision")
|
requires_backends(self, "vision")
|
||||||
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
|
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_image(image: Union[str, "Image.Image"]):
|
|
||||||
if isinstance(image, str):
|
|
||||||
if image.startswith("http://") or image.startswith("https://"):
|
|
||||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
|
||||||
# like http_huggingface_co.png
|
|
||||||
image = Image.open(requests.get(image, stream=True).raw)
|
|
||||||
elif os.path.isfile(image):
|
|
||||||
image = Image.open(image)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
|
|
||||||
)
|
|
||||||
elif isinstance(image, Image.Image):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
|
|
||||||
)
|
|
||||||
image = image.convert("RGB")
|
|
||||||
return image
|
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(self, **kwargs):
|
||||||
postprocess_kwargs = {}
|
postprocess_kwargs = {}
|
||||||
if "threshold" in kwargs:
|
if "threshold" in kwargs:
|
||||||
@@ -105,7 +81,7 @@ class ObjectDetectionPipeline(Pipeline):
|
|||||||
return super().__call__(*args, **kwargs)
|
return super().__call__(*args, **kwargs)
|
||||||
|
|
||||||
def preprocess(self, image):
|
def preprocess(self, image):
|
||||||
image = self.load_image(image)
|
image = load_image(image)
|
||||||
target_size = torch.IntTensor([[image.height, image.width]])
|
target_size = torch.IntTensor([[image.height, image.width]])
|
||||||
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
inputs = self.feature_extractor(images=[image], return_tensors="pt")
|
||||||
inputs["target_size"] = target_size
|
inputs["target_size"] = target_size
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import is_torch_available, is_vision_available
|
from transformers import is_torch_available, is_vision_available
|
||||||
@@ -28,6 +29,7 @@ if is_vision_available():
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
|
||||||
from transformers import ImageFeatureExtractionMixin
|
from transformers import ImageFeatureExtractionMixin
|
||||||
|
from transformers.image_utils import load_image
|
||||||
|
|
||||||
|
|
||||||
def get_random_image(height, width):
|
def get_random_image(height, width):
|
||||||
@@ -367,3 +369,68 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
|||||||
# Check result is consistent with PIL.Image.crop
|
# Check result is consistent with PIL.Image.crop
|
||||||
cropped_image = feature_extractor.center_crop(image, size)
|
cropped_image = feature_extractor.center_crop(image, size)
|
||||||
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
|
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
class LoadImageTester(unittest.TestCase):
    """Checks the array shape of images returned by :obj:`load_image` for several input modes and for EXIF data."""

    def _fixture_file(self, index):
        # Returns the `file` entry of the `index`-th sample of the shared image fixtures dataset.
        dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
        return dataset[index]["file"]

    def test_load_img_local(self):
        img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
        img_arr = np.array(img)

        self.assertEqual(img_arr.shape, (480, 640, 3))

    def test_load_img_rgba(self):
        img = load_image(self._fixture_file(0))  # img with mode RGBA
        img_arr = np.array(img)

        self.assertEqual(img_arr.shape, (512, 512, 3))

    def test_load_img_la(self):
        img = load_image(self._fixture_file(1))  # img with mode LA
        img_arr = np.array(img)

        self.assertEqual(img_arr.shape, (512, 768, 3))

    def test_load_img_l(self):
        img = load_image(self._fixture_file(2))  # img with mode L
        img_arr = np.array(img)

        self.assertEqual(img_arr.shape, (381, 225, 3))

    def test_load_img_exif_transpose(self):
        img_file = self._fixture_file(3)

        # Opened directly with PIL, i.e. without the `exif_transpose` step of `load_image`. The context manager
        # closes the underlying file handle once the pixel data has been copied into the array.
        with PIL.Image.open(img_file) as img_without_exif_transpose:
            img_arr_without_exif_transpose = np.array(img_without_exif_transpose)

        self.assertEqual(img_arr_without_exif_transpose.shape, (333, 500, 3))

        img_with_exif_transpose = load_image(img_file)
        img_arr_with_exif_transpose = np.array(img_with_exif_transpose)

        # Height and width are swapped once the EXIF orientation is applied.
        self.assertEqual(img_arr_with_exif_transpose.shape, (500, 333, 3))
|
||||||
|
|||||||
Reference in New Issue
Block a user