From dbc16f4404eca4a75459683d5135f6accea35a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haylee=20Sch=C3=A4fer?= Date: Tue, 29 Aug 2023 20:24:24 +0200 Subject: [PATCH] Support loading base64 images in pipelines (#25633) * support loading base64 images * add test * mention in docs * remove the logging * sort imports * update error message * Update tests/utils/test_image_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * restructure to catch base64 exception * doesn't like the newline * download files * format * optimize imports * guess it needs a space? * support loading base64 images * add test * remove the logging * sort imports * restructure to catch base64 exception * doesn't like the newline * download files * optimize imports * guess it needs a space? --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/pipeline_tutorial.md | 2 +- src/transformers/image_utils.py | 18 ++++++++++---- tests/utils/test_image_utils.py | 37 +++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/docs/source/en/pipeline_tutorial.md b/docs/source/en/pipeline_tutorial.md index 1b13c401b9..e2d728aea3 100644 --- a/docs/source/en/pipeline_tutorial.md +++ b/docs/source/en/pipeline_tutorial.md @@ -204,7 +204,7 @@ page. Using a [`pipeline`] for vision tasks is practically identical. -Specify your task and pass your image to the classifier. The image can be a link or a local path to the image. For example, what species of cat is shown below? +Specify your task and pass your image to the classifier. The image can be a link, a local path or a base64-encoded image. For example, what species of cat is shown below? ![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index d76ee57281..628fe5dea7 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import os +from io import BytesIO from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -298,14 +300,22 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = elif os.path.isfile(image): image = PIL.Image.open(image) else: - raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" - ) + if image.startswith("data:image/"): + image = image.split(",")[1] + + # Try to load as base64 + try: + b64 = base64.b64decode(image, validate=True) + image = PIL.Image.open(BytesIO(b64)) + except Exception as e: + raise ValueError( + f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}" + ) elif isinstance(image, PIL.Image.Image): image = image else: raise ValueError( - "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image." ) image = PIL.ImageOps.exif_transpose(image) image = image.convert("RGB") diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py index 0ba901b6c3..1813c2a21f 100644 --- a/tests/utils/test_image_utils.py +++ b/tests/utils/test_image_utils.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest import datasets import numpy as np import pytest +from huggingface_hub.file_download import http_get from requests import ReadTimeout from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL @@ -500,6 +503,40 @@ class LoadImageTester(unittest.TestCase): (480, 640, 3), ) + def test_load_img_base64_prefix(self): + try: + tmp_file = tempfile.mktemp() + with open(tmp_file, "wb") as f: + http_get( + "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_0.txt", f + ) + + with open(tmp_file, encoding="utf-8") as b64: + img = load_image(b64.read()) + img_arr = np.array(img) + + finally: + os.remove(tmp_file) + + self.assertEqual(img_arr.shape, (64, 32, 3)) + + def test_load_img_base64(self): + try: + tmp_file = tempfile.mktemp() + with open(tmp_file, "wb") as f: + http_get( + "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_1.txt", f + ) + + with open(tmp_file, encoding="utf-8") as b64: + img = load_image(b64.read()) + img_arr = np.array(img) + + finally: + os.remove(tmp_file) + + self.assertEqual(img_arr.shape, (64, 32, 3)) + def test_load_img_rgba(self): dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")