Chat template: update for processor (#35953)

* update

* we need batched nested input to always process correctly

* update a bit

* fix copies
This commit is contained in:
Raushan Turganbay
2025-02-10 09:52:19 +01:00
committed by GitHub
parent 5bd7694781
commit eebd2c972c
21 changed files with 966 additions and 111 deletions

View File

@@ -16,7 +16,7 @@ import shutil
import tempfile
import unittest
from transformers.testing_utils import require_av, require_torch, require_vision
from transformers.testing_utils import require_av, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@@ -32,7 +32,7 @@ if is_vision_available():
)
# `is_torch_available` is a function: it has to be *called*. Referencing the
# bare function object is always truthy, which would make this guard a no-op
# and attempt the import even when torch is not installed.
if is_torch_available():
    import torch
@require_vision
@@ -61,7 +61,11 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
def prepare_processor_dict(self):
    """Return the extra processor kwargs used by this test class.

    The chat template is a real Jinja template rather than a placeholder
    string. For every message it emits the role header, then one
    ``<image>`` per image entry, then one ``<video>`` per video entry, then
    the text entries; assistant text is wrapped in
    ``{% generation %}`` / ``{% endgeneration %}``. When
    ``add_generation_prompt`` is set, an assistant header is appended.

    Returns:
        dict: ``chat_template``, ``num_image_tokens`` and
        ``vision_feature_select_strategy`` entries.
    """
    # The template must stay on a single line so the string is reproduced
    # exactly; `# fmt: skip` keeps the formatter from touching it.
    return {
        "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + ' '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>' }}{% endfor %}{# Render all video then #}{% for content in message['content'] | selectattr('type', 'equalto', 'video') %}{{ '<video>' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ '\n' + content['text'] }}{% endgeneration %}{% endfor %}{% endif %}{{'<|im_end|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
        "num_image_tokens": 6,
        "vision_feature_select_strategy": "default"
    }  # fmt: skip
def test_processor_to_json_string(self):
processor = self.get_processor()
@@ -133,30 +137,3 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
@require_torch
@require_av
def test_chat_template_dict_torch(self):
    """Check that `apply_chat_template` can tokenize a video prompt into torch tensors.

    Builds a single user message holding a video (referenced by URL) and a
    text question, renders it with `tokenize=True`, `return_dict=True` and
    `return_tensors="pt"`, then verifies the returned keys and that
    `input_ids` is a `torch.Tensor`.

    NOTE(review): the processor is loaded by checkpoint name and the video
    is given as a remote URL, so this test presumably needs network access.
    """
    processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
                },
                {"type": "text", "text": "What is shown in this video?"},
            ],
        },
    ]

    out_dict_tensors = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
    # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
    self.assertIsInstance(out_dict_tensors["input_ids"], torch.Tensor)