Overhaul Conversation class and prompt templating (#25323)

* First commit while I figure this out

* make fixup

* Remove unused method

* Store prompt attrib

* Fix prompt argument for tests

* Make same changes in fast tokenizer

* Remove global prompts from fast tokenizer too

* stash commit

* stash commit

* Migrate PromptConfig to its True Final Location

* Replace Conversation entirely with the new class

* Import/dependency fixes

* Import/dependency fixes

* Change format for lots of default prompts

* More default prompt fixups

* Revert llama old methods so we can compare

* Fix some default configs

* Fix some default configs

* Fix misspelled kwarg

* Fixes for Blenderbot

* make fixup

* little rebase cleanup

* Add basic documentation

* Quick doc fix

* Truncate docstring for now

* Add handling for the case when messages is a single string

* Quick llama merges

* Update conversational pipeline and tests

* Add a couple of legacy properties for backward compatibility

* More legacy handling

* Add docstring for build_conversation_input_ids

* Restructure PromptConfig

* Let's start T E M P L A T I N G

* Refactor all default configs to use templates instead

* Revert changes to the special token properties since we don't need them anymore

* More class templates

* Make the sandbox even sandier

* Everything replaced with pure templating

* Remove docs for PromptConfig

* Add testing and optional requirement boilerplate

* Fix imports and make fixup

* Fix LLaMA tests and add Conversation docstring

* Finally get LLaMA working with the template system

* Finally get LLaMA working with the template system

* make fixup

* make fixup

* fmt-off for the long lists of test tokens

* Rename method to apply_chat_template for now

* Start on documentation

* Make chat_template a property that reads through to the default if it's not set

* Expand docs

* Expand chat templating doc some more

* trim/lstrip blocks by default and update doc

* Few doc tweaks

* rebase cleanup

* Clarify docstring

* rebase cleanup

* rebase cleanup

* make fixup

* Quick doc edit

* Reformat the standard template to match ChatML

* Re-add PEFT check

* Update docs/source/en/chat_templating.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add apply_chat_template to the tokenizer doc

* make fixup

* Add doc links

* Fix chat links

* Fix chat links

* Explain system messages in the doc

* Add chat template test

* Proper save-loading for chat template attribute

* Add test skips for layout models

* Remove _build_conversation_input_ids, add default_chat_template to code_llama

* Make sure all LLaMA models are using the latest template

* Remove default_system_prompt block in code_llama because it has no default prompt

* Update ConversationPipeline preprocess

* Add correct #Copied from links to the default_chat_templates

* Remove unneeded type checking line

* Add a dummy mark_processsed method

* Reorganize Conversation to have **deprecated_kwargs

* Update chat_templating.md

* Quick fix to LLAMA tests

* Small doc tweaks

* Add proper docstrings and "copied from" statements to all default chat templates

* Merge use_default_system_prompt support for code_llama too

* Improve clarity around self.chat_template

* Docstring fix

* Fix blenderbot default template

* More doctest fix

* Break out some tokenizer kwargs

* Update doc to explain default templates

* Quick tweaks to tokenizer args

* Cleanups for tokenizer args

* Add note about cacheing

* Quick tweak to the chat-templating doc

* Update the LLaMA template with error checking and correct system message embedding

* make fixup

* make fixup

* add requires_jinja

* Cleanup to expected output formatting

* Add cacheing

* Fix typo in llama default template

* Update LLaMA tests

* Update documentation

* Improved legacy handling in the Conversation class

* Update Jinja template with proper error handling

* Quick bugfix

* Proper exception raising

* Change cacheing behaviour so it doesn't try to pickle an entire Jinja env

* make fixup

* rebase cleanup

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Matt
2023-09-14 15:10:34 +01:00
committed by GitHub
parent 7c63e6fc8c
commit 866df66fe4
39 changed files with 1051 additions and 598 deletions

View File

@@ -50,6 +50,7 @@ from transformers.testing_utils import (
check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
require_jinja,
require_tf,
require_tokenizers,
require_torch,
@@ -1052,6 +1053,40 @@ class TokenizerTesterMixin:
if tokenizer.num_special_tokens_to_add(pair=True):
self.assertIn(None, output.sequence_ids())
@require_jinja
def test_chat_template(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
expected_output = "systemsystem messageuseruser messageassistantassistant message"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: