[processor] clean up mulitmodal tests (#37362)
* clkea up mulitmodal processor tests * fixup * fix tests * fix one last test * forgot
This commit is contained in:
committed by
GitHub
parent
3c39c07939
commit
a563999a02
@@ -19,7 +19,7 @@ import unittest
|
||||
import requests
|
||||
|
||||
from transformers import PixtralProcessor
|
||||
from transformers.testing_utils import require_read_token, require_vision
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@@ -34,7 +34,6 @@ if is_vision_available():
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_read_token
|
||||
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
||||
|
||||
@@ -49,30 +48,37 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
cls.url_2 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
|
||||
cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw)
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
cls.tmpdirname = tempfile.mkdtemp()
|
||||
cls.addClassCleanup(lambda tempdir=cls.tmpdirname: shutil.rmtree(tempdir))
|
||||
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
processor = PixtralProcessor.from_pretrained(
|
||||
"hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor"
|
||||
"hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor", **processor_kwargs
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
cls.image_token = processor.image_token
|
||||
|
||||
def get_processor(self):
|
||||
return self.processor_class.from_pretrained(self.tmpdirname)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
@staticmethod
|
||||
def prepare_processor_dict():
|
||||
return {
|
||||
"chat_template": "{%- set today = strftime_now(\"%Y-%m-%d\") %}\n{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\\nYour knowledge base was last updated on 2023-10-01. The current date is \" + today + \".\\n\\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \\\"What are some good restaurants around me?\\\" => \\\"Where are you?\\\" or \\\"When is the next flight to Tokyo\\\" => \\\"Where do you travel from?\\\")\" %}\n\n{{- bos_token }}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- if messages[0] is string %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n {%- else %} \n {%- set system_message = messages[0]['content'][0]['text'] %}\n {%- set loop_messages = messages[1:] %}\n {%- endif %}\n{%- else %}\n {%- set system_message = default_system_message %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n\n{%- for message in loop_messages %}\n {%- if message['role'] == 'user' %}\n {%- if message['content'] is string %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n {%- else %}\n {{- '[INST]' }}\n {%- for block in message['content'] %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- elif block['type'] == 'image' or block['type'] == 'image_url' %}\n {{- '[IMG]' }}\n {%- else %}\n {{- raise_exception('Only text and image blocks are supported in message content!') }}\n {%- endif %}\n {%- endfor %}\n {{- '[/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'system' %}\n {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}\n {%- elif message['role'] == 'assistant' %}\n {%- if message['content'] is string %}\n {{- message['content'] + eos_token }}\n {%- else %}\n {{- message['content'][0]['text'] + eos_token }}\n {%- endif %}\n {%- else %}\n {{- raise_exception('Only user, system and assistant roles are supported!') }}\n {%- endif %}\n{%- endfor %}",
|
||||
"patch_size": 128,
|
||||
} # fmt: skip
|
||||
|
||||
def test_image_token_filling(self):
|
||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
||||
# Important to check with non square image
|
||||
image = torch.randint(0, 2, (3, 500, 316))
|
||||
expected_image_tokens = 198
|
||||
expected_image_tokens = 4
|
||||
image_token_index = 10
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -104,14 +110,14 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -121,36 +127,36 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to "USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the content of the image? ASSISTANT:"
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing inputs as a single list
|
||||
inputs_image = processor(text=prompt_string, images=[self.image_0], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test as nested single list
|
||||
inputs_image = processor(text=prompt_string, images=[[self.image_0]], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([1, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 4701, 1307, 1278, 3937, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -168,14 +174,14 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -185,25 +191,25 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 1)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# Test passing in as a nested list
|
||||
inputs_url = processor(text=prompt_string, images=[[self.image_0, self.image_1]], return_tensors="pt")
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([2, 3, 24, 36]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_url["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -226,14 +232,14 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(len(inputs_image["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 36, 36]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_image["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -243,14 +249,14 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(len(inputs_url["input_ids"]) == 2)
|
||||
self.assertIsInstance(inputs_url["input_ids"], torch.Tensor)
|
||||
self.assertIsInstance(inputs_image["pixel_values"], torch.Tensor)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 36, 36]))
|
||||
|
||||
# fmt: off
|
||||
input_ids = inputs_url["input_ids"]
|
||||
self.assertEqual(
|
||||
input_ids[0].tolist(),
|
||||
# Equivalent to ["USER: [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]\nWhat's the difference between these two images? ASSISTANT:"]
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -258,12 +264,12 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
inputs_image = processor(
|
||||
text=prompt_string, images=[self.image_0, self.image_1, self.image_2], return_tensors="pt", padding=True
|
||||
)
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 30, 30]))
|
||||
self.assertTrue(inputs_image["pixel_values"].shape == torch.Size([3, 3, 36, 36]))
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
inputs_image["input_ids"][0].tolist(),
|
||||
[1, 21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
[1, 21510, 1058, 1032, 10, 10, 10, 12, 10, 10, 10, 13, 10, 10, 10, 12, 10, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user