Overhaul Conversation class and prompt templating (#25323)
* First commit while I figure this out * make fixup * Remove unused method * Store prompt attrib * Fix prompt argument for tests * Make same changes in fast tokenizer * Remove global prompts from fast tokenizer too * stash commit * stash commit * Migrate PromptConfig to its True Final Location * Replace Conversation entirely with the new class * Import/dependency fixes * Import/dependency fixes * Change format for lots of default prompts * More default prompt fixups * Revert llama old methods so we can compare * Fix some default configs * Fix some default configs * Fix misspelled kwarg * Fixes for Blenderbot * make fixup * little rebase cleanup * Add basic documentation * Quick doc fix * Truncate docstring for now * Add handling for the case when messages is a single string * Quick llama merges * Update conversational pipeline and tests * Add a couple of legacy properties for backward compatibility * More legacy handling * Add docstring for build_conversation_input_ids * Restructure PromptConfig * Let's start T E M P L A T I N G * Refactor all default configs to use templates instead * Revert changes to the special token properties since we don't need them anymore * More class templates * Make the sandbox even sandier * Everything replaced with pure templating * Remove docs for PromptConfig * Add testing and optional requirement boilerplate * Fix imports and make fixup * Fix LLaMA tests and add Conversation docstring * Finally get LLaMA working with the template system * Finally get LLaMA working with the template system * make fixup * make fixup * fmt-off for the long lists of test tokens * Rename method to apply_chat_template for now * Start on documentation * Make chat_template a property that reads through to the default if it's not set * Expand docs * Expand chat templating doc some more * trim/lstrip blocks by default and update doc * Few doc tweaks * rebase cleanup * Clarify docstring * rebase cleanup * rebase cleanup * make fixup * Quick doc edit * Reformat 
the standard template to match ChatML * Re-add PEFT check * Update docs/source/en/chat_templating.md Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Add apply_chat_template to the tokenizer doc * make fixup * Add doc links * Fix chat links * Fix chat links * Explain system messages in the doc * Add chat template test * Proper save-loading for chat template attribute * Add test skips for layout models * Remove _build_conversation_input_ids, add default_chat_template to code_llama * Make sure all LLaMA models are using the latest template * Remove default_system_prompt block in code_llama because it has no default prompt * Update ConversationPipeline preprocess * Add correct #Copied from links to the default_chat_templates * Remove unneeded type checking line * Add a dummy mark_processed method * Reorganize Conversation to have **deprecated_kwargs * Update chat_templating.md * Quick fix to LLAMA tests * Small doc tweaks * Add proper docstrings and "copied from" statements to all default chat templates * Merge use_default_system_prompt support for code_llama too * Improve clarity around self.chat_template * Docstring fix * Fix blenderbot default template * More doctest fix * Break out some tokenizer kwargs * Update doc to explain default templates * Quick tweaks to tokenizer args * Cleanups for tokenizer args * Add note about caching * Quick tweak to the chat-templating doc * Update the LLaMA template with error checking and correct system message embedding * make fixup * make fixup * add requires_jinja * Cleanup to expected output formatting * Add caching * Fix typo in llama default template * Update LLaMA tests * Update documentation * Improved legacy handling in the Conversation class * Update Jinja template with proper error handling * Quick bugfix * Proper exception raising * Change caching behaviour so it doesn't try to pickle an entire Jinja env * make fixup * rebase cleanup --------- Co-authored-by: Patrick von Platen 
<patrick.v.platen@gmail.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import BlenderbotTokenizer, BlenderbotTokenizerFast
|
||||
from transformers.testing_utils import require_jinja
|
||||
from transformers.utils import cached_property
|
||||
|
||||
|
||||
@@ -50,3 +51,24 @@ class Blenderbot3BTokenizerTests(unittest.TestCase):
|
||||
def test_3B_tokenization_same_as_parlai_rust_tokenizer(self):
    # The rust tokenizer fixture is expected to have a truthy `add_prefix_space`.
    assert self.rust_tokenizer_3b.add_prefix_space
    # " Sam" and "Sam" must both encode to the same ids, [5502, 2].
    assert self.rust_tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` on the 3B tokenizer reproduces the reference ids for three chats."""
    tok = self.tokenizer_3b
    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "assistant", "content": "Nice to meet you."},
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [tok.apply_chat_template(conversation) for conversation in conversations]
    expected_ids_per_chat = [
        [553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2],
        [553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2],
        [3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2],
    ]
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import BloomTokenizerFast
|
||||
from transformers.testing_utils import require_tokenizers
|
||||
from transformers.testing_utils import require_jinja, require_tokenizers
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -134,6 +134,27 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
|
||||
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` on the rust tokenizer reproduces the reference ids for three chats."""
    tokenizer = self.get_rust_tokenizer()
    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "assistant", "content": "Nice to meet you."},
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [tokenizer.apply_chat_template(conversation) for conversation in conversations]
    expected_ids_per_chat = [
        [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
        [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
        [229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
    ]
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
def test_add_prefix_space_fast(self):
|
||||
tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
|
||||
tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_tokenizers
|
||||
from transformers.testing_utils import require_jinja, require_tokenizers
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -275,6 +275,27 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||
self.assertEqual(encoded_sequence, filtered_sequence)
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` reproduces the reference ids for three chats.

    The tokenizer is loaded from the test's temporary directory (`self.tmpdirname`).
    """
    tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname)
    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "assistant", "content": "Nice to meet you."},
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [tokenizer.apply_chat_template(conversation) for conversation in conversations]
    # fmt: off
    expected_ids_per_chat = [
        [20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20],
        [20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20, 20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20],
        [20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20, 20, 3, 0, 0, 1, 20, 20],
    ]
    # fmt: on
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
class OPTTokenizationTest(unittest.TestCase):
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import GPTSw3Tokenizer
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
|
||||
from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -128,3 +128,27 @@ class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
model_name="AI-Sweden/gpt-sw3-126m",
|
||||
sequences=sequences,
|
||||
)
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` reproduces the reference ids for three chats.

    The tokenizer is built directly from `SAMPLE_VOCAB`.
    """
    tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB)
    # This is in English, but it's just here to make sure the chat control tokens are being added properly
    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "assistant", "content": "Nice to meet you."},
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [tokenizer.apply_chat_template(conversation) for conversation in conversations]
    # fmt: off
    expected_ids_per_chat = [
        [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419],
        [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 938, 541, 419],
        [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419],
    ]
    # fmt: on
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers.models.gptsan_japanese.tokenization_gptsan_japanese import (
|
||||
VOCAB_FILES_NAMES,
|
||||
GPTSanJapaneseTokenizer,
|
||||
)
|
||||
from transformers.testing_utils import require_tokenizers, slow
|
||||
from transformers.testing_utils import require_jinja, require_tokenizers, slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -193,3 +193,27 @@ class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_padding_different_model_input_name(self):
    # Deliberately a no-op: this tokenizer has no padding token.
    pass
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` reproduces the reference ids for three chats.

    The tokenizer is loaded from the "Tanrei/GPTSAN-japanese" checkpoint.
    """
    tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
    # This is in English, but it's just here to make sure the chat control tokens are being added properly
    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "assistant", "content": "Nice to meet you."},
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [tokenizer.apply_chat_template(conversation) for conversation in conversations]
    # fmt: off
    expected_ids_per_chat = [
        [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
        [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999, 35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999],
        [35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
    ]
    # fmt: on
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
@@ -2486,3 +2486,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
    # Intentionally empty: the skip decorator above disables this test with the stated reason.
    pass
|
||||
|
||||
@unittest.skip("Chat is not supported")
def test_chat_template(self):
    # Intentionally empty: the skip decorator above disables this test with the reason
    # "Chat is not supported".
    pass
|
||||
|
||||
@@ -2439,3 +2439,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# This should not fail
|
||||
model(encoded_sequence)
|
||||
model(batch_encoded_sequence)
|
||||
|
||||
@unittest.skip("Chat is not supported")
def test_chat_template(self):
    # Intentionally empty: the skip decorator above disables this test with the reason
    # "Chat is not supported".
    pass
|
||||
|
||||
@@ -1958,3 +1958,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("Doesn't use SentencePiece")
def test_sentencepiece_tokenize_and_decode(self):
    # Intentionally empty: the skip decorator above disables this test with the stated reason.
    pass
|
||||
|
||||
@unittest.skip("Chat is not supported")
def test_chat_template(self):
    # Intentionally empty: the skip decorator above disables this test with the reason
    # "Chat is not supported".
    pass
|
||||
|
||||
@@ -32,6 +32,7 @@ from transformers.convert_slow_tokenizer import convert_slow_tokenizer
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
nested_simplify,
|
||||
require_jinja,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
@@ -574,6 +575,32 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
# a dummy prefix space is not added by the sp_model as it was de-activated
|
||||
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` reproduces the reference ids for three chats.

    The tokenizer is loaded from "huggyllama/llama-7b" with `legacy=False`. The third chat contains only a
    user message.
    """
    tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)

    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Matt: The third test case tests the default system message, but if this is ever changed in the
    #       class/repo code then that test will fail, and the case will need to be updated.
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [tokenizer.apply_chat_template(conversation) for conversation in conversations]
    # fmt: off
    expected_ids_per_chat = [
        [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
        [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2],
        [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
    ]
    # fmt: on
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
||||
@@ -2311,3 +2311,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"Dummy warning",
|
||||
cm.records[0].message,
|
||||
)
|
||||
|
||||
@unittest.skip("Chat is not supported")
def test_chat_template(self):
    # Intentionally empty: the skip decorator above disables this test with the reason
    # "Chat is not supported".
    pass
|
||||
|
||||
@@ -1274,3 +1274,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
    # Intentionally empty: the skip decorator above disables this test with the stated reason.
    pass
|
||||
|
||||
@unittest.skip("Chat is not supported")
def test_chat_template(self):
    # Intentionally empty: the skip decorator above disables this test with the reason
    # "Chat is not supported".
    pass
|
||||
|
||||
@@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.testing_utils import require_jinja, slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -473,3 +473,25 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
|
||||
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
|
||||
self.assertEqual(output, [])
|
||||
|
||||
@require_jinja
def test_tokenization_for_chat(self):
    """Check that `apply_chat_template` reproduces the reference ids for three chats.

    The tokenizer is loaded from the "openai/whisper-tiny" checkpoint.
    """
    multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
    # This is in English, but it's just here to make sure the chat control tokens are being added properly
    conversations = [
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
        [
            {"role": "system", "content": "You are a helpful chatbot."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Nice to meet you."},
        ],
        [
            {"role": "assistant", "content": "Nice to meet you."},
            {"role": "user", "content": "Hello!"},
        ],
    ]
    # Every conversation is templated before the first comparison is made.
    encoded_chats = [multilingual_tokenizer.apply_chat_template(conversation) for conversation in conversations]
    expected_ids_per_chat = [
        [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
        [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
        [37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
    ]
    # The loop targets use their own names so they do not rebind the list given to zip().
    for encoded_chat, expected_ids in zip(encoded_chats, expected_ids_per_chat):
        self.assertListEqual(encoded_chat, expected_ids)
|
||||
|
||||
@@ -78,17 +78,23 @@ class ConversationalPipelineTests(unittest.TestCase):
|
||||
def run_pipeline_test(self, conversation_agent, _):
|
||||
# Simple
|
||||
outputs = conversation_agent(Conversation("Hi there!"))
|
||||
self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||
)
|
||||
|
||||
# Single list
|
||||
outputs = conversation_agent([Conversation("Hi there!")])
|
||||
self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
|
||||
)
|
||||
|
||||
# Batch
|
||||
conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
|
||||
conversation_2 = Conversation("What's the last book you have read?")
|
||||
self.assertEqual(len(conversation_1.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_2.past_user_inputs), 0)
|
||||
self.assertEqual(len(conversation_1), 1)
|
||||
self.assertEqual(len(conversation_2), 1)
|
||||
|
||||
outputs = conversation_agent([conversation_1, conversation_2])
|
||||
self.assertEqual(outputs, [conversation_1, conversation_2])
|
||||
@@ -96,32 +102,35 @@ class ConversationalPipelineTests(unittest.TestCase):
|
||||
outputs,
|
||||
[
|
||||
Conversation(
|
||||
past_user_inputs=["Going to the movies tonight - any suggestions?"],
|
||||
generated_responses=[ANY(str)],
|
||||
[
|
||||
{"role": "user", "content": "Going to the movies tonight - any suggestions?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
],
|
||||
),
|
||||
Conversation(
|
||||
[
|
||||
{"role": "user", "content": "What's the last book you have read?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
]
|
||||
),
|
||||
Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
|
||||
],
|
||||
)
|
||||
|
||||
# One conversation with history
|
||||
conversation_2.add_user_input("Why do you recommend it?")
|
||||
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
|
||||
outputs = conversation_agent(conversation_2)
|
||||
self.assertEqual(outputs, conversation_2)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
Conversation(
|
||||
past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
|
||||
generated_responses=[ANY(str), ANY(str)],
|
||||
[
|
||||
{"role": "user", "content": "What's the last book you have read?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
{"role": "user", "content": "Why do you recommend it?"},
|
||||
{"role": "assistant", "content": ANY(str)},
|
||||
]
|
||||
),
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
conversation_agent("Hi there!")
|
||||
with self.assertRaises(ValueError):
|
||||
conversation_agent(Conversation())
|
||||
# Conversation have been consumed and are not valid anymore
|
||||
# Inactive conversations passed to the pipeline raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
conversation_agent(conversation_2)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
@@ -50,6 +50,7 @@ from transformers.testing_utils import (
|
||||
check_json_file_has_correct_format,
|
||||
get_tests_dir,
|
||||
is_pt_tf_cross_test,
|
||||
require_jinja,
|
||||
require_tf,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
@@ -1052,6 +1053,40 @@ class TokenizerTesterMixin:
|
||||
if tokenizer.num_special_tokens_to_add(pair=True):
|
||||
self.assertIn(None, output.sequence_ids())
|
||||
|
||||
@require_jinja
def test_chat_template(self):
    """Exercise `apply_chat_template` three ways for every tokenizer from `get_tokenizers()`.

    The template is first passed per call, then set through the `chat_template` attribute, and finally
    checked again after a `save_pretrained` / `from_pretrained` round trip.
    """
    template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
    conversation = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "assistant message"},
    ]
    rendered = "systemsystem messageuseruser messageassistantassistant message"
    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # 1) Template handed over through the `chat_template` argument.
            text = tokenizer.apply_chat_template(conversation, chat_template=template, tokenize=False)
            self.assertEqual(text, rendered)
            # tokenize=True only has to run without raising.
            tokenizer.apply_chat_template(conversation, chat_template=template, tokenize=True)

            # 2) Template stored on the tokenizer and picked up when no argument is given.
            tokenizer.chat_template = template
            self.assertEqual(tokenizer.chat_template, template)
            text = tokenizer.apply_chat_template(conversation, tokenize=False)
            self.assertEqual(text, rendered)
            tokenizer.apply_chat_template(conversation, tokenize=True)

            # 3) Template must persist through saving and reloading.
            with tempfile.TemporaryDirectory() as save_dir:
                tokenizer.save_pretrained(save_dir)
                reloaded = tokenizer.from_pretrained(save_dir)

            self.assertEqual(reloaded.chat_template, template)
            text = reloaded.apply_chat_template(conversation, tokenize=False)
            self.assertEqual(text, rendered)
            reloaded.apply_chat_template(conversation, tokenize=True)
|
||||
|
||||
def test_number_of_added_tokens(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
|
||||
Reference in New Issue
Block a user