Fix llava_onevision tests (#37280)

* Fix llava_onevision tests

* Trigger tests
This commit is contained in:
Matt
2025-04-04 15:03:38 +01:00
committed by GitHub
parent ad3d157188
commit 8ebc435267

View File

@@ -39,17 +39,19 @@ if is_torch_available:
class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase): class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = LlavaOnevisionProcessor processor_class = LlavaOnevisionProcessor
def setUp(self): @classmethod
self.tmpdirname = tempfile.mkdtemp() def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
cls.addClassCleanup(lambda tempdir=cls.tmpdirname: shutil.rmtree(tempdir))
image_processor = LlavaOnevisionImageProcessor() image_processor = LlavaOnevisionImageProcessor()
video_processor = LlavaOnevisionVideoProcessor() video_processor = LlavaOnevisionVideoProcessor()
tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct") tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
processor_kwargs = self.prepare_processor_dict() processor_kwargs = cls.prepare_processor_dict()
processor = LlavaOnevisionProcessor( processor = LlavaOnevisionProcessor(
video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
) )
processor.save_pretrained(self.tmpdirname) processor.save_pretrained(cls.tmpdirname)
def get_tokenizer(self, **kwargs): def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
@@ -60,7 +62,8 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def get_video_processor(self, **kwargs): def get_video_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
def prepare_processor_dict(self): @staticmethod
def prepare_processor_dict():
return { return {
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"num_image_tokens": 6, "num_image_tokens": 6,
@@ -88,9 +91,6 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_dict = self.prepare_processor_dict() processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None)) self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_chat_template(self): def test_chat_template(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf") processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
expected_prompt = "<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" expected_prompt = "<|im_start|>user <image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"