Support loading base64 images in pipelines (#25633)
* support loading base64 images * add test * mention in docs * remove the logging * sort imports * update error message * Update tests/utils/test_image_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * restructure to catch base64 exception * doesn't like the newline * download files * format * optimize imports * guess it needs a space? * support loading base64 images * add test * remove the logging * sort imports * restructure to catch base64 exception * doesn't like the newline * download files * optimize imports * guess it needs a space? --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -13,11 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests import ReadTimeout
|
||||
|
||||
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
|
||||
@@ -500,6 +503,40 @@ class LoadImageTester(unittest.TestCase):
|
||||
(480, 640, 3),
|
||||
)
|
||||
|
||||
def test_load_img_base64_prefix(self):
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_0.txt", f
|
||||
)
|
||||
|
||||
with open(tmp_file, encoding="utf-8") as b64:
|
||||
img = load_image(b64.read())
|
||||
img_arr = np.array(img)
|
||||
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_base64(self):
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_1.txt", f
|
||||
)
|
||||
|
||||
with open(tmp_file, encoding="utf-8") as b64:
|
||||
img = load_image(b64.read())
|
||||
img_arr = np.array(img)
|
||||
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_rgba(self):
|
||||
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user