[chat template] return assistant mask in processors (#38545)

* messed up the git history, squash commits

* raise error if slow and refine tests

* index was off by one

* fix the test
This commit is contained in:
Raushan Turganbay
2025-07-18 14:23:20 +02:00
committed by GitHub
parent 328ca9cf1d
commit bcc0091937
4 changed files with 113 additions and 6 deletions

View File

@@ -137,3 +137,8 @@ class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
)
torch.testing.assert_close(input_ids, expected_ids)
@require_torch
@unittest.skip("CSM doesn't need assistant masks as an audio generation model")
def test_apply_chat_template_assistant_mask(self):
# Deliberately empty: this definition replaces the `test_apply_chat_template_assistant_mask`
# inherited from `ProcessorTesterMixin`, and the `unittest.skip` decorator above reports it
# as skipped (with the reason given there) instead of executing a body.
pass

View File

@@ -206,3 +206,7 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_kwargs_overrides_default_image_processor_kwargs(self):
# Deliberately empty: the `unittest.skip` decorator above reports this test as skipped.
# NOTE(review): presumably shadows a same-named test on the parent mixin (per the skip
# reason) -- remove this override once the parent test supports ShieldGemma 2.
pass
@unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
def test_apply_chat_template_assistant_mask(self):
# Deliberately empty: this definition replaces the `test_apply_chat_template_assistant_mask`
# inherited from `ProcessorTesterMixin` (which sends text-only messages), and the
# `unittest.skip` decorator above reports it as skipped instead of executing a body.
pass

View File

@@ -100,7 +100,11 @@ class ProcessorTesterMixin:
assert attribute in self.processor_class.attributes
component_class_name = getattr(self.processor_class, f"{attribute}_class")
if isinstance(component_class_name, tuple):
component_class_name = component_class_name[0]
if attribute == "image_processor":
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
component_class_name = component_class_name[0]
else:
component_class_name = component_class_name[-1]
component_class = processor_class_from_name(component_class_name)
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
@@ -1149,3 +1153,77 @@ class ProcessorTesterMixin:
)
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
self.assertEqual(formatted_prompt, expected_prompt)
@require_torch
def test_apply_chat_template_assistant_mask(self):
    """Check that `apply_chat_template` returns a mask selecting the assistant tokens.

    A two-turn, text-only conversation is rendered with a dummy chat template that wraps
    each assistant reply (plus its closing `<|special_end|>` and newline) in a
    `{% generation %}` block. The test then asserts that:

    * the output contains an `assistant_masks` entry with the same batch length as
      `input_ids`;
    * the tokens selected by the mask correspond to the two assistant replies, either
      as decoded text or, failing that, as token ids.

    The test is skipped when the processor has no chat template.
    """
    processor = self.get_processor()

    if processor.chat_template is None:
        self.skipTest("Processor has no chat template")

    # A single-conversation batch with two user/assistant exchanges.
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the capital of France?"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "The capital of France is Paris."},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What about Italy?"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "The capital of Italy is Rome."},
                ],
            },
        ]
    ]

    # Only the assistant reply text and its end marker sit inside `{% generation %}`;
    # the `<|special_start|>assistant\n` header is rendered outside of it.
    dummy_template = (
        "{% for message in messages %}"
        "{% if (message['role'] != 'assistant') %}"
        "{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
        "{% elif (message['role'] == 'assistant')%}"
        "{{'<|special_start|>' + message['role'] + '\n'}}"
        "{% generation %}"
        "{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
        "{% endgeneration %}"
        "{% endif %}"
        "{% endfor %}"
    )

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        return_assistant_tokens_mask=True,
        chat_template=dummy_template,
    )
    self.assertIn("assistant_masks", inputs)
    self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))

    # Keep only the token ids flagged as assistant tokens.
    mask = inputs["assistant_masks"].bool()
    assistant_ids = inputs["input_ids"][mask]

    assistant_text = (
        "The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
    )

    # Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
    # if we can't get identical text outputs
    decoded_text = processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
    expected_ids = processor.tokenizer.encode(assistant_text, add_special_tokens=False)
    text_is_same = assistant_text == decoded_text
    # This must be an equality check: building a tuple of the two lists here would always
    # be truthy and make the assertion below pass unconditionally.
    ids_is_same = expected_ids == assistant_ids.tolist()
    self.assertTrue(
        text_is_same or ids_is_same,
        msg=f"Masked assistant tokens decode to {decoded_text!r}, expected {assistant_text!r}",
    )