Add timeout parameter to load_image function (#25184)
* Add timeout parameter to load_image function. * Remove line. * Reformat code Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Add parameter to docs. --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,9 @@ import unittest
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pytest
|
||||
from requests import ReadTimeout
|
||||
|
||||
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
@@ -478,6 +480,16 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
||||
|
||||
@require_vision
|
||||
class LoadImageTester(unittest.TestCase):
|
||||
def test_load_img_url(self):
|
||||
img = load_image(INVOICE_URL)
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(img_arr.shape, (1061, 750, 3))
|
||||
|
||||
def test_load_img_url_timeout(self):
|
||||
with self.assertRaises(ReadTimeout):
|
||||
load_image(INVOICE_URL, timeout=0.001)
|
||||
|
||||
def test_load_img_local(self):
|
||||
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
img_arr = np.array(img)
|
||||
|
||||
Reference in New Issue
Block a user