Uniformize kwargs for LLaVa processor and update docs (#32858)
* Uniformize kwargs for LlaVa and update docs * Change order of processor inputs in docstring * Improve BC support for reversed images and text inputs * cleanup llava processor call docstring * Add encoded inputs as valid text inputs in reverse input check, add deprecation version in warning * Put function check reversed images text outside base processor class * Refactor _validate_images_text_input_order * Add ProcessingUtilTester * fix processing and test_processing
This commit is contained in:
@@ -405,7 +405,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
|
|||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
>>> # Generate
|
>>> # Generate
|
||||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
|||||||
@@ -16,18 +16,33 @@
|
|||||||
Processor class for Llava.
|
Processor class for Llava.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
import sys
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||||
from ...processing_utils import ProcessorMixin
|
from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
|
||||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import TensorType, logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
from typing import Unpack
|
||||||
|
else:
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
_defaults = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"padding": False,
|
||||||
|
},
|
||||||
|
"images_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class LlavaProcessor(ProcessorMixin):
|
class LlavaProcessor(ProcessorMixin):
|
||||||
r"""
|
r"""
|
||||||
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
|
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
|
||||||
@@ -73,12 +88,11 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
|
||||||
images: ImageInput = None,
|
images: ImageInput = None,
|
||||||
padding: Union[bool, str, PaddingStrategy] = False,
|
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||||
truncation: Union[bool, str, TruncationStrategy] = None,
|
audio=None,
|
||||||
max_length=None,
|
videos=None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
|
**kwargs: Unpack[LlavaProcessorKwargs],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||||
@@ -88,29 +102,15 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
of the above two methods for more information.
|
of the above two methods for more information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||||
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||||
|
tensor. Both channels-first and channels-last formats are supported.
|
||||||
text (`str`, `List[str]`, `List[List[str]]`):
|
text (`str`, `List[str]`, `List[List[str]]`):
|
||||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
|
||||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
|
||||||
tensor. Both channels-first and channels-last formats are supported.
|
|
||||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
|
|
||||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
|
||||||
index) among:
|
|
||||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
|
||||||
sequence if provided).
|
|
||||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
|
||||||
acceptable input length for the model if that argument is not provided.
|
|
||||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
|
||||||
lengths).
|
|
||||||
max_length (`int`, *optional*):
|
|
||||||
Maximum length of the returned list and optionally padding length (see above).
|
|
||||||
truncation (`bool`, *optional*):
|
|
||||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
|
||||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||||
If set, will return tensors of a particular framework. Acceptable values are:
|
If set, will return tensors of a particular framework. Acceptable values are:
|
||||||
|
|
||||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||||
@@ -125,8 +125,19 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
`None`).
|
`None`).
|
||||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||||
"""
|
"""
|
||||||
|
if images is None and text is None:
|
||||||
|
raise ValueError("You have to specify at least one of `images` or `text`.")
|
||||||
|
|
||||||
|
# check if images and text inputs are reversed for BC
|
||||||
|
images, text = _validate_images_text_input_order(images, text)
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
LlavaProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_inputs = self.image_processor(images, return_tensors=return_tensors)
|
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
else:
|
else:
|
||||||
image_inputs = {}
|
image_inputs = {}
|
||||||
|
|
||||||
@@ -158,13 +169,7 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
|
||||||
)
|
)
|
||||||
|
|
||||||
text_inputs = self.tokenizer(
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
||||||
prompt_strings,
|
|
||||||
return_tensors=return_tensors,
|
|
||||||
padding=padding,
|
|
||||||
truncation=truncation,
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs})
|
return BatchFeature(data={**text_inputs, **image_inputs})
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
|
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
|
||||||
image_file = "https://llava-vl.github.io/static/images/view.jpg"
|
image_file = "https://llava-vl.github.io/static/images/view.jpg"
|
||||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||||
inputs = self.processor(prompt, raw_image, return_tensors="pt")
|
inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
|
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
|
||||||
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
|
||||||
@@ -299,7 +299,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
|
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
|
||||||
image_file = "https://llava-vl.github.io/static/images/view.jpg"
|
image_file = "https://llava-vl.github.io/static/images/view.jpg"
|
||||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
|
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
|
||||||
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
|
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
|
||||||
@@ -325,7 +325,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||||
|
|
||||||
inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
|
inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
@@ -349,7 +349,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||||
|
|
||||||
inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
|
inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
@@ -381,7 +381,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
|
||||||
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
|
||||||
|
|
||||||
inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
|
inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=20)
|
output = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
|
||||||
@@ -409,8 +409,8 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
image2 = Image.open(requests.get(url2, stream=True).raw)
|
image2 = Image.open(requests.get(url2, stream=True).raw)
|
||||||
|
|
||||||
inputs = processor(
|
inputs = processor(
|
||||||
text=[prompt1, prompt2, prompt3],
|
|
||||||
images=[image1, image2, image1, image2],
|
images=[image1, image2, image1, image2],
|
||||||
|
text=[prompt1, prompt2, prompt3],
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
@@ -444,7 +444,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
|
||||||
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
||||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
|
||||||
|
|
||||||
# Make sure that `generate` works
|
# Make sure that `generate` works
|
||||||
_ = model.generate(**inputs, max_new_tokens=20)
|
_ = model.generate(**inputs, max_new_tokens=20)
|
||||||
@@ -510,7 +510,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
processor = AutoProcessor.from_pretrained(model_id)
|
processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
|
||||||
# Prepare inputs with no images
|
# Prepare inputs with no images
|
||||||
inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)
|
inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
# Make sure that `generate` works
|
# Make sure that `generate` works
|
||||||
_ = model.generate(**inputs, max_new_tokens=20)
|
_ = model.generate(**inputs, max_new_tokens=20)
|
||||||
@@ -554,13 +554,13 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
# check processing with expansion of inputs
|
# check processing with expansion of inputs
|
||||||
processor.vision_feature_select_strategy = "default"
|
processor.vision_feature_select_strategy = "default"
|
||||||
processor.patch_size = 14
|
processor.patch_size = 14
|
||||||
inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
|
||||||
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)
|
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)
|
||||||
|
|
||||||
# check processing without expansion of inputs (legacy behavior)
|
# check processing without expansion of inputs (legacy behavior)
|
||||||
processor.vision_feature_select_strategy = None
|
processor.vision_feature_select_strategy = None
|
||||||
processor.patch_size = None
|
processor.patch_size = None
|
||||||
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
|
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
|
||||||
self.assertTrue(inputs.input_ids.shape[-1] == 18)
|
self.assertTrue(inputs.input_ids.shape[-1] == 18)
|
||||||
|
|
||||||
# generate exactly 20 tokens
|
# generate exactly 20 tokens
|
||||||
|
|||||||
@@ -11,18 +11,43 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.testing_utils import require_vision
|
from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
|
||||||
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
from transformers.utils import is_vision_available
|
from transformers.utils import is_vision_available
|
||||||
|
|
||||||
|
from ...test_processing_common import ProcessorTesterMixin
|
||||||
|
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from transformers import AutoTokenizer, LlavaProcessor
|
from transformers import CLIPImageProcessor
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
class LlavaProcessorTest(unittest.TestCase):
|
class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||||
|
processor_class = LlavaProcessor
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
image_processor = CLIPImageProcessor(do_center_crop=False)
|
||||||
|
tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
|
||||||
|
|
||||||
|
processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||||
|
|
||||||
|
def get_image_processor(self, **kwargs):
|
||||||
|
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
def test_can_load_various_tokenizers(self):
|
def test_can_load_various_tokenizers(self):
|
||||||
for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
|
for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
|
||||||
processor = LlavaProcessor.from_pretrained(checkpoint)
|
processor = LlavaProcessor.from_pretrained(checkpoint)
|
||||||
@@ -45,3 +70,29 @@ class LlavaProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
self.assertEqual(expected_prompt, formatted_prompt)
|
self.assertEqual(expected_prompt, formatted_prompt)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
def test_unstructured_kwargs_batched(self):
|
||||||
|
if "image_processor" not in self.processor_class.attributes:
|
||||||
|
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||||
|
image_processor = self.get_component("image_processor")
|
||||||
|
tokenizer = self.get_component("tokenizer")
|
||||||
|
|
||||||
|
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||||
|
self.skip_processor_without_typed_kwargs(processor)
|
||||||
|
|
||||||
|
input_str = ["lower newer", "upper older longer string"]
|
||||||
|
image_input = self.prepare_image_inputs() * 2
|
||||||
|
inputs = processor(
|
||||||
|
images=image_input,
|
||||||
|
text=input_str,
|
||||||
|
return_tensors="pt",
|
||||||
|
size={"height": 214, "width": 214},
|
||||||
|
padding="longest",
|
||||||
|
max_length=76,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||||
|
|
||||||
|
self.assertEqual(len(inputs["input_ids"][0]), 5)
|
||||||
|
|||||||
Reference in New Issue
Block a user