[vlm] adjust max length for special tokens (#37342)

* update

* apply suggestion

* fix tests for main branch

* remove unused logger

* add special tokens in tests

* nit

* fix more tests

* fix test

* pg also
This commit is contained in:
Raushan Turganbay
2025-04-16 20:49:20 +02:00
committed by GitHub
parent c94c59fc47
commit 32eca7197a
39 changed files with 414 additions and 98 deletions

View File

@@ -271,3 +271,29 @@ And who is that?<|im_end|>
return_tensors="np",
)
self.assertListEqual(list(out_dict[self.images_input_name].shape), [1, 3, 980, 980])
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=3,
)

View File

@@ -40,10 +40,37 @@ class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenizer = LlamaTokenizer(vocab_file=SAMPLE_VOCAB)
tokenizer.pad_token_id = 0
tokenizer.sep_token_id = 1
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
processor = cls.processor_class(image_processor=image_processor, tokenizer=tokenizer, image_seq_length=2)
processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.image_token
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=20,
)
@staticmethod
def prepare_processor_dict():
return {"image_seq_length": 2} # fmt: skip

View File

@@ -124,3 +124,28 @@ class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# base image + 4 crops
self.assertEqual(len(inputs[self.images_input_name]), 5)
self.assertEqual(len(inputs[self.text_input_name][0]), 67)
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=5,
)

View File

@@ -66,8 +66,8 @@ class Idefics2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
)
cls.bos_token = processor.tokenizer.bos_token
cls.image_token = processor.image_token.content
cls.fake_image_token = processor.fake_image_token.content
cls.image_token = processor.image_token
cls.fake_image_token = processor.fake_image_token
cls.bos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.bos_token)
cls.image_token_id = processor.tokenizer.convert_tokens_to_ids(cls.image_token)

View File

@@ -60,8 +60,8 @@ class Idefics3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
)
cls.bos_token = processor.tokenizer.bos_token
cls.image_token = processor.image_token.content
cls.fake_image_token = processor.fake_image_token.content
cls.image_token = processor.image_token
cls.fake_image_token = processor.fake_image_token
cls.global_img_token = processor.global_image_tag
cls.bos_token_id = processor.tokenizer.convert_tokens_to_ids(cls.bos_token)

View File

@@ -40,6 +40,7 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_processor = CLIPImageProcessor(do_center_crop=False)
tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
processor_kwargs = cls.prepare_processor_dict()
processor = LlavaProcessor(image_processor, tokenizer, **processor_kwargs)
processor.save_pretrained(cls.tmpdirname)
@@ -79,3 +80,29 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor = LlavaProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=5,
)

View File

@@ -40,6 +40,7 @@ class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_processor = LlavaNextImageProcessor()
tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
processor_kwargs = cls.prepare_processor_dict()
processor = LlavaNextProcessor(image_processor, tokenizer, **processor_kwargs)
processor.save_pretrained(cls.tmpdirname)

View File

@@ -41,6 +41,7 @@ class LlavaNextVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_processor = LlavaNextImageProcessor()
video_processor = LlavaNextVideoImageProcessor()
tokenizer = LlamaTokenizerFast.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>", "<video>"]})
processor_kwargs = cls.prepare_processor_dict()
processor = LlavaNextVideoProcessor(

View File

@@ -45,6 +45,7 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_processor = LlavaOnevisionImageProcessor()
video_processor = LlavaOnevisionVideoProcessor()
tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>", "<video>"]})
processor_kwargs = cls.prepare_processor_dict()
processor = LlavaOnevisionProcessor(

View File

@@ -290,3 +290,29 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True)
self.assertIn("input_ids", inputs_image)
self.assertTrue(len(inputs_image["input_ids"]) == 5)
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=3,
)

View File

@@ -360,3 +360,29 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
len(inputs[self.text_input_name][0]) == len(inputs[self.text_input_name][1])
and len(inputs[self.text_input_name][1]) < 76
)
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
image_input = [[image_input[0]], [image_input[1]]]
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=3,
)

View File

@@ -39,6 +39,7 @@ class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image_processor.image_seq_length = 0 # TODO: raushan fix me in #37342
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.image_token
@@ -59,7 +60,7 @@ class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112 + 14)
self.assertEqual(len(inputs["input_ids"][0]), 112)
def test_text_with_image_tokens(self):
image_processor = self.get_component("image_processor")

View File

@@ -397,3 +397,29 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(inputs[self.images_input_name].shape[0], 100)
inputs = processor(text=input_str, images=image_input, max_pixels=56 * 56 * 4, return_tensors="pt")
self.assertEqual(inputs[self.images_input_name].shape[0], 612)
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=20,
)

View File

@@ -435,7 +435,8 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
@@ -445,14 +446,14 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
text=input_str,
images=image_input,
return_tensors="pt",
padding="longest",
padding="max_length",
max_length=76,
truncation=True,
max_image_size={"longest_edge": 30},
max_image_size={"longest_edge": 300},
)
self.assertEqual(inputs["pixel_values"].shape[2], 3)
self.assertEqual(inputs["pixel_values"].shape[3], 30)
self.assertEqual(inputs["pixel_values"].shape[3], 300)
self.assertEqual(len(inputs["input_ids"][0]), 76)
@require_torch
@@ -529,3 +530,29 @@ class SmolVLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
with self.assertRaises(ValueError) as context:
processor(text=texts, images=None)
self.assertTrue("tokens in the text but no images/videos were passed" in str(context.exception))
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modality="image")
image_input = self.prepare_image_inputs(batch_size=2)
image_input = [[image_input[0]], [image_input[1]]]
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=20,
)