Allow apply_chat_template to pass kwargs to the template and support a dict of templates (#29658)
* Allow apply_chat_template to pass kwargs to the template
* Fix priority for template_kwargs
* Fix docstring
* style fix
* Add the option for the model to have a dict of templates
* Error message cleanup
* Add test for chat template dicts
* Simplify the chat template dict test and apply it to all tokenizers in self.get_tokenizers()
* Save chat template dicts as lists with fixed key names
* Add test for serialization/reloading
* Add require_jinja just to be safe, even though I don't think we use it
This commit is contained in:
@@ -1118,6 +1118,52 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(output, expected_output) # Test output is the same after reloading
|
||||
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
|
||||
|
||||
@require_jinja
def test_chat_template_dict(self):
    """Check that templates stored in a dict-valued `chat_template` can be selected by name.

    For every tokenizer returned by `self.get_tokenizers()`, a dict with two
    named templates is assigned to `chat_template`. Each template is then
    rendered twice: once by passing the template string itself, and once by
    passing only its dict key. Both renderings must be equal.
    """
    named_templates = {"template1": "{{'a'}}", "template2": "{{'b'}}"}
    conversation = [{"role": "user", "content": "user message"}]

    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Give each tokenizer its own dict so no state is shared between subtests.
            tokenizer.chat_template = dict(named_templates)

            for template_name, template_source in named_templates.items():
                rendered_from_source = tokenizer.apply_chat_template(
                    conversation, chat_template=template_source, tokenize=False
                )
                rendered_from_name = tokenizer.apply_chat_template(
                    conversation, chat_template=template_name, tokenize=False
                )
                self.assertEqual(rendered_from_source, rendered_from_name)
|
||||
|
||||
@require_jinja
def test_chat_template_dict_saving(self):
    """Check that a dict-valued `chat_template` survives a save/reload round trip.

    For every tokenizer returned by `self.get_tokenizers()`:

    1. a dict with two named templates is assigned to `chat_template`;
    2. the tokenizer is saved to a temporary directory;
    3. the `chat_template` entry of the written `tokenizer_config.json` must be a
       list of `{"name": ..., "template": ...}` records, in insertion order;
    4. the tokenizer reloaded from that directory must expose a `chat_template`
       equal to the original dict.
    """
    dummy_template_1 = "{{'a'}}"
    dummy_template_2 = "{{'b'}}"
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                tokenizer.save_pretrained(tmp_dir_name)

                # Read the config inside a context manager so the handle is closed
                # before the temporary directory is removed; a bare
                # `json.load(open(...))` leaks the file object (ResourceWarning, and
                # an open handle can block directory cleanup on Windows).
                config_path = os.path.join(tmp_dir_name, "tokenizer_config.json")
                with open(config_path) as config_file:
                    config_dict = json.load(config_file)

                # Assert that chat templates are correctly serialized as lists of dictionaries
                self.assertEqual(
                    config_dict["chat_template"],
                    [{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
                )

                new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
                # Assert that the serialized list is correctly reconstructed as a single dict
                self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
|
||||
|
||||
def test_number_of_added_tokens(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
|
||||
Reference in New Issue
Block a user