Chat template: return vectorized output in processors (#34275)

* update chat template

* style

* fix tests

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* typehints + docs

* fix tests

* remove unnecessary warnings

* forgot code style :(

* allow users to pass backend and num frames

* Update docs/source/en/chat_templating.md

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/image_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* Update src/transformers/processing_utils.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

* typo fix

* style

* address comments

* align with "pipeline" template

* update docs

* update docs

* unpack for all kwargs?

* wrong conflict resolution while rebasing

* tmp

* update docs

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/chat_templating.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-01-10 11:05:29 +01:00
committed by GitHub
parent 5f087d1335
commit e0646f3dce
12 changed files with 880 additions and 46 deletions

View File

@@ -17,8 +17,8 @@ import tempfile
import unittest
from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@@ -26,6 +26,9 @@ from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import CLIPImageProcessor
if is_torch_available:
import torch
@require_vision
class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@@ -94,6 +97,55 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_dict(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# add image URL for return dict
messages[0]["content"][0] = {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
out_dict_with_image = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True
)
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
@require_torch
def test_chat_template_dict_torch(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"])
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
def test_chat_template_with_continue_final_message(self):
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"