Allow apply_chat_template to pass kwargs to the template and support a dict of templates (#29658)
* Allow apply_chat_template to pass kwargs to the template * Fix priority for template_kwargs * Fix docstring * style fix * Add the option for the model to have a dict of templates * Error message cleanup * Add test for chat template dicts * Simplify the chat template dict test and apply it to all tokenizers in self.get_tokenizers() * Save chat template dicts as lists with fixed key names * Add test for serialization/reloading * Add require_jinja just to be safe, even though I don't think we use it
This commit is contained in:
@@ -1610,6 +1610,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
|
|
||||||
# Stores a Jinja template that formats chat histories into tokenizable strings
|
# Stores a Jinja template that formats chat histories into tokenizable strings
|
||||||
self.chat_template = kwargs.pop("chat_template", None)
|
self.chat_template = kwargs.pop("chat_template", None)
|
||||||
|
if isinstance(self.chat_template, (list, tuple)):
|
||||||
|
# Chat templates are stored as lists of dicts with fixed key names,
|
||||||
|
# we reconstruct that into a single dict while loading them.
|
||||||
|
self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@@ -1697,7 +1701,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
return_dict: bool = False,
|
return_dict: bool = False,
|
||||||
**tokenizer_kwargs,
|
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
) -> Union[str, List[int]]:
|
) -> Union[str, List[int]]:
|
||||||
"""
|
"""
|
||||||
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
|
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
|
||||||
@@ -1732,7 +1737,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||||
return_dict (`bool`, *optional*, defaults to `False`):
|
return_dict (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
|
||||||
**tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
|
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
|
||||||
|
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
|
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
|
||||||
@@ -1743,8 +1749,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# Indicates it's a Conversation object
|
# Indicates it's a Conversation object
|
||||||
conversation = conversation.messages
|
conversation = conversation.messages
|
||||||
|
|
||||||
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
|
if tokenizer_kwargs is None:
|
||||||
if chat_template is None:
|
tokenizer_kwargs = {}
|
||||||
|
|
||||||
|
# First, handle the cases when the model has a dict of multiple templates
|
||||||
|
if isinstance(self.chat_template, dict) or (
|
||||||
|
self.chat_template is None and isinstance(self.default_chat_template, dict)
|
||||||
|
):
|
||||||
|
template_dict = self.chat_template or self.default_chat_template
|
||||||
|
if chat_template is not None and chat_template in template_dict:
|
||||||
|
# The user can pass the name of a template to the chat template argument instead of an entire template
|
||||||
|
chat_template = template_dict[chat_template]
|
||||||
|
elif chat_template is None and "default" in template_dict:
|
||||||
|
chat_template = template_dict["default"]
|
||||||
|
elif chat_template is None:
|
||||||
|
raise ValueError(
|
||||||
|
"This model has multiple chat templates with no default specified! Please either pass a chat "
|
||||||
|
"template or the name of the template you wish to use to the `chat_template` argument. Available "
|
||||||
|
f"template names are {sorted(template_dict.keys())}."
|
||||||
|
)
|
||||||
|
elif chat_template is None:
|
||||||
|
# These are the cases when the model has a single template
|
||||||
|
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
|
||||||
if self.chat_template is not None:
|
if self.chat_template is not None:
|
||||||
chat_template = self.chat_template
|
chat_template = self.chat_template
|
||||||
else:
|
else:
|
||||||
@@ -1753,8 +1779,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# Compilation function uses a cache to avoid recompiling the same template
|
# Compilation function uses a cache to avoid recompiling the same template
|
||||||
compiled_template = self._compile_jinja_template(chat_template)
|
compiled_template = self._compile_jinja_template(chat_template)
|
||||||
|
|
||||||
|
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
||||||
rendered = compiled_template.render(
|
rendered = compiled_template.render(
|
||||||
messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
|
messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if padding is True:
|
if padding is True:
|
||||||
@@ -2426,7 +2453,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
tokenizer_config.update(self.special_tokens_map)
|
tokenizer_config.update(self.special_tokens_map)
|
||||||
|
|
||||||
if self.chat_template is not None:
|
if self.chat_template is not None:
|
||||||
tokenizer_config["chat_template"] = self.chat_template
|
if isinstance(self.chat_template, dict):
|
||||||
|
# Chat template dicts are saved to the config as lists of dicts with fixed key names.
|
||||||
|
# They will be reconstructed as a single dict during loading.
|
||||||
|
tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
|
||||||
|
else:
|
||||||
|
tokenizer_config["chat_template"] = self.chat_template
|
||||||
|
|
||||||
if len(self.init_inputs) > 0:
|
if len(self.init_inputs) > 0:
|
||||||
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
|
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
|
||||||
|
|||||||
@@ -1118,6 +1118,52 @@ class TokenizerTesterMixin:
|
|||||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_chat_template_dict(self):
|
||||||
|
dummy_template_1 = "{{'a'}}"
|
||||||
|
dummy_template_2 = "{{'b'}}"
|
||||||
|
dummy_conversation = [
|
||||||
|
{"role": "user", "content": "user message"},
|
||||||
|
]
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||||
|
output1 = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template_1, tokenize=False
|
||||||
|
)
|
||||||
|
output1_via_dict = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template="template1", tokenize=False
|
||||||
|
)
|
||||||
|
self.assertEqual(output1, output1_via_dict)
|
||||||
|
output2 = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template=dummy_template_2, tokenize=False
|
||||||
|
)
|
||||||
|
output2_via_dict = tokenizer.apply_chat_template(
|
||||||
|
dummy_conversation, chat_template="template2", tokenize=False
|
||||||
|
)
|
||||||
|
self.assertEqual(output2, output2_via_dict)
|
||||||
|
|
||||||
|
@require_jinja
|
||||||
|
def test_chat_template_dict_saving(self):
|
||||||
|
dummy_template_1 = "{{'a'}}"
|
||||||
|
dummy_template_2 = "{{'b'}}"
|
||||||
|
tokenizers = self.get_tokenizers()
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
tokenizer.save_pretrained(tmp_dir_name)
|
||||||
|
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
|
||||||
|
# Assert that chat templates are correctly serialized as lists of dictionaries
|
||||||
|
self.assertEqual(
|
||||||
|
config_dict["chat_template"],
|
||||||
|
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
|
||||||
|
)
|
||||||
|
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
|
||||||
|
# Assert that the serialized list is correctly reconstructed as a single dict
|
||||||
|
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||||
|
|
||||||
def test_number_of_added_tokens(self):
|
def test_number_of_added_tokens(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
|
|||||||
Reference in New Issue
Block a user