Pixtral: vectorize patch embeddings and enable tests (#35122)
* initial POC * - batch mix feature * fix tests * fix tests * make style * do not skip and instead fix tests * update * return back the test * correct text with the correct ckpt
This commit is contained in:
committed by
GitHub
parent
8bc4c89ee9
commit
9725e5be2f
@@ -14,7 +14,6 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -28,7 +27,7 @@ from ...test_processing_common import ProcessorTesterMixin
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import AutoTokenizer, PixtralImageProcessor, PixtralProcessor
|
||||
from transformers import PixtralProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
@@ -46,20 +45,15 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
# FIXME - just load the processor directly from the checkpoint
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b")
|
||||
image_processor = PixtralImageProcessor()
|
||||
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor)
|
||||
processor = PixtralProcessor.from_pretrained("mistral-community/pixtral-12b")
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
@unittest.skip("No chat template was set for this model (yet)")
|
||||
def test_chat_template(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:"
|
||||
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -73,13 +67,12 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
@unittest.skip("No chat template was set for this model (yet)")
|
||||
def test_image_token_filling(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
# Important to check with non square image
|
||||
image = torch.randint(0, 2, (3, 500, 316))
|
||||
expected_image_tokens = 1526
|
||||
image_token_index = 32000
|
||||
expected_image_tokens = 640
|
||||
image_token_index = 10
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -111,11 +104,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"][0]) == 1)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
@@ -131,11 +121,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_url["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"][0]) == 1)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
@@ -146,6 +133,28 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing inputs as a single list
|
||||
inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test as nested single list
|
||||
inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_with_multiple_images_single_list(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
|
||||
@@ -159,11 +168,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
@@ -179,11 +185,9 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_url["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"]) == 1)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
@@ -193,6 +197,17 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in as a nested list
|
||||
inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_url["input_ids"][0].tolist(),
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_with_multiple_images_multiple_lists(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
prompt_string = [
|
||||
@@ -211,11 +226,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 2)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
@@ -231,11 +243,8 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_url)
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_url["pixel_values"], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"]) == 2)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0], list)
|
||||
self.assertTrue(len(inputs_url["pixel_values"][0]) == 2)
|
||||
self.assertIsInstance(inputs_url["pixel_values"][0][0], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
@@ -246,6 +255,19 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing as a single flat list
|
||||
inputs_image = processor(
|
||||
text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True
|
||||
)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 32, 32]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_processor_returns_full_length_batches(self):
|
||||
# to avoid https://github.com/huggingface/transformers/issues/34204
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
@@ -264,13 +286,3 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertIn("input_ids", inputs_image)
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 5)
|
||||
self.assertTrue(len(inputs_image["pixel_values"]) == 5)
|
||||
|
||||
# Override as PixtralProcessor needs nested images to work properly with batched inputs
|
||||
@require_vision
|
||||
def prepare_image_inputs(self, batch_size: Optional[int] = None):
|
||||
"""This function prepares a list of PIL images for testing"""
|
||||
if batch_size is None:
|
||||
return super().prepare_image_inputs()
|
||||
if batch_size < 1:
|
||||
raise ValueError("batch_size must be greater than 0")
|
||||
return [[super().prepare_image_inputs()]] * batch_size
|
||||
|
||||
Reference in New Issue
Block a user