Chat template: return vectorized output in processors (#34275)
* update chat template * style * fix tests * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * typehints + docs * fix tests * remove unnecessary warnings * forgot code style :( * allow users to pass backend and num frames * Update docs/source/en/chat_templating.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/processing_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * typo fix * style * address comments * align with "pipeline" template * update docs * update docs * unpack for all kwargs? * wrong conflict resolution while rebasing * tmp * update docs * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5f087d1335
commit
e0646f3dce
@@ -17,8 +17,8 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
@@ -26,6 +26,9 @@ from ...test_processing_common import ProcessorTesterMixin
|
||||
if is_vision_available():
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
if is_torch_available:
|
||||
import torch
|
||||
|
||||
|
||||
@require_vision
|
||||
class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@@ -94,6 +97,55 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
def test_chat_template_dict(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# add image URL for return dict
|
||||
messages[0]["content"][0] = {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
out_dict_with_image = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
|
||||
|
||||
@require_torch
|
||||
def test_chat_template_dict_torch(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"])
|
||||
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
|
||||
|
||||
def test_chat_template_with_continue_final_message(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
|
||||
|
||||
@@ -16,8 +16,8 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.testing_utils import require_av, require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
@@ -31,6 +31,9 @@ if is_vision_available():
|
||||
Qwen2TokenizerFast,
|
||||
)
|
||||
|
||||
if is_torch_available:
|
||||
import torch
|
||||
|
||||
|
||||
@require_vision
|
||||
class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@@ -100,3 +103,60 @@ class LlavaOnevisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
@require_av
|
||||
def test_chat_template_dict(self):
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "video"},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = [[151644, 872, 220, 151647, 198, 3838, 374, 6839, 304, 419, 2766, 30, 151645, 151644, 77091, 198]] # fmt: skip
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# add image URL for return dict
|
||||
messages[0]["content"][0] = {
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
}
|
||||
out_dict_with_video = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict_with_video.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
|
||||
|
||||
@require_torch
|
||||
@require_av
|
||||
def test_chat_template_dict_torch(self):
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this video?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values_videos"])
|
||||
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
|
||||
|
||||
@@ -110,32 +110,34 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
]
|
||||
input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_ids = [
|
||||
128000, # <|begin_of_text|>
|
||||
128006, # <|start_header_id|>
|
||||
9125, # "system"
|
||||
128007, # <|end_of_header|>
|
||||
271, # "\n\n"
|
||||
2028,
|
||||
374,
|
||||
264,
|
||||
1296,
|
||||
11914,
|
||||
13, # "This is a test sentence."
|
||||
128009, # <|eot_id|>
|
||||
128006, # <|start_header_id|>
|
||||
882, # "user"
|
||||
128007, # <|end_of_header|>
|
||||
271, # "\n\n"
|
||||
2028,
|
||||
374,
|
||||
264,
|
||||
2077,
|
||||
13, # "This is a response.",
|
||||
128009, # <|eot_id|>
|
||||
128006, # <|start_header_id|>
|
||||
78191, # "assistant"
|
||||
128007, # <|end_of_header|>
|
||||
271, # "\n\n"
|
||||
[
|
||||
128000, # <|begin_of_text|>
|
||||
128006, # <|start_header_id|>
|
||||
9125, # "system"
|
||||
128007, # <|end_of_header|>
|
||||
271, # "\n\n"
|
||||
2028,
|
||||
374,
|
||||
264,
|
||||
1296,
|
||||
11914,
|
||||
13, # "This is a test sentence."
|
||||
128009, # <|eot_id|>
|
||||
128006, # <|start_header_id|>
|
||||
882, # "user"
|
||||
128007, # <|end_of_header|>
|
||||
271, # "\n\n"
|
||||
2028,
|
||||
374,
|
||||
264,
|
||||
2077,
|
||||
13, # "This is a response.",
|
||||
128009, # <|eot_id|>
|
||||
128006, # <|start_header_id|>
|
||||
78191, # "assistant"
|
||||
128007, # <|end_of_header|>
|
||||
271, # "\n\n"
|
||||
]
|
||||
]
|
||||
|
||||
self.assertEqual(input_ids, expected_ids)
|
||||
@@ -146,9 +148,9 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this image in two sentences"},
|
||||
{"type": "image"},
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": " Test sentence "},
|
||||
{"type": "image"},
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "ok\n"},
|
||||
],
|
||||
}
|
||||
@@ -164,10 +166,10 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
# fmt: off
|
||||
expected_ids = [
|
||||
expected_ids = [[
|
||||
128000, 128006, 882, 128007, 271, 75885, 420, 2217, 304, 1403, 23719, 128256,
|
||||
3475, 11914, 262, 128256, 564, 198, 128009, 128006, 78191, 128007, 271,
|
||||
]
|
||||
]]
|
||||
# fmt: on
|
||||
self.assertEqual(input_ids, expected_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user