[MllamaProcessor] Update errors and API with multiple image (#33715)

* update error

* update and add a test

* update

* update
This commit is contained in:
Arthur
2024-09-26 16:33:25 +02:00
committed by GitHub
parent 0a21381ba3
commit 46841d3eb2
2 changed files with 134 additions and 16 deletions

View File

@@ -12,11 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Processor class for Mllama.
"""
from statistics import mean """Processor class for Mllama."""
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
@@ -296,25 +294,27 @@ class MllamaProcessor(ProcessorMixin):
encoding = self.tokenizer(text, **text_kwargs) encoding = self.tokenizer(text, **text_kwargs)
data.update(encoding) data.update(encoding)
n_images_in_images = [0]
if images is not None: if images is not None:
images = make_list_of_images(images) images = make_list_of_images(images)
n_images_in_images = [len(sample) for sample in images] n_images_in_images = [len(sample) for sample in images]
if text is not None: if text is not None:
if ( if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
not all(batch_img_per_prompt == n_images_in_images for batch_img_per_prompt in n_images_in_text) batch_img == 0 for batch_img in n_images_in_text
and len(text) > 1 ):
): raise ValueError(
"If a batch of text is provided, there should be either no images or at least one image per sample"
)
if sum(n_images_in_images) != sum(n_images_in_text):
if images is None:
raise ValueError("No image were provided, but there are image tokens in the prompt")
else:
raise ValueError( raise ValueError(
f"The number of images in each batch {n_images_in_text} should be the same {n_images_in_images} should be the same. Yes, the model does not \ f"The number of image token ({sum(n_images_in_images)}) should be the same as in the number of provided images ({sum(n_images_in_images)})"
support having a different number of images per batch."
)
if int(mean(n_images_in_text)) != int(mean(n_images_in_images)):
raise ValueError(
f"The number of images in the text ({n_images_in_text}) should be the same as in the number of provided images ({n_images_in_images}) \
should be the same."
) )
if images is not None:
image_features = self.image_processor(images, **images_kwargs) image_features = self.image_processor(images, **images_kwargs)
num_tiles = image_features.pop("num_tiles") num_tiles = image_features.pop("num_tiles")
data.update(image_features) data.update(image_features)

View File

@@ -15,6 +15,8 @@
import unittest import unittest
import numpy as np
from transformers import MllamaProcessor from transformers import MllamaProcessor
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_vision_available
@@ -177,3 +179,119 @@ class MllamaProcessorTest(unittest.TestCase):
rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False) rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False) rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
self.assertEqual(rendered_list, rendered_str) self.assertEqual(rendered_list, rendered_str)
def test_process_interleaved_images_prompts_image_splitting(self):
    """Check processor outputs for image-only, text-only, single-sample and batched inputs.

    Covers the shape of ``pixel_values``, the produced ``input_ids`` /
    ``attention_mask`` and the content of ``cross_attention_mask``.
    """
    # Test that a single image is processed correctly
    inputs = self.processor(images=self.image2, size={"width": 224, "height": 224})
    self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 224, 224))

    # Test that text is processed correctly
    text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>"
    inputs = self.processor(text=text)
    expected_ids = [128000, 2028, 374, 264, 1296, 11914, 13, 128001]
    self.assertEqual(inputs["input_ids"][0], expected_ids)
    self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
    # Without images there must be no cross attention mask at all
    self.assertIsNone(inputs.get("cross_attention_mask"))

    # Test a single sample with image and text
    image_str = "<|image|>"
    text_str = "This is a test sentence."
    text = image_str + text_str
    inputs = self.processor(
        text=text,
        images=self.image1,
        size={"width": 128, "height": 128},
    )
    expected_ids = [self.image_token_id, self.bos_token_id] + [2028, 374, 264, 1296, 11914, 13]

    self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 128, 128))
    self.assertEqual(inputs["input_ids"][0], expected_ids)
    self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
    cross_attention_mask = inputs["cross_attention_mask"]
    self.assertEqual(cross_attention_mask.shape, (1, 8, 1, 4))
    self.assertTrue(
        np.all(cross_attention_mask == 1), f"Cross attention mask is not all ones: {cross_attention_mask}"
    )

    # Test batch
    text = [
        "<|image|>This is a test sentence.",
        "This is a test sentence.<|image|><|image|>This is a test sentence.",
    ]
    # fmt: off
    expected_ids = [
        [self.image_token_id, self.bos_token_id, 2028, 374, 264, 1296, 11914, 13],
        [self.bos_token_id, 2028, 374, 264, 1296, 11914, 13, self.image_token_id, self.image_token_id, 2028, 374, 264, 1296, 11914, 13],
    ]
    # fmt: on
    images = [[self.image1], [self.image1, self.image2]]
    inputs = self.processor(text=text, images=images, padding=True, size={"width": 256, "height": 256})

    self.assertEqual(inputs["pixel_values"].shape, (2, 2, 4, 3, 256, 256))
    for input_ids_i, attention_mask_i, expected_ids_i in zip(
        inputs["input_ids"], inputs["attention_mask"], expected_ids
    ):
        # Split each row into padding positions (mask == 0) and real tokens (mask == 1)
        pad_ids = [token_id for token_id, m in zip(input_ids_i, attention_mask_i) if m == 0]
        input_ids = [token_id for token_id, m in zip(input_ids_i, attention_mask_i) if m == 1]
        self.assertEqual(input_ids, expected_ids_i)
        self.assertEqual(pad_ids, [self.pad_token_id] * len(pad_ids))

    cross_attention_mask = inputs["cross_attention_mask"]
    self.assertEqual(cross_attention_mask.shape, (2, 15, 2, 4))

    # Check that only first tile of first sample is attended to all text tokens
    first_sample_mask = cross_attention_mask[0].copy()
    first_image_first_tile_attention = first_sample_mask[:, :1, :1]  # text tokens, images, tiles
    self.assertTrue(
        np.all(first_image_first_tile_attention == 1),
        f"Cross attention mask is not all ones: {first_image_first_tile_attention}",
    )

    # zero out first tile of first image; nothing else in the sample may be attended.
    # The assertion is made on the whole sample mask: asserting on the slice that was
    # just zeroed would pass regardless of what the processor produced.
    first_sample_mask[:, :1, :1] = 0
    self.assertTrue(
        np.all(first_sample_mask == 0), f"Cross attention mask is not all zeros: {first_sample_mask}"
    )

    # second sample
    second_sample_mask = cross_attention_mask[1].copy()
    first_image_first_tile_attention = second_sample_mask[7:, :1, :1]  # text tokens, images, tiles
    self.assertTrue(
        np.all(first_image_first_tile_attention == 1),
        f"Cross attention mask is not all ones: {first_image_first_tile_attention}",
    )

    second_image_two_tiles_attention = second_sample_mask[8:, 1:2, :2]  # text tokens, images, tiles
    self.assertTrue(
        np.all(second_image_two_tiles_attention == 1),
        f"Cross attention mask is not all ones: {second_image_two_tiles_attention}",
    )

    # zero out both images masks
    second_sample_mask[7:, :1, :1] = 0
    second_sample_mask[8:, 1:2, :2] = 0
    self.assertTrue(
        np.all(second_sample_mask == 0), f"Cross attention mask is not all zeros: {second_sample_mask}"
    )
def test_process_interleaved_images_prompts_image_error(self):
    """Exercise the processor with consistent and inconsistent image-token / image inputs.

    Consistent combinations must be processed; the inconsistent ones below are
    expected to raise ``ValueError``.
    """
    # A batch with no image tokens and no images is accepted.
    plain_prompts = [
        "This is a test sentence.",
        "In this other sentence we try some good things",
    ]
    encoded = self.processor(text=plain_prompts, images=None, padding=True)
    self.assertIsNotNone(encoded["input_ids"])

    # Only the first prompt carries an image token.
    partially_tagged_prompts = [
        "This is a test sentence.<|image|>",
        "In this other sentence we try some good things",
    ]
    # ... with no images at all
    with self.assertRaises(ValueError):
        self.processor(text=partially_tagged_prompts, images=None, padding=True)
    # ... and with an image supplied for the first sample only
    first_sample_only_images = [[self.image1], []]
    with self.assertRaises(ValueError):
        self.processor(text=partially_tagged_prompts, images=first_sample_only_images, padding=True)

    # Every prompt carries one image token.
    fully_tagged_prompts = [
        "This is a test sentence.<|image|>",
        "In this other sentence we try some good things<|image|>",
    ]
    # ... but no images are provided
    with self.assertRaises(ValueError):
        self.processor(text=fully_tagged_prompts, images=None, padding=True)

    # ... one image per prompt is the valid combination and must not raise
    one_image_per_prompt = [[self.image1], [self.image2]]
    encoded = self.processor(text=fully_tagged_prompts, images=one_image_per_prompt, padding=True)

    # NOTE(review): this list is built but never handed to the processor; the call
    # below passes images=None, so it repeats the "no images provided" check above.
    # Confirm whether an uneven per-sample image distribution was meant to be tested.
    uneven_images = [[self.image1, self.image2], []]
    with self.assertRaises(ValueError):
        self.processor(text=fully_tagged_prompts, images=None, padding=True)