[chat template] return assistant mask in processors (#38545)
* messed up the git history, squash commits * raise error if slow and refine tests * index was off by one * fix the test
This commit is contained in:
committed by
GitHub
parent
328ca9cf1d
commit
bcc0091937
@@ -15,6 +15,7 @@
|
||||
Processing saving/loading class for common processors.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
@@ -1468,6 +1469,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
# It's a template string, render it directly
|
||||
chat_template = chat_template
|
||||
|
||||
is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")
|
||||
|
||||
if kwargs.get("continue_final_message", False):
|
||||
if kwargs.get("add_generation_prompt", False):
|
||||
raise ValueError(
|
||||
@@ -1476,6 +1479,15 @@ class ProcessorMixin(PushToHubMixin):
|
||||
if kwargs.get("return_assistant_tokens_mask", False):
|
||||
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
|
||||
|
||||
if kwargs.get("return_assistant_tokens_mask", False):
|
||||
if not is_tokenizers_fast:
|
||||
raise ValueError(
|
||||
"`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
|
||||
"If the error persists, open an issue to support a Fast tokenizer for your model."
|
||||
)
|
||||
else:
|
||||
kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries
|
||||
|
||||
# Fill sets of kwargs that should be used by different parts of template
|
||||
processed_kwargs = {
|
||||
"mm_load_kwargs": {},
|
||||
@@ -1605,19 +1617,27 @@ class ProcessorMixin(PushToHubMixin):
|
||||
video_metadata=batch_video_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if return_dict:
|
||||
if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
|
||||
assistant_masks = []
|
||||
offset_mapping = out.pop("offset_mapping")
|
||||
input_ids = out["input_ids"]
|
||||
for i in range(len(input_ids)):
|
||||
current_mask = [0] * len(input_ids[i])
|
||||
offsets = offset_mapping[i]
|
||||
offset_starts = [start for start, end in offsets]
|
||||
for assistant_start_char, assistant_end_char in generation_indices[i]:
|
||||
start_token = out.char_to_token(i, assistant_start_char)
|
||||
end_token = out.char_to_token(i, assistant_end_char - 1)
|
||||
if start_token is None:
|
||||
start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
|
||||
end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
|
||||
|
||||
if not (
|
||||
start_pos >= 0
|
||||
and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
|
||||
):
|
||||
# start_token is out of bounds maybe due to truncation.
|
||||
break
|
||||
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
|
||||
continue
|
||||
for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
|
||||
current_mask[token_id] = 1
|
||||
assistant_masks.append(current_mask)
|
||||
out["assistant_masks"] = assistant_masks
|
||||
|
||||
@@ -137,3 +137,8 @@ class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
|
||||
)
|
||||
torch.testing.assert_close(input_ids, expected_ids)
|
||||
|
||||
@require_torch
|
||||
@unittest.skip("CSM doesn't need assistant masks as an audio generation model")
|
||||
def test_apply_chat_template_assistant_mask(self):
|
||||
pass
|
||||
|
||||
@@ -206,3 +206,7 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
|
||||
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
|
||||
def test_apply_chat_template_assistant_mask(self):
|
||||
pass
|
||||
|
||||
@@ -100,7 +100,11 @@ class ProcessorTesterMixin:
|
||||
assert attribute in self.processor_class.attributes
|
||||
component_class_name = getattr(self.processor_class, f"{attribute}_class")
|
||||
if isinstance(component_class_name, tuple):
|
||||
if attribute == "image_processor":
|
||||
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
|
||||
component_class_name = component_class_name[0]
|
||||
else:
|
||||
component_class_name = component_class_name[-1]
|
||||
|
||||
component_class = processor_class_from_name(component_class_name)
|
||||
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
|
||||
@@ -1149,3 +1153,77 @@ class ProcessorTesterMixin:
|
||||
)
|
||||
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
|
||||
self.assertEqual(formatted_prompt, expected_prompt)
|
||||
|
||||
@require_torch
|
||||
def test_apply_chat_template_assistant_mask(self):
|
||||
processor = self.get_processor()
|
||||
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is the capital of France?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The capital of France is Paris."},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What about Italy?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "The capital of Italy is Rome."},
|
||||
],
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
dummy_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if (message['role'] != 'assistant') %}"
|
||||
"{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
|
||||
"{% elif (message['role'] == 'assistant')%}"
|
||||
"{{'<|special_start|>' + message['role'] + '\n'}}"
|
||||
"{% generation %}"
|
||||
"{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
|
||||
"{% endgeneration %}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
return_assistant_tokens_mask=True,
|
||||
chat_template=dummy_template,
|
||||
)
|
||||
self.assertTrue("assistant_masks" in inputs)
|
||||
self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))
|
||||
|
||||
mask = inputs["assistant_masks"].bool()
|
||||
assistant_ids = inputs["input_ids"][mask]
|
||||
|
||||
assistant_text = (
|
||||
"The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
|
||||
)
|
||||
|
||||
# Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
|
||||
# if we can't get identical text outputs
|
||||
text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
|
||||
ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False), assistant_ids.tolist()
|
||||
self.assertTrue(text_is_same or ids_is_same)
|
||||
|
||||
Reference in New Issue
Block a user