[vlm] adjust max length for special tokens (#37342)
* update
* apply suggestion
* fix tests for main branch
* remove unused logger
* add special tokens in tests
* nit
* fix more tests
* fix test
* pg also
This commit is contained in:
committed by
GitHub
parent
c94c59fc47
commit
32eca7197a
@@ -40,6 +40,7 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
image_processor = CLIPImageProcessor(do_center_crop=False)
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
processor = LlavaProcessor(image_processor, tokenizer, **processor_kwargs)
|
||||
processor.save_pretrained(cls.tmpdirname)
|
||||
@@ -79,3 +80,29 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor = LlavaProcessor.from_pretrained(checkpoint)
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)
|
||||
|
||||
def test_special_mm_token_truncation(self):
    """Verify that truncating a multimodal prompt too aggressively is rejected.

    The processor is called twice on the same batch of two text/image pairs:

    * with ``truncation=None`` the call must complete without raising;
    * with ``truncation=True`` and ``max_length=5`` the call must raise
      ``ValueError``.
    """
    llava_processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # Arguments common to both processor invocations below.
    base_inputs = {
        "text": self.prepare_text_inputs(batch_size=2, modality="image"),
        "images": self.prepare_image_inputs(batch_size=2),
        "return_tensors": "pt",
    }

    # Sanity check: with truncation disabled, processing must succeed.
    llava_processor(**base_inputs, truncation=None, padding=True)

    # NOTE(review): presumably 5 tokens is too short to keep the expanded
    # image placeholder tokens, which is what triggers the error -- confirm
    # against the processor implementation.
    with self.assertRaises(ValueError):
        llava_processor(**base_inputs, truncation=True, padding=True, max_length=5)
|
||||
|
||||
Reference in New Issue
Block a user