Overhaul Conversation class and prompt templating (#25323)
* First commit while I figure this out * make fixup * Remove unused method * Store prompt attrib * Fix prompt argument for tests * Make same changes in fast tokenizer * Remove global prompts from fast tokenizer too * stash commit * stash commit * Migrate PromptConfig to its True Final Location * Replace Conversation entirely with the new class * Import/dependency fixes * Import/dependency fixes * Change format for lots of default prompts * More default prompt fixups * Revert llama old methods so we can compare * Fix some default configs * Fix some default configs * Fix misspelled kwarg * Fixes for Blenderbot * make fixup * little rebase cleanup * Add basic documentation * Quick doc fix * Truncate docstring for now * Add handling for the case when messages is a single string * Quick llama merges * Update conversational pipeline and tests * Add a couple of legacy properties for backward compatibility * More legacy handling * Add docstring for build_conversation_input_ids * Restructure PromptConfig * Let's start T E M P L A T I N G * Refactor all default configs to use templates instead * Revert changes to the special token properties since we don't need them anymore * More class templates * Make the sandbox even sandier * Everything replaced with pure templating * Remove docs for PromptConfig * Add testing and optional requirement boilerplate * Fix imports and make fixup * Fix LLaMA tests and add Conversation docstring * Finally get LLaMA working with the template system * Finally get LLaMA working with the template system * make fixup * make fixup * fmt-off for the long lists of test tokens * Rename method to apply_chat_template for now * Start on documentation * Make chat_template a property that reads through to the default if it's not set * Expand docs * Expand chat templating doc some more * trim/lstrip blocks by default and update doc * Few doc tweaks * rebase cleanup * Clarify docstring * rebase cleanup * rebase cleanup * make fixup * Quick doc edit * Reformat 
the standard template to match ChatML * Re-add PEFT check * Update docs/source/en/chat_templating.md Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Add apply_chat_template to the tokenizer doc * make fixup * Add doc links * Fix chat links * Fix chat links * Explain system messages in the doc * Add chat template test * Proper save-loading for chat template attribute * Add test skips for layout models * Remove _build_conversation_input_ids, add default_chat_template to code_llama * Make sure all LLaMA models are using the latest template * Remove default_system_prompt block in code_llama because it has no default prompt * Update ConversationPipeline preprocess * Add correct #Copied from links to the default_chat_templates * Remove unneeded type checking line * Add a dummy mark_processsed method * Reorganize Conversation to have **deprecated_kwargs * Update chat_templating.md * Quick fix to LLAMA tests * Small doc tweaks * Add proper docstrings and "copied from" statements to all default chat templates * Merge use_default_system_prompt support for code_llama too * Improve clarity around self.chat_template * Docstring fix * Fix blenderbot default template * More doctest fix * Break out some tokenizer kwargs * Update doc to explain default templates * Quick tweaks to tokenizer args * Cleanups for tokenizer args * Add note about cacheing * Quick tweak to the chat-templating doc * Update the LLaMA template with error checking and correct system message embedding * make fixup * make fixup * add requires_jinja * Cleanup to expected output formatting * Add cacheing * Fix typo in llama default template * Update LLaMA tests * Update documentation * Improved legacy handling in the Conversation class * Update Jinja template with proper error handling * Quick bugfix * Proper exception raising * Change cacheing behaviour so it doesn't try to pickle an entire Jinja env * make fixup * rebase cleanup --------- Co-authored-by: Patrick von Platen 
<patrick.v.platen@gmail.com>
@@ -98,6 +98,8 @@
|
||||
title: Use model-specific APIs
|
||||
- local: custom_models
|
||||
title: Share a custom model
|
||||
- local: chat_templating
|
||||
title: Templates for chat models
|
||||
- local: sagemaker
|
||||
title: Run training on Amazon SageMaker
|
||||
- local: serialization
|
||||
|
||||
255
docs/source/en/chat_templating.md
Normal file
@@ -0,0 +1,255 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Templates for Chat Models
|
||||
|
||||
## Introduction
|
||||
|
||||
An increasingly common use case for LLMs is **chat**. In a chat context, rather than continuing a single string
|
||||
of text (as is the case with a standard language model), the model instead continues a conversation that consists
|
||||
of one or more **messages**, each of which includes a **role** as well as message text.
|
||||
|
||||
Most commonly, these roles are "user" for messages sent by the user, and "assistant" for messages sent by the model.
|
||||
Some models also support a "system" role. System messages are usually sent at the beginning of the conversation
|
||||
and include directives about how the model should behave in the subsequent chat.
|
||||
|
||||
All language models, including models fine-tuned for chat, operate on linear sequences of tokens and do not intrinsically
|
||||
have special handling for roles. This means that role information is usually injected by adding control tokens
|
||||
between messages, to indicate both the message boundary and the relevant roles.
|
||||
|
||||
Unfortunately, there isn't (yet!) a standard for which tokens to use, and so different models have been trained
|
||||
with wildly different formatting and control tokens for chat. This can be a real problem for users - if you use the
wrong format, the model will be confused by your input, and its output quality will be a lot worse than it should be.
|
||||
This is the problem that **chat templates** aim to resolve.
|
||||
|
||||
Chat conversations are typically represented as a list of dictionaries, where each dictionary contains `role`
|
||||
and `content` keys, and represents a single chat message. Chat templates are strings containing a Jinja template that
|
||||
specifies how to format a conversation for a given model into a single tokenizable sequence. By storing this information
|
||||
with the tokenizer, we can ensure that models get input data in the format they expect.
|
||||
|
||||
Let's make this concrete with a quick example using the `BlenderBot` model. BlenderBot has an extremely simple default
|
||||
template, which mostly just adds whitespace between rounds of dialogue:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
|
||||
>>> chat = [
|
||||
... {"role": "user", "content": "Hello, how are you?"},
|
||||
... {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
... {"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
... ]
|
||||
|
||||
>>> tokenizer.apply_chat_template(chat, tokenize=False)
|
||||
" Hello, how are you? I'm doing great. How can I help you today? I'd like to show off how chat templating works!</s>"
|
||||
```
|
||||
|
||||
Notice how the entire chat is condensed into a single string. If we use `tokenize=True`, which is the default setting,
|
||||
that string will also be tokenized for us. To see a more complex template in action, though, let's use the
|
||||
`meta-llama/Llama-2-7b-chat-hf` model. Note that this model has gated access, so you will have to
|
||||
[request access on the repo](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) if you want to run this code yourself:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

>>> chat = [
...   {"role": "user", "content": "Hello, how are you?"},
...   {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
...   {"role": "user", "content": "I'd like to show off how chat templating works!"},
... ]

>>> tokenizer.use_default_system_prompt = False
>>> tokenizer.apply_chat_template(chat, tokenize=False)
|
||||
"<s>[INST] Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]"
|
||||
```
|
||||
|
||||
Note that this time, the tokenizer has added the control tokens `[INST]` and `[/INST]` to indicate the start and end of
user messages (but not assistant messages!).
|
||||
|
||||
## How do chat templates work?
|
||||
|
||||
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
|
||||
default template for that model class is used instead. Let's take a look at the template for `BlenderBot`:
|
||||
|
||||
```python
|
||||
|
||||
>>> from transformers import AutoTokenizer
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
|
||||
|
||||
>>> tokenizer.default_chat_template
|
||||
"{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
|
||||
```
|
||||
|
||||
That's kind of intimidating. Let's add some newlines and indentation to make it more readable. Note that
we remove the first newline after each block as well as any preceding whitespace before a block by default, using the
Jinja `trim_blocks` and `lstrip_blocks` flags. This means that you can put `{% ... %}` block tags on their own indented
lines without changing the output. These flags only apply to block tags, though: whitespace and newlines around
`{{ ... }}` expressions are still emitted, so read the expanded version below as an explanatory layout rather than an
exact drop-in replacement.
|
||||
|
||||
```
|
||||
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ ' ' }}
    {% endif %}
    {{ message['content'] }}
    {% if not loop.last %}
        {{ ' ' }}
    {% endif %}
{% endfor %}
{{ eos_token }}
|
||||
```
|
||||
|
||||
If you've never seen one of these before, this is a [Jinja template](https://jinja.palletsprojects.com/en/3.1.x/templates/).
|
||||
Jinja is a templating language that allows you to write simple code that generates text. In many ways, the code and
|
||||
syntax resembles Python. In pure Python, this template would look something like this:
|
||||
|
||||
```python
|
||||
for idx, message in enumerate(messages):
    if message['role'] == 'user':
        print(' ')
    print(message['content'])
    if not idx == len(messages) - 1:  # Check for the last message in the conversation
        print(' ')
print(eos_token)
|
||||
```
|
||||
|
||||
Effectively, the template does three things:
|
||||
1. For each message, if the message is a user message, add a blank space before it, otherwise print nothing.
|
||||
2. Add the message content.
|
||||
3. If the message is not the last message, add two spaces after it. After the final message, print the EOS token.
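
The three steps above can be sketched as a small function that returns the formatted string instead of printing it. This is only an illustration: the function name, the `</s>` EOS token and the `user_prefix` / `separator` parameters are assumptions introduced here, with defaults that follow the wording of the list (one space before user messages, two spaces between messages). Exact whitespace is model-specific, so check it against your own tokenizer.

```python
def format_simple_chat(messages, eos_token="</s>", user_prefix=" ", separator="  "):
    """Mimic a whitespace-only chat template (hypothetical helper, not a transformers API)."""
    parts = []
    for idx, message in enumerate(messages):
        # Step 1: user messages get a leading space, other roles get nothing.
        if message["role"] == "user":
            parts.append(user_prefix)
        # Step 2: add the message content itself.
        parts.append(message["content"])
        # Step 3: every message except the last is followed by the separator.
        if idx != len(messages) - 1:
            parts.append(separator)
    # The EOS token is appended once, after the final message.
    parts.append(eos_token)
    return "".join(parts)


chat = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
formatted = format_simple_chat(chat)
print(repr(formatted))
```

Returning a single string, rather than calling `print` per piece, avoids the extra newlines that `print` would insert and is closer to what a template render produces.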
|
||||
|
||||
This is a pretty simple template - it doesn't add any control tokens, and it doesn't support "system" messages, which
|
||||
are a common way to give the model directives about how it should behave in the subsequent conversation.
|
||||
But Jinja gives you a lot of flexibility to do those things! Let's see a Jinja template that can format inputs
|
||||
similarly to the way LLaMA formats them (note that the real LLaMA template includes handling for default system
|
||||
messages and slightly different system message handling in general - don't use this one in your actual code!)
|
||||
|
||||
```
|
||||
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}
    {% elif message['role'] == 'system' %}
        {{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}
    {% elif message['role'] == 'assistant' %}
        {{ ' ' + message['content'] + ' ' + eos_token }}
    {% endif %}
{% endfor %}
|
||||
```
|
||||
|
||||
Hopefully if you stare at this for a little bit you can see what this template is doing - it adds specific tokens based
|
||||
on the "role" of each message, which represents who sent it. User, assistant and system messages are clearly
|
||||
distinguishable to the model because of the tokens they're wrapped in.
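
For comparison with the pure-Python version of the BlenderBot template shown earlier, here is a rough Python rendering of the simplified LLaMA-style template. It is a sketch under assumptions: the function name is made up for illustration, and the `<s>` / `</s>` values for `bos_token` and `eos_token` are taken from the example output earlier in this document. Like the template it mirrors, it omits the default system prompt and role-ordering checks, so it should not be used in real code.

```python
def format_llama_like(messages, bos_token="<s>", eos_token="</s>"):
    """Mirror the simplified LLaMA-style template (illustrative only)."""
    pieces = []
    for message in messages:
        role, content = message["role"], message["content"]
        if role == "user":
            # User turns open a new sequence and are wrapped in [INST] ... [/INST].
            pieces.append(f"{bos_token}[INST] {content} [/INST]")
        elif role == "system":
            # System turns are wrapped in <<SYS>> ... <</SYS>> with newlines.
            pieces.append(f"<<SYS>>\n{content}\n<</SYS>>\n\n")
        elif role == "assistant":
            # Assistant turns get no control tokens, only spaces and a closing EOS.
            pieces.append(f" {content} {eos_token}")
    return "".join(pieces)


chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works!"},
]
llama_text = format_llama_like(chat)
print(llama_text)
```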
|
||||
|
||||
## How do I create a chat template?
|
||||
|
||||
It's simple: just write a Jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
|
||||
existing template from another model and simply edit it for your needs! For example, we could take the LLaMA template
|
||||
above and add "[ASST]" and "[/ASST]" to assistant messages:
|
||||
|
||||
```
|
||||
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}
    {% elif message['role'] == 'system' %}
        {{ '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}
    {% elif message['role'] == 'assistant' %}
        {{ '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}
    {% endif %}
{% endfor %}
|
||||
```
|
||||
|
||||
Now, simply set the `tokenizer.chat_template` attribute. Next time you use [`~PreTrainedTokenizer.apply_chat_template`], it will
|
||||
use your new template! This attribute will be saved in the `tokenizer_config.json` file, so you can use
|
||||
[`~utils.PushToHubMixin.push_to_hub`] to upload your new template to the Hub and make sure everyone's using the right
|
||||
template for your model!
|
||||
|
||||
```python
|
||||
template = tokenizer.chat_template
|
||||
template = template.replace("SYS", "SYSTEM") # Change the system token
|
||||
tokenizer.chat_template = template # Set the new template
|
||||
tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
|
||||
```
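
When you write a template in the expanded, indented style, it is worth checking that it renders exactly like the compact one before uploading it. The sketch below uses the `jinja2` package directly with the `trim_blocks` and `lstrip_blocks` flags mentioned earlier. The `> ` marker, the variable names and the use of a plain `Environment` are assumptions for illustration and may differ from how the library sets up its own environment. It also relies on Jinja's `{{- ... -}}` whitespace control, because the two flags act on `{% %}` block tags only.

```python
from jinja2 import Environment

# A plain Environment with the two whitespace flags enabled (illustrative setup).
env = Environment(trim_blocks=True, lstrip_blocks=True)

compact = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '> ' }}{% endif %}"
    "{{ message['content'] + eos_token }}"
    "{% endfor %}"
)

# Block tags may sit on their own indented lines; expressions use {{- -}} to
# strip the surrounding indentation and newlines explicitly.
expanded = """\
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{- '> ' -}}
    {% endif %}
    {{- message['content'] + eos_token -}}
{% endfor %}
"""

# Same layout without the minus signs: whitespace around {{ }} is kept.
naive = """\
{% for message in messages %}
    {% if message['role'] == 'user' %}
        {{ '> ' }}
    {% endif %}
    {{ message['content'] + eos_token }}
{% endfor %}
"""


def render(template_source, messages, eos_token="</s>"):
    """Compile and render a template string with the environment above."""
    return env.from_string(template_source).render(messages=messages, eos_token=eos_token)


chat = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
print(repr(render(compact, chat)))
print(repr(render(expanded, chat)))
print(repr(render(naive, chat)))
```

With these flags, `expanded` and `compact` should render identically, while `naive` keeps the indentation and newlines around its `{{ }}` lines. Comparing `repr` output like this is a cheap way to catch stray whitespace in a new template.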
|
||||
|
||||
The method [`~PreTrainedTokenizer.apply_chat_template`], which uses your chat template, is called by the [`ConversationalPipeline`] class, so
|
||||
once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`].
|
||||
|
||||
## What are "default" templates?
|
||||
|
||||
Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards
|
||||
compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a
|
||||
model does not have a chat template set, but there is a default template for its model class, the `ConversationalPipeline`
|
||||
class and methods like `apply_chat_template` will use the class template instead. You can find out what the default
|
||||
template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute.
|
||||
|
||||
This is something we do purely for backward compatibility reasons, to avoid breaking any existing workflows. Even when
|
||||
the class template is appropriate for your model, we strongly recommend overriding the default template by
|
||||
setting the `chat_template` attribute explicitly to make it clear to users that your model has been correctly configured
|
||||
for chat, and to future-proof in case the default templates are ever altered or deprecated.
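
The fallback described in this section can be pictured with a small stand-in class. This is a hedged sketch of the lookup order only: `TokenizerSketch` and `resolve_chat_template` are hypothetical names invented for this example and are not part of the library, whose actual implementation may differ.

```python
from typing import Optional

CHATML_TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}"
    "{% endfor %}"
)


class TokenizerSketch:
    """Hypothetical stand-in for a tokenizer class; not the transformers implementation."""

    # Class-level fallback, playing the role of `default_chat_template` in the text.
    default_chat_template = "{% for message in messages %}{{ message['content'] }}{{ eos_token }}{% endfor %}"

    def __init__(self, chat_template: Optional[str] = None):
        # Instance-level template, playing the role of `tokenizer.chat_template`.
        self.chat_template = chat_template

    def resolve_chat_template(self) -> str:
        # An explicitly set template always wins over the class default.
        if self.chat_template is not None:
            return self.chat_template
        return self.default_chat_template


legacy = TokenizerSketch()  # nothing set: falls back to the class default
explicit = TokenizerSketch(chat_template=CHATML_TEMPLATE)  # explicit template is used

print(legacy.resolve_chat_template() == TokenizerSketch.default_chat_template)
print(explicit.resolve_chat_template() == CHATML_TEMPLATE)
```

Setting the template explicitly, as `explicit` does, means the formatting no longer depends on whatever the class default happens to be.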
|
||||
|
||||
## What template should I use?
|
||||
|
||||
When setting the template for a model that's already been trained for chat, you should ensure that the template
|
||||
exactly matches the message formatting that the model saw during training, or else you will probably experience
|
||||
performance degradation. This is true even if you're training the model further - you will probably get the best
|
||||
performance if you keep the chat tokens constant. This is closely analogous to tokenization - you generally get the
|
||||
best performance for inference or fine-tuning when you precisely match the tokenization used during training.
|
||||
|
||||
If you're training a model from scratch, or fine-tuning a base language model for chat, on the other hand,
|
||||
you have a lot of freedom to choose an appropriate template! LLMs are smart enough to learn to handle lots of different
|
||||
input formats. Our default template for models that don't have a class-specific template follows the
|
||||
[ChatML format](https://github.com/openai/openai-python/blob/main/chatml.md), and this is a good, flexible choice for many use-cases. It looks like this:
|
||||
|
||||
```
|
||||
{% for message in messages %}
    {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{% endfor %}
|
||||
```
|
||||
|
||||
If you like this one, here it is in one-liner form, ready to copy into your code:
|
||||
|
||||
```
|
||||
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
|
||||
```
|
||||
|
||||
This template wraps each message in `<|im_start|>` and `<|im_end|>` tokens, and simply writes the role as a string, which
|
||||
allows for flexibility in the roles you train with. The output looks like this:
|
||||
|
||||
```
|
||||
<|im_start|>system
|
||||
You are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.<|im_end|>
|
||||
<|im_start|>user
|
||||
How are you?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'm doing great!<|im_end|>
|
||||
```
|
||||
|
||||
The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense,
|
||||
particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
|
||||
to these roles - templating is extremely flexible, and any string can be a role.
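
To make the ChatML layout concrete without any dependencies, the same formatting can be written as a short Python function. This is a sketch for illustration: `format_chatml` is a hypothetical helper rather than a library function, and the `tool` role in the second call is an arbitrary example of a non-standard role.

```python
def format_chatml(messages):
    """Wrap each message as <|im_start|>{role}\\n{content}<|im_end|>\\n (illustrative helper)."""
    return "".join(
        f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        for message in messages
    )


chat = [
    {
        "role": "system",
        "content": "You are a helpful chatbot that will do its best not to say anything so stupid that people tweet about it.",
    },
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "I'm doing great!"},
]
chatml_text = format_chatml(chat)
print(chatml_text)

# The role is written out as a plain string, so any role name works.
custom_text = format_chatml([{"role": "tool", "content": "42"}])
print(custom_text)
```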
|
||||
|
||||
## I want to use chat templates! How should I get started?
|
||||
|
||||
If you have any chat models, you should set their `tokenizer.chat_template` attribute and test it using
|
||||
[`~PreTrainedTokenizer.apply_chat_template`]. This applies even if you're not the model owner - if you're using a model
|
||||
with an empty chat template, or one that's still using the default class template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to
|
||||
the model repository so that this attribute can be set properly!
|
||||
|
||||
Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that
|
||||
model, which means it is also automatically supported in places like `ConversationalPipeline`!
|
||||
|
||||
By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
|
||||
open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
|
||||
it's time to put an end to them!
|
||||
@@ -58,6 +58,7 @@ to a given token).
|
||||
- batch_decode
|
||||
- decode
|
||||
- encode
|
||||
- apply_chat_template
|
||||
- push_to_hub
|
||||
- all
|
||||
|
||||
@@ -71,6 +72,7 @@ loaded very simply into 🤗 transformers. Take a look at the [Using tokenizers
|
||||
- batch_decode
|
||||
- decode
|
||||
- encode
|
||||
- apply_chat_template
|
||||
- push_to_hub
|
||||
- all
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import regex as re
|
||||
|
||||
@@ -25,9 +25,6 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -413,19 +410,16 @@ class BlenderbotTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
return token_ids_0 + [self.eos_token_id]
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
inputs = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
if is_user:
|
||||
# We need to space prefix as it's being done within blenderbot
|
||||
inputs.append(" " + text)
|
||||
else:
|
||||
# Generated responses should contain them already.
|
||||
inputs.append(text)
|
||||
|
||||
full_string = " ".join(inputs)
|
||||
input_ids = self.encode(full_string)
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
|
||||
return input_ids
|
||||
@property
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
A very simple chat template that just adds whitespace between messages.
|
||||
"""
|
||||
return (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
|
||||
"{{ message['content'] }}"
|
||||
"{% if not loop.last %}{{ ' ' }}{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{{ eos_token }}"
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Fast Tokenization class for Blenderbot."""
|
||||
import json
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from tokenizers import pre_tokenizers, processors
|
||||
|
||||
@@ -24,9 +24,6 @@ from ...utils import logging
|
||||
from .tokenization_blenderbot import BlenderbotTokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -297,19 +294,17 @@ class BlenderbotTokenizerFast(PreTrainedTokenizerFast):
|
||||
"""
|
||||
return token_ids_0 + [self.eos_token_id]
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
inputs = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
if is_user:
|
||||
# We need to space prefix as it's being done within blenderbot
|
||||
inputs.append(" " + text)
|
||||
else:
|
||||
# Generated responses should contain them already.
|
||||
inputs.append(text)
|
||||
|
||||
full_string = " ".join(inputs)
|
||||
input_ids = self.encode(full_string)
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
logger.warning(f"Trimmed input from conversation as it was longer than {self.model_max_length} tokens.")
|
||||
return input_ids
|
||||
@property
|
||||
# Copied from transformers.models.blenderbot.tokenization_blenderbot.BlenderbotTokenizer.default_chat_template
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
A very simple chat template that just adds whitespace between messages.
|
||||
"""
|
||||
return (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}"
|
||||
"{{ message['content'] }}"
|
||||
"{% if not loop.last %}{{ ' ' }}{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{{ eos_token }}"
|
||||
)
|
||||
|
||||
@@ -16,17 +16,13 @@
|
||||
|
||||
|
||||
import pickle
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}
|
||||
@@ -166,12 +162,10 @@ class BloomTokenizerFast(PreTrainedTokenizerFast):
|
||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||
return tuple(files)
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
"""This corresponds to DialoGPT variants of models."""
|
||||
input_ids = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
return input_ids
|
||||
@property
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
A simple chat template that ignores role information and just concatenates messages with EOS tokens.
|
||||
"""
|
||||
return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""Tokenization classes for Code LLaMA."""
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
@@ -26,9 +26,6 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging, requires_backends
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
||||
@@ -441,70 +438,57 @@ class CodeLlamaTokenizer(PreTrainedTokenizer):
|
||||
|
||||
return output
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
r"""Builds the input ids for a conversation.
|
||||
This is the format used in the provided examples. System prompts should be manually added at the beginning of
|
||||
the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
|
||||
```
|
||||
<bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
|
||||
<bos>[INST] Prompt [/INST] Answer <eos>
|
||||
<bos>[INST] Prompt [/INST]
|
||||
```
|
||||
|
||||
If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
|
||||
```python
|
||||
>>> from transformers import Conversation
|
||||
|
||||
>>> Conversation(
|
||||
... "<<SYS>>\n Complete the functions without any documentation\n<</SYS>>\n\n `def remove_non_ascii(s: str) -> str:`"
|
||||
... ) # doctest: +IGNORE_RESULT
|
||||
```
|
||||
Args:
|
||||
conversation (`Conversation`):
|
||||
Conversation to build input ids for.
|
||||
Returns:
|
||||
`List[int]`:
|
||||
Input ids for the conversation.
|
||||
@property
|
||||
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
if self.use_default_system_prompt:
|
||||
if len(conversation.past_user_inputs) > 0:
|
||||
if (
|
||||
not conversation.past_user_inputs[0].startswith(B_SYS)
|
||||
or E_SYS not in conversation.past_user_inputs[0]
|
||||
):
|
||||
conversation.past_user_inputs[0] = (
|
||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||
)
|
||||
elif conversation.new_user_input:
|
||||
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
||||
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
||||
else:
|
||||
raise ValueError("Last message must be from user")
|
||||
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
|
||||
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
|
||||
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
|
||||
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
|
||||
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
|
||||
to fine-tune a model with more flexible role ordering!
|
||||
|
||||
dialogue = list(conversation.iter_texts())
|
||||
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
||||
[not is_user for is_user, msg in dialogue[1::2]]
|
||||
):
|
||||
raise ValueError(
|
||||
"The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
|
||||
)
|
||||
The output should look something like:
|
||||
|
||||
dialog_tokens: List[int] = []
|
||||
dialog_tokens += sum(
|
||||
[
|
||||
[self.bos_token_id]
|
||||
+ self.encode(
|
||||
f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
|
||||
)
|
||||
+ [self.eos_token_id]
|
||||
for prompt, answer in zip(dialogue[::2], dialogue[1::2])
|
||||
],
|
||||
[],
|
||||
+        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
+        <bos>[INST] Prompt [/INST]
+        """
+
+        template = (
+            "{% if messages[0]['role'] == 'system' %}"
+            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
+            "{% set system_message = messages[0]['content'] %}"
+            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
+            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% else %}"
+            "{% set loop_messages = messages %}"
+            "{% set system_message = false %}"
+            "{% endif %}"
+            "{% for message in loop_messages %}"  # Loop over all non-system messages
+            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+            "{% endif %}"
+            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
+            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
+            "{% else %}"
+            "{% set content = message['content'] %}"
+            "{% endif %}"
+            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
+            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+            "{% elif message['role'] == 'system' %}"
+            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ ' ' + content.strip() + ' ' + eos_token }}"
+            "{% endif %}"
+            "{% endfor %}"
         )
-        dialog_tokens += [self.bos_token_id] + self.encode(
-            f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
-        )
-        return dialog_tokens
+        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+        return template

     def __getstate__(self):
         state = self.__dict__.copy()
@@ -14,7 +14,7 @@
 # limitations under the License.
 import os
 from shutil import copyfile
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple

 from tokenizers import normalizers, processors

@@ -23,9 +23,6 @@ from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
 require_version("tokenizers>=0.13.3")

 if is_sentencepiece_available():
@@ -344,6 +341,58 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):

         return (out_vocab_file,)

+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
+    def default_chat_template(self):
+        """
+        LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+        to fine-tune a model with more flexible role ordering!
+
+        The output should look something like:
+
+        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
+        <bos>[INST] Prompt [/INST]
+        """
+
+        template = (
+            "{% if messages[0]['role'] == 'system' %}"
+            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
+            "{% set system_message = messages[0]['content'] %}"
+            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
+            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% else %}"
+            "{% set loop_messages = messages %}"
+            "{% set system_message = false %}"
+            "{% endif %}"
+            "{% for message in loop_messages %}"  # Loop over all non-system messages
+            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+            "{% endif %}"
+            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
+            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
+            "{% else %}"
+            "{% set content = message['content'] %}"
+            "{% endif %}"
+            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
+            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+            "{% elif message['role'] == 'system' %}"
+            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ ' ' + content.strip() + ' ' + eos_token }}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+        return template
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
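To make the control flow of the `default_chat_template` above easier to follow, here is a plain-Python paraphrase of the same logic. It is only a sketch: the function name `render_llama_chat`, its parameters, and the `<s>`/`</s>` token strings are assumptions for illustration, and the real rendering goes through the Jinja template rather than this function.

```python
from typing import Dict, List, Optional

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def render_llama_chat(
    messages: List[Dict[str, str]],
    bos_token: str = "<s>",  # placeholder token strings, chosen for the example
    eos_token: str = "</s>",
    default_system_prompt: Optional[str] = None,  # stands in for the USE_DEFAULT_PROMPT flag
) -> str:
    """Hypothetical helper mirroring the branches of the Jinja template."""
    if messages[0]["role"] == "system":
        # An explicit system message is pulled out of the loop.
        system_message, loop_messages = messages[0]["content"], messages[1:]
    elif default_system_prompt is not None and "<<SYS>>" not in messages[0]["content"]:
        system_message, loop_messages = default_system_prompt, messages
    else:
        system_message, loop_messages = None, messages

    pieces = []
    for index, message in enumerate(loop_messages):
        # Even positions must be user turns; anything else at an even position is rejected.
        if (message["role"] == "user") != (index % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        content = message["content"]
        if index == 0 and system_message is not None:
            # The system message is embedded inside the first user turn.
            content = B_SYS + system_message + E_SYS + content
        if message["role"] == "user":
            pieces.append(f"{bos_token}{B_INST} {content.strip()} {E_INST}")
        elif message["role"] == "system":
            pieces.append(B_SYS + content.strip() + E_SYS)
        elif message["role"] == "assistant":
            pieces.append(f" {content.strip()} {eos_token}")
    return "".join(pieces)


chat = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Bye"},
]
print(repr(render_llama_chat(chat)))
```

Note that, as written, the template emits no separator between `<eos>` and the next `<bos>`; the space shown in the docstring example is only approximate.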
@@ -371,69 +420,3 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
         if token_ids_1 is None:
             return self.bos_token_id + token_ids_0 + self.eos_token_id
         return self.bos_token_id + token_ids_0 + token_ids_1 + self.eos_token_id
-
-    # Copied from transformers.models.code_llama.tokenization_code_llama.CodeLlamaTokenizer._build_conversation_input_ids
-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        r"""Builds the input ids for a conversation.
-        This is the format used in the provided examples. System prompts should be manually added at the beginning of
-        the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
-        ```
-        <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
-        <bos>[INST] Prompt [/INST] Answer <eos>
-        <bos>[INST] Prompt [/INST]
-        ```
-
-        If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
-        ```python
-        >>> from transformers import Conversation
-
-        >>> Conversation(
-        ...     "<<SYS>>\n Complete the functions without any documentation\n<</SYS>>\n\n `def remove_non_ascii(s: str) -> str:`"
-        ... )  # doctest: +IGNORE_RESULT
-        ```
-        Args:
-            conversation (`Conversation`):
-                Conversation to build input ids for.
-        Returns:
-            `List[int]`:
-                Input ids for the conversation.
-        """
-        if self.use_default_system_prompt:
-            if len(conversation.past_user_inputs) > 0:
-                if (
-                    not conversation.past_user_inputs[0].startswith(B_SYS)
-                    or E_SYS not in conversation.past_user_inputs[0]
-                ):
-                    conversation.past_user_inputs[0] = (
-                        B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                    )
-            elif conversation.new_user_input:
-                if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
-                    conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
-            else:
-                raise ValueError("Last message must be from user")
-
-        dialogue = list(conversation.iter_texts())
-        if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
-            [not is_user for is_user, msg in dialogue[1::2]]
-        ):
-            raise ValueError(
-                "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
-            )
-
-        dialog_tokens: List[int] = []
-        dialog_tokens += sum(
-            [
-                [self.bos_token_id]
-                + self.encode(
-                    f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
-                )
-                + [self.eos_token_id]
-                for prompt, answer in zip(dialogue[::2], dialogue[1::2])
-            ],
-            [],
-        )
-        dialog_tokens += [self.bos_token_id] + self.encode(
-            f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
-        )
-        return dialog_tokens
@@ -16,7 +16,7 @@

 import json
 import os
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import regex as re

@@ -24,9 +24,6 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
@@ -433,12 +430,3 @@ class DebertaTokenizer(PreTrainedTokenizer):
         if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
             text = " " + text
         return (text, kwargs)
-
-    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids
-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
@@ -15,7 +15,7 @@
 """ Fast Tokenization class for model DeBERTa."""

 import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple

 from tokenizers import pre_tokenizers

@@ -25,10 +25,6 @@ from ...utils import logging
 from .tokenization_deberta import DebertaTokenizer


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -288,14 +284,3 @@ class DebertaTokenizerFast(PreTrainedTokenizerFast):
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         files = self._tokenizer.model.save(save_directory, name=filename_prefix)
         return tuple(files)
-
-    # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast._build_conversation_input_ids
-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """This corresponds to DialoGPT variants of models."""
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
@@ -18,7 +18,7 @@
 import json
 import os
 from functools import lru_cache
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import regex as re

@@ -26,9 +26,6 @@ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {
@@ -354,10 +351,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
             text = " " + text
         return (text, kwargs)

-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
+    @property
+    def default_chat_template(self):
+        """
+        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+        """
+        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
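The GPT-2 template in the hunk above replaces the token-level loop of `_build_conversation_input_ids` with a string-level one. As a rough illustration of the text it produces, the sketch below does the same concatenation in plain Python; `render_concat_chat` is a hypothetical name, and the function only approximates the template's output, not the library's rendering path.

```python
from typing import Dict, List


def render_concat_chat(messages: List[Dict[str, str]], eos_token: str = "<|endoftext|>") -> str:
    """Join message contents, each followed by EOS; role information is ignored."""
    return "".join(message["content"] + eos_token for message in messages)


chat = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
print(render_concat_chat(chat))  # prints: Hello<|endoftext|>Hi there<|endoftext|>
```

Unlike the removed method, the template itself contains no truncation to `model_max_length`; it only describes the text layout.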
@@ -16,7 +16,7 @@


 import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple

 from tokenizers import pre_tokenizers

@@ -26,10 +26,6 @@ from ...utils import logging
 from .tokenization_gpt2 import GPT2Tokenizer


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -181,12 +177,10 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
         files = self._tokenizer.model.save(save_directory, name=filename_prefix)
         return tuple(files)

-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """This corresponds to DialoGPT variants of models."""
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
+    @property
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+    def default_chat_template(self):
+        """
+        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+        """
+        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Tokenization classes for GPTNeoX."""
 import json
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple

 from tokenizers import pre_tokenizers

@@ -22,10 +22,6 @@ from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -133,12 +129,10 @@ class GPTNeoXTokenizerFast(PreTrainedTokenizerFast):
         files = self._tokenizer.model.save(save_directory, name=filename_prefix)
         return tuple(files)

-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """This corresponds to DialoGPT variants of models."""
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
+    @property
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
+    def default_chat_template(self):
+        """
+        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
+        """
+        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
@@ -17,7 +17,7 @@ import collections
 import json
 import os
 import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import Optional, Tuple

 import numpy as np

@@ -25,10 +25,6 @@ from ...tokenization_utils_fast import PreTrainedTokenizer
 from ...utils import logging


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
@@ -179,15 +175,14 @@ class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer):
         out_string = "".join(tokens).strip()
         return out_string

-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """This corresponds to DialoGPT variants of models."""
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
+    @property
+    def default_chat_template(self):
+        """
+        A simple chat template that just adds BOS/EOS tokens around messages while discarding role information.
+        """
+        return (
+            "{% for message in messages %}" "{{ bos_token + eos_token + message.content + eos_token }}" "{% endfor %}"
+        )

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         index = 0
@@ -4,7 +4,7 @@ import os
 import re
 import unicodedata
 from shutil import copyfile
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import sentencepiece as spm

@@ -16,10 +16,6 @@ if is_torch_available():
     import torch


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-
 logger = logging.get_logger(__name__)
 VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

@@ -319,31 +315,18 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):

         return self.sp_model.decode(token_ids)

-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """Builds the input ids for a conversation.
-
-        This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2].
-        The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows:
-
-        ```
-        <eos><bos>User: Jag tycker träd är fina<bos>Bot: Kul att du tycker det!<bos>...
-        ```
-
-        Args:
-            conversation (`Conversation`):
-                Conversation to build input ids for.
-
-        Returns:
-            `List[int]`:
-                Input ids for the conversation.
-
-        References:
-            - [1] https://doi.org/10.48550/arXiv.2305.12987
-            - [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct
-            - [3] https://github.com/openai/openai-python/blob/main/chatml.md
+    @property
+    def default_chat_template(self):
         """
-        all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()]
-        prompt = (
-            f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:"
+        This chat template formats messages like an instant messenger chat log, with "User:" and "Bot:" strings
+        preceding messages. BOS tokens are added between all messages.
+        """
+        return (
+            "{{ eos_token }}{{ bos_token }}"
+            "{% for message in messages %}"
+            "{% if message['role'] == 'user' %}{{ 'User: ' + message['content']}}"
+            "{% else %}{{ 'Bot: ' + message['content']}}{% endif %}"
+            "{{ message['text'] }}{{ bos_token }}"
+            "{% endfor %}"
+            "Bot:"
         )
-        return self.encode(text=prompt)
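The GPT-SW3 template above lays out a chat log with "User:" and "Bot:" prefixes and a trailing "Bot:" cue. A plain-Python sketch of that layout follows; `render_gpt_sw3_chat` and the `<bos>`/`<eos>` strings are placeholders chosen to match the docstring example, not the tokenizer's actual special tokens. The sketch also assumes that `{{ message['text'] }}` contributes nothing to the output, since the messages carry no `text` key.

```python
from typing import Dict, List


def render_gpt_sw3_chat(
    messages: List[Dict[str, str]],
    bos_token: str = "<bos>",  # placeholder strings, as in the removed docstring
    eos_token: str = "<eos>",
) -> str:
    """Hypothetical paraphrase of the chat-log layout described by the template."""
    parts = [eos_token + bos_token]
    for message in messages:
        # Any non-user role is rendered as the bot, as in the template's else branch.
        speaker = "User" if message["role"] == "user" else "Bot"
        parts.append(f"{speaker}: {message['content']}{bos_token}")
    parts.append("Bot:")  # cue the model to answer as the bot
    return "".join(parts)


chat = [
    {"role": "user", "content": "Jag tycker träd är fina"},
    {"role": "assistant", "content": "Kul att du tycker det!"},
]
print(render_gpt_sw3_chat(chat))
```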
@@ -17,7 +17,7 @@ import collections
 import json
 import os
 import re
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np

@@ -33,10 +33,6 @@ from ...tokenization_utils_base import (
 from ...utils import PaddingStrategy, logging


-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
@@ -258,16 +254,18 @@ class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
         text = "".join(words)
         return text

-    # Copied from tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizer._build_conversation_input_ids
-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        """This corresponds to DialoGPT variants of models."""
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
-
-        if len(input_ids) > self.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
-        return input_ids
+    @property
+    def default_chat_template(self):
+        """
+        A simple chat template that adds standard BOS, SEP and EOS tokens between messages while discarding role
+        information.
+        """
+        return (
+            "{% for message in messages %}"
+            "{% if not loop.first %}{{ bos_token}}{% endif %}"
+            "{{ sep_token }}{{ message.content }} {{ eos_token }}"
+            "{% endfor %}"
+        )

     # Copied from tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizer.save_vocabulary
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
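The GPTSAN template above uses `loop.first` so that BOS is emitted before every message except the first. The sketch below reproduces that layout in plain Python; `render_gptsan_chat` and the `<bos>`/`<sep>`/`<eos>` strings are illustrative placeholders, not the tokenizer's real special tokens.

```python
from typing import Dict, List


def render_gptsan_chat(
    messages: List[Dict[str, str]],
    bos_token: str = "<bos>",  # placeholder token strings for the example
    sep_token: str = "<sep>",
    eos_token: str = "<eos>",
) -> str:
    """Hypothetical paraphrase: BOS before all but the first message, then SEP, content, space, EOS."""
    parts = []
    for index, message in enumerate(messages):
        if index > 0:  # corresponds to `not loop.first` in the template
            parts.append(bos_token)
        parts.append(f"{sep_token}{message['content']} {eos_token}")
    return "".join(parts)


chat = [
    {"role": "user", "content": "A"},
    {"role": "assistant", "content": "B"},
]
print(render_gptsan_chat(chat))  # prints: <sep>A <eos><bos><sep>B <eos>
```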
@@ -31,7 +31,6 @@ from ...utils import logging


 if TYPE_CHECKING:
-    from ...pipelines.conversational import Conversation
     from ...tokenization_utils_base import TextInput

 logger = logging.get_logger(__name__)
@@ -374,67 +373,53 @@ class LlamaTokenizer(PreTrainedTokenizer):

         return output

-    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
-        r"""Builds the input ids for a conversation.
-        This is the format used in the provided examples. System prompts should be manually added at the beginning of
-        the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
-        ```
-        <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
-        <bos>[INST] Prompt [/INST] Answer <eos>
-        <bos>[INST] Prompt [/INST]
-        ```
-
-        If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
-        ```python
-        >>> from transformers import Conversation
-
-        >>> Conversation(
-        ...     "<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 septs?"
-        ... )  # doctest: +IGNORE_RESULT
-        ```
-        Args:
-            conversation (`Conversation`):
-                Conversation to build input ids for.
-        Returns:
-            `List[int]`:
-                Input ids for the conversation.
+    @property
+    def default_chat_template(self):
         """
-        if self.use_default_system_prompt:
-            if len(conversation.past_user_inputs) > 0:
-                if (
-                    not conversation.past_user_inputs[0].startswith(B_SYS)
-                    or E_SYS not in conversation.past_user_inputs[0]
-                ):
-                    conversation.past_user_inputs[0] = (
-                        B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
-                    )
-            elif conversation.new_user_input:
-                if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
-                    conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
-            else:
-                raise ValueError("Last message must be from user")
+        LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
+        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+        to fine-tune a model with more flexible role ordering!

-        dialogue = list(conversation.iter_texts())
-        if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
-            [not is_user for is_user, msg in dialogue[1::2]]
-        ):
-            raise ValueError(
-                "The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
-            )
+        The output should look something like:

-        dialog_tokens: List[int] = []
-        dialog_tokens += sum(
-            [
-                [self.bos_token_id]
-                + self.encode(
-                    f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
-                )
-                + [self.eos_token_id]
-                for prompt, answer in zip(dialogue[::2], dialogue[1::2])
-            ],
-            [],
+        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
+        <bos>[INST] Prompt [/INST]
+        """
+
+        template = (
+            "{% if messages[0]['role'] == 'system' %}"
+            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
+            "{% set system_message = messages[0]['content'] %}"
+            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
+            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% else %}"
+            "{% set loop_messages = messages %}"
+            "{% set system_message = false %}"
+            "{% endif %}"
+            "{% for message in loop_messages %}"  # Loop over all non-system messages
+            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+            "{% endif %}"
+            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
+            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
+            "{% else %}"
+            "{% set content = message['content'] %}"
+            "{% endif %}"
+            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
+            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+            "{% elif message['role'] == 'system' %}"
+            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ ' ' + content.strip() + ' ' + eos_token }}"
+            "{% endif %}"
+            "{% endfor %}"
         )
-        dialog_tokens += [self.bos_token_id] + self.encode(
-            f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
-        )
-        return dialog_tokens
+        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+        return template
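The last three added lines of the hunk above splice `DEFAULT_SYSTEM_PROMPT` into the template as a single-quoted string literal, which is why newlines and single quotes are escaped first. The sketch below isolates that step. The helper name `embed_in_single_quotes` and the sample prompt are made up for illustration, and Python's own literal parser is used as a stand-in for Jinja's string-literal handling, which is an assumption of the sketch rather than something the source states.

```python
import ast


def embed_in_single_quotes(text: str) -> str:
    """Escape newlines and single quotes, as the two .replace() calls in the hunk do."""
    return text.replace("\n", "\\n").replace("'", "\\'")


prompt = "You're a helpful assistant.\nAnswer briefly."  # sample text, not the real default prompt
escaped = embed_in_single_quotes(prompt)

# Substitute the placeholder the same way the property does.
fragment = "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}".replace("DEFAULT_SYSTEM_MESSAGE", escaped)
print(fragment)

# The escaped text stays on one line and contains no unescaped quote,
# so reading it back as a quoted literal recovers the original prompt.
recovered = ast.literal_eval("'" + escaped + "'")
```

Without the escaping, a prompt containing `'` would end the literal early and a raw newline would be embedded in the template source; backslashes in the prompt are not handled by these two replacements.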
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from tokenizers import processors
|
||||
|
||||
@@ -23,9 +23,6 @@ from ...utils import is_sentencepiece_available, logging
|
||||
from ...utils.versions import require_version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
require_version("tokenizers>=0.13.3")
|
||||
|
||||
if is_sentencepiece_available():
|
||||
@@ -192,67 +189,54 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation"):
|
||||
"""Builds the input ids for a conversation.
|
||||
This is the format used in the provided examples. System prompts should be manually added at the beginning of
|
||||
the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
|
||||
```
|
||||
<bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
|
||||
<bos>[INST] Prompt [/INST] Answer <eos>
|
||||
<bos>[INST] Prompt [/INST]
|
||||
```
|
||||
|
||||
If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following:
|
||||
```python
|
||||
>>> from transformers import Conversation
|
||||
|
||||
>>> Conversation(
|
||||
... "<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 septs?"
|
||||
... )
|
||||
```
|
||||
Args:
|
||||
conversation (`Conversation`):
|
||||
Conversation to build input ids for.
|
||||
Returns:
|
||||
`List[int]`:
|
||||
Input ids for the conversation.
|
||||
@property
|
||||
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
if self.use_default_system_prompt:
|
||||
if len(conversation.past_user_inputs) > 0:
|
||||
if (
|
||||
not conversation.past_user_inputs[0].startswith(B_SYS)
|
||||
or E_SYS not in conversation.past_user_inputs[0]
|
||||
):
|
||||
conversation.past_user_inputs[0] = (
|
||||
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.past_user_inputs[0]
|
||||
)
|
||||
elif conversation.new_user_input:
|
||||
if not conversation.new_user_input.startswith(B_SYS) or E_SYS not in conversation.new_user_input:
|
||||
conversation.new_user_input = B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + conversation.new_user_input
|
||||
else:
|
||||
raise ValueError("Last message must be from user")
|
||||
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
|
||||
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
|
||||
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
|
||||
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
|
||||
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
|
||||
to fine-tune a model with more flexible role ordering!
|
||||
|
||||
dialogue = list(conversation.iter_texts())
|
||||
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
|
||||
[not is_user for is_user, msg in dialogue[1::2]]
|
||||
):
|
||||
raise ValueError(
|
||||
"The model only supports 'user' and 'assistant' roles, starting with user and alternating (u/a/u/a/u...)"
|
||||
)
|
||||
The output should look something like:
|
||||
|
||||
dialog_tokens = []
|
||||
dialog_tokens += sum(
|
||||
[
|
||||
[self.bos_token_id]
|
||||
+ self.encode(
|
||||
f"{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ", add_special_tokens=False
|
||||
)
|
||||
+ [self.eos_token_id]
|
||||
for prompt, answer in zip(dialogue[::2], dialogue[1::2])
|
||||
],
|
||||
[],
|
||||
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
|
||||
<bos>[INST] Prompt [/INST]
|
||||
"""
|
||||
|
||||
template = (
|
||||
"{% if messages[0]['role'] == 'system' %}"
|
||||
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
|
||||
"{% set system_message = messages[0]['content'] %}"
|
||||
"{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
|
||||
"{% set loop_messages = messages %}" # Or use the default system message if the flag is set
|
||||
"{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
|
||||
"{% else %}"
|
||||
"{% set loop_messages = messages %}"
|
||||
"{% set system_message = false %}"
|
||||
"{% endif %}"
|
||||
"{% for message in loop_messages %}" # Loop over all non-system messages
|
||||
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
|
||||
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
|
||||
"{% endif %}"
|
||||
"{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
|
||||
"{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
|
||||
"{% else %}"
|
||||
"{% set content = message['content'] %}"
|
||||
"{% endif %}"
|
||||
"{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
|
||||
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
|
||||
"{% elif message['role'] == 'system' %}"
|
||||
"{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"{{ ' ' + content.strip() + ' ' + eos_token }}"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
dialog_tokens += [self.bos_token_id] + self.encode(
|
||||
f"{B_INST} {(dialogue[-1][1]).strip()} {E_INST}", add_special_tokens=False
|
||||
)
|
||||
return dialog_tokens
|
||||
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
|
||||
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
|
||||
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
|
||||
|
||||
return template
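Review note: to make the rendered format easier to check by eye, here is a minimal plain-Python sketch that mirrors the template logic above without Jinja. The helper name `render_llama_chat` and the literal `<s>` / `</s>` tokens are assumptions for illustration; the actual method substitutes the tokenizer's own `bos_token` and `eos_token`, and the default-system-prompt branch is left out here.

```python
from typing import Dict, List


def render_llama_chat(
    messages: List[Dict[str, str]], bos_token: str = "<s>", eos_token: str = "</s>"
) -> str:
    """Plain-Python mirror of the template logic (hypothetical helper, not library code)."""
    # Extract a leading system message, if one is present.
    if messages and messages[0]["role"] == "system":
        system_message = messages[0]["content"]
        loop_messages = messages[1:]
    else:
        system_message = None
        loop_messages = messages

    parts = []
    for index, message in enumerate(loop_messages):
        # Even positions must be "user" turns; odd positions must not be.
        if (message["role"] == "user") != (index % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        content = message["content"]
        # The system message is embedded in the first message of the loop.
        if index == 0 and system_message is not None:
            content = "<<SYS>>\n" + system_message + "\n<</SYS>>\n\n" + content
        if message["role"] == "user":
            parts.append(bos_token + "[INST] " + content.strip() + " [/INST]")
        elif message["role"] == "system":
            parts.append("<<SYS>>\n" + content.strip() + "\n<</SYS>>\n\n")
        elif message["role"] == "assistant":
            parts.append(" " + content.strip() + " " + eos_token)
    return "".join(parts)


if __name__ == "__main__":
    chat = [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello"},
    ]
    print(repr(render_llama_chat(chat)))
```

The sketch raises a plain `ValueError` where the template calls `raise_exception`; the alternation check itself is the same comparison.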
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
@@ -26,10 +26,6 @@ from ...utils import logging
|
||||
from .english_normalizer import EnglishTextNormalizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...pipelines.conversational import Conversation
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"tokenizer_file": "tokenizer.json",
|
||||
@@ -751,14 +747,13 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
text = " " + text
|
||||
return (text, kwargs)
|
||||
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
input_ids = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
return input_ids
|
||||
@property
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
A simple chat template that ignores role information and just concatenates messages with EOS tokens.
|
||||
"""
|
||||
return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
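Review note: the concatenating template above is simple enough to restate in plain Python. The sketch below is only an illustration of the intended output; the function name `concat_with_eos` is hypothetical and the default `eos_token` value is a placeholder, since the real template reads the token from the tokenizer.

```python
from typing import Dict, List


def concat_with_eos(messages: List[Dict[str, str]], eos_token: str = "<|endoftext|>") -> str:
    """Join message contents, appending the EOS token after each one (illustrative only)."""
    # Role information is deliberately ignored; only the content is kept.
    return "".join(message["content"] + eos_token for message in messages)


if __name__ == "__main__":
    chat = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Nice to meet you."},
    ]
    print(concat_with_eos(chat))
```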
|
||||
|
||||
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
||||
self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from tokenizers import pre_tokenizers, processors
|
||||
@@ -28,10 +28,6 @@ from .english_normalizer import EnglishTextNormalizer
|
||||
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...pipelines.conversational import Conversation
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
@@ -520,14 +516,13 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||
|
||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._build_conversation_input_ids
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
input_ids = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
|
||||
if len(input_ids) > self.model_max_length:
|
||||
input_ids = input_ids[-self.model_max_length :]
|
||||
return input_ids
|
||||
@property
|
||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
A simple chat template that ignores role information and just concatenates messages with EOS tokens.
|
||||
"""
|
||||
return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
|
||||
|
||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_decoder_prompt_ids
|
||||
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
||||
from .base import PIPELINE_INIT_ARGS, Pipeline
|
||||
@@ -19,137 +19,153 @@ class Conversation:
|
||||
"""
|
||||
Utility class containing a conversation and its history. This class is meant to be used as an input to the
|
||||
[`ConversationalPipeline`]. The conversation contains several utility functions to manage the addition of new user
|
||||
inputs and generated model responses. A conversation needs to contain an unprocessed user input before being passed
|
||||
to the [`ConversationalPipeline`]. This user input is either created when the class is instantiated, or by calling
|
||||
`conversational_pipeline.append_response("input")` after a conversation turn.
|
||||
inputs and generated model responses.
|
||||
|
||||
Arguments:
|
||||
text (`str`, *optional*):
|
||||
The initial user input to start the conversation. If not provided, a user input needs to be provided
|
||||
manually using the [`~Conversation.add_user_input`] method before the conversation can begin.
|
||||
messages (Union[str, List[Dict[str, str]]], *optional*):
|
||||
The initial messages to start the conversation, either a string, or a list of dicts containing "role" and
|
||||
"content" keys. If a string is passed, it is interpreted as a single message with the "user" role.
|
||||
conversation_id (`uuid.UUID`, *optional*):
|
||||
Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
|
||||
conversation.
|
||||
past_user_inputs (`List[str]`, *optional*):
|
||||
Eventual past history of the conversation of the user. You don't need to pass it manually if you use the
|
||||
pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and
|
||||
`generated_responses` with equal length lists of strings
|
||||
generated_responses (`List[str]`, *optional*):
|
||||
Eventual past history of the conversation of the model. You don't need to pass it manually if you use the
|
||||
pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and
|
||||
`generated_responses` with equal length lists of strings
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
conversation = Conversation("Going to the movies tonight - any suggestions?")
|
||||
|
||||
# Steps usually performed by the model when generating a response:
|
||||
# 1. Mark the user input as processed (moved to the history)
|
||||
conversation.mark_processed()
|
||||
# 2. Append a model response
|
||||
conversation.append_response("The Big Lebowski.")
|
||||
|
||||
conversation.add_user_input("Is it good?")
|
||||
conversation.add_message({"role": "assistant", "content": "The Big Lebowski."})
|
||||
conversation.add_message({"role": "user", "content": "Is it good?"})
|
||||
```"""
|
||||
|
||||
def __init__(
|
||||
self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
|
||||
self, messages: Union[str, List[Dict[str, str]]] = None, conversation_id: uuid.UUID = None, **deprecated_kwargs
|
||||
):
|
||||
if not conversation_id:
|
||||
conversation_id = uuid.uuid4()
|
||||
if past_user_inputs is None:
|
||||
past_user_inputs = []
|
||||
if generated_responses is None:
|
||||
generated_responses = []
|
||||
|
||||
self.uuid: uuid.UUID = conversation_id
|
||||
self.past_user_inputs: List[str] = past_user_inputs
|
||||
self.generated_responses: List[str] = generated_responses
|
||||
self.new_user_input: Optional[str] = text
|
||||
if messages is None:
|
||||
text = deprecated_kwargs.pop("text", None)
|
||||
if text is not None:
|
||||
messages = [{"role": "user", "content": text}]
|
||||
else:
|
||||
messages = []
|
||||
elif isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# This block deals with the legacy args - new code should just totally
|
||||
# avoid past_user_inputs and generated_responses
|
||||
generated_responses = deprecated_kwargs.pop("generated_responses", None)
|
||||
past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
|
||||
if generated_responses is not None and past_user_inputs is None:
|
||||
raise ValueError("generated_responses cannot be passed without past_user_inputs!")
|
||||
if past_user_inputs is not None:
|
||||
legacy_messages = []
|
||||
if generated_responses is None:
|
||||
generated_responses = []
|
||||
# We structure it this way instead of using zip() because the lengths may differ by 1
|
||||
for i in range(max([len(past_user_inputs), len(generated_responses)])):
|
||||
if i < len(past_user_inputs):
|
||||
legacy_messages.append({"role": "user", "content": past_user_inputs[i]})
|
||||
if i < len(generated_responses):
|
||||
legacy_messages.append({"role": "assistant", "content": generated_responses[i]})
|
||||
messages = legacy_messages + messages
|
||||
|
||||
self.uuid = conversation_id
|
||||
self.messages = messages
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Conversation):
|
||||
return False
|
||||
if self.uuid == other.uuid:
|
||||
return True
|
||||
return (
|
||||
self.new_user_input == other.new_user_input
|
||||
and self.past_user_inputs == other.past_user_inputs
|
||||
and self.generated_responses == other.generated_responses
|
||||
)
|
||||
return self.uuid == other.uuid or self.messages == other.messages
|
||||
|
||||
def add_message(self, message: Dict[str, str]):
|
||||
if set(message.keys()) != {"role", "content"}:
|
||||
raise ValueError("Message should contain only 'role' and 'content' keys!")
|
||||
if message["role"] not in ("user", "assistant", "system"):
|
||||
raise ValueError("Only 'user', 'assistant' and 'system' roles are supported for now!")
|
||||
self.messages.append(message)
|
||||
|
||||
def add_user_input(self, text: str, overwrite: bool = False):
|
||||
"""
|
||||
Add a user input to the conversation for the next round. This populates the internal `new_user_input` field.
|
||||
|
||||
Args:
|
||||
text (`str`): The user input for the next conversation round.
|
||||
overwrite (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not existing and unprocessed user input should be overwritten when this function is called.
|
||||
Add a user input to the conversation for the next round. This is a legacy method that assumes that inputs must
|
||||
alternate user/assistant/user/assistant, and so will not add multiple user messages in succession. We recommend
|
||||
just using `add_message` with role "user" instead.
|
||||
"""
|
||||
if self.new_user_input:
|
||||
if len(self) > 0 and self[-1]["role"] == "user":
|
||||
if overwrite:
|
||||
logger.warning(
|
||||
f'User input added while unprocessed input was existing: "{self.new_user_input}" was overwritten '
|
||||
f'User input added while unprocessed input was existing: "{self[-1]["content"]}" was overwritten '
|
||||
f'with: "{text}".'
|
||||
)
|
||||
self.new_user_input = text
|
||||
self[-1]["content"] = text
|
||||
else:
|
||||
logger.warning(
|
||||
f'User input added while unprocessed input was existing: "{self.new_user_input}" new input '
|
||||
f'User input added while unprocessed input was existing: "{self[-1]["content"]}" new input '
|
||||
f'ignored: "{text}". Set `overwrite` to True to overwrite unprocessed user input'
|
||||
)
|
||||
else:
|
||||
self.new_user_input = text
|
||||
|
||||
def mark_processed(self):
|
||||
"""
|
||||
Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties
|
||||
the `new_user_input` field.
|
||||
"""
|
||||
if self.new_user_input:
|
||||
self.past_user_inputs.append(self.new_user_input)
|
||||
self.new_user_input = None
|
||||
self.messages.append({"role": "user", "content": text})
|
||||
|
||||
def append_response(self, response: str):
|
||||
"""
|
||||
Append a response to the list of generated responses.
|
||||
|
||||
Args:
|
||||
response (`str`): The model generated response.
|
||||
This is a legacy method. We recommend just using `add_message` with an appropriate role instead.
|
||||
"""
|
||||
self.generated_responses.append(response)
|
||||
self.messages.append({"role": "assistant", "content": response})
|
||||
|
||||
def iter_texts(self):
|
||||
def mark_processed(self):
|
||||
"""
|
||||
Iterates over all blobs of the conversation.
|
||||
This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
|
||||
processed and unprocessed user input.
|
||||
"""
|
||||
pass
|
||||
|
||||
Returns: Iterator of (is_user, text_chunk) in chronological order of the conversation. `is_user` is a `bool`,
|
||||
`text_chunks` is a `str`.
|
||||
"""
|
||||
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
|
||||
yield True, user_input
|
||||
yield False, generated_response
|
||||
if self.new_user_input:
|
||||
yield True, self.new_user_input
|
||||
def __iter__(self):
|
||||
for message in self.messages:
|
||||
yield message
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.messages[item]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.messages[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.messages)
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Generates a string representation of the conversation.
|
||||
|
||||
Return:
|
||||
Returns:
|
||||
`str`:
|
||||
|
||||
Example: Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user >> Going to the movies tonight - any
|
||||
suggestions? bot >> The Big Lebowski
|
||||
Example:
|
||||
Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user: Going to the movies tonight - any suggestions?
|
||||
bot: The Big Lebowski
|
||||
"""
|
||||
output = f"Conversation id: {self.uuid} \n"
|
||||
for is_user, text in self.iter_texts():
|
||||
name = "user" if is_user else "bot"
|
||||
output += f"{name} >> {text} \n"
|
||||
output = f"Conversation id: {self.uuid}\n"
|
||||
for message in self.messages:
|
||||
output += f"{message['role']}: {message['content']}\n"
|
||||
return output
|
||||
|
||||
def iter_texts(self):
|
||||
# This is a legacy method for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
for message in self.messages:
|
||||
yield message["role"] == "user", message["content"]
|
||||
|
||||
@property
|
||||
def past_user_inputs(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
return [message["content"] for message in self.messages if message["role"] == "user"]
|
||||
|
||||
@property
|
||||
def generated_responses(self):
|
||||
# This is a legacy property for backwards compatibility. It is recommended to just directly access
|
||||
# conversation.messages instead.
|
||||
return [message["content"] for message in self.messages if message["role"] == "assistant"]
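Review note: the legacy-argument handling in `__init__` and the two legacy properties can be exercised in isolation. The following is a minimal standalone sketch of that logic, assuming hypothetical helper names `legacy_to_messages` and `contents_for_role`; it is not the class itself, only the interleaving and filtering steps.

```python
from typing import Dict, List, Optional


def legacy_to_messages(
    past_user_inputs: Optional[List[str]] = None,
    generated_responses: Optional[List[str]] = None,
) -> List[Dict[str, str]]:
    """Interleave legacy history lists into role/content dicts (hypothetical helper)."""
    if generated_responses is not None and past_user_inputs is None:
        raise ValueError("generated_responses cannot be passed without past_user_inputs!")
    past_user_inputs = past_user_inputs or []
    generated_responses = generated_responses or []
    messages = []
    # An index loop is used instead of zip() because the lengths may differ by one.
    for i in range(max(len(past_user_inputs), len(generated_responses))):
        if i < len(past_user_inputs):
            messages.append({"role": "user", "content": past_user_inputs[i]})
        if i < len(generated_responses):
            messages.append({"role": "assistant", "content": generated_responses[i]})
    return messages


def contents_for_role(messages: List[Dict[str, str]], role: str) -> List[str]:
    """Same filtering as the legacy `past_user_inputs` / `generated_responses` properties."""
    return [message["content"] for message in messages if message["role"] == role]


if __name__ == "__main__":
    history = legacy_to_messages(["Hi", "Is it good?"], ["The Big Lebowski."])
    print(history)
    print(contents_for_role(history, "user"))
```

Because the properties are derived by filtering on role, a conversation with consecutive user messages will report them all in `past_user_inputs`, which differs from the old one-pending-input model.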
|
||||
|
||||
|
||||
@add_end_docstrings(
|
||||
PIPELINE_INIT_ARGS,
|
||||
@@ -246,18 +262,7 @@ class ConversationalPipeline(Pipeline):
|
||||
return outputs
|
||||
|
||||
def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
|
||||
if not isinstance(conversation, Conversation):
|
||||
raise ValueError("ConversationalPipeline, expects Conversation as inputs")
|
||||
if conversation.new_user_input is None:
|
||||
raise ValueError(
|
||||
f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. "
|
||||
"Add user inputs with the conversation's `add_user_input` method"
|
||||
)
|
||||
if hasattr(self.tokenizer, "_build_conversation_input_ids"):
|
||||
input_ids = self.tokenizer._build_conversation_input_ids(conversation)
|
||||
else:
|
||||
# If the tokenizer cannot handle conversations, we default to only the old version
|
||||
input_ids = self._legacy_parse_and_tokenize(conversation)
|
||||
input_ids = self.tokenizer.apply_chat_template(conversation)
|
||||
|
||||
if self.framework == "pt":
|
||||
input_ids = torch.LongTensor([input_ids])
|
||||
@@ -292,19 +297,5 @@ class ConversationalPipeline(Pipeline):
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
)
|
||||
conversation = model_outputs["conversation"]
|
||||
conversation.mark_processed()
|
||||
conversation.append_response(answer)
|
||||
conversation.add_message({"role": "assistant", "content": answer})
|
||||
return conversation
|
||||
|
||||
def _legacy_parse_and_tokenize(self, conversation: Conversation) -> Dict:
|
||||
eos_token_id = self.tokenizer.eos_token_id
|
||||
input_ids = []
|
||||
for is_user, text in conversation.iter_texts():
|
||||
if eos_token_id is not None:
|
||||
input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
|
||||
else:
|
||||
input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False))
|
||||
|
||||
if len(input_ids) > self.tokenizer.model_max_length:
|
||||
input_ids = input_ids[-self.tokenizer.model_max_length :]
|
||||
return input_ids
|
||||
|
||||
@@ -64,6 +64,7 @@ from .utils import (
|
||||
is_ftfy_available,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
is_jinja_available,
|
||||
is_jumanpp_available,
|
||||
is_keras_nlp_available,
|
||||
is_librosa_available,
|
||||
@@ -336,6 +337,13 @@ def require_jieba(test_case):
|
||||
return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)
|
||||
|
||||
|
||||
def require_jinja(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
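Review note: for readers unfamiliar with the `require_*` pattern, here is a self-contained sketch of how such a decorator behaves. The availability check via `importlib.util.find_spec` is an assumption made for this example; the library's own `_is_package_available` helper may do more (for instance, version lookup).

```python
import importlib.util
import unittest


def is_jinja_available() -> bool:
    # Assumed check for this sketch: the package is importable if a spec can be found.
    return importlib.util.find_spec("jinja2") is not None


def require_jinja(test_case):
    """Skip the decorated test unless jinja2 is installed."""
    return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)


class ExampleChatTemplateTest(unittest.TestCase):
    @require_jinja
    def test_needs_jinja(self):
        # Only runs when jinja2 is importable; otherwise it is reported as skipped.
        import jinja2

        self.assertTrue(hasattr(jinja2, "Template"))


if __name__ == "__main__":
    unittest.main()
```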
|
||||
|
||||
|
||||
def require_tf2onnx(test_case):
|
||||
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from collections import OrderedDict, UserDict
|
||||
from collections.abc import Mapping, Sized
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
|
||||
import tensorflow as tf
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp # noqa: F401
|
||||
from .pipelines.conversational import Conversation
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
@@ -1426,6 +1428,7 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
|
||||
- **length** -- The length of the inputs (when `return_length=True`)
|
||||
"""
|
||||
|
||||
|
||||
INIT_TOKENIZER_DOCSTRING = r"""
|
||||
Class attributes (overridden by derived classes)
|
||||
|
||||
@@ -1461,6 +1464,9 @@ INIT_TOKENIZER_DOCSTRING = r"""
|
||||
truncation_side (`str`, *optional*):
|
||||
The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
|
||||
Default value is picked from the class attribute of the same name.
|
||||
chat_template (`str`, *optional*):
|
||||
A Jinja template string that will be used to format lists of chat messages. See
|
||||
https://huggingface.co/docs/transformers/chat_templating for a full description.
|
||||
model_input_names (`List[string]`, *optional*):
|
||||
The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
|
||||
`"attention_mask"`). Default value is picked from the class attribute of the same name.
|
||||
@@ -1558,6 +1564,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
{}
|
||||
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
|
||||
self._in_target_context_manager = False
|
||||
|
||||
# Stores a Jinja template that formats chat histories into tokenizable strings
|
||||
self.chat_template = kwargs.pop("chat_template", None)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@@ -1627,6 +1637,109 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation: Union[List[Dict[str, str]], "Conversation"],
|
||||
chat_template: Optional[str] = None,
|
||||
tokenize: bool = True,
|
||||
padding: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**tokenizer_kwargs,
|
||||
) -> Union[str, List[int]]:
|
||||
"""
|
||||
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
|
||||
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
|
||||
determine the format and control tokens to use when converting. When chat_template is None, it will fall back
|
||||
to the default_chat_template specified at the class level.
|
||||
|
||||
Args:
|
||||
conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
|
||||
with "role" and "content" keys, representing the chat history so far.
|
||||
chat_template (str, *optional*): A Jinja template to use for this conversion. If
|
||||
this is not passed, the model's default chat template will be used instead.
|
||||
tokenize (`bool`, defaults to `True`):
|
||||
Whether to tokenize the output. If `False`, the output will be a string.
|
||||
padding (`bool`, defaults to `False`):
|
||||
Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
|
||||
truncation (`bool`, defaults to `False`):
|
||||
Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
not specified, the tokenizer's `model_max_length` attribute will be used as a default.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
|
||||
values are:
|
||||
- `'tf'`: Return TensorFlow `tf.Tensor` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
**tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
|
||||
|
||||
Returns:
|
||||
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
output is ready to pass to the model, either directly or via methods like `generate()`. If `tokenize` is
`False`, the rendered string is returned instead.
|
||||
"""
|
||||
|
||||
if hasattr(conversation, "messages"):
|
||||
# Indicates it's a Conversation object
|
||||
conversation = conversation.messages
|
||||
|
||||
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
|
||||
if chat_template is None:
|
||||
if self.chat_template is not None:
|
||||
chat_template = self.chat_template
|
||||
else:
|
||||
chat_template = self.default_chat_template
|
||||
|
||||
# Compilation function uses a cache to avoid recompiling the same template
|
||||
compiled_template = self._compile_jinja_template(chat_template)
|
||||
|
||||
rendered = compiled_template.render(messages=conversation, **self.special_tokens_map)
|
||||
|
||||
if padding is True:
|
||||
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
|
||||
if tokenize:
|
||||
return self.encode(
|
||||
rendered,
|
||||
add_special_tokens=False,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
return_tensors=return_tensors,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
else:
|
||||
return rendered
|
||||
|
||||
@lru_cache
|
||||
def _compile_jinja_template(self, chat_template):
|
||||
try:
|
||||
from jinja2.exceptions import TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
except ImportError:
|
||||
raise ImportError("apply_chat_template requires jinja2 to be installed.")
|
||||
|
||||
def raise_exception(message):
|
||||
raise TemplateError(message)
|
||||
|
||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
||||
jinja_env.globals["raise_exception"] = raise_exception
|
||||
return jinja_env.from_string(chat_template)
|
||||
|
||||
@property
|
||||
def default_chat_template(self):
|
||||
"""
|
||||
This template formats inputs in the standard ChatML format. See
|
||||
https://github.com/openai/openai-python/blob/main/chatml.md
|
||||
"""
|
||||
return (
|
||||
"{% for message in messages %}"
|
||||
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
|
||||
"{% endfor %}"
|
||||
)
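Review note: two pieces of behaviour in this hunk are easy to check without Jinja installed: the template priority order and the ChatML layout produced by the default template. The sketch below restates both in plain Python. The names `resolve_template` and `render_chatml` are hypothetical and exist only for this illustration.

```python
from typing import Dict, List, Optional


def resolve_template(
    argument: Optional[str], instance_template: Optional[str], class_default: str
) -> str:
    """Priority: explicit argument > tokenizer.chat_template > default_chat_template."""
    if argument is not None:
        return argument
    if instance_template is not None:
        return instance_template
    return class_default


def render_chatml(messages: List[Dict[str, str]]) -> str:
    """Plain-Python equivalent of the default ChatML template's expression."""
    return "".join(
        "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
        for message in messages
    )


if __name__ == "__main__":
    chat = [
        {"role": "system", "content": "You are a helpful chatbot."},
        {"role": "user", "content": "Hello!"},
    ]
    print(render_chatml(chat))
    print(resolve_template(None, None, "class default"))
```

Each message ends with `<|im_end|>` followed by a newline, so the rendered string for a multi-turn chat is a sequence of self-delimited blocks.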
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -2187,6 +2300,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if hasattr(self, k):
|
||||
tokenizer_config[k] = getattr(self, k)
|
||||
|
||||
if self.chat_template is not None:
|
||||
tokenizer_config["chat_template"] = self.chat_template
|
||||
|
||||
if len(self.init_inputs) > 0:
|
||||
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
|
||||
for file_id in self.vocab_files_names.keys():
|
||||
|
||||
@@ -119,6 +119,7 @@ from .import_utils import (
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_jieba_available,
|
||||
is_jinja_available,
|
||||
is_jumanpp_available,
|
||||
is_kenlm_available,
|
||||
is_keras_nlp_available,
|
||||
|
||||
@@ -91,6 +91,7 @@ except importlib.metadata.PackageNotFoundError:
|
||||
_ftfy_available = _is_package_available("ftfy")
|
||||
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
|
||||
_jieba_available = _is_package_available("jieba")
|
||||
_jinja_available = _is_package_available("jinja2")
|
||||
_kenlm_available = _is_package_available("kenlm")
|
||||
_keras_nlp_available = _is_package_available("keras_nlp")
|
||||
_librosa_available = _is_package_available("librosa")
|
||||
@@ -793,6 +794,10 @@ def is_jieba_available():
|
||||
return _jieba_available
|
||||
|
||||
|
||||
def is_jinja_available():
|
||||
return _jinja_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
DATASETS_IMPORT_ERROR = """
|
||||
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
|
||||
@@ -1081,6 +1086,11 @@ PEFT_IMPORT_ERROR = """
|
||||
peft`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
JINJA_IMPORT_ERROR = """
|
||||
{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
|
||||
jinja2`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
@@ -1118,6 +1128,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
|
||||
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
|
||||
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
||||
("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import BlenderbotTokenizer, BlenderbotTokenizerFast
|
||||
from transformers.testing_utils import require_jinja
|
||||
from transformers.utils import cached_property
|
||||
|
||||
|
||||
@@ -50,3 +51,24 @@ class Blenderbot3BTokenizerTests(unittest.TestCase):
|
||||
def test_3B_tokenization_same_as_parlai_rust_tokenizer(self):
|
||||
assert self.rust_tokenizer_3b.add_prefix_space
|
||||
assert self.rust_tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]
|
||||
|
||||
@require_jinja
|
||||
def test_tokenization_for_chat(self):
|
||||
tok = self.tokenizer_3b
|
||||
test_chats = [
|
||||
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful chatbot."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Nice to meet you."},
|
||||
],
|
||||
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
|
||||
]
|
||||
tokenized_chats = [tok.apply_chat_template(test_chat) for test_chat in test_chats]
|
||||
expected_tokens = [
|
||||
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2],
|
||||
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2],
|
||||
[3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2],
|
||||
]
|
||||
for tokenized_chat, expected in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected)
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import BloomTokenizerFast
|
||||
from transformers.testing_utils import require_tokenizers
|
||||
from transformers.testing_utils import require_jinja, require_tokenizers
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -134,6 +134,27 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
|
||||
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
|
||||
|
||||
@require_jinja
|
||||
def test_tokenization_for_chat(self):
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
test_chats = [
|
||||
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful chatbot."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "Nice to meet you."},
|
||||
],
|
||||
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
|
||||
]
|
||||
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
|
||||
expected_tokens = [
|
||||
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
|
||||
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
|
||||
[229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
|
||||
]
|
||||
for tokenized_chat, expected in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected)
|
||||
|
||||
def test_add_prefix_space_fast(self):
|
||||
tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
|
||||
tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
|
||||
|
@@ -20,7 +20,7 @@ import unittest

 from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
 from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import require_tokenizers
+from transformers.testing_utils import require_jinja, require_tokenizers

 from ...test_tokenization_common import TokenizerTesterMixin

@@ -275,6 +275,27 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                 filtered_sequence = [x for x in filtered_sequence if x is not None]
                 self.assertEqual(encoded_sequence, filtered_sequence)

+    @require_jinja
+    def test_tokenization_for_chat(self):
+        tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname)
+        test_chats = [
+            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+            [
+                {"role": "system", "content": "You are a helpful chatbot."},
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Nice to meet you."},
+            ],
+            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+        ]
+        tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+        # fmt: off
+        expected_tokens = [[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20],
+                           [20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20, 20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20],
+                           [20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20, 20, 3, 0, 0, 1, 20, 20]]
+        # fmt: on
+        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+            self.assertListEqual(tokenized_chat, expected_tokens)
+

 @require_tokenizers
 class OPTTokenizationTest(unittest.TestCase):
@@ -16,7 +16,7 @@
 import unittest

 from transformers import GPTSw3Tokenizer
-from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
+from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow

 from ...test_tokenization_common import TokenizerTesterMixin

@@ -128,3 +128,27 @@ class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             model_name="AI-Sweden/gpt-sw3-126m",
             sequences=sequences,
         )
+
+    @require_jinja
+    def test_tokenization_for_chat(self):
+        tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB)
+        # This is in English, but it's just here to make sure the chat control tokens are being added properly
+        test_chats = [
+            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+            [
+                {"role": "system", "content": "You are a helpful chatbot."},
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Nice to meet you."},
+            ],
+            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+        ]
+        tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+        # fmt: off
+        expected_tokens = [
+            [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ],
+            [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 938, 541, 419, ],
+            [268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ]
+        ]
+        # fmt: on
+        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+            self.assertListEqual(tokenized_chat, expected_tokens)
@@ -22,7 +22,7 @@ from transformers.models.gptsan_japanese.tokenization_gptsan_japanese import (
     VOCAB_FILES_NAMES,
     GPTSanJapaneseTokenizer,
 )
-from transformers.testing_utils import require_tokenizers, slow
+from transformers.testing_utils import require_jinja, require_tokenizers, slow

 from ...test_tokenization_common import TokenizerTesterMixin

@@ -193,3 +193,27 @@ class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_padding_different_model_input_name(self):
         # tokenizer has no padding token
         pass
+
+    @require_jinja
+    def test_tokenization_for_chat(self):
+        tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
+        # This is in English, but it's just here to make sure the chat control tokens are being added properly
+        test_chats = [
+            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+            [
+                {"role": "system", "content": "You are a helpful chatbot."},
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Nice to meet you."},
+            ],
+            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+        ]
+        tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+        # fmt: off
+        expected_tokens = [
+            [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
+            [35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999, 35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999],
+            [35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
+        ]
+        # fmt: on
+        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+            self.assertListEqual(tokenized_chat, expected_tokens)
@@ -2486,3 +2486,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @unittest.skip("Doesn't support another framework than PyTorch")
     def test_np_encode_plus_sent_to_model(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template(self):
+        pass
@@ -2439,3 +2439,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                 # This should not fail
                 model(encoded_sequence)
                 model(batch_encoded_sequence)
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template(self):
+        pass
@@ -1958,3 +1958,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @unittest.skip("Doesn't use SentencePiece")
     def test_sentencepiece_tokenize_and_decode(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template(self):
+        pass
@@ -32,6 +32,7 @@ from transformers.convert_slow_tokenizer import convert_slow_tokenizer
 from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
+    require_jinja,
     require_sentencepiece,
     require_tokenizers,
     require_torch,
@@ -574,6 +575,32 @@ class LlamaIntegrationTest(unittest.TestCase):
         # a dummy prefix space is not added by the sp_model as it was de-activated
         self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))

+    @require_jinja
+    def test_tokenization_for_chat(self):
+        tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
+
+        test_chats = [
+            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+            [
+                {"role": "system", "content": "You are a helpful chatbot."},
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Nice to meet you."},
+            ],
+            [{"role": "user", "content": "Hello!"}],
+        ]
+        # Matt: The third test case tests the default system message, but if this is ever changed in the
+        #       class/repo code then that test will fail, and the case will need to be updated.
+        tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+        # fmt: off
+        expected_tokens = [
+            [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
+            [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2],
+            [1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962]
+        ]
+        # fmt: on
+        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+            self.assertListEqual(tokenized_chat, expected_tokens)
+

 @require_sentencepiece
 @require_tokenizers
@@ -2311,3 +2311,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "Dummy warning",
             cm.records[0].message,
         )
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template(self):
+        pass
@@ -1274,3 +1274,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @unittest.skip("Doesn't support another framework than PyTorch")
     def test_np_encode_plus_sent_to_model(self):
         pass
+
+    @unittest.skip("Chat is not supported")
+    def test_chat_template(self):
+        pass
@@ -16,7 +16,7 @@ import unittest

 from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
 from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
-from transformers.testing_utils import slow
+from transformers.testing_utils import require_jinja, slow

 from ...test_tokenization_common import TokenizerTesterMixin

@@ -473,3 +473,25 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):

         output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
         self.assertEqual(output, [])
+
+    @require_jinja
+    def test_tokenization_for_chat(self):
+        multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
+        # This is in English, but it's just here to make sure the chat control tokens are being added properly
+        test_chats = [
+            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
+            [
+                {"role": "system", "content": "You are a helpful chatbot."},
+                {"role": "user", "content": "Hello!"},
+                {"role": "assistant", "content": "Nice to meet you."},
+            ],
+            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
+        ]
+        tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
+        expected_tokens = [
+            [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
+            [3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
+            [37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
+        ]
+        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
+            self.assertListEqual(tokenized_chat, expected_tokens)
@@ -78,17 +78,23 @@ class ConversationalPipelineTests(unittest.TestCase):
     def run_pipeline_test(self, conversation_agent, _):
         # Simple
         outputs = conversation_agent(Conversation("Hi there!"))
-        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+        self.assertEqual(
+            outputs,
+            Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
+        )

         # Single list
         outputs = conversation_agent([Conversation("Hi there!")])
-        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+        self.assertEqual(
+            outputs,
+            Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
+        )

         # Batch
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         conversation_2 = Conversation("What's the last book you have read?")
-        self.assertEqual(len(conversation_1.past_user_inputs), 0)
-        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+        self.assertEqual(len(conversation_1), 1)
+        self.assertEqual(len(conversation_2), 1)

         outputs = conversation_agent([conversation_1, conversation_2])
         self.assertEqual(outputs, [conversation_1, conversation_2])
@@ -96,32 +102,35 @@ class ConversationalPipelineTests(unittest.TestCase):
             outputs,
             [
                 Conversation(
-                    past_user_inputs=["Going to the movies tonight - any suggestions?"],
-                    generated_responses=[ANY(str)],
+                    [
+                        {"role": "user", "content": "Going to the movies tonight - any suggestions?"},
+                        {"role": "assistant", "content": ANY(str)},
+                    ],
                 ),
+                Conversation(
+                    [
+                        {"role": "user", "content": "What's the last book you have read?"},
+                        {"role": "assistant", "content": ANY(str)},
+                    ]
+                ),
-                Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
             ],
         )

         # One conversation with history
-        conversation_2.add_user_input("Why do you recommend it?")
+        conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
         outputs = conversation_agent(conversation_2)
         self.assertEqual(outputs, conversation_2)
         self.assertEqual(
             outputs,
             Conversation(
-                past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
-                generated_responses=[ANY(str), ANY(str)],
+                [
+                    {"role": "user", "content": "What's the last book you have read?"},
+                    {"role": "assistant", "content": ANY(str)},
+                    {"role": "user", "content": "Why do you recommend it?"},
+                    {"role": "assistant", "content": ANY(str)},
+                ]
             ),
         )
-        with self.assertRaises(ValueError):
-            conversation_agent("Hi there!")
-        with self.assertRaises(ValueError):
-            conversation_agent(Conversation())
-        # Conversation have been consumed and are not valid anymore
-        # Inactive conversations passed to the pipeline raise a ValueError
-        with self.assertRaises(ValueError):
-            conversation_agent(conversation_2)

     @require_torch
     @slow
@@ -50,6 +50,7 @@ from transformers.testing_utils import (
     check_json_file_has_correct_format,
     get_tests_dir,
     is_pt_tf_cross_test,
+    require_jinja,
     require_tf,
     require_tokenizers,
     require_torch,
@@ -1052,6 +1053,40 @@ class TokenizerTesterMixin:
                 if tokenizer.num_special_tokens_to_add(pair=True):
                     self.assertIn(None, output.sequence_ids())

+    @require_jinja
+    def test_chat_template(self):
+        dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
+        dummy_conversation = [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "user message"},
+            {"role": "assistant", "content": "assistant message"},
+        ]
+        expected_output = "systemsystem messageuseruser messageassistantassistant message"
+        tokenizers = self.get_tokenizers()
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                output = tokenizer.apply_chat_template(
+                    dummy_conversation, chat_template=dummy_template, tokenize=False
+                )
+                self.assertEqual(output, expected_output)  # Test we can pass chat_template arg
+                # Check that no error raised when tokenize=True
+                tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
+
+                tokenizer.chat_template = dummy_template
+                self.assertEqual(tokenizer.chat_template, dummy_template)  # Test property setter
+                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
+                self.assertEqual(output, expected_output)  # Test chat_template attribute is used if no arg is passed
+                tokenizer.apply_chat_template(dummy_conversation, tokenize=True)  # Check that no error raised
+
+                with tempfile.TemporaryDirectory() as tmp_dir_name:
+                    tokenizer.save_pretrained(tmp_dir_name)
+                    tokenizer = tokenizer.from_pretrained(tmp_dir_name)
+
+                self.assertEqual(tokenizer.chat_template, dummy_template)  # Test template has persisted
+                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
+                self.assertEqual(output, expected_output)  # Test output is the same after reloading
+                tokenizer.apply_chat_template(dummy_conversation, tokenize=True)  # Check that no error raised
+
     def test_number_of_added_tokens(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
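Note on `test_chat_template`: the `dummy_template` in that test loops over the messages and emits each message's role followed directly by its content, with no separators. The sketch below reproduces that rendering in plain Python so the `expected_output` string can be checked by hand. It is only an illustration: `render_dummy_template` is a hypothetical helper name, not part of `transformers`, and it stands in for the Jinja rendering rather than calling it.

```python
def render_dummy_template(messages):
    """Plain-Python stand-in for the test's dummy Jinja template.

    Emits role + content for every message, with nothing in between,
    which is what the loop in `dummy_template` describes.
    """
    return "".join(message["role"] + message["content"] for message in messages)


# Same conversation as in test_chat_template
dummy_conversation = [
    {"role": "system", "content": "system message"},
    {"role": "user", "content": "user message"},
    {"role": "assistant", "content": "assistant message"},
]

rendered = render_dummy_template(dummy_conversation)
print(rendered)  # systemsystem messageuseruser messageassistantassistant message
```

The doubled words ("systemsystem", "useruser") are expected: the role name is concatenated with content that happens to start with the same word.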