Use non nested images and batched text Idefics2/3 (#34222)
* add support for non nested images and add tests * add tests error scenario * fix style * added single and no image to error tests
This commit is contained in:
@@ -99,6 +99,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
|||||||
isinstance(images, (list, tuple))
|
isinstance(images, (list, tuple))
|
||||||
and len(images) > 0
|
and len(images) > 0
|
||||||
and isinstance(images[0], (list, tuple))
|
and isinstance(images[0], (list, tuple))
|
||||||
|
and len(images[0]) > 0
|
||||||
and is_valid_image(images[0][0])
|
and is_valid_image(images[0][0])
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
Processor class for IDEFICS2.
|
Processor class for IDEFICS2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from itertools import accumulate
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
@@ -218,7 +219,21 @@ class Idefics2Processor(ProcessorMixin):
|
|||||||
if is_image_or_image_url(images):
|
if is_image_or_image_url(images):
|
||||||
images = [[images]]
|
images = [[images]]
|
||||||
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||||
images = [images]
|
if text is not None:
|
||||||
|
if sum(n_images_in_text) != len(images):
|
||||||
|
raise ValueError(
|
||||||
|
f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed."
|
||||||
|
f" Found {sum(n_images_in_text)} {image_token} tokens and {len(images)} images."
|
||||||
|
)
|
||||||
|
# Reorganize the images to match the prompts
|
||||||
|
cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
|
||||||
|
images = [
|
||||||
|
images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
|
||||||
|
for i in range(len(n_images_in_text))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
images = [images]
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
not isinstance(images, list)
|
not isinstance(images, list)
|
||||||
and not isinstance(images[0], list)
|
and not isinstance(images[0], list)
|
||||||
|
|||||||
@@ -151,9 +151,11 @@ def get_resize_output_image_size(
|
|||||||
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
Convert a single image or a list of images to a list of numpy arrays.
|
Convert a single image or a list of images to a list of numpy arrays.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
images (`ImageInput`):
|
images (`ImageInput`):
|
||||||
A single image or a list of images.
|
A single image or a list of images.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of numpy arrays.
|
A list of numpy arrays.
|
||||||
"""
|
"""
|
||||||
@@ -168,6 +170,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
|||||||
isinstance(images, (list, tuple))
|
isinstance(images, (list, tuple))
|
||||||
and len(images) > 0
|
and len(images) > 0
|
||||||
and isinstance(images[0], (list, tuple))
|
and isinstance(images[0], (list, tuple))
|
||||||
|
and len(images[0]) > 0
|
||||||
and is_valid_image(images[0][0])
|
and is_valid_image(images[0][0])
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ Processor class for Idefics3.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from itertools import accumulate
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
@@ -241,11 +242,31 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
n_images_in_images = []
|
n_images_in_images = []
|
||||||
inputs = BatchFeature()
|
inputs = BatchFeature()
|
||||||
|
|
||||||
|
if text is not None:
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||||
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
n_images_in_text = [sample.count(self.image_token.content) for sample in text]
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
if is_image_or_image_url(images):
|
if is_image_or_image_url(images):
|
||||||
images = [[images]]
|
images = [[images]]
|
||||||
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||||
images = [images]
|
if text is not None:
|
||||||
|
if sum(n_images_in_text) != len(images):
|
||||||
|
raise ValueError(
|
||||||
|
f"The total number of {self.image_token.content} tokens in the prompts should be the same as the number of images passed."
|
||||||
|
f" Found {sum(n_images_in_text)} {self.image_token.content} tokens and {len(images)} images."
|
||||||
|
)
|
||||||
|
# Reorganize the images to match the prompts
|
||||||
|
cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
|
||||||
|
images = [
|
||||||
|
images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
|
||||||
|
for i in range(len(n_images_in_text))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
images = [images]
|
||||||
elif (
|
elif (
|
||||||
not isinstance(images, list)
|
not isinstance(images, list)
|
||||||
and not isinstance(images[0], list)
|
and not isinstance(images[0], list)
|
||||||
@@ -263,10 +284,10 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
inputs.update(image_inputs)
|
inputs.update(image_inputs)
|
||||||
|
|
||||||
if text is not None:
|
if text is not None:
|
||||||
if isinstance(text, str):
|
if n_images_in_images != n_images_in_text:
|
||||||
text = [text]
|
raise ValueError(
|
||||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
|
||||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
)
|
||||||
|
|
||||||
image_rows = inputs.pop("rows", [[0] * len(text)])
|
image_rows = inputs.pop("rows", [[0] * len(text)])
|
||||||
image_cols = inputs.pop("cols", [[0] * len(text)])
|
image_cols = inputs.pop("cols", [[0] * len(text)])
|
||||||
@@ -277,8 +298,6 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
|
|
||||||
prompt_strings = []
|
prompt_strings = []
|
||||||
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
|
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
|
||||||
n_images_in_text.append(sample.count(image_token))
|
|
||||||
|
|
||||||
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
|
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
|
||||||
image_prompt_strings = []
|
image_prompt_strings = []
|
||||||
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
||||||
@@ -305,11 +324,6 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
|
||||||
inputs.update(text_inputs)
|
inputs.update(text_inputs)
|
||||||
|
|
||||||
if n_images_in_images != n_images_in_text:
|
|
||||||
raise ValueError(
|
|
||||||
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
|
|||||||
isinstance(images, (list, tuple))
|
isinstance(images, (list, tuple))
|
||||||
and len(images) > 0
|
and len(images) > 0
|
||||||
and isinstance(images[0], (list, tuple))
|
and isinstance(images[0], (list, tuple))
|
||||||
|
and len(images[0]) > 0
|
||||||
and is_valid_image(images[0][0])
|
and is_valid_image(images[0][0])
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -226,6 +226,73 @@ class Idefics2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
def test_non_nested_images_with_batched_text(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
processor.image_processor.do_image_splitting = False
|
||||||
|
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str_1 = "In this image, we see"
|
||||||
|
text_str_2 = "bla, bla"
|
||||||
|
|
||||||
|
text = [
|
||||||
|
image_str + text_str_1,
|
||||||
|
text_str_2 + image_str + image_str,
|
||||||
|
]
|
||||||
|
images = [self.image1, self.image2, self.image3]
|
||||||
|
|
||||||
|
inputs = processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
self.assertEqual(inputs["pixel_values"].shape, (2, 2, 3, 767, 980))
|
||||||
|
self.assertEqual(inputs["pixel_attention_mask"].shape, (2, 2, 767, 980))
|
||||||
|
|
||||||
|
def test_process_interleaved_images_prompts_image_error(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.",
|
||||||
|
"In this other sentence we try some good things",
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[self.image1], []]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.<image>",
|
||||||
|
"In this other sentence we try some good things<image>",
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2, self.image3]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1, self.image2, self.image3]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.",
|
||||||
|
"In this other sentence we try some good things<image>",
|
||||||
|
]
|
||||||
|
images = [[self.image1], []]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1, self.image2]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
def test_apply_chat_template(self):
|
def test_apply_chat_template(self):
|
||||||
# Message contains content which a mix of lists with images and image urls and string
|
# Message contains content which a mix of lists with images and image urls and string
|
||||||
messages = [
|
messages = [
|
||||||
@@ -275,13 +342,3 @@ class Idefics2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
|
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
|
||||||
batch_size - 2
|
batch_size - 2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Override as PixtralProcessor needs nested images to work properly with batched inputs
|
|
||||||
@require_vision
|
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
|
||||||
"""This function prepares a list of PIL images for testing"""
|
|
||||||
if batch_size is None:
|
|
||||||
return super().prepare_image_inputs()
|
|
||||||
if batch_size < 1:
|
|
||||||
raise ValueError("batch_size must be greater than 0")
|
|
||||||
return [[super().prepare_image_inputs()]] * batch_size
|
|
||||||
|
|||||||
@@ -250,6 +250,74 @@ class Idefics3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
self.assertEqual(inputs["input_ids"], expected_input_ids)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
def test_non_nested_images_with_batched_text(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
processor.image_processor.do_image_splitting = False
|
||||||
|
|
||||||
|
image_str = "<image>"
|
||||||
|
text_str_1 = "In this image, we see"
|
||||||
|
text_str_2 = "In this image, we see"
|
||||||
|
|
||||||
|
text = [
|
||||||
|
image_str + text_str_1,
|
||||||
|
image_str + image_str + text_str_2,
|
||||||
|
]
|
||||||
|
images = [self.image1, self.image2, self.image3]
|
||||||
|
|
||||||
|
inputs = processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 2, 3, 364, 364))
|
||||||
|
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (2, 2, 364, 364))
|
||||||
|
|
||||||
|
# Copied from tests.models.idefics2.test_processor_idefics2.Idefics2ProcessorTest.test_process_interleaved_images_prompts_image_error
|
||||||
|
def test_process_interleaved_images_prompts_image_error(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.",
|
||||||
|
"In this other sentence we try some good things",
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[self.image1], []]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.<image>",
|
||||||
|
"In this other sentence we try some good things<image>",
|
||||||
|
]
|
||||||
|
images = [[self.image1], [self.image2, self.image3]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1, self.image2, self.image3]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
|
text = [
|
||||||
|
"This is a test sentence.",
|
||||||
|
"In this other sentence we try some good things<image>",
|
||||||
|
]
|
||||||
|
images = [[self.image1], []]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [[], [self.image2]]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1, self.image2]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
images = [self.image1]
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
processor(text=text, images=images, padding=True)
|
||||||
|
|
||||||
def test_apply_chat_template(self):
|
def test_apply_chat_template(self):
|
||||||
# Message contains content which a mix of lists with images and image urls and string
|
# Message contains content which a mix of lists with images and image urls and string
|
||||||
messages = [
|
messages = [
|
||||||
@@ -299,16 +367,7 @@ class Idefics3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
batch_size - 2
|
batch_size - 2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Override as Idefics3Processor needs nested images to work properly with batched inputs
|
# Override tests as inputs_ids padded dimension is the second one but not the last one
|
||||||
@require_vision
|
|
||||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
|
||||||
"""This function prepares a list of PIL images for testing"""
|
|
||||||
if batch_size is None:
|
|
||||||
return super().prepare_image_inputs()
|
|
||||||
if batch_size < 1:
|
|
||||||
raise ValueError("batch_size must be greater than 0")
|
|
||||||
return [[super().prepare_image_inputs()]] * batch_size
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_kwargs_overrides_default_tokenizer_kwargs(self):
|
def test_kwargs_overrides_default_tokenizer_kwargs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user