Chat template: return vectorized output in processors (#34275)
* update chat template * style * fix tests * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * typehints + docs * fix tests * remove unnecessary warnings * forgot code style :( * allow users to pass backend and num frames * Update docs/source/en/chat_templating.md Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/image_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/processing_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * typo fix * style * address comments * align with "pipeline" template * update docs * update docs * unpack for all kwargs? * wrong conflict resolution while rebasing * tmp * update docs * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5f087d1335
commit
e0646f3dce
@@ -17,8 +17,8 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
@@ -26,6 +26,9 @@ from ...test_processing_common import ProcessorTesterMixin
|
||||
if is_vision_available():
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
if is_torch_available:
|
||||
import torch
|
||||
|
||||
|
||||
@require_vision
|
||||
class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@@ -94,6 +97,55 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
def test_chat_template_dict(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
|
||||
expected_output = [[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, 9047, 13566, 29901]] # fmt: skip
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
|
||||
|
||||
# add image URL for return dict
|
||||
messages[0]["content"][0] = {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||
out_dict_with_image = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True
|
||||
)
|
||||
self.assertListEqual(list(out_dict_with_image.keys()), ["input_ids", "attention_mask", "pixel_values"])
|
||||
|
||||
@require_torch
|
||||
def test_chat_template_dict_torch(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
self.assertListEqual(list(out_dict_tensors.keys()), ["input_ids", "attention_mask", "pixel_values"])
|
||||
self.assertTrue(isinstance(out_dict_tensors["input_ids"], torch.Tensor))
|
||||
|
||||
def test_chat_template_with_continue_final_message(self):
|
||||
processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
|
||||
|
||||
Reference in New Issue
Block a user