[Chat] Add Chat from TRL 🐈 (#35714)
* tmp commit * add working chat * add docts * docs 2 * use auto dtype by default
This commit is contained in:
@@ -23,8 +23,8 @@ of text (as is the case with a standard language model), the model instead conti
|
|||||||
of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text.
|
of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text.
|
||||||
|
|
||||||
Much like tokenization, different models expect very different input formats for chat. This is the reason we added
|
Much like tokenization, different models expect very different input formats for chat. This is the reason we added
|
||||||
**chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs or processor for multimodal LLMs. They specify how to convert conversations,
|
**chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs or processor for multimodal LLMs. They specify how to convert conversations,
|
||||||
represented as lists of messages, into a single tokenizable string in the format that the model expects.
|
represented as lists of messages, into a single tokenizable string in the format that the model expects.
|
||||||
|
|
||||||
Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model:
|
Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model:
|
||||||
|
|
||||||
@@ -42,8 +42,8 @@ Let's make this concrete with a quick example using the `mistralai/Mistral-7B-In
|
|||||||
"<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
|
"<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
|
||||||
```
|
```
|
||||||
|
|
||||||
Notice how the tokenizer has added the control tokens [INST] and [/INST] to indicate the start and end of
|
Notice how the tokenizer has added the control tokens [INST] and [/INST] to indicate the start and end of
|
||||||
user messages (but not assistant messages!), and the entire chat is condensed into a single string.
|
user messages (but not assistant messages!), and the entire chat is condensed into a single string.
|
||||||
If we use `tokenize=True`, which is the default setting, that string will also be tokenized for us.
|
If we use `tokenize=True`, which is the default setting, that string will also be tokenized for us.
|
||||||
|
|
||||||
Now, try the same code, but swap in the `HuggingFaceH4/zephyr-7b-beta` model instead, and you should get:
|
Now, try the same code, but swap in the `HuggingFaceH4/zephyr-7b-beta` model instead, and you should get:
|
||||||
@@ -59,9 +59,16 @@ I'd like to show off how chat templating works!</s>
|
|||||||
|
|
||||||
Both Zephyr and Mistral-Instruct were fine-tuned from the same base model, `Mistral-7B-v0.1`. However, they were trained
|
Both Zephyr and Mistral-Instruct were fine-tuned from the same base model, `Mistral-7B-v0.1`. However, they were trained
|
||||||
with totally different chat formats. Without chat templates, you would have to write manual formatting code for each
|
with totally different chat formats. Without chat templates, you would have to write manual formatting code for each
|
||||||
model, and it's very easy to make minor errors that hurt performance! Chat templates handle the details of formatting
|
model, and it's very easy to make minor errors that hurt performance! Chat templates handle the details of formatting
|
||||||
for you, allowing you to write universal code that works for any model.
|
for you, allowing you to write universal code that works for any model.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Chat templates are a critical component of our [`transformers-cli chat` CLI](quicktour#chat-with-text-generation-models).
|
||||||
|
You can apply the learnings of this guide there as well.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
|
||||||
## How do I use chat templates?
|
## How do I use chat templates?
|
||||||
|
|
||||||
@@ -69,7 +76,7 @@ As you can see in the example above, chat templates are easy to use. Simply buil
|
|||||||
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] or [`~ProcessorMixin.apply_chat_template`] method
|
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] or [`~ProcessorMixin.apply_chat_template`] method
|
||||||
depending on what type of model you are using. Once you do that,
|
depending on what type of model you are using. Once you do that,
|
||||||
you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea
|
you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea
|
||||||
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).
|
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).
|
||||||
|
|
||||||
## Usage with text-only LLMs
|
## Usage with text-only LLMs
|
||||||
Here's an example of preparing input for `model.generate()`, using `Zephyr` again:
|
Here's an example of preparing input for `model.generate()`, using `Zephyr` again:
|
||||||
@@ -91,19 +98,19 @@ messages = [
|
|||||||
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||||
print(tokenizer.decode(tokenized_chat[0]))
|
print(tokenizer.decode(tokenized_chat[0]))
|
||||||
```
|
```
|
||||||
This will yield a string in the input format that Zephyr expects.
|
This will yield a string in the input format that Zephyr expects.
|
||||||
```text
|
```text
|
||||||
<|system|>
|
<|system|>
|
||||||
You are a friendly chatbot who always responds in the style of a pirate</s>
|
You are a friendly chatbot who always responds in the style of a pirate</s>
|
||||||
<|user|>
|
<|user|>
|
||||||
How many helicopters can a human eat in one sitting?</s>
|
How many helicopters can a human eat in one sitting?</s>
|
||||||
<|assistant|>
|
<|assistant|>
|
||||||
```
|
```
|
||||||
|
|
||||||
Now that our input is formatted correctly for Zephyr, we can use the model to generate a response to the user's question:
|
Now that our input is formatted correctly for Zephyr, we can use the model to generate a response to the user's question:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
outputs = model.generate(tokenized_chat, max_new_tokens=128)
|
outputs = model.generate(tokenized_chat, max_new_tokens=128)
|
||||||
print(tokenizer.decode(outputs[0]))
|
print(tokenizer.decode(outputs[0]))
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -111,9 +118,9 @@ This will yield:
|
|||||||
|
|
||||||
```text
|
```text
|
||||||
<|system|>
|
<|system|>
|
||||||
You are a friendly chatbot who always responds in the style of a pirate</s>
|
You are a friendly chatbot who always responds in the style of a pirate</s>
|
||||||
<|user|>
|
<|user|>
|
||||||
How many helicopters can a human eat in one sitting?</s>
|
How many helicopters can a human eat in one sitting?</s>
|
||||||
<|assistant|>
|
<|assistant|>
|
||||||
Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
|
Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
|
||||||
```
|
```
|
||||||
@@ -152,7 +159,7 @@ print(processor.batch_decode(processed_chat["input_ids"][:, :30]))
|
|||||||
This yields a string in LLaVAs expected input format with many `<image>` tokens at the end.
|
This yields a string in LLaVAs expected input format with many `<image>` tokens at the end.
|
||||||
The `<image>` tokens are placeholders and each one will be replaced by image embeddings when the mode is run in the forward call. The `processed_chat` can be further passed into [`~GenerationMixin.generate`] to generate text.
|
The `<image>` tokens are placeholders and each one will be replaced by image embeddings when the mode is run in the forward call. The `processed_chat` can be further passed into [`~GenerationMixin.generate`] to generate text.
|
||||||
```text
|
```text
|
||||||
'<|im_start|>system
|
'<|im_start|>system
|
||||||
You are a friendly chatbot who always responds in the style of a pirate<|im_end|><|im_start|>user <image><image><image><image><image><image><image><image>'
|
You are a friendly chatbot who always responds in the style of a pirate<|im_end|><|im_start|>user <image><image><image><image><image><image><image><image>'
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -162,7 +169,7 @@ Arr, 'twas easy after all!
|
|||||||
|
|
||||||
Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past,
|
Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past,
|
||||||
we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality
|
we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality
|
||||||
has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using
|
has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using
|
||||||
a pipeline:
|
a pipeline:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -227,9 +234,9 @@ Can I ask a question?<|im_end|>
|
|||||||
```
|
```
|
||||||
|
|
||||||
Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
|
Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
|
||||||
generates text it will write a bot response instead of doing something unexpected, like continuing the user's
|
generates text it will write a bot response instead of doing something unexpected, like continuing the user's
|
||||||
message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
|
message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
|
||||||
special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're
|
special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're
|
||||||
supposed to be doing.
|
supposed to be doing.
|
||||||
|
|
||||||
Not all models require generation prompts. Some models, like LLaMA, don't have any
|
Not all models require generation prompts. Some models, like LLaMA, don't have any
|
||||||
@@ -241,7 +248,7 @@ effect that `add_generation_prompt` has will depend on the template being used.
|
|||||||
When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
|
When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
|
||||||
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
|
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
|
||||||
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
|
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
|
||||||
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
|
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
|
||||||
|
|
||||||
Here's an example:
|
Here's an example:
|
||||||
|
|
||||||
@@ -266,9 +273,9 @@ get an error if you try!
|
|||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
|
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
|
||||||
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
|
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
|
||||||
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
|
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
|
||||||
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
|
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
|
||||||
argument when calling the pipeline.
|
argument when calling the pipeline.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
@@ -277,8 +284,8 @@ argument when calling the pipeline.
|
|||||||
|
|
||||||
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
|
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
|
||||||
We recommend that you apply the chat template as a preprocessing step for your dataset. After this, you
|
We recommend that you apply the chat template as a preprocessing step for your dataset. After this, you
|
||||||
can simply continue like any other language model training task. When training, you should usually set
|
can simply continue like any other language model training task. When training, you should usually set
|
||||||
`add_generation_prompt=False`, because the added tokens to prompt an assistant response will not be helpful during
|
`add_generation_prompt=False`, because the added tokens to prompt an assistant response will not be helpful during
|
||||||
training. Let's see an example:
|
training. Let's see an example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -312,8 +319,8 @@ From here, just continue training like you would with a standard language modell
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
By default, some tokenizers add special tokens like `<bos>` and `<eos>` to text they tokenize. Chat templates should
|
By default, some tokenizers add special tokens like `<bos>` and `<eos>` to text they tokenize. Chat templates should
|
||||||
already include all the special tokens they need, and so additional special tokens will often be incorrect or
|
already include all the special tokens they need, and so additional special tokens will often be incorrect or
|
||||||
duplicated, which will hurt model performance.
|
duplicated, which will hurt model performance.
|
||||||
|
|
||||||
Therefore, if you format text with `apply_chat_template(tokenize=False)`, you should set the argument
|
Therefore, if you format text with `apply_chat_template(tokenize=False)`, you should set the argument
|
||||||
@@ -326,7 +333,7 @@ Therefore, if you format text with `apply_chat_template(tokenize=False)`, you sh
|
|||||||
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
|
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
|
||||||
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
|
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
|
||||||
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
|
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
|
||||||
strings, lists, dicts or whatever else you want.
|
strings, lists, dicts or whatever else you want.
|
||||||
|
|
||||||
That said, there are some common use-cases for these extra arguments,
|
That said, there are some common use-cases for these extra arguments,
|
||||||
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
|
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
|
||||||
@@ -349,7 +356,7 @@ def current_time():
|
|||||||
def multiply(a: float, b: float):
|
def multiply(a: float, b: float):
|
||||||
"""
|
"""
|
||||||
A function that multiplies two numbers
|
A function that multiplies two numbers
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The first number to multiply
|
a: The first number to multiply
|
||||||
b: The second number to multiply
|
b: The second number to multiply
|
||||||
@@ -369,8 +376,8 @@ correctly as tools. Specifically, you should follow these rules:
|
|||||||
|
|
||||||
- The function should have a descriptive name
|
- The function should have a descriptive name
|
||||||
- Every argument must have a type hint
|
- Every argument must have a type hint
|
||||||
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
||||||
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
||||||
- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
|
- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
|
||||||
`a (int): The first number to multiply`. Type hints should go in the function header instead.
|
`a (int): The first number to multiply`. Type hints should go in the function header instead.
|
||||||
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
|
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
|
||||||
@@ -412,7 +419,7 @@ Next, let's define a list of tools:
|
|||||||
def get_current_temperature(location: str, unit: str) -> float:
|
def get_current_temperature(location: str, unit: str) -> float:
|
||||||
"""
|
"""
|
||||||
Get the current temperature at a location.
|
Get the current temperature at a location.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
location: The location to get the temperature for, in the format "City, Country"
|
location: The location to get the temperature for, in the format "City, Country"
|
||||||
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
|
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
|
||||||
@@ -424,7 +431,7 @@ def get_current_temperature(location: str, unit: str) -> float:
|
|||||||
def get_current_wind_speed(location: str) -> float:
|
def get_current_wind_speed(location: str) -> float:
|
||||||
"""
|
"""
|
||||||
Get the current wind speed in km/h at a given location.
|
Get the current wind speed in km/h at a given location.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
location: The location to get the temperature for, in the format "City, Country"
|
location: The location to get the temperature for, in the format "City, Country"
|
||||||
Returns:
|
Returns:
|
||||||
@@ -469,8 +476,8 @@ the temperature in France should certainly be displayed in Celsius.
|
|||||||
|
|
||||||
The output format above is specific to the `Hermes-2-Pro` model we're using in this example. Other models may emit different
|
The output format above is specific to the `Hermes-2-Pro` model we're using in this example. Other models may emit different
|
||||||
tool call formats, and you may need to do some manual parsing at this step. For example, `Llama-3.1` models will emit
|
tool call formats, and you may need to do some manual parsing at this step. For example, `Llama-3.1` models will emit
|
||||||
slightly different JSON, with `parameters` instead of `arguments`. Regardless of the format the model outputs, you
|
slightly different JSON, with `parameters` instead of `arguments`. Regardless of the format the model outputs, you
|
||||||
should add the tool call to the conversation in the format below, with `tool_calls`, `function` and `arguments` keys.
|
should add the tool call to the conversation in the format below, with `tool_calls`, `function` and `arguments` keys.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
@@ -489,7 +496,7 @@ a dict, but in the OpenAI API it's a JSON string. Passing a string may cause err
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Now that we've added the tool call to the conversation, we can call the function and append the result to the
|
Now that we've added the tool call to the conversation, we can call the function and append the result to the
|
||||||
conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append
|
conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append
|
||||||
that result directly.
|
that result directly.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -500,7 +507,7 @@ messages.append({"role": "tool", "name": "get_current_temperature", "content": "
|
|||||||
|
|
||||||
Some model architectures, notably Mistral/Mixtral, also require a `tool_call_id` here, which should be
|
Some model architectures, notably Mistral/Mixtral, also require a `tool_call_id` here, which should be
|
||||||
9 randomly-generated alphanumeric characters, and assigned to the `id` key of the tool call
|
9 randomly-generated alphanumeric characters, and assigned to the `id` key of the tool call
|
||||||
dictionary. The same key should also be assigned to the `tool_call_id` key of the tool response dictionary below, so
|
dictionary. The same key should also be assigned to the `tool_call_id` key of the tool response dictionary below, so
|
||||||
that tool calls can be matched to tool responses. So, for Mistral/Mixtral models, the code above would be:
|
that tool calls can be matched to tool responses. So, for Mistral/Mixtral models, the code above would be:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -532,13 +539,13 @@ And we get:
|
|||||||
The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
|
The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
|
||||||
```
|
```
|
||||||
|
|
||||||
Although this was a simple demo with dummy tools and a single call, the same technique works with
|
Although this was a simple demo with dummy tools and a single call, the same technique works with
|
||||||
multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
|
multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
|
||||||
agents with real-time information, computational tools like calculators, or access to large databases.
|
agents with real-time information, computational tools like calculators, or access to large databases.
|
||||||
|
|
||||||
### Understanding tool schemas
|
### Understanding tool schemas
|
||||||
|
|
||||||
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
||||||
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
|
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
|
||||||
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
|
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
|
||||||
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
|
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
|
||||||
@@ -547,7 +554,7 @@ to read their outputs, detect if they have requested to use a tool, pass their a
|
|||||||
return the response in the chat.
|
return the response in the chat.
|
||||||
|
|
||||||
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
|
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
|
||||||
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
||||||
you can handle the conversion manually. Here is an example of a manual schema conversion.
|
you can handle the conversion manually. Here is an example of a manual schema conversion.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -556,7 +563,7 @@ from transformers.utils import get_json_schema
|
|||||||
def multiply(a: float, b: float):
|
def multiply(a: float, b: float):
|
||||||
"""
|
"""
|
||||||
A function that multiplies two numbers
|
A function that multiplies two numbers
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
a: The first number to multiply
|
a: The first number to multiply
|
||||||
b: The second number to multiply
|
b: The second number to multiply
|
||||||
@@ -571,33 +578,33 @@ This will yield:
|
|||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "multiply",
|
"name": "multiply",
|
||||||
"description": "A function that multiplies two numbers",
|
"description": "A function that multiplies two numbers",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {
|
"a": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"description": "The first number to multiply"
|
"description": "The first number to multiply"
|
||||||
},
|
},
|
||||||
"b": {
|
"b": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"description": "The second number to multiply"
|
"description": "The second number to multiply"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["a", "b"]
|
"required": ["a", "b"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
||||||
all. JSON schemas can be passed directly to the `tools` argument of
|
all. JSON schemas can be passed directly to the `tools` argument of
|
||||||
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
|
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
|
||||||
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
||||||
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
||||||
to a minimum.
|
to a minimum.
|
||||||
|
|
||||||
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
|
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
|
||||||
@@ -605,7 +612,7 @@ Here is an example of defining schemas by hand, and passing them directly to `ap
|
|||||||
```python
|
```python
|
||||||
# A simple function that takes no arguments
|
# A simple function that takes no arguments
|
||||||
current_time = {
|
current_time = {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "current_time",
|
"name": "current_time",
|
||||||
"description": "Get the current local time as a string.",
|
"description": "Get the current local time as a string.",
|
||||||
@@ -621,18 +628,18 @@ multiply = {
|
|||||||
'type': 'function',
|
'type': 'function',
|
||||||
'function': {
|
'function': {
|
||||||
'name': 'multiply',
|
'name': 'multiply',
|
||||||
'description': 'A function that multiplies two numbers',
|
'description': 'A function that multiplies two numbers',
|
||||||
'parameters': {
|
'parameters': {
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
'properties': {
|
'properties': {
|
||||||
'a': {
|
'a': {
|
||||||
'type': 'number',
|
'type': 'number',
|
||||||
'description': 'The first number to multiply'
|
'description': 'The first number to multiply'
|
||||||
},
|
},
|
||||||
'b': {
|
'b': {
|
||||||
'type': 'number', 'description': 'The second number to multiply'
|
'type': 'number', 'description': 'The second number to multiply'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
'required': ['a', 'b']
|
'required': ['a', 'b']
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -647,7 +654,7 @@ model_input = tokenizer.apply_chat_template(
|
|||||||
## Advanced: Retrieval-augmented generation
|
## Advanced: Retrieval-augmented generation
|
||||||
|
|
||||||
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
|
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
|
||||||
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
||||||
recommendation for RAG models is that their template
|
recommendation for RAG models is that their template
|
||||||
should accept a `documents` argument. This should be a list of documents, where each "document"
|
should accept a `documents` argument. This should be a list of documents, where each "document"
|
||||||
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
|
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
|
||||||
@@ -672,7 +679,7 @@ conversation = [
|
|||||||
# Define documents for retrieval-based generation
|
# Define documents for retrieval-based generation
|
||||||
documents = [
|
documents = [
|
||||||
{
|
{
|
||||||
"title": "The Moon: Our Age-Old Foe",
|
"title": "The Moon: Our Age-Old Foe",
|
||||||
"text": "Man has always dreamed of destroying the moon. In this essay, I shall..."
|
"text": "Man has always dreamed of destroying the moon. In this essay, I shall..."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -690,7 +697,7 @@ input_ids = tokenizer.apply_chat_template(
|
|||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
return_tensors="pt").to(device)
|
return_tensors="pt").to(device)
|
||||||
|
|
||||||
# Generate a response
|
# Generate a response
|
||||||
gen_tokens = model.generate(
|
gen_tokens = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
max_new_tokens=100,
|
max_new_tokens=100,
|
||||||
@@ -750,8 +757,8 @@ Effectively, the template does three things:
|
|||||||
an assistant response.
|
an assistant response.
|
||||||
|
|
||||||
This is a pretty simple template but Jinja gives you a lot of flexibility to do more complex things! Let's see a Jinja
|
This is a pretty simple template but Jinja gives you a lot of flexibility to do more complex things! Let's see a Jinja
|
||||||
template that can format inputs similarly to the way LLaMA formats them (note that the real LLaMA template includes
|
template that can format inputs similarly to the way LLaMA formats them (note that the real LLaMA template includes
|
||||||
handling for default system messages and slightly different system message handling in general - don't use this one
|
handling for default system messages and slightly different system message handling in general - don't use this one
|
||||||
in your actual code!)
|
in your actual code!)
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -774,7 +781,7 @@ distinguishable to the model because of the tokens they're wrapped in.
|
|||||||
|
|
||||||
### How do I create a chat template?
|
### How do I create a chat template?
|
||||||
|
|
||||||
Simple, just write a jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
|
Simple, just write a jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
|
||||||
existing template from another model and simply edit it for your needs! For example, we could take the LLaMA template
|
existing template from another model and simply edit it for your needs! For example, we could take the LLaMA template
|
||||||
above and add "[ASST]" and "[/ASST]" to assistant messages:
|
above and add "[ASST]" and "[/ASST]" to assistant messages:
|
||||||
|
|
||||||
@@ -802,13 +809,13 @@ tokenizer.chat_template = template # Set the new template
|
|||||||
tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
|
tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
|
||||||
```
|
```
|
||||||
|
|
||||||
The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so
|
The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so
|
||||||
once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`].
|
once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`].
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat
|
If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat
|
||||||
control tokens as special tokens in the tokenizer. Special tokens are never split,
|
control tokens as special tokens in the tokenizer. Special tokens are never split,
|
||||||
ensuring that your control tokens are always handled as single tokens rather than being tokenized in pieces. You
|
ensuring that your control tokens are always handled as single tokens rather than being tokenized in pieces. You
|
||||||
should also set the tokenizer's `eos_token` attribute to the token that marks the end of assistant generations in your
|
should also set the tokenizer's `eos_token` attribute to the token that marks the end of assistant generations in your
|
||||||
template. This will ensure that text generation tools can correctly figure out when to stop generating text.
|
template. This will ensure that text generation tools can correctly figure out when to stop generating text.
|
||||||
</Tip>
|
</Tip>
|
||||||
@@ -836,13 +843,13 @@ trying to put it all in a single template where possible!
|
|||||||
|
|
||||||
When setting the template for a model that's already been trained for chat, you should ensure that the template
|
When setting the template for a model that's already been trained for chat, you should ensure that the template
|
||||||
exactly matches the message formatting that the model saw during training, or else you will probably experience
|
exactly matches the message formatting that the model saw during training, or else you will probably experience
|
||||||
performance degradation. This is true even if you're training the model further - you will probably get the best
|
performance degradation. This is true even if you're training the model further - you will probably get the best
|
||||||
performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the
|
performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the
|
||||||
best performance for inference or fine-tuning when you precisely match the tokenization used during training.
|
best performance for inference or fine-tuning when you precisely match the tokenization used during training.
|
||||||
|
|
||||||
If you're training a model from scratch, or fine-tuning a base language model for chat, on the other hand,
|
If you're training a model from scratch, or fine-tuning a base language model for chat, on the other hand,
|
||||||
you have a lot of freedom to choose an appropriate template! LLMs are smart enough to learn to handle lots of different
|
you have a lot of freedom to choose an appropriate template! LLMs are smart enough to learn to handle lots of different
|
||||||
input formats. One popular choice is the `ChatML` format, and this is a good, flexible choice for many use-cases.
|
input formats. One popular choice is the `ChatML` format, and this is a good, flexible choice for many use-cases.
|
||||||
It looks like this:
|
It looks like this:
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -888,7 +895,7 @@ Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_templat
|
|||||||
model, which means it is also automatically supported in places like `TextGenerationPipeline`!
|
model, which means it is also automatically supported in places like `TextGenerationPipeline`!
|
||||||
|
|
||||||
By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
|
By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
|
||||||
open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
|
open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
|
||||||
it's time to put an end to them!
|
it's time to put an end to them!
|
||||||
|
|
||||||
## Advanced: Template writing tips
|
## Advanced: Template writing tips
|
||||||
@@ -896,17 +903,17 @@ it's time to put an end to them!
|
|||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
The easiest way to get started with writing Jinja templates is to take a look at some existing ones. You can use
|
The easiest way to get started with writing Jinja templates is to take a look at some existing ones. You can use
|
||||||
`print(tokenizer.chat_template)` for any chat model to see what template it's using. In general, models that support tool use have
|
`print(tokenizer.chat_template)` for any chat model to see what template it's using. In general, models that support tool use have
|
||||||
much more complex templates than other models - so when you're just getting started, they're probably a bad example
|
much more complex templates than other models - so when you're just getting started, they're probably a bad example
|
||||||
to learn from! You can also take a look at the
|
to learn from! You can also take a look at the
|
||||||
[Jinja documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) for details
|
[Jinja documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) for details
|
||||||
of general Jinja formatting and syntax.
|
of general Jinja formatting and syntax.
|
||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Jinja templates in `transformers` are identical to Jinja templates elsewhere. The main thing to know is that
|
Jinja templates in `transformers` are identical to Jinja templates elsewhere. The main thing to know is that
|
||||||
the conversation history will be accessible inside your template as a variable called `messages`.
|
the conversation history will be accessible inside your template as a variable called `messages`.
|
||||||
You will be able to access `messages` in your template just like you can in Python, which means you can loop over
|
You will be able to access `messages` in your template just like you can in Python, which means you can loop over
|
||||||
it with `{% for message in messages %}` or access individual messages with `{{ messages[0] }}`, for example.
|
it with `{% for message in messages %}` or access individual messages with `{{ messages[0] }}`, for example.
|
||||||
|
|
||||||
You can also use the following tips to write clean, efficient Jinja templates:
|
You can also use the following tips to write clean, efficient Jinja templates:
|
||||||
@@ -936,7 +943,7 @@ and indentation may end up being included in the output, which is probably not w
|
|||||||
|
|
||||||
### Special variables
|
### Special variables
|
||||||
|
|
||||||
Inside your template, you will have access several special variables. The most important of these is `messages`,
|
Inside your template, you will have access several special variables. The most important of these is `messages`,
|
||||||
which contains the chat history as a list of message dicts. However, there are several others. Not every
|
which contains the chat history as a list of message dicts. However, there are several others. Not every
|
||||||
variable will be used in every template. The most common other variables are:
|
variable will be used in every template. The most common other variables are:
|
||||||
|
|
||||||
@@ -970,7 +977,7 @@ There are multiple implementations of Jinja in various languages. They generally
|
|||||||
but a key difference is that when you're writing a template in Python you can use Python methods, such as
|
but a key difference is that when you're writing a template in Python you can use Python methods, such as
|
||||||
`.lower()` on strings or `.items()` on dicts. This will break if someone tries to use your template on a non-Python
|
`.lower()` on strings or `.items()` on dicts. This will break if someone tries to use your template on a non-Python
|
||||||
implementation of Jinja. Non-Python implementations are particularly common in deployment environments, where JS
|
implementation of Jinja. Non-Python implementations are particularly common in deployment environments, where JS
|
||||||
and Rust are very popular.
|
and Rust are very popular.
|
||||||
|
|
||||||
Don't panic, though! There are a few easy changes you can make to your templates to ensure they're compatible across
|
Don't panic, though! There are a few easy changes you can make to your templates to ensure they're compatible across
|
||||||
all implementations of Jinja:
|
all implementations of Jinja:
|
||||||
@@ -1002,21 +1009,21 @@ Here is an example of a template that formats messages ChatML-style, with genera
|
|||||||
```
|
```
|
||||||
|
|
||||||
The exact content of the assistant header will depend on your specific model, but it should always be **the string
|
The exact content of the assistant header will depend on your specific model, but it should always be **the string
|
||||||
that represents the start of an assistant message**, so that if the user applies your template with
|
that represents the start of an assistant message**, so that if the user applies your template with
|
||||||
`add_generation_prompt=True` and then generates text, the model will write an assistant response. Also note that some
|
`add_generation_prompt=True` and then generates text, the model will write an assistant response. Also note that some
|
||||||
models do not need a generation prompt, because assistant messages always begin immediately after user messages.
|
models do not need a generation prompt, because assistant messages always begin immediately after user messages.
|
||||||
This is particularly common for LLaMA and Mistral models, where assistant messages begin immediately after the `[/INST]`
|
This is particularly common for LLaMA and Mistral models, where assistant messages begin immediately after the `[/INST]`
|
||||||
token that ends user messages. In these cases, the template can ignore the `add_generation_prompt` flag.
|
token that ends user messages. In these cases, the template can ignore the `add_generation_prompt` flag.
|
||||||
|
|
||||||
Generation prompts are important! If your model requires a generation prompt but it is not set in the template, then
|
Generation prompts are important! If your model requires a generation prompt but it is not set in the template, then
|
||||||
model generations will likely be severely degraded, or the model may display unusual behaviour like continuing
|
model generations will likely be severely degraded, or the model may display unusual behaviour like continuing
|
||||||
the final user message!
|
the final user message!
|
||||||
|
|
||||||
### Writing and debugging larger templates
|
### Writing and debugging larger templates
|
||||||
|
|
||||||
When this feature was introduced, most templates were quite small, the Jinja equivalent of a "one-liner" script.
|
When this feature was introduced, most templates were quite small, the Jinja equivalent of a "one-liner" script.
|
||||||
However, with new models and features like tool-use and RAG, some templates can be 100 lines long or more. When
|
However, with new models and features like tool-use and RAG, some templates can be 100 lines long or more. When
|
||||||
writing templates like these, it's a good idea to write them in a separate file, using a text editor. You can easily
|
writing templates like these, it's a good idea to write them in a separate file, using a text editor. You can easily
|
||||||
extract a chat template to a file:
|
extract a chat template to a file:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -1035,7 +1042,7 @@ identify the source of issues.
|
|||||||
|
|
||||||
### Writing templates for tools
|
### Writing templates for tools
|
||||||
|
|
||||||
Although chat templates do not enforce a specific API for tools (or for anything, really), we recommend
|
Although chat templates do not enforce a specific API for tools (or for anything, really), we recommend
|
||||||
template authors try to stick to a standard API where possible. The whole point of chat templates is to allow code
|
template authors try to stick to a standard API where possible. The whole point of chat templates is to allow code
|
||||||
to be transferable across models, so deviating from the standard tools API means users will have to write
|
to be transferable across models, so deviating from the standard tools API means users will have to write
|
||||||
custom code to use tools with your model. Sometimes it's unavoidable, but often with clever templating you can
|
custom code to use tools with your model. Sometimes it's unavoidable, but often with clever templating you can
|
||||||
@@ -1045,30 +1052,30 @@ Below, we'll list the elements of the standard API, and give tips on writing tem
|
|||||||
|
|
||||||
#### Tool definitions
|
#### Tool definitions
|
||||||
|
|
||||||
Your template should expect that the variable `tools` will either be null (if no tools are passed), or is a list
|
Your template should expect that the variable `tools` will either be null (if no tools are passed), or is a list
|
||||||
of JSON schema dicts. Our chat template methods allow users to pass tools as either JSON schema or Python functions, but when
|
of JSON schema dicts. Our chat template methods allow users to pass tools as either JSON schema or Python functions, but when
|
||||||
functions are passed, we automatically generate JSON schema and pass that to your template. As a result, the
|
functions are passed, we automatically generate JSON schema and pass that to your template. As a result, the
|
||||||
`tools` variable that your template receives will always be a list of JSON schema. Here is
|
`tools` variable that your template receives will always be a list of JSON schema. Here is
|
||||||
a sample tool JSON schema:
|
a sample tool JSON schema:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "multiply",
|
"name": "multiply",
|
||||||
"description": "A function that multiplies two numbers",
|
"description": "A function that multiplies two numbers",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"a": {
|
"a": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"description": "The first number to multiply"
|
"description": "The first number to multiply"
|
||||||
},
|
},
|
||||||
"b": {
|
"b": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"description": "The second number to multiply"
|
"description": "The second number to multiply"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["a", "b"]
|
"required": ["a", "b"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1092,13 +1099,13 @@ specific format - your model will probably need different formatting!
|
|||||||
|
|
||||||
The specific tokens and tool descriptions your template renders should of course be chosen to match the ones your model
|
The specific tokens and tool descriptions your template renders should of course be chosen to match the ones your model
|
||||||
was trained with. There is no requirement that your **model** understands JSON schema input, only that your template can translate
|
was trained with. There is no requirement that your **model** understands JSON schema input, only that your template can translate
|
||||||
JSON schema into your model's format. For example, [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024)
|
JSON schema into your model's format. For example, [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024)
|
||||||
was trained with tools defined using Python function headers, but the Command-R tool template accepts JSON schema,
|
was trained with tools defined using Python function headers, but the Command-R tool template accepts JSON schema,
|
||||||
converts types internally and renders the input tools as Python headers. You can do a lot with templates!
|
converts types internally and renders the input tools as Python headers. You can do a lot with templates!
|
||||||
|
|
||||||
#### Tool calls
|
#### Tool calls
|
||||||
|
|
||||||
Tool calls, if present, will be a list attached to a message with the "assistant" role. Note that `tool_calls` is
|
Tool calls, if present, will be a list attached to a message with the "assistant" role. Note that `tool_calls` is
|
||||||
always a list, even though most tool-calling models only support single tool calls at a time, which means
|
always a list, even though most tool-calling models only support single tool calls at a time, which means
|
||||||
the list will usually only have a single element. Here is a sample message dict containing a tool call:
|
the list will usually only have a single element. Here is a sample message dict containing a tool call:
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,13 @@ This guide describes:
|
|||||||
* common decoding strategies and their main parameters
|
* common decoding strategies and their main parameters
|
||||||
* saving and sharing custom generation configurations with your fine-tuned model on 🤗 Hub
|
* saving and sharing custom generation configurations with your fine-tuned model on 🤗 Hub
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
`generate()` is a critical component of our [`transformers-cli chat` CLI](quicktour#chat-with-text-generation-models).
|
||||||
|
You can apply the learnings of this guide there as well.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
## Default text generation configuration
|
## Default text generation configuration
|
||||||
|
|
||||||
A decoding strategy for a model is defined in its generation configuration. When using pre-trained models for inference
|
A decoding strategy for a model is defined in its generation configuration. When using pre-trained models for inference
|
||||||
|
|||||||
@@ -23,6 +23,12 @@ LLMs, or Large Language Models, are the key component behind text generation. In
|
|||||||
|
|
||||||
Autoregressive generation is the inference-time procedure of iteratively calling a model with its own generated outputs, given a few initial inputs. In 🤗 Transformers, this is handled by the [`~generation.GenerationMixin.generate`] method, which is available to all models with generative capabilities.
|
Autoregressive generation is the inference-time procedure of iteratively calling a model with its own generated outputs, given a few initial inputs. In 🤗 Transformers, this is handled by the [`~generation.GenerationMixin.generate`] method, which is available to all models with generative capabilities.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
If you want to jump straight to chatting with a model, [try our `transformers-cli chat` CLI](quicktour#chat-with-text-generation-models).
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
This tutorial will show you how to:
|
This tutorial will show you how to:
|
||||||
|
|
||||||
* Generate text with an LLM
|
* Generate text with an LLM
|
||||||
|
|||||||
@@ -553,6 +553,32 @@ All models are a standard [`tf.keras.Model`](https://www.tensorflow.org/api_docs
|
|||||||
>>> model.fit(tf_dataset) # doctest: +SKIP
|
>>> model.fit(tf_dataset) # doctest: +SKIP
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Chat with text generation models
|
||||||
|
|
||||||
|
If you're working with a model that generates text as an output, you can also engage in a multi-turn conversation with
|
||||||
|
it through the `transformers-cli chat` command. This is the fastest way to interact with a model, e.g. for a
|
||||||
|
qualitative assessment (aka vibe check).
|
||||||
|
|
||||||
|
This CLI is implemented on top of our `AutoClass` abstraction, leveraging our [text generation](llm_tutorial.md) and
|
||||||
|
[chat](chat_templating.md) tooling, and thus will be compatible with any 🤗 Transformers model. If you have the library
|
||||||
|
[installed](installation.md), you can launch the chat session on your terminal with
|
||||||
|
|
||||||
|
```
|
||||||
|
transformers-cli chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
For a full list of options to launch the chat, type
|
||||||
|
|
||||||
|
```
|
||||||
|
transformers-cli chat -h
|
||||||
|
```
|
||||||
|
|
||||||
|
After the chat is launched, you will enter an interactive session with the model. There are special commands for this
|
||||||
|
session as well, such as `clear` to reset the conversation. Type `help` at any moment to display all special chat
|
||||||
|
commands, and `exit` to terminate the session.
|
||||||
|
|
||||||
|
|
||||||
## What's next?
|
## What's next?
|
||||||
|
|
||||||
Now that you've completed the 🤗 Transformers quick tour, check out our guides and learn how to do more specific things like writing a custom model, fine-tuning a model for a task, and how to train a model with a script. If you're interested in learning more about 🤗 Transformers core concepts, grab a cup of coffee and take a look at our Conceptual Guides!
|
Now that you've completed the 🤗 Transformers quick tour, check out our guides and learn how to do more specific things like writing a custom model, fine-tuning a model for a task, and how to train a model with a script. If you're interested in learning more about 🤗 Transformers core concepts, grab a cup of coffee and take a look at our Conceptual Guides!
|
||||||
|
|||||||
539
src/transformers/commands/chat.py
Normal file
539
src/transformers/commands/chat.py
Normal file
@@ -0,0 +1,539 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser, Namespace
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
||||||
|
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
|
if platform.system() != "Windows":
|
||||||
|
import pwd
|
||||||
|
|
||||||
|
|
||||||
|
HELP_STRING = """\
|
||||||
|
|
||||||
|
**TRANSFORMERS CHAT INTERFACE**
|
||||||
|
|
||||||
|
The chat interface is a simple tool to try out a chat model.
|
||||||
|
|
||||||
|
Besides talking to the model there are several commands:
|
||||||
|
- **help**: show this help message
|
||||||
|
- **clear**: clears the current conversation and start a new one
|
||||||
|
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||||
|
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||||
|
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||||
|
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||||
|
- **exit**: closes the interface
|
||||||
|
"""
|
||||||
|
|
||||||
|
SUPPORTED_GENERATION_KWARGS = [
|
||||||
|
"max_new_tokens",
|
||||||
|
"do_sample",
|
||||||
|
"num_beams",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
"repetition_penalty",
|
||||||
|
]
|
||||||
|
|
||||||
|
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
|
||||||
|
|
||||||
|
DEFAULT_EXAMPLES = {
|
||||||
|
"llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
|
||||||
|
"code": {
|
||||||
|
"text": "Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]."
|
||||||
|
},
|
||||||
|
"helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
|
||||||
|
"numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
|
||||||
|
"birds": {"text": "Why aren't birds real?"},
|
||||||
|
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_username():
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
return os.getlogin()
|
||||||
|
else:
|
||||||
|
return pwd.getpwuid(os.getuid()).pw_name
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_filename(model_name):
|
||||||
|
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
return f"{model_name}/chat_{time_str}.json"
|
||||||
|
|
||||||
|
|
||||||
|
def save_chat(chat, args, filename):
|
||||||
|
output_dict = {}
|
||||||
|
output_dict["settings"] = vars(args)
|
||||||
|
output_dict["chat_history"] = chat
|
||||||
|
|
||||||
|
folder = args.save_folder
|
||||||
|
|
||||||
|
if filename is None:
|
||||||
|
filename = create_default_filename(args.model_name_or_path)
|
||||||
|
filename = os.path.join(folder, filename)
|
||||||
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||||
|
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(output_dict, f, indent=4)
|
||||||
|
return os.path.abspath(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_chat_history(system_prompt):
|
||||||
|
if system_prompt is None:
|
||||||
|
chat = []
|
||||||
|
else:
|
||||||
|
chat = [{"role": "system", "content": system_prompt}]
|
||||||
|
return chat
|
||||||
|
|
||||||
|
|
||||||
|
def parse_settings(user_input, current_args, interface):
|
||||||
|
settings = user_input[4:].strip().split(";")
|
||||||
|
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
||||||
|
settings = dict(settings)
|
||||||
|
error = False
|
||||||
|
|
||||||
|
for name in settings:
|
||||||
|
if hasattr(current_args, name):
|
||||||
|
try:
|
||||||
|
if isinstance(getattr(current_args, name), bool):
|
||||||
|
if settings[name] == "True":
|
||||||
|
settings[name] = True
|
||||||
|
elif settings[name] == "False":
|
||||||
|
settings[name] = False
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
settings[name] = type(getattr(current_args, name))(settings[name])
|
||||||
|
except ValueError:
|
||||||
|
interface.print_red(
|
||||||
|
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
interface.print_red(f"There is no '{name}' setting.")
|
||||||
|
|
||||||
|
if error:
|
||||||
|
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
|
||||||
|
return current_args, False
|
||||||
|
else:
|
||||||
|
for name in settings:
|
||||||
|
setattr(current_args, name, settings[name])
|
||||||
|
interface.print_green(f"Set {name} to {settings[name]}.")
|
||||||
|
|
||||||
|
time.sleep(1.5) # so the user has time to read the changes
|
||||||
|
return current_args, True
|
||||||
|
|
||||||
|
|
||||||
|
def get_quantization_config(model_args) -> Optional[BitsAndBytesConfig]:
|
||||||
|
if model_args.load_in_4bit:
|
||||||
|
quantization_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
|
||||||
|
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||||
|
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||||
|
bnb_4bit_quant_storage=model_args.torch_dtype,
|
||||||
|
)
|
||||||
|
elif model_args.load_in_8bit:
|
||||||
|
quantization_config = BitsAndBytesConfig(
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
quantization_config = None
|
||||||
|
|
||||||
|
return quantization_config
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_and_tokenizer(args):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
revision=args.model_revision,
|
||||||
|
trust_remote_code=args.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
|
||||||
|
quantization_config = get_quantization_config(args)
|
||||||
|
model_kwargs = {
|
||||||
|
"revision": args.model_revision,
|
||||||
|
"attn_implementation": args.attn_implementation,
|
||||||
|
"torch_dtype": torch_dtype,
|
||||||
|
"device_map": "auto",
|
||||||
|
"quantization_config": quantization_config,
|
||||||
|
}
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if getattr(model, "hf_device_map", None) is None:
|
||||||
|
model = model.to(args.device)
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
|
||||||
|
if tokenizer.pad_token_id is None:
|
||||||
|
pad_token_id = tokenizer.eos_token_id
|
||||||
|
else:
|
||||||
|
pad_token_id = tokenizer.pad_token_id
|
||||||
|
|
||||||
|
all_eos_token_ids = []
|
||||||
|
|
||||||
|
if eos_tokens is not None:
|
||||||
|
all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
|
||||||
|
|
||||||
|
if eos_token_ids is not None:
|
||||||
|
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
|
||||||
|
|
||||||
|
if len(all_eos_token_ids) == 0:
|
||||||
|
all_eos_token_ids.append(tokenizer.eos_token_id)
|
||||||
|
|
||||||
|
return pad_token_id, all_eos_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
class RichInterface:
|
||||||
|
def __init__(self, model_name=None, user_name=None):
|
||||||
|
self._console = Console()
|
||||||
|
if model_name is None:
|
||||||
|
self.model_name = "assistant"
|
||||||
|
else:
|
||||||
|
self.model_name = model_name
|
||||||
|
if user_name is None:
|
||||||
|
self.user_name = "user"
|
||||||
|
else:
|
||||||
|
self.user_name = user_name
|
||||||
|
|
||||||
|
def stream_output(self, output_stream):
|
||||||
|
"""Stream output from a role."""
|
||||||
|
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
|
||||||
|
# Create a Live context for updating the console output
|
||||||
|
text = ""
|
||||||
|
self._console.print(f"[bold blue]<{self.model_name}>:")
|
||||||
|
with Live(console=self._console, refresh_per_second=4) as live:
|
||||||
|
# Read lines from the stream
|
||||||
|
for i, outputs in enumerate(output_stream):
|
||||||
|
if not outputs or i == 0:
|
||||||
|
continue
|
||||||
|
text += outputs
|
||||||
|
# Render the accumulated text as Markdown
|
||||||
|
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||||
|
# in rich. The chatbots output treat "\n" as a new line for
|
||||||
|
# better compatibility with real-world text. However, rendering
|
||||||
|
# in markdown would break the format. It is because standard markdown
|
||||||
|
# treat a single "\n" in normal text as a space.
|
||||||
|
# Our workaround is adding two spaces at the end of each line.
|
||||||
|
# This is not a perfect solution, as it would
|
||||||
|
# introduce trailing spaces (only) in code block, but it works well
|
||||||
|
# especially for console output, because in general the console does not
|
||||||
|
# care about trailing spaces.
|
||||||
|
lines = []
|
||||||
|
for line in text.splitlines():
|
||||||
|
lines.append(line)
|
||||||
|
if line.startswith("```"):
|
||||||
|
# Code block marker - do not add trailing spaces, as it would
|
||||||
|
# break the syntax highlighting
|
||||||
|
lines.append("\n")
|
||||||
|
else:
|
||||||
|
lines.append(" \n")
|
||||||
|
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
||||||
|
# Update the Live console output
|
||||||
|
live.update(markdown)
|
||||||
|
self._console.print()
|
||||||
|
return text
|
||||||
|
|
||||||
|
def input(self):
|
||||||
|
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
|
||||||
|
self._console.print()
|
||||||
|
return input
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._console.clear()
|
||||||
|
|
||||||
|
def print_user_message(self, text):
|
||||||
|
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
|
||||||
|
self._console.print()
|
||||||
|
|
||||||
|
def print_green(self, text):
|
||||||
|
self._console.print(f"[bold green]{text}")
|
||||||
|
self._console.print()
|
||||||
|
|
||||||
|
def print_red(self, text):
|
||||||
|
self._console.print(f"[bold red]{text}")
|
||||||
|
self._console.print()
|
||||||
|
|
||||||
|
def print_help(self):
|
||||||
|
self._console.print(Markdown(HELP_STRING))
|
||||||
|
self._console.print()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatArguments:
|
||||||
|
r"""
|
||||||
|
Arguments for the chat script.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name_or_path (`str`):
|
||||||
|
Name of the pre-trained model.
|
||||||
|
user (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
Username to display in chat interface.
|
||||||
|
system_prompt (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
System prompt.
|
||||||
|
save_folder (`str`, *optional*, defaults to `"./chat_history/"`):
|
||||||
|
Folder to save chat history.
|
||||||
|
device (`str`, *optional*, defaults to `"cpu"`):
|
||||||
|
Device to use for inference.
|
||||||
|
examples_path (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
Path to a yaml file with examples.
|
||||||
|
max_new_tokens (`int`, *optional*, defaults to `256`):
|
||||||
|
Maximum number of tokens to generate.
|
||||||
|
do_sample (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to sample outputs during generation.
|
||||||
|
num_beams (`int`, *optional*, defaults to `1`):
|
||||||
|
Number of beams for beam search.
|
||||||
|
temperature (`float`, *optional*, defaults to `1.0`):
|
||||||
|
Temperature parameter for generation.
|
||||||
|
top_k (`int`, *optional*, defaults to `50`):
|
||||||
|
Value of k for top-k sampling.
|
||||||
|
top_p (`float`, *optional*, defaults to `1.0`):
|
||||||
|
Value of p for nucleus sampling.
|
||||||
|
repetition_penalty (`float`, *optional*, defaults to `1.0`):
|
||||||
|
Repetition penalty.
|
||||||
|
eos_tokens (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
EOS tokens to stop the generation. If multiple they should be comma separated.
|
||||||
|
eos_token_ids (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
EOS token IDs to stop the generation. If multiple they should be comma separated.
|
||||||
|
model_revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
Specific model version to use (can be a branch name, tag name or commit id).
|
||||||
|
torch_dtype (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype
|
||||||
|
will be automatically derived from the model's weights.
|
||||||
|
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to trust remote code when loading a model.
|
||||||
|
attn_implementation (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case
|
||||||
|
you must install this manually by running `pip install flash-attn --no-build-isolation`.
|
||||||
|
load_in_8bit (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use 8 bit precision for the base model - works only with LoRA.
|
||||||
|
load_in_4bit (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use 4 bit precision for the base model - works only with LoRA.
|
||||||
|
bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
|
||||||
|
Quantization type.
|
||||||
|
use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to use nested quantization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# General settings
|
||||||
|
model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model."})
|
||||||
|
user: Optional[str] = field(default=None, metadata={"help": "Username to display in chat interface."})
|
||||||
|
system_prompt: Optional[str] = field(default=None, metadata={"help": "System prompt."})
|
||||||
|
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history."})
|
||||||
|
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
|
||||||
|
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
||||||
|
|
||||||
|
# Generation settings
|
||||||
|
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
||||||
|
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
||||||
|
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
||||||
|
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation."})
|
||||||
|
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling."})
|
||||||
|
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling."})
|
||||||
|
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
|
||||||
|
eos_tokens: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
|
||||||
|
)
|
||||||
|
eos_token_ids: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model loading
|
||||||
|
model_revision: str = field(
|
||||||
|
default="main",
|
||||||
|
metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."},
|
||||||
|
)
|
||||||
|
torch_dtype: Optional[str] = field(
|
||||||
|
default="auto",
|
||||||
|
metadata={
|
||||||
|
"help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
|
||||||
|
"the dtype will be automatically derived from the model's weights.",
|
||||||
|
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
trust_remote_code: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether to trust remote code when loading a model."}
|
||||||
|
)
|
||||||
|
attn_implementation: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
|
||||||
|
"which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
load_in_8bit: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
|
||||||
|
)
|
||||||
|
load_in_4bit: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
|
||||||
|
)
|
||||||
|
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
|
||||||
|
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
|
||||||
|
|
||||||
|
|
||||||
|
def chat_command_factory(args: Namespace):
|
||||||
|
"""
|
||||||
|
Factory function used to chat with a local model.
|
||||||
|
"""
|
||||||
|
return ChatCommand(args)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCommand(BaseTransformersCLICommand):
|
||||||
|
@staticmethod
|
||||||
|
def register_subcommand(parser: ArgumentParser):
|
||||||
|
"""
|
||||||
|
Register this command to argparse so it's available for the transformer-cli
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parser: Root parser to register command-specific arguments
|
||||||
|
"""
|
||||||
|
dataclass_types = (ChatArguments,)
|
||||||
|
chat_parser = parser.add_parser("chat", help=HELP_STRING, dataclass_types=dataclass_types)
|
||||||
|
chat_parser.set_defaults(func=chat_command_factory)
|
||||||
|
|
||||||
|
def __init__(self, args):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
args = self.args
|
||||||
|
if args.examples_path is None:
|
||||||
|
examples = DEFAULT_EXAMPLES
|
||||||
|
else:
|
||||||
|
with open(args.examples_path) as f:
|
||||||
|
examples = yaml.safe_load(f)
|
||||||
|
|
||||||
|
current_args = copy.deepcopy(args)
|
||||||
|
|
||||||
|
if args.user is None:
|
||||||
|
user = get_username()
|
||||||
|
else:
|
||||||
|
user = args.user
|
||||||
|
|
||||||
|
model, tokenizer = load_model_and_tokenizer(args)
|
||||||
|
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||||
|
|
||||||
|
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
||||||
|
|
||||||
|
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
||||||
|
interface.clear()
|
||||||
|
chat = clear_chat_history(current_args.system_prompt)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
user_input = interface.input()
|
||||||
|
|
||||||
|
if user_input == "clear":
|
||||||
|
chat = clear_chat_history(current_args.system_prompt)
|
||||||
|
interface.clear()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input == "help":
|
||||||
|
interface.print_help()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input == "exit":
|
||||||
|
break
|
||||||
|
|
||||||
|
if user_input == "reset":
|
||||||
|
interface.clear()
|
||||||
|
current_args = copy.deepcopy(args)
|
||||||
|
chat = clear_chat_history(current_args.system_prompt)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input.startswith("save") and len(user_input.split()) < 2:
|
||||||
|
split_input = user_input.split()
|
||||||
|
|
||||||
|
if len(split_input) == 2:
|
||||||
|
filename = split_input[1]
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
filename = save_chat(chat, current_args, filename)
|
||||||
|
interface.print_green(f"Chat saved in {filename}!")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if re.match(SETTING_RE, user_input):
|
||||||
|
current_args, success = parse_settings(user_input, current_args, interface)
|
||||||
|
if success:
|
||||||
|
chat = []
|
||||||
|
interface.clear()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input.startswith("example") and len(user_input.split()) == 2:
|
||||||
|
example_name = user_input.split()[1]
|
||||||
|
if example_name in examples:
|
||||||
|
interface.clear()
|
||||||
|
chat = []
|
||||||
|
interface.print_user_message(examples[example_name]["text"])
|
||||||
|
user_input = examples[example_name]["text"]
|
||||||
|
else:
|
||||||
|
interface.print_red(
|
||||||
|
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
chat.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
|
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||||
|
model.device
|
||||||
|
)
|
||||||
|
attention_mask = torch.ones_like(inputs)
|
||||||
|
generation_kwargs = {
|
||||||
|
"inputs": inputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"streamer": generation_streamer,
|
||||||
|
"max_new_tokens": current_args.max_new_tokens,
|
||||||
|
"do_sample": current_args.do_sample,
|
||||||
|
"num_beams": current_args.num_beams,
|
||||||
|
"temperature": current_args.temperature,
|
||||||
|
"top_k": current_args.top_k,
|
||||||
|
"top_p": current_args.top_p,
|
||||||
|
"repetition_penalty": current_args.repetition_penalty,
|
||||||
|
"pad_token_id": pad_token_id,
|
||||||
|
"eos_token_id": eos_token_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||||
|
thread.start()
|
||||||
|
model_output = interface.stream_output(generation_streamer)
|
||||||
|
thread.join()
|
||||||
|
chat.append({"role": "assistant", "content": model_output})
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
break
|
||||||
@@ -13,9 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from argparse import ArgumentParser
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
from .add_new_model_like import AddNewModelLikeCommand
|
from .add_new_model_like import AddNewModelLikeCommand
|
||||||
|
from .chat import ChatCommand
|
||||||
from .convert import ConvertCommand
|
from .convert import ConvertCommand
|
||||||
from .download import DownloadCommand
|
from .download import DownloadCommand
|
||||||
from .env import EnvironmentCommand
|
from .env import EnvironmentCommand
|
||||||
@@ -26,10 +27,11 @@ from .user import UserCommands
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
|
parser = HfArgumentParser(prog="Transformers CLI tool", usage="transformers-cli <command> [<args>]")
|
||||||
commands_parser = parser.add_subparsers(help="transformers-cli command helpers")
|
commands_parser = parser.add_subparsers(help="transformers-cli command helpers")
|
||||||
|
|
||||||
# Register commands
|
# Register commands
|
||||||
|
ChatCommand.register_subcommand(commands_parser)
|
||||||
ConvertCommand.register_subcommand(commands_parser)
|
ConvertCommand.register_subcommand(commands_parser)
|
||||||
DownloadCommand.register_subcommand(commands_parser)
|
DownloadCommand.register_subcommand(commands_parser)
|
||||||
EnvironmentCommand.register_subcommand(commands_parser)
|
EnvironmentCommand.register_subcommand(commands_parser)
|
||||||
|
|||||||
@@ -114,18 +114,23 @@ class HfArgumentParser(ArgumentParser):
|
|||||||
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
|
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
|
||||||
arguments to the parser after initialization and you'll get the output back after parsing as an additional
|
arguments to the parser after initialization and you'll get the output back after parsing as an additional
|
||||||
namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
|
namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataclass_types (`DataClassType` or `Iterable[DataClassType]`, *optional*):
|
||||||
|
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
|
||||||
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
|
Passed to `argparse.ArgumentParser()` in the regular way.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataclass_types: Iterable[DataClassType]
|
dataclass_types: Iterable[DataClassType]
|
||||||
|
|
||||||
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
|
def __init__(self, dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, **kwargs):
|
||||||
"""
|
# Make sure dataclass_types is an iterable
|
||||||
Args:
|
if dataclass_types is None:
|
||||||
dataclass_types:
|
dataclass_types = []
|
||||||
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
|
elif not isinstance(dataclass_types, Iterable):
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
dataclass_types = [dataclass_types]
|
||||||
Passed to `argparse.ArgumentParser()` in the regular way.
|
|
||||||
"""
|
|
||||||
# To make the default appear when using --help
|
# To make the default appear when using --help
|
||||||
if "formatter_class" not in kwargs:
|
if "formatter_class" not in kwargs:
|
||||||
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
|
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
|
||||||
|
|||||||
Reference in New Issue
Block a user