[chat template] return assistant mask in processors (#38545)
* messed up the git history, squash commits * raise error if slow and refine tests * index was off by one * fix the test
This commit is contained in:
committed by
GitHub
parent
328ca9cf1d
commit
bcc0091937
@@ -137,3 +137,8 @@ class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
|
||||
)
|
||||
torch.testing.assert_close(input_ids, expected_ids)
|
||||
|
||||
@require_torch
|
||||
@unittest.skip("CSM doesn't need assistant masks as an audio generation model")
|
||||
def test_apply_chat_template_assistant_mask(self):
|
||||
pass
|
||||
|
||||
@@ -206,3 +206,7 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
|
||||
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
|
||||
def test_apply_chat_template_assistant_mask(self):
|
||||
pass
|
||||
|
||||
@@ -100,7 +100,11 @@ class ProcessorTesterMixin:
|
||||
assert attribute in self.processor_class.attributes
|
||||
component_class_name = getattr(self.processor_class, f"{attribute}_class")
|
||||
if isinstance(component_class_name, tuple):
|
||||
component_class_name = component_class_name[0]
|
||||
if attribute == "image_processor":
|
||||
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
|
||||
component_class_name = component_class_name[0]
|
||||
else:
|
||||
component_class_name = component_class_name[-1]
|
||||
|
||||
component_class = processor_class_from_name(component_class_name)
|
||||
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
|
||||
@@ -1149,3 +1153,77 @@ class ProcessorTesterMixin:
|
||||
)
|
||||
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
|
||||
self.assertEqual(formatted_prompt, expected_prompt)
|
||||
|
||||
@require_torch
|
||||
def test_apply_chat_template_assistant_mask(self):
|
||||
processor = self.get_processor()
|
||||
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the capital of France?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The capital of France is Paris."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What about Italy?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The capital of Italy is Rome."},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
dummy_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if (message['role'] != 'assistant') %}"
|
||||
"{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
|
||||
"{% elif (message['role'] == 'assistant')%}"
|
||||
"{{'<|special_start|>' + message['role'] + '\n'}}"
|
||||
"{% generation %}"
|
||||
"{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
|
||||
"{% endgeneration %}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
return_assistant_tokens_mask=True,
|
||||
chat_template=dummy_template,
|
||||
)
|
||||
self.assertTrue("assistant_masks" in inputs)
|
||||
self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))
|
||||
|
||||
mask = inputs["assistant_masks"].bool()
|
||||
assistant_ids = inputs["input_ids"][mask]
|
||||
|
||||
assistant_text = (
|
||||
"The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
|
||||
)
|
||||
|
||||
# Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
|
||||
# if we can't get identical text outputs
|
||||
text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
|
||||
ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False), assistant_ids.tolist()
|
||||
self.assertTrue(text_is_same or ids_is_same)
|
||||
|
||||
Reference in New Issue
Block a user