Chat template: update for processor (#35953)

* update

* we need batched nested input to always process correctly

* update a bit

* fix copies
This commit is contained in:
Raushan Turganbay
2025-02-10 09:52:19 +01:00
committed by GitHub
parent 5bd7694781
commit eebd2c972c
21 changed files with 966 additions and 111 deletions

View File

@@ -27,10 +27,11 @@ from transformers.models.auto.processing_auto import processor_class_from_name
from transformers.processing_utils import Unpack
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_av,
require_torch,
require_vision,
)
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available
global_rng = random.Random()
@@ -38,6 +39,9 @@ global_rng = random.Random()
if is_vision_available():
from PIL import Image
if is_torch_available():
import torch
def prepare_image_inputs():
"""This function prepares a list of PIL images"""
@@ -131,8 +135,10 @@ class ProcessorTesterMixin:
processor = self.get_processor()
obj = json.loads(processor.to_json_string())
for key, value in self.prepare_processor_dict().items():
self.assertEqual(obj[key], value)
self.assertEqual(getattr(processor, key, None), value)
# Chat template is saved as a separate file
if key not in "chat_template":
self.assertEqual(obj[key], value)
self.assertEqual(getattr(processor, key, None), value)
def test_processor_from_and_save_pretrained(self):
processor_first = self.get_processor()
@@ -532,6 +538,10 @@ class ProcessorTesterMixin:
def test_chat_template_save_loading(self):
processor = self.get_processor()
signature = inspect.signature(processor.__call__)
if "chat_template" not in {*signature.parameters.keys()}:
self.skipTest("Processor doesn't accept chat templates at input")
existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
processor.chat_template = "test template"
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -553,3 +563,298 @@ class ProcessorTesterMixin:
# When we save as single files, tokenizers and processors share a chat template, which means
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
def test_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1)
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
expected_output = processor.tokenizer(
formatted_prompt, return_tensors=None, add_special_tokens=add_special_tokens
).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1)
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 1)
def test_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "What do you see?"},
],
},
],
]
formatted_prompt = processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 2)
formatted_prompt_tokenized = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, padding=True
)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
expected_output = processor.tokenizer(
formatted_prompt,
return_tensors=None,
padding=True,
add_special_tokens=add_special_tokens,
).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Now test the ability to return dict
batched_messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
batched_messages[1][0]["content"].append(
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
)
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 2)
self.assertEqual(len(out_dict["attention_mask"]), 2)
self.assertEqual(len(out_dict[self.images_input_name]), 2)
def test_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
max_length=50,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
# Now test the ability to return dict
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
do_rescale=True,
rescale_factor=-1,
return_tensors="np",
)
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
@require_torch
def test_chat_template_dict_torch(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.images_input_name in out_dict_tensors)
for k in out_dict_tensors:
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)
@require_av
def test_chat_template_video(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
signature = inspect.signature(processor.__call__)
if "videos" not in {*signature.parameters.keys()} or (
signature.parameters.get("videos") is not None
and signature.parameters["videos"].annotation == inspect._empty
):
self.skipTest("Processor doesn't accept videos at input")
messages = [
[
{
"role": "user",
"content": [
{"type": "video"},
{"type": "text", "text": "What is shown in this video?"},
],
},
]
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1)
formatted_prompt_tokenized = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
add_special_tokens = True
if processor.tokenizer.bos_token is not None and formatted_prompt[0].startswith(processor.tokenizer.bos_token):
add_special_tokens = False
expected_output = processor.tokenizer(
formatted_prompt,
return_tensors=None,
add_special_tokens=add_special_tokens,
).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Add video URL for return dict and load with `num_frames` arg
messages[0][0]["content"][0] = {
"type": "video",
"url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
}
num_frames = 3
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
num_frames=num_frames,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), num_frames)
# Load with `video_fps` arg
video_fps = 1
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), video_fps * 10)
# Load with `video_fps` and `num_frames` args, should raise an error
with self.assertRaises(ValueError):
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=video_fps,
num_frames=num_frames,
)
# Load without any arg should load the whole video
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 300)
# Load video as a list of frames (i.e. images). NOTE: each frame should have same size
# because we assume they come from one video
messages[0][0]["content"][0] = {
"type": "video",
"url": [
"https://www.ilankelman.org/stopsigns/australia.jpg",
"https://www.ilankelman.org/stopsigns/australia.jpg",
],
}
out_dict_with_video = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
)
self.assertTrue(self.videos_input_name in out_dict_with_video)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 2)