Allow apply_chat_template to pass kwargs to the template and support a dict of templates (#29658)

* Allow apply_chat_template to pass kwargs to the template

* Fix priority for template_kwargs

* Fix docstring

* style fix

* Add the option for the model to have a dict of templates

* Error message cleanup

* Add test for chat template dicts

* Simplify the chat template dict test and apply it to all tokenizers in self.get_tokenizers()

* Save chat template dicts as lists with fixed key names

* Add test for serialization/reloading

* Add require_jinja just to be safe, even though I don't think we use it
This commit is contained in:
Matt
2024-03-14 18:23:14 +00:00
committed by GitHub
parent 23db187d92
commit 48fbab7330
2 changed files with 84 additions and 6 deletions

View File

@@ -1118,6 +1118,52 @@ class TokenizerTesterMixin:
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
@require_jinja
def test_chat_template_dict(self):
    """Check that templates stored in a dict can be selected by name.

    For every tokenizer under test, ``chat_template`` is set to a dict of
    named templates. Rendering with a template's *name* must then give the
    same output as rendering with that template's string passed directly.
    """
    named_templates = {"template1": "{{'a'}}", "template2": "{{'b'}}"}
    conversation = [
        {"role": "user", "content": "user message"},
    ]
    for tokenizer in self.get_tokenizers():
        with self.subTest(tokenizer.__class__.__name__):
            # Give each tokenizer its own dict so nothing is shared between subtests.
            tokenizer.chat_template = dict(named_templates)
            for template_name, template_source in named_templates.items():
                rendered_directly = tokenizer.apply_chat_template(
                    conversation, chat_template=template_source, tokenize=False
                )
                rendered_by_name = tokenizer.apply_chat_template(
                    conversation, chat_template=template_name, tokenize=False
                )
                self.assertEqual(rendered_directly, rendered_by_name)
@require_jinja
def test_chat_template_dict_saving(self):
    """Check that a dict of chat templates survives a save/reload round trip.

    For every tokenizer under test, ``chat_template`` is set to a dict of
    named templates and the tokenizer is saved to a temporary directory.
    The test then asserts that:

    * ``tokenizer_config.json`` stores the templates as a list of
      ``{"name": ..., "template": ...}`` dictionaries, and
    * reloading from that directory yields a ``chat_template`` equal to the
      original dict.
    """
    dummy_template_1 = "{{'a'}}"
    dummy_template_2 = "{{'b'}}"
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                tokenizer.save_pretrained(tmp_dir_name)
                config_path = os.path.join(tmp_dir_name, "tokenizer_config.json")
                # Use a context manager so the handle is closed before the
                # temporary directory is removed (avoids a ResourceWarning
                # and cleanup failures on platforms that lock open files).
                with open(config_path, encoding="utf-8") as config_file:
                    config_dict = json.load(config_file)
                # Assert that chat templates are correctly serialized as lists of dictionaries
                self.assertEqual(
                    config_dict["chat_template"],
                    [
                        {"name": "template1", "template": dummy_template_1},
                        {"name": "template2", "template": dummy_template_2},
                    ],
                )
                new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
                # Assert that the serialized list is correctly reconstructed as a single dict
                self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: