Support loading base64 images in pipelines (#25633)

* support loading base64 images

* add test

* mention in docs

* remove the logging

* sort imports

* update error message

* Update tests/utils/test_image_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* restructure to catch base64 exception

* doesn't like the newline

* download files

* format

* optimize imports

* guess it needs a space?

* support loading base64 images

* add test

* remove the logging

* sort imports

* restructure to catch base64 exception

* doesn't like the newline

* download files

* optimize imports

* guess it needs a space?

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Haylee Schäfer
2023-08-29 20:24:24 +02:00
committed by GitHub
parent ce2d4bc6a1
commit dbc16f4404
3 changed files with 52 additions and 5 deletions

View File

@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import datasets
import numpy as np
import pytest
from huggingface_hub.file_download import http_get
from requests import ReadTimeout
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
@@ -500,6 +503,40 @@ class LoadImageTester(unittest.TestCase):
(480, 640, 3),
)
def test_load_img_base64_prefix(self):
    """Check that `load_image` decodes a base64 string fetched from the hub.

    Downloads the `image_0.txt` fixture, reads it back as UTF-8 text, hands
    the text to `load_image` and verifies the decoded array is (64, 32, 3).
    NOTE(review): going by the test name, this fixture presumably carries a
    `data:image/...;base64,` style prefix -- confirm against the dataset.
    """
    fixture_url = (
        "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_0.txt"
    )
    # A private temporary directory replaces the deprecated, race-prone
    # `tempfile.mktemp()`; it is removed on exit even when the download or
    # decoding fails, so no manual `os.remove` in a `finally` is needed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, "image_0.txt")
        with open(tmp_file, "wb") as f:
            http_get(fixture_url, f)

        with open(tmp_file, encoding="utf-8") as b64:
            img = load_image(b64.read())
            img_arr = np.array(img)

    self.assertEqual(img_arr.shape, (64, 32, 3))
def test_load_img_base64(self):
    """Check that `load_image` decodes a base64 string fetched from the hub.

    Downloads the `image_1.txt` fixture, reads it back as UTF-8 text, hands
    the text to `load_image` and verifies the decoded array is (64, 32, 3).
    NOTE(review): going by the sibling test's name, this fixture presumably
    holds a bare base64 payload without a data-URI prefix -- confirm against
    the dataset.
    """
    fixture_url = (
        "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_1.txt"
    )
    # A private temporary directory replaces the deprecated, race-prone
    # `tempfile.mktemp()`; it is removed on exit even when the download or
    # decoding fails, so no manual `os.remove` in a `finally` is needed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, "image_1.txt")
        with open(tmp_file, "wb") as f:
            http_get(fixture_url, f)

        with open(tmp_file, encoding="utf-8") as b64:
            img = load_image(b64.read())
            img_arr = np.array(img)

    self.assertEqual(img_arr.shape, (64, 32, 3))
def test_load_img_rgba(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")