[chat template] return assistant mask in processors (#38545)
* messed up the git history, squash commits * raise error if slow and refine tests * index was off by one * fix the test
This commit is contained in:
committed by
GitHub
parent
328ca9cf1d
commit
bcc0091937
@@ -15,6 +15,7 @@
|
|||||||
Processing saving/loading class for common processors.
|
Processing saving/loading class for common processors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import bisect
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
@@ -1468,6 +1469,8 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
# It's a template string, render it directly
|
# It's a template string, render it directly
|
||||||
chat_template = chat_template
|
chat_template = chat_template
|
||||||
|
|
||||||
|
is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")
|
||||||
|
|
||||||
if kwargs.get("continue_final_message", False):
|
if kwargs.get("continue_final_message", False):
|
||||||
if kwargs.get("add_generation_prompt", False):
|
if kwargs.get("add_generation_prompt", False):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1476,6 +1479,15 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
if kwargs.get("return_assistant_tokens_mask", False):
|
if kwargs.get("return_assistant_tokens_mask", False):
|
||||||
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
|
raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
|
||||||
|
|
||||||
|
if kwargs.get("return_assistant_tokens_mask", False):
|
||||||
|
if not is_tokenizers_fast:
|
||||||
|
raise ValueError(
|
||||||
|
"`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
|
||||||
|
"If the error persists, open an issue to support a Fast tokenizer for your model."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kwargs["return_offsets_mapping"] = True # force offset mapping so we can infer token boundaries
|
||||||
|
|
||||||
# Fill sets of kwargs that should be used by different parts of template
|
# Fill sets of kwargs that should be used by different parts of template
|
||||||
processed_kwargs = {
|
processed_kwargs = {
|
||||||
"mm_load_kwargs": {},
|
"mm_load_kwargs": {},
|
||||||
@@ -1605,19 +1617,27 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
video_metadata=batch_video_metadata,
|
video_metadata=batch_video_metadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict:
|
if return_dict:
|
||||||
if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
|
if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
|
||||||
assistant_masks = []
|
assistant_masks = []
|
||||||
|
offset_mapping = out.pop("offset_mapping")
|
||||||
input_ids = out["input_ids"]
|
input_ids = out["input_ids"]
|
||||||
for i in range(len(input_ids)):
|
for i in range(len(input_ids)):
|
||||||
current_mask = [0] * len(input_ids[i])
|
current_mask = [0] * len(input_ids[i])
|
||||||
|
offsets = offset_mapping[i]
|
||||||
|
offset_starts = [start for start, end in offsets]
|
||||||
for assistant_start_char, assistant_end_char in generation_indices[i]:
|
for assistant_start_char, assistant_end_char in generation_indices[i]:
|
||||||
start_token = out.char_to_token(i, assistant_start_char)
|
start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
|
||||||
end_token = out.char_to_token(i, assistant_end_char - 1)
|
end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
|
||||||
if start_token is None:
|
|
||||||
|
if not (
|
||||||
|
start_pos >= 0
|
||||||
|
and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
|
||||||
|
):
|
||||||
# start_token is out of bounds maybe due to truncation.
|
# start_token is out of bounds maybe due to truncation.
|
||||||
break
|
continue
|
||||||
for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
|
for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
|
||||||
current_mask[token_id] = 1
|
current_mask[token_id] = 1
|
||||||
assistant_masks.append(current_mask)
|
assistant_masks.append(current_mask)
|
||||||
out["assistant_masks"] = assistant_masks
|
out["assistant_masks"] = assistant_masks
|
||||||
|
|||||||
@@ -137,3 +137,8 @@ class CsmProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
|
[[128000, 58, 15, 60, 2028, 374, 264, 1296, 11914, 13, 128001, 128002, 128002, 128002, 128003]]
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(input_ids, expected_ids)
|
torch.testing.assert_close(input_ids, expected_ids)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@unittest.skip("CSM doesn't need assistant masks as an audio generation model")
|
||||||
|
def test_apply_chat_template_assistant_mask(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -206,3 +206,7 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
|
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
|
||||||
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("ShieldGemma requires images in input, and fails in text-only processing")
|
||||||
|
def test_apply_chat_template_assistant_mask(self):
|
||||||
|
pass
|
||||||
|
|||||||
@@ -100,7 +100,11 @@ class ProcessorTesterMixin:
|
|||||||
assert attribute in self.processor_class.attributes
|
assert attribute in self.processor_class.attributes
|
||||||
component_class_name = getattr(self.processor_class, f"{attribute}_class")
|
component_class_name = getattr(self.processor_class, f"{attribute}_class")
|
||||||
if isinstance(component_class_name, tuple):
|
if isinstance(component_class_name, tuple):
|
||||||
component_class_name = component_class_name[0]
|
if attribute == "image_processor":
|
||||||
|
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
|
||||||
|
component_class_name = component_class_name[0]
|
||||||
|
else:
|
||||||
|
component_class_name = component_class_name[-1]
|
||||||
|
|
||||||
component_class = processor_class_from_name(component_class_name)
|
component_class = processor_class_from_name(component_class_name)
|
||||||
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
|
component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa
|
||||||
@@ -1149,3 +1153,77 @@ class ProcessorTesterMixin:
|
|||||||
)
|
)
|
||||||
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
|
expected_prompt = "You are a helpful assistant.<|special_start|>user\nWhich of these animals is making the sound?<|special_end|>\nYou are a helpful assistant.<|special_start|>assistant\nIt is a cow.<|special_end|>\n"
|
||||||
self.assertEqual(formatted_prompt, expected_prompt)
|
self.assertEqual(formatted_prompt, expected_prompt)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_apply_chat_template_assistant_mask(self):
|
||||||
|
processor = self.get_processor()
|
||||||
|
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is the capital of France?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The capital of France is Paris."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What about Italy?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The capital of Italy is Rome."},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
dummy_template = (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"{% if (message['role'] != 'assistant') %}"
|
||||||
|
"{{'<|special_start|>' + message['role'] + '\n' + message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
|
||||||
|
"{% elif (message['role'] == 'assistant')%}"
|
||||||
|
"{{'<|special_start|>' + message['role'] + '\n'}}"
|
||||||
|
"{% generation %}"
|
||||||
|
"{{message['content'][0]['text'] + '<|special_end|>' + '\n'}}"
|
||||||
|
"{% endgeneration %}"
|
||||||
|
"{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_assistant_tokens_mask=True,
|
||||||
|
chat_template=dummy_template,
|
||||||
|
)
|
||||||
|
self.assertTrue("assistant_masks" in inputs)
|
||||||
|
self.assertEqual(len(inputs["assistant_masks"]), len(inputs["input_ids"]))
|
||||||
|
|
||||||
|
mask = inputs["assistant_masks"].bool()
|
||||||
|
assistant_ids = inputs["input_ids"][mask]
|
||||||
|
|
||||||
|
assistant_text = (
|
||||||
|
"The capital of France is Paris.<|special_end|>\nThe capital of Italy is Rome.<|special_end|>\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Some tokenizers add extra spaces which aren't then removed when decoding, so we need to check token ids
|
||||||
|
# if we can't get identical text outputs
|
||||||
|
text_is_same = assistant_text == processor.decode(assistant_ids, clean_up_tokenization_spaces=True)
|
||||||
|
ids_is_same = processor.tokenizer.encode(assistant_text, add_special_tokens=False), assistant_ids.tolist()
|
||||||
|
self.assertTrue(text_is_same or ids_is_same)
|
||||||
|
|||||||
Reference in New Issue
Block a user