Overhaul Conversation class and prompt templating (#25323)

* First commit while I figure this out

* make fixup

* Remove unused method

* Store prompt attrib

* Fix prompt argument for tests

* Make same changes in fast tokenizer

* Remove global prompts from fast tokenizer too

* stash commit

* stash commit

* Migrate PromptConfig to its True Final Location

* Replace Conversation entirely with the new class

* Import/dependency fixes

* Import/dependency fixes

* Change format for lots of default prompts

* More default prompt fixups

* Revert llama old methods so we can compare

* Fix some default configs

* Fix some default configs

* Fix misspelled kwarg

* Fixes for Blenderbot

* make fixup

* little rebase cleanup

* Add basic documentation

* Quick doc fix

* Truncate docstring for now

* Add handling for the case when messages is a single string

* Quick llama merges

* Update conversational pipeline and tests

* Add a couple of legacy properties for backward compatibility

* More legacy handling

* Add docstring for build_conversation_input_ids

* Restructure PromptConfig

* Let's start T E M P L A T I N G

* Refactor all default configs to use templates instead

* Revert changes to the special token properties since we don't need them anymore

* More class templates

* Make the sandbox even sandier

* Everything replaced with pure templating

* Remove docs for PromptConfig

* Add testing and optional requirement boilerplate

* Fix imports and make fixup

* Fix LLaMA tests and add Conversation docstring

* Finally get LLaMA working with the template system

* Finally get LLaMA working with the template system

* make fixup

* make fixup

* fmt-off for the long lists of test tokens

* Rename method to apply_chat_template for now

* Start on documentation

* Make chat_template a property that reads through to the default if it's not set

* Expand docs

* Expand chat templating doc some more

* trim/lstrip blocks by default and update doc

* Few doc tweaks

* rebase cleanup

* Clarify docstring

* rebase cleanup

* rebase cleanup

* make fixup

* Quick doc edit

* Reformat the standard template to match ChatML

* Re-add PEFT check

* Update docs/source/en/chat_templating.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add apply_chat_template to the tokenizer doc

* make fixup

* Add doc links

* Fix chat links

* Fix chat links

* Explain system messages in the doc

* Add chat template test

* Proper save-loading for chat template attribute

* Add test skips for layout models

* Remove _build_conversation_input_ids, add default_chat_template to code_llama

* Make sure all LLaMA models are using the latest template

* Remove default_system_prompt block in code_llama because it has no default prompt

* Update ConversationPipeline preprocess

* Add correct #Copied from links to the default_chat_templates

* Remove unneeded type checking line

* Add a dummy mark_processsed method

* Reorganize Conversation to have **deprecated_kwargs

* Update chat_templating.md

* Quick fix to LLAMA tests

* Small doc tweaks

* Add proper docstrings and "copied from" statements to all default chat templates

* Merge use_default_system_prompt support for code_llama too

* Improve clarity around self.chat_template

* Docstring fix

* Fix blenderbot default template

* More doctest fix

* Break out some tokenizer kwargs

* Update doc to explain default templates

* Quick tweaks to tokenizer args

* Cleanups for tokenizer args

* Add note about cacheing

* Quick tweak to the chat-templating doc

* Update the LLaMA template with error checking and correct system message embedding

* make fixup

* make fixup

* add requires_jinja

* Cleanup to expected output formatting

* Add cacheing

* Fix typo in llama default template

* Update LLaMA tests

* Update documentation

* Improved legacy handling in the Conversation class

* Update Jinja template with proper error handling

* Quick bugfix

* Proper exception raising

* Change cacheing behaviour so it doesn't try to pickle an entire Jinja env

* make fixup

* rebase cleanup

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Matt
2023-09-14 15:10:34 +01:00
committed by GitHub
parent 7c63e6fc8c
commit 866df66fe4
39 changed files with 1051 additions and 598 deletions

View File

@@ -17,6 +17,7 @@
import unittest
from transformers import BlenderbotTokenizer, BlenderbotTokenizerFast
from transformers.testing_utils import require_jinja
from transformers.utils import cached_property
@@ -50,3 +51,24 @@ class Blenderbot3BTokenizerTests(unittest.TestCase):
def test_3B_tokenization_same_as_parlai_rust_tokenizer(self):
assert self.rust_tokenizer_3b.add_prefix_space
assert self.rust_tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]
@require_jinja
def test_tokenization_for_chat(self):
tok = self.tokenizer_3b
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tok.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2],
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2],
[3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)

View File

@@ -18,7 +18,7 @@ import unittest
from datasets import load_dataset
from transformers import BloomTokenizerFast
from transformers.testing_utils import require_tokenizers
from transformers.testing_utils import require_jinja, require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@@ -134,6 +134,27 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = self.get_rust_tokenizer()
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
[229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
def test_add_prefix_space_fast(self):
tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)

View File

@@ -20,7 +20,7 @@ import unittest
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
from transformers.testing_utils import require_jinja, require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@@ -275,6 +275,27 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname)
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20],
[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20, 20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20],
[20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20, 20, 3, 0, 0, 1, 20, 20]]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@require_tokenizers
class OPTTokenizationTest(unittest.TestCase):

View File

@@ -16,7 +16,7 @@
import unittest
from transformers import GPTSw3Tokenizer
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -128,3 +128,27 @@ class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
model_name="AI-Sweden/gpt-sw3-126m",
sequences=sequences,
)
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB)
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [
[268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ],
[268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 938, 541, 419, ],
[268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ]
]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)

View File

@@ -22,7 +22,7 @@ from transformers.models.gptsan_japanese.tokenization_gptsan_japanese import (
VOCAB_FILES_NAMES,
GPTSanJapaneseTokenizer,
)
from transformers.testing_utils import require_tokenizers, slow
from transformers.testing_utils import require_jinja, require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -193,3 +193,27 @@ class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_padding_different_model_input_name(self):
# tokenizer has no padding token
pass
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [
[35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
[35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999, 35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999],
[35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)

View File

@@ -2486,3 +2486,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass

View File

@@ -2439,3 +2439,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# This should not fail
model(encoded_sequence)
model(batch_encoded_sequence)
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass

View File

@@ -1958,3 +1958,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip("Doesn't use SentencePiece")
def test_sentencepiece_tokenize_and_decode(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass

View File

@@ -32,6 +32,7 @@ from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.testing_utils import (
get_tests_dir,
nested_simplify,
require_jinja,
require_sentencepiece,
require_tokenizers,
require_torch,
@@ -574,6 +575,32 @@ class LlamaIntegrationTest(unittest.TestCase):
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "user", "content": "Hello!"}],
]
# Matt: The third test case tests the default system message, but if this is ever changed in the
# class/repo code then that test will fail, and the case will need to be updated.
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2],
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962]
]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@require_sentencepiece
@require_tokenizers

View File

@@ -2311,3 +2311,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
"Dummy warning",
cm.records[0].message,
)
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass

View File

@@ -1274,3 +1274,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass

View File

@@ -16,7 +16,7 @@ import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from transformers.testing_utils import require_jinja, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -473,3 +473,25 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])
@require_jinja
def test_tokenization_for_chat(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
[37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)

View File

@@ -78,17 +78,23 @@ class ConversationalPipelineTests(unittest.TestCase):
def run_pipeline_test(self, conversation_agent, _):
# Simple
outputs = conversation_agent(Conversation("Hi there!"))
self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
)
# Single list
outputs = conversation_agent([Conversation("Hi there!")])
self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
)
# Batch
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
conversation_2 = Conversation("What's the last book you have read?")
self.assertEqual(len(conversation_1.past_user_inputs), 0)
self.assertEqual(len(conversation_2.past_user_inputs), 0)
self.assertEqual(len(conversation_1), 1)
self.assertEqual(len(conversation_2), 1)
outputs = conversation_agent([conversation_1, conversation_2])
self.assertEqual(outputs, [conversation_1, conversation_2])
@@ -96,32 +102,35 @@ class ConversationalPipelineTests(unittest.TestCase):
outputs,
[
Conversation(
past_user_inputs=["Going to the movies tonight - any suggestions?"],
generated_responses=[ANY(str)],
[
{"role": "user", "content": "Going to the movies tonight - any suggestions?"},
{"role": "assistant", "content": ANY(str)},
],
),
Conversation(
[
{"role": "user", "content": "What's the last book you have read?"},
{"role": "assistant", "content": ANY(str)},
]
),
Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
],
)
# One conversation with history
conversation_2.add_user_input("Why do you recommend it?")
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
outputs = conversation_agent(conversation_2)
self.assertEqual(outputs, conversation_2)
self.assertEqual(
outputs,
Conversation(
past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
generated_responses=[ANY(str), ANY(str)],
[
{"role": "user", "content": "What's the last book you have read?"},
{"role": "assistant", "content": ANY(str)},
{"role": "user", "content": "Why do you recommend it?"},
{"role": "assistant", "content": ANY(str)},
]
),
)
with self.assertRaises(ValueError):
conversation_agent("Hi there!")
with self.assertRaises(ValueError):
conversation_agent(Conversation())
# Conversation have been consumed and are not valid anymore
# Inactive conversations passed to the pipeline raise a ValueError
with self.assertRaises(ValueError):
conversation_agent(conversation_2)
@require_torch
@slow

View File

@@ -50,6 +50,7 @@ from transformers.testing_utils import (
check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
require_jinja,
require_tf,
require_tokenizers,
require_torch,
@@ -1052,6 +1053,40 @@ class TokenizerTesterMixin:
if tokenizer.num_special_tokens_to_add(pair=True):
self.assertIn(None, output.sequence_ids())
@require_jinja
def test_chat_template(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
expected_output = "systemsystem messageuseruser messageassistantassistant message"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: