Add validate images and text inputs order util for processors and test_processing_utils (#33285)
* Add validate images and test processing utils * Remove encoded text from possible inputs in tests * Removed encoded inputs as valid in processing_utils * change text input check to be recursive * change text check to all element of lists and not just the first one in recursive checks
This commit is contained in:
@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .image_utils import ChannelDimension, is_vision_available
|
from .image_utils import ChannelDimension, is_vision_available, valid_images
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -993,6 +993,50 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_images_text_input_order(images, text):
|
||||||
|
"""
|
||||||
|
For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped.
|
||||||
|
This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes.
|
||||||
|
Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled
|
||||||
|
in the processor's `__call__` method before calling this method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _is_valid_text_input_for_processor(t):
|
||||||
|
if isinstance(t, str):
|
||||||
|
# Strings are fine
|
||||||
|
return True
|
||||||
|
elif isinstance(t, (list, tuple)):
|
||||||
|
# List are fine as long as they are...
|
||||||
|
if len(t) == 0:
|
||||||
|
# ... not empty
|
||||||
|
return False
|
||||||
|
for t_s in t:
|
||||||
|
return _is_valid_text_input_for_processor(t_s)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _is_valid(input, validator):
|
||||||
|
return validator(input) or input is None
|
||||||
|
|
||||||
|
images_is_valid = _is_valid(images, valid_images)
|
||||||
|
images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False
|
||||||
|
|
||||||
|
text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
|
||||||
|
text_is_images = valid_images(text) if not text_is_valid else False
|
||||||
|
# Handle cases where both inputs are valid
|
||||||
|
if images_is_valid and text_is_valid:
|
||||||
|
return images, text
|
||||||
|
|
||||||
|
# Handle cases where inputs need to and can be swapped
|
||||||
|
if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images):
|
||||||
|
logger.warning_once(
|
||||||
|
"You may have used the wrong order for inputs. `images` should be passed before `text`. "
|
||||||
|
"The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47."
|
||||||
|
)
|
||||||
|
return text, images
|
||||||
|
|
||||||
|
raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.")
|
||||||
|
|
||||||
|
|
||||||
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
||||||
if ProcessorMixin.push_to_hub.__doc__ is not None:
|
if ProcessorMixin.push_to_hub.__doc__ is not None:
|
||||||
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
|
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
|
||||||
|
|||||||
164
tests/utils/test_processing_utils.py
Normal file
164
tests/utils/test_processing_utils.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import is_torch_available, is_vision_available
|
||||||
|
from transformers.processing_utils import _validate_images_text_input_order
|
||||||
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
class ProcessingUtilTester(unittest.TestCase):
|
||||||
|
def test_validate_images_text_input_order(self):
|
||||||
|
# text string and PIL images inputs
|
||||||
|
images = PIL.Image.new("RGB", (224, 224))
|
||||||
|
text = "text"
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertEqual(valid_images, images)
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertEqual(valid_images, images)
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# text list of string and numpy images inputs
|
||||||
|
images = np.random.rand(224, 224, 3)
|
||||||
|
text = ["text1", "text2"]
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertTrue(np.array_equal(valid_images, images))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertTrue(np.array_equal(valid_images, images))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# text nested list of string and list of pil images inputs
|
||||||
|
images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
|
||||||
|
text = [["text1", "text2, text3"], ["text3", "text4"]]
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertEqual(valid_images, images)
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertEqual(valid_images, images)
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# list of strings and list of numpy images inputs
|
||||||
|
images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
|
||||||
|
text = ["text1", "text2"]
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertTrue(np.array_equal(valid_images[0], images[0]))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertTrue(np.array_equal(valid_images[0], images[0]))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# list of strings and nested list of numpy images inputs
|
||||||
|
images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
|
||||||
|
text = ["text1", "text2"]
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# nested list of strings and nested list of PIL images inputs
|
||||||
|
images = [
|
||||||
|
[PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
|
||||||
|
[PIL.Image.new("RGB", (224, 224))],
|
||||||
|
]
|
||||||
|
text = [["text1", "text2, text3"], ["text3", "text4"]]
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertEqual(valid_images, images)
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertEqual(valid_images, images)
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# None images
|
||||||
|
images = None
|
||||||
|
text = "text"
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertEqual(images, None)
|
||||||
|
self.assertEqual(text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertEqual(images, None)
|
||||||
|
self.assertEqual(text, text)
|
||||||
|
|
||||||
|
# None text
|
||||||
|
images = PIL.Image.new("RGB", (224, 224))
|
||||||
|
text = None
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertEqual(images, images)
|
||||||
|
self.assertEqual(text, None)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertEqual(images, images)
|
||||||
|
self.assertEqual(text, None)
|
||||||
|
|
||||||
|
# incorrect inputs
|
||||||
|
images = "text"
|
||||||
|
text = "text"
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_validate_images_text_input_order(images=images, text=text)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_validate_images_text_input_order_torch(self):
|
||||||
|
# text string and torch images inputs
|
||||||
|
images = torch.rand(224, 224, 3)
|
||||||
|
text = "text"
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertTrue(torch.equal(valid_images, images))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertTrue(torch.equal(valid_images, images))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
|
||||||
|
# text list of string and list of torch images inputs
|
||||||
|
images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
|
||||||
|
text = ["text1", "text2"]
|
||||||
|
# test correct text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
|
||||||
|
self.assertTrue(torch.equal(valid_images[0], images[0]))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
|
# test incorrect text and images order
|
||||||
|
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
|
||||||
|
self.assertTrue(torch.equal(valid_images[0], images[0]))
|
||||||
|
self.assertEqual(valid_text, text)
|
||||||
Reference in New Issue
Block a user