[chat template] return assistant mask in processors (#38545)

* messed up the git history, squash commits

* raise error if slow and refine tests

* index was off by one

* fix the test
This commit is contained in:
Raushan Turganbay
2025-07-18 14:23:20 +02:00
committed by GitHub
parent 328ca9cf1d
commit bcc0091937
4 changed files with 113 additions and 6 deletions

View File

@@ -15,6 +15,7 @@
Processing saving/loading class for common processors.
"""
import bisect
import copy
import inspect
import json
@@ -1468,6 +1469,8 @@ class ProcessorMixin(PushToHubMixin):
# It's a template string, render it directly
chat_template = chat_template
is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")
if kwargs.get("continue_final_message", False):
if kwargs.get("add_generation_prompt", False):
raise ValueError(
@@ -1476,6 +1479,15 @@ class ProcessorMixin(PushToHubMixin):
if kwargs.get("return_assistant_tokens_mask", False):
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
if kwargs.get("return_assistant_tokens_mask", False):
if not is_tokenizers_fast:
raise ValueError(
"`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
"If the error persists, open an issue to support a Fast tokenizer for your model."
)
else:
kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries
# Fill sets of kwargs that should be used by different parts of template
processed_kwargs = {
"mm_load_kwargs": {},
@@ -1605,19 +1617,27 @@ class ProcessorMixin(PushToHubMixin):
video_metadata=batch_video_metadata,
**kwargs,
)
if return_dict:
if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
assistant_masks = []
offset_mapping = out.pop("offset_mapping")
input_ids = out["input_ids"]
for i in range(len(input_ids)):
current_mask = [0] * len(input_ids[i])
offsets = offset_mapping[i]
offset_starts = [start for start, end in offsets]
for assistant_start_char, assistant_end_char in generation_indices[i]:
start_token = out.char_to_token(i, assistant_start_char)
end_token = out.char_to_token(i, assistant_end_char - 1)
if start_token is None:
start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
if not (
start_pos >= 0
and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
):
# start_token is out of bounds maybe due to truncation.
break
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
continue
for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks

View File

@@ -137,3 +137,8 @@ class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
)
torch.testing.assert_close(input_ids, expected_ids)
@require_torch
@unittest.skip("CSM doesn't need assistant masks as an audio generation model")
def test_apply_chat_template_assistant_mask(self):
pass

View File

@@ -206,3 +206,7 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_kwargs_overrides_default_image_processor_kwargs(self):
pass
@unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
def test_apply_chat_template_assistant_mask(self):
pass

View File

@@ -100,7 +100,11 @@ class ProcessorTesterMixin:
assert attribute in self.processor_class.attributes
component_class_name = getattr(self.processor_class, f"{attribute}_class")
if isinstance(component_class_name, tuple):
component_class_name = component_class_name[0]
if attribute == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
component_class_name = component_class_name[0]
else:
component_class_name = component_class_name[-1]
component_class = processor_class_from_name(component_class_name)
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
@@ -1149,3 +1153,77 @@ class ProcessorTesterMixin:
)
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
self.assertEqual(formatted_prompt, expected_prompt)
@require_torch
def test_apply_chat_template_assistant_mask(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is the capital of France?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The capital of France is Paris."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "What about Italy?"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The capital of Italy is Rome."},
],
},
]
]
dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|special_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
"{% endgeneration %}"
"{% endif %}"
"{% endfor %}"
)
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
return_assistant_tokens_mask=True,
chat_template=dummy_template,
)
self.assertTrue("assistant_masks" in inputs)
self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))
mask = inputs["assistant_masks"].bool()
assistant_ids = inputs["input_ids"][mask]
assistant_text = (
"The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
)
# Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
# if we can't get identical text outputs
text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False), assistant_ids.tolist()
self.assertTrue(text_is_same or ids_is_same)