load_image - decode b64encode and encodebytes strings (#30192)
* Decode b64encode and encodebytes strings * Remove conditional encode -- image is always a string
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import codecs
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -544,6 +545,23 @@ class LoadImageTester(unittest.TestCase):
|
||||
|
||||
self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_base64_encoded_bytes(self):
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_2.txt", f
|
||||
)
|
||||
|
||||
with codecs.open(tmp_file, encoding="unicode_escape") as b64:
|
||||
img = load_image(b64.read())
|
||||
img_arr = np.array(img)
|
||||
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
self.assertEqual(img_arr.shape, (256, 256, 3))
|
||||
|
||||
def test_load_img_rgba(self):
|
||||
# we use revision="refs/pr/1" until the PR is merged
|
||||
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
|
||||
|
||||
Reference in New Issue
Block a user