[Chat] Add Chat from TRL 🐈 (#35714)
* tmp commit * add working chat * add docts * docs 2 * use auto dtype by default
This commit is contained in:
@@ -23,8 +23,8 @@ of text (as is the case with a standard language model), the model instead conti
|
||||
of one or more **messages**, each of which includes a **role**, like "user" or "assistant", as well as message text.
|
||||
|
||||
Much like tokenization, different models expect very different input formats for chat. This is the reason we added
|
||||
**chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs or processor for multimodal LLMs. They specify how to convert conversations,
|
||||
represented as lists of messages, into a single tokenizable string in the format that the model expects.
|
||||
**chat templates** as a feature. Chat templates are part of the tokenizer for text-only LLMs or processor for multimodal LLMs. They specify how to convert conversations,
|
||||
represented as lists of messages, into a single tokenizable string in the format that the model expects.
|
||||
|
||||
Let's make this concrete with a quick example using the `mistralai/Mistral-7B-Instruct-v0.1` model:
|
||||
|
||||
@@ -42,8 +42,8 @@ Let's make this concrete with a quick example using the `mistralai/Mistral-7B-In
|
||||
"<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]"
|
||||
```
|
||||
|
||||
Notice how the tokenizer has added the control tokens [INST] and [/INST] to indicate the start and end of
|
||||
user messages (but not assistant messages!), and the entire chat is condensed into a single string.
|
||||
Notice how the tokenizer has added the control tokens [INST] and [/INST] to indicate the start and end of
|
||||
user messages (but not assistant messages!), and the entire chat is condensed into a single string.
|
||||
If we use `tokenize=True`, which is the default setting, that string will also be tokenized for us.
|
||||
|
||||
Now, try the same code, but swap in the `HuggingFaceH4/zephyr-7b-beta` model instead, and you should get:
|
||||
@@ -59,9 +59,16 @@ I'd like to show off how chat templating works!</s>
|
||||
|
||||
Both Zephyr and Mistral-Instruct were fine-tuned from the same base model, `Mistral-7B-v0.1`. However, they were trained
|
||||
with totally different chat formats. Without chat templates, you would have to write manual formatting code for each
|
||||
model, and it's very easy to make minor errors that hurt performance! Chat templates handle the details of formatting
|
||||
model, and it's very easy to make minor errors that hurt performance! Chat templates handle the details of formatting
|
||||
for you, allowing you to write universal code that works for any model.
|
||||
|
||||
<Tip>
|
||||
|
||||
Chat templates are a critical component of our [`transformers-cli chat` CLI](quicktour#chat-with-text-generation-models).
|
||||
You can apply the learnings of this guide there as well.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## How do I use chat templates?
|
||||
|
||||
@@ -69,7 +76,7 @@ As you can see in the example above, chat templates are easy to use. Simply buil
|
||||
and `content` keys, and then pass it to the [`~PreTrainedTokenizer.apply_chat_template`] or [`~ProcessorMixin.apply_chat_template`] method
|
||||
depending on what type of model you are using. Once you do that,
|
||||
you'll get output that's ready to go! When using chat templates as input for model generation, it's also a good idea
|
||||
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).
|
||||
to use `add_generation_prompt=True` to add a [generation prompt](#what-are-generation-prompts).
|
||||
|
||||
## Usage with text-only LLMs
|
||||
Here's an example of preparing input for `model.generate()`, using `Zephyr` again:
|
||||
@@ -91,19 +98,19 @@ messages = [
|
||||
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
||||
print(tokenizer.decode(tokenized_chat[0]))
|
||||
```
|
||||
This will yield a string in the input format that Zephyr expects.
|
||||
This will yield a string in the input format that Zephyr expects.
|
||||
```text
|
||||
<|system|>
|
||||
You are a friendly chatbot who always responds in the style of a pirate</s>
|
||||
You are a friendly chatbot who always responds in the style of a pirate</s>
|
||||
<|user|>
|
||||
How many helicopters can a human eat in one sitting?</s>
|
||||
How many helicopters can a human eat in one sitting?</s>
|
||||
<|assistant|>
|
||||
```
|
||||
|
||||
Now that our input is formatted correctly for Zephyr, we can use the model to generate a response to the user's question:
|
||||
|
||||
```python
|
||||
outputs = model.generate(tokenized_chat, max_new_tokens=128)
|
||||
outputs = model.generate(tokenized_chat, max_new_tokens=128)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
@@ -111,9 +118,9 @@ This will yield:
|
||||
|
||||
```text
|
||||
<|system|>
|
||||
You are a friendly chatbot who always responds in the style of a pirate</s>
|
||||
You are a friendly chatbot who always responds in the style of a pirate</s>
|
||||
<|user|>
|
||||
How many helicopters can a human eat in one sitting?</s>
|
||||
How many helicopters can a human eat in one sitting?</s>
|
||||
<|assistant|>
|
||||
Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
|
||||
```
|
||||
@@ -152,7 +159,7 @@ print(processor.batch_decode(processed_chat["input_ids"][:, :30]))
|
||||
This yields a string in LLaVAs expected input format with many `<image>` tokens at the end.
|
||||
The `<image>` tokens are placeholders and each one will be replaced by image embeddings when the mode is run in the forward call. The `processed_chat` can be further passed into [`~GenerationMixin.generate`] to generate text.
|
||||
```text
|
||||
'<|im_start|>system
|
||||
'<|im_start|>system
|
||||
You are a friendly chatbot who always responds in the style of a pirate<|im_end|><|im_start|>user <image><image><image><image><image><image><image><image>'
|
||||
```
|
||||
|
||||
@@ -162,7 +169,7 @@ Arr, 'twas easy after all!
|
||||
|
||||
Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past,
|
||||
we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality
|
||||
has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using
|
||||
has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using
|
||||
a pipeline:
|
||||
|
||||
```python
|
||||
@@ -227,9 +234,9 @@ Can I ask a question?<|im_end|>
|
||||
```
|
||||
|
||||
Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
|
||||
generates text it will write a bot response instead of doing something unexpected, like continuing the user's
|
||||
message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
|
||||
special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're
|
||||
generates text it will write a bot response instead of doing something unexpected, like continuing the user's
|
||||
message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
|
||||
special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're
|
||||
supposed to be doing.
|
||||
|
||||
Not all models require generation prompts. Some models, like LLaMA, don't have any
|
||||
@@ -241,7 +248,7 @@ effect that `add_generation_prompt` has will depend on the template being used.
|
||||
When passing a list of messages to `apply_chat_template` or `TextGenerationPipeline`, you can choose
|
||||
to format the chat so the model will continue the final message in the chat instead of starting a new one. This is done
|
||||
by removing any end-of-sequence tokens that indicate the end of the final message, so that the model will simply
|
||||
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
|
||||
extend the final message when it begins to generate text. This is useful for "prefilling" the model's response.
|
||||
|
||||
Here's an example:
|
||||
|
||||
@@ -266,9 +273,9 @@ get an error if you try!
|
||||
<Tip>
|
||||
|
||||
The default behaviour of `TextGenerationPipeline` is to set `add_generation_prompt=True` so that it starts a new
|
||||
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
|
||||
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
|
||||
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
|
||||
message. However, if the final message in the input chat has the "assistant" role, it will assume that this message is
|
||||
a prefill and switch to `continue_final_message=True` instead, because most models do not support multiple
|
||||
consecutive assistant messages. You can override this behaviour by explicitly passing the `continue_final_message`
|
||||
argument when calling the pipeline.
|
||||
|
||||
</Tip>
|
||||
@@ -277,8 +284,8 @@ argument when calling the pipeline.
|
||||
|
||||
Yes! This is a good way to ensure that the chat template matches the tokens the model sees during training.
|
||||
We recommend that you apply the chat template as a preprocessing step for your dataset. After this, you
|
||||
can simply continue like any other language model training task. When training, you should usually set
|
||||
`add_generation_prompt=False`, because the added tokens to prompt an assistant response will not be helpful during
|
||||
can simply continue like any other language model training task. When training, you should usually set
|
||||
`add_generation_prompt=False`, because the added tokens to prompt an assistant response will not be helpful during
|
||||
training. Let's see an example:
|
||||
|
||||
```python
|
||||
@@ -312,8 +319,8 @@ From here, just continue training like you would with a standard language modell
|
||||
|
||||
<Tip>
|
||||
|
||||
By default, some tokenizers add special tokens like `<bos>` and `<eos>` to text they tokenize. Chat templates should
|
||||
already include all the special tokens they need, and so additional special tokens will often be incorrect or
|
||||
By default, some tokenizers add special tokens like `<bos>` and `<eos>` to text they tokenize. Chat templates should
|
||||
already include all the special tokens they need, and so additional special tokens will often be incorrect or
|
||||
duplicated, which will hurt model performance.
|
||||
|
||||
Therefore, if you format text with `apply_chat_template(tokenize=False)`, you should set the argument
|
||||
@@ -326,7 +333,7 @@ Therefore, if you format text with `apply_chat_template(tokenize=False)`, you sh
|
||||
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
|
||||
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
|
||||
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
|
||||
strings, lists, dicts or whatever else you want.
|
||||
strings, lists, dicts or whatever else you want.
|
||||
|
||||
That said, there are some common use-cases for these extra arguments,
|
||||
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
|
||||
@@ -349,7 +356,7 @@ def current_time():
|
||||
def multiply(a: float, b: float):
|
||||
"""
|
||||
A function that multiplies two numbers
|
||||
|
||||
|
||||
Args:
|
||||
a: The first number to multiply
|
||||
b: The second number to multiply
|
||||
@@ -369,8 +376,8 @@ correctly as tools. Specifically, you should follow these rules:
|
||||
|
||||
- The function should have a descriptive name
|
||||
- Every argument must have a type hint
|
||||
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
||||
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
||||
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
||||
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
||||
- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
|
||||
`a (int): The first number to multiply`. Type hints should go in the function header instead.
|
||||
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
|
||||
@@ -412,7 +419,7 @@ Next, let's define a list of tools:
|
||||
def get_current_temperature(location: str, unit: str) -> float:
|
||||
"""
|
||||
Get the current temperature at a location.
|
||||
|
||||
|
||||
Args:
|
||||
location: The location to get the temperature for, in the format "City, Country"
|
||||
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
|
||||
@@ -424,7 +431,7 @@ def get_current_temperature(location: str, unit: str) -> float:
|
||||
def get_current_wind_speed(location: str) -> float:
|
||||
"""
|
||||
Get the current wind speed in km/h at a given location.
|
||||
|
||||
|
||||
Args:
|
||||
location: The location to get the temperature for, in the format "City, Country"
|
||||
Returns:
|
||||
@@ -469,8 +476,8 @@ the temperature in France should certainly be displayed in Celsius.
|
||||
|
||||
The output format above is specific to the `Hermes-2-Pro` model we're using in this example. Other models may emit different
|
||||
tool call formats, and you may need to do some manual parsing at this step. For example, `Llama-3.1` models will emit
|
||||
slightly different JSON, with `parameters` instead of `arguments`. Regardless of the format the model outputs, you
|
||||
should add the tool call to the conversation in the format below, with `tool_calls`, `function` and `arguments` keys.
|
||||
slightly different JSON, with `parameters` instead of `arguments`. Regardless of the format the model outputs, you
|
||||
should add the tool call to the conversation in the format below, with `tool_calls`, `function` and `arguments` keys.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -489,7 +496,7 @@ a dict, but in the OpenAI API it's a JSON string. Passing a string may cause err
|
||||
</Tip>
|
||||
|
||||
Now that we've added the tool call to the conversation, we can call the function and append the result to the
|
||||
conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append
|
||||
conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append
|
||||
that result directly.
|
||||
|
||||
```python
|
||||
@@ -500,7 +507,7 @@ messages.append({"role": "tool", "name": "get_current_temperature", "content": "
|
||||
|
||||
Some model architectures, notably Mistral/Mixtral, also require a `tool_call_id` here, which should be
|
||||
9 randomly-generated alphanumeric characters, and assigned to the `id` key of the tool call
|
||||
dictionary. The same key should also be assigned to the `tool_call_id` key of the tool response dictionary below, so
|
||||
dictionary. The same key should also be assigned to the `tool_call_id` key of the tool response dictionary below, so
|
||||
that tool calls can be matched to tool responses. So, for Mistral/Mixtral models, the code above would be:
|
||||
|
||||
```python
|
||||
@@ -532,13 +539,13 @@ And we get:
|
||||
The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
|
||||
```
|
||||
|
||||
Although this was a simple demo with dummy tools and a single call, the same technique works with
|
||||
Although this was a simple demo with dummy tools and a single call, the same technique works with
|
||||
multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
|
||||
agents with real-time information, computational tools like calculators, or access to large databases.
|
||||
|
||||
### Understanding tool schemas
|
||||
|
||||
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
||||
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
||||
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
|
||||
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
|
||||
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
|
||||
@@ -547,7 +554,7 @@ to read their outputs, detect if they have requested to use a tool, pass their a
|
||||
return the response in the chat.
|
||||
|
||||
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
|
||||
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
||||
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
||||
you can handle the conversion manually. Here is an example of a manual schema conversion.
|
||||
|
||||
```python
|
||||
@@ -556,7 +563,7 @@ from transformers.utils import get_json_schema
|
||||
def multiply(a: float, b: float):
|
||||
"""
|
||||
A function that multiplies two numbers
|
||||
|
||||
|
||||
Args:
|
||||
a: The first number to multiply
|
||||
b: The second number to multiply
|
||||
@@ -571,33 +578,33 @@ This will yield:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "function",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "number",
|
||||
"type": "number",
|
||||
"description": "The first number to multiply"
|
||||
},
|
||||
},
|
||||
"b": {
|
||||
"type": "number",
|
||||
"description": "The second number to multiply"
|
||||
}
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
||||
all. JSON schemas can be passed directly to the `tools` argument of
|
||||
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
||||
all. JSON schemas can be passed directly to the `tools` argument of
|
||||
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
|
||||
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
||||
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
||||
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
||||
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
||||
to a minimum.
|
||||
|
||||
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
|
||||
@@ -605,7 +612,7 @@ Here is an example of defining schemas by hand, and passing them directly to `ap
|
||||
```python
|
||||
# A simple function that takes no arguments
|
||||
current_time = {
|
||||
"type": "function",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "current_time",
|
||||
"description": "Get the current local time as a string.",
|
||||
@@ -621,18 +628,18 @@ multiply = {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'multiply',
|
||||
'description': 'A function that multiplies two numbers',
|
||||
'description': 'A function that multiplies two numbers',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'a': {
|
||||
'type': 'number',
|
||||
'description': 'The first number to multiply'
|
||||
},
|
||||
},
|
||||
'b': {
|
||||
'type': 'number', 'description': 'The second number to multiply'
|
||||
}
|
||||
},
|
||||
},
|
||||
'required': ['a', 'b']
|
||||
}
|
||||
}
|
||||
@@ -647,7 +654,7 @@ model_input = tokenizer.apply_chat_template(
|
||||
## Advanced: Retrieval-augmented generation
|
||||
|
||||
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
|
||||
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
||||
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
||||
recommendation for RAG models is that their template
|
||||
should accept a `documents` argument. This should be a list of documents, where each "document"
|
||||
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
|
||||
@@ -672,7 +679,7 @@ conversation = [
|
||||
# Define documents for retrieval-based generation
|
||||
documents = [
|
||||
{
|
||||
"title": "The Moon: Our Age-Old Foe",
|
||||
"title": "The Moon: Our Age-Old Foe",
|
||||
"text": "Man has always dreamed of destroying the moon. In this essay, I shall..."
|
||||
},
|
||||
{
|
||||
@@ -690,7 +697,7 @@ input_ids = tokenizer.apply_chat_template(
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt").to(device)
|
||||
|
||||
# Generate a response
|
||||
# Generate a response
|
||||
gen_tokens = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=100,
|
||||
@@ -750,8 +757,8 @@ Effectively, the template does three things:
|
||||
an assistant response.
|
||||
|
||||
This is a pretty simple template but Jinja gives you a lot of flexibility to do more complex things! Let's see a Jinja
|
||||
template that can format inputs similarly to the way LLaMA formats them (note that the real LLaMA template includes
|
||||
handling for default system messages and slightly different system message handling in general - don't use this one
|
||||
template that can format inputs similarly to the way LLaMA formats them (note that the real LLaMA template includes
|
||||
handling for default system messages and slightly different system message handling in general - don't use this one
|
||||
in your actual code!)
|
||||
|
||||
```
|
||||
@@ -774,7 +781,7 @@ distinguishable to the model because of the tokens they're wrapped in.
|
||||
|
||||
### How do I create a chat template?
|
||||
|
||||
Simple, just write a jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
|
||||
Simple, just write a jinja template and set `tokenizer.chat_template`. You may find it easier to start with an
|
||||
existing template from another model and simply edit it for your needs! For example, we could take the LLaMA template
|
||||
above and add "[ASST]" and "[/ASST]" to assistant messages:
|
||||
|
||||
@@ -802,13 +809,13 @@ tokenizer.chat_template = template # Set the new template
|
||||
tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
|
||||
```
|
||||
|
||||
The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so
|
||||
The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so
|
||||
once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`].
|
||||
|
||||
<Tip>
|
||||
If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat
|
||||
control tokens as special tokens in the tokenizer. Special tokens are never split,
|
||||
ensuring that your control tokens are always handled as single tokens rather than being tokenized in pieces. You
|
||||
control tokens as special tokens in the tokenizer. Special tokens are never split,
|
||||
ensuring that your control tokens are always handled as single tokens rather than being tokenized in pieces. You
|
||||
should also set the tokenizer's `eos_token` attribute to the token that marks the end of assistant generations in your
|
||||
template. This will ensure that text generation tools can correctly figure out when to stop generating text.
|
||||
</Tip>
|
||||
@@ -836,13 +843,13 @@ trying to put it all in a single template where possible!
|
||||
|
||||
When setting the template for a model that's already been trained for chat, you should ensure that the template
|
||||
exactly matches the message formatting that the model saw during training, or else you will probably experience
|
||||
performance degradation. This is true even if you're training the model further - you will probably get the best
|
||||
performance degradation. This is true even if you're training the model further - you will probably get the best
|
||||
performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the
|
||||
best performance for inference or fine-tuning when you precisely match the tokenization used during training.
|
||||
|
||||
If you're training a model from scratch, or fine-tuning a base language model for chat, on the other hand,
|
||||
you have a lot of freedom to choose an appropriate template! LLMs are smart enough to learn to handle lots of different
|
||||
input formats. One popular choice is the `ChatML` format, and this is a good, flexible choice for many use-cases.
|
||||
input formats. One popular choice is the `ChatML` format, and this is a good, flexible choice for many use-cases.
|
||||
It looks like this:
|
||||
|
||||
```
|
||||
@@ -888,7 +895,7 @@ Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_templat
|
||||
model, which means it is also automatically supported in places like `TextGenerationPipeline`!
|
||||
|
||||
By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
|
||||
open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
|
||||
open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
|
||||
it's time to put an end to them!
|
||||
|
||||
## Advanced: Template writing tips
|
||||
@@ -896,17 +903,17 @@ it's time to put an end to them!
|
||||
<Tip>
|
||||
|
||||
The easiest way to get started with writing Jinja templates is to take a look at some existing ones. You can use
|
||||
`print(tokenizer.chat_template)` for any chat model to see what template it's using. In general, models that support tool use have
|
||||
`print(tokenizer.chat_template)` for any chat model to see what template it's using. In general, models that support tool use have
|
||||
much more complex templates than other models - so when you're just getting started, they're probably a bad example
|
||||
to learn from! You can also take a look at the
|
||||
to learn from! You can also take a look at the
|
||||
[Jinja documentation](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) for details
|
||||
of general Jinja formatting and syntax.
|
||||
|
||||
</Tip>
|
||||
|
||||
Jinja templates in `transformers` are identical to Jinja templates elsewhere. The main thing to know is that
|
||||
the conversation history will be accessible inside your template as a variable called `messages`.
|
||||
You will be able to access `messages` in your template just like you can in Python, which means you can loop over
|
||||
Jinja templates in `transformers` are identical to Jinja templates elsewhere. The main thing to know is that
|
||||
the conversation history will be accessible inside your template as a variable called `messages`.
|
||||
You will be able to access `messages` in your template just like you can in Python, which means you can loop over
|
||||
it with `{% for message in messages %}` or access individual messages with `{{ messages[0] }}`, for example.
|
||||
|
||||
You can also use the following tips to write clean, efficient Jinja templates:
|
||||
@@ -936,7 +943,7 @@ and indentation may end up being included in the output, which is probably not w
|
||||
|
||||
### Special variables
|
||||
|
||||
Inside your template, you will have access several special variables. The most important of these is `messages`,
|
||||
Inside your template, you will have access several special variables. The most important of these is `messages`,
|
||||
which contains the chat history as a list of message dicts. However, there are several others. Not every
|
||||
variable will be used in every template. The most common other variables are:
|
||||
|
||||
@@ -970,7 +977,7 @@ There are multiple implementations of Jinja in various languages. They generally
|
||||
but a key difference is that when you're writing a template in Python you can use Python methods, such as
|
||||
`.lower()` on strings or `.items()` on dicts. This will break if someone tries to use your template on a non-Python
|
||||
implementation of Jinja. Non-Python implementations are particularly common in deployment environments, where JS
|
||||
and Rust are very popular.
|
||||
and Rust are very popular.
|
||||
|
||||
Don't panic, though! There are a few easy changes you can make to your templates to ensure they're compatible across
|
||||
all implementations of Jinja:
|
||||
@@ -1002,21 +1009,21 @@ Here is an example of a template that formats messages ChatML-style, with genera
|
||||
```
|
||||
|
||||
The exact content of the assistant header will depend on your specific model, but it should always be **the string
|
||||
that represents the start of an assistant message**, so that if the user applies your template with
|
||||
that represents the start of an assistant message**, so that if the user applies your template with
|
||||
`add_generation_prompt=True` and then generates text, the model will write an assistant response. Also note that some
|
||||
models do not need a generation prompt, because assistant messages always begin immediately after user messages.
|
||||
models do not need a generation prompt, because assistant messages always begin immediately after user messages.
|
||||
This is particularly common for LLaMA and Mistral models, where assistant messages begin immediately after the `[/INST]`
|
||||
token that ends user messages. In these cases, the template can ignore the `add_generation_prompt` flag.
|
||||
|
||||
Generation prompts are important! If your model requires a generation prompt but it is not set in the template, then
|
||||
model generations will likely be severely degraded, or the model may display unusual behaviour like continuing
|
||||
the final user message!
|
||||
model generations will likely be severely degraded, or the model may display unusual behaviour like continuing
|
||||
the final user message!
|
||||
|
||||
### Writing and debugging larger templates
|
||||
|
||||
When this feature was introduced, most templates were quite small, the Jinja equivalent of a "one-liner" script.
|
||||
When this feature was introduced, most templates were quite small, the Jinja equivalent of a "one-liner" script.
|
||||
However, with new models and features like tool-use and RAG, some templates can be 100 lines long or more. When
|
||||
writing templates like these, it's a good idea to write them in a separate file, using a text editor. You can easily
|
||||
writing templates like these, it's a good idea to write them in a separate file, using a text editor. You can easily
|
||||
extract a chat template to a file:
|
||||
|
||||
```python
|
||||
@@ -1035,7 +1042,7 @@ identify the source of issues.
|
||||
|
||||
### Writing templates for tools
|
||||
|
||||
Although chat templates do not enforce a specific API for tools (or for anything, really), we recommend
|
||||
Although chat templates do not enforce a specific API for tools (or for anything, really), we recommend
|
||||
template authors try to stick to a standard API where possible. The whole point of chat templates is to allow code
|
||||
to be transferable across models, so deviating from the standard tools API means users will have to write
|
||||
custom code to use tools with your model. Sometimes it's unavoidable, but often with clever templating you can
|
||||
@@ -1045,30 +1052,30 @@ Below, we'll list the elements of the standard API, and give tips on writing tem
|
||||
|
||||
#### Tool definitions
|
||||
|
||||
Your template should expect that the variable `tools` will either be null (if no tools are passed), or is a list
|
||||
Your template should expect that the variable `tools` will either be null (if no tools are passed), or is a list
|
||||
of JSON schema dicts. Our chat template methods allow users to pass tools as either JSON schema or Python functions, but when
|
||||
functions are passed, we automatically generate JSON schema and pass that to your template. As a result, the
|
||||
functions are passed, we automatically generate JSON schema and pass that to your template. As a result, the
|
||||
`tools` variable that your template receives will always be a list of JSON schema. Here is
|
||||
a sample tool JSON schema:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "function",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "number",
|
||||
"type": "number",
|
||||
"description": "The first number to multiply"
|
||||
},
|
||||
},
|
||||
"b": {
|
||||
"type": "number",
|
||||
"description": "The second number to multiply"
|
||||
}
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
@@ -1092,13 +1099,13 @@ specific format - your model will probably need different formatting!
|
||||
|
||||
The specific tokens and tool descriptions your template renders should of course be chosen to match the ones your model
|
||||
was trained with. There is no requirement that your **model** understands JSON schema input, only that your template can translate
|
||||
JSON schema into your model's format. For example, [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024)
|
||||
was trained with tools defined using Python function headers, but the Command-R tool template accepts JSON schema,
|
||||
JSON schema into your model's format. For example, [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-plus-08-2024)
|
||||
was trained with tools defined using Python function headers, but the Command-R tool template accepts JSON schema,
|
||||
converts types internally and renders the input tools as Python headers. You can do a lot with templates!
|
||||
|
||||
#### Tool calls
|
||||
|
||||
Tool calls, if present, will be a list attached to a message with the "assistant" role. Note that `tool_calls` is
|
||||
Tool calls, if present, will be a list attached to a message with the "assistant" role. Note that `tool_calls` is
|
||||
always a list, even though most tool-calling models only support single tool calls at a time, which means
|
||||
the list will usually only have a single element. Here is a sample message dict containing a tool call:
|
||||
|
||||
|
||||
@@ -41,6 +41,13 @@ This guide describes:
|
||||
* common decoding strategies and their main parameters
|
||||
* saving and sharing custom generation configurations with your fine-tuned model on 🤗 Hub
|
||||
|
||||
<Tip>
|
||||
|
||||
`generate()` is a critical component of our [`transformers-cli chat` CLI](quicktour#chat-with-text-generation-models).
|
||||
You can apply the learnings of this guide there as well.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Default text generation configuration
|
||||
|
||||
A decoding strategy for a model is defined in its generation configuration. When using pre-trained models for inference
|
||||
|
||||
@@ -23,6 +23,12 @@ LLMs, or Large Language Models, are the key component behind text generation. In
|
||||
|
||||
Autoregressive generation is the inference-time procedure of iteratively calling a model with its own generated outputs, given a few initial inputs. In 🤗 Transformers, this is handled by the [`~generation.GenerationMixin.generate`] method, which is available to all models with generative capabilities.
|
||||
|
||||
<Tip>
|
||||
|
||||
If you want to jump straight to chatting with a model, [try our `transformers-cli chat` CLI](quicktour#chat-with-text-generation-models).
|
||||
|
||||
</Tip>
|
||||
|
||||
This tutorial will show you how to:
|
||||
|
||||
* Generate text with an LLM
|
||||
|
||||
@@ -553,6 +553,32 @@ All models are a standard [`tf.keras.Model`](https://www.tensorflow.org/api_docs
|
||||
>>> model.fit(tf_dataset) # doctest: +SKIP
|
||||
```
|
||||
|
||||
|
||||
## Chat with text generation models
|
||||
|
||||
If you're working with a model that generates text as an output, you can also engage in a multi-turn conversation with
|
||||
it through the `transformers-cli chat` command. This is the fastest way to interact with a model, e.g. for a
|
||||
qualitative assessment (aka vibe check).
|
||||
|
||||
This CLI is implemented on top of our `AutoClass` abstraction, leveraging our [text generation](llm_tutorial.md) and
|
||||
[chat](chat_templating.md) tooling, and thus will be compatible with any 🤗 Transformers model. If you have the library
|
||||
[installed](installation.md), you can launch the chat session on your terminal with
|
||||
|
||||
```
|
||||
transformers-cli chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
||||
```
|
||||
|
||||
For a full list of options to launch the chat, type
|
||||
|
||||
```
|
||||
transformers-cli chat -h
|
||||
```
|
||||
|
||||
After the chat is launched, you will enter an interactive session with the model. There are special commands for this
|
||||
session as well, such as `clear` to reset the conversation. Type `help` at any moment to display all special chat
|
||||
commands, and `exit` to terminate the session.
|
||||
|
||||
|
||||
## What's next?
|
||||
|
||||
Now that you've completed the 🤗 Transformers quick tour, check out our guides and learn how to do more specific things like writing a custom model, fine-tuning a model for a task, and how to train a model with a script. If you're interested in learning more about 🤗 Transformers core concepts, grab a cup of coffee and take a look at our Conceptual Guides!
|
||||
|
||||
539
src/transformers/commands/chat.py
Normal file
539
src/transformers/commands/chat.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
||||
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import pwd
|
||||
|
||||
|
||||
HELP_STRING = """\
|
||||
|
||||
**TRANSFORMERS CHAT INTERFACE**
|
||||
|
||||
The chat interface is a simple tool to try out a chat model.
|
||||
|
||||
Besides talking to the model there are several commands:
|
||||
- **help**: show this help message
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
"""
|
||||
|
||||
SUPPORTED_GENERATION_KWARGS = [
|
||||
"max_new_tokens",
|
||||
"do_sample",
|
||||
"num_beams",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"repetition_penalty",
|
||||
]
|
||||
|
||||
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
|
||||
|
||||
DEFAULT_EXAMPLES = {
|
||||
"llama": {"text": "There is a Llama in my lawn, how can I get rid of it?"},
|
||||
"code": {
|
||||
"text": "Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]."
|
||||
},
|
||||
"helicopter": {"text": "How many helicopters can a human eat in one sitting?"},
|
||||
"numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
|
||||
"birds": {"text": "Why aren't birds real?"},
|
||||
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
||||
}
|
||||
|
||||
|
||||
def get_username():
|
||||
if platform.system() == "Windows":
|
||||
return os.getlogin()
|
||||
else:
|
||||
return pwd.getpwuid(os.getuid()).pw_name
|
||||
|
||||
|
||||
def create_default_filename(model_name):
|
||||
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return f"{model_name}/chat_{time_str}.json"
|
||||
|
||||
|
||||
def save_chat(chat, args, filename):
|
||||
output_dict = {}
|
||||
output_dict["settings"] = vars(args)
|
||||
output_dict["chat_history"] = chat
|
||||
|
||||
folder = args.save_folder
|
||||
|
||||
if filename is None:
|
||||
filename = create_default_filename(args.model_name_or_path)
|
||||
filename = os.path.join(folder, filename)
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(output_dict, f, indent=4)
|
||||
return os.path.abspath(filename)
|
||||
|
||||
|
||||
def clear_chat_history(system_prompt):
|
||||
if system_prompt is None:
|
||||
chat = []
|
||||
else:
|
||||
chat = [{"role": "system", "content": system_prompt}]
|
||||
return chat
|
||||
|
||||
|
||||
def parse_settings(user_input, current_args, interface):
|
||||
settings = user_input[4:].strip().split(";")
|
||||
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
||||
settings = dict(settings)
|
||||
error = False
|
||||
|
||||
for name in settings:
|
||||
if hasattr(current_args, name):
|
||||
try:
|
||||
if isinstance(getattr(current_args, name), bool):
|
||||
if settings[name] == "True":
|
||||
settings[name] = True
|
||||
elif settings[name] == "False":
|
||||
settings[name] = False
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
settings[name] = type(getattr(current_args, name))(settings[name])
|
||||
except ValueError:
|
||||
interface.print_red(
|
||||
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
|
||||
)
|
||||
else:
|
||||
interface.print_red(f"There is no '{name}' setting.")
|
||||
|
||||
if error:
|
||||
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
|
||||
return current_args, False
|
||||
else:
|
||||
for name in settings:
|
||||
setattr(current_args, name, settings[name])
|
||||
interface.print_green(f"Set {name} to {settings[name]}.")
|
||||
|
||||
time.sleep(1.5) # so the user has time to read the changes
|
||||
return current_args, True
|
||||
|
||||
|
||||
def get_quantization_config(model_args) -> Optional[BitsAndBytesConfig]:
|
||||
if model_args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
|
||||
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||
bnb_4bit_quant_storage=model_args.torch_dtype,
|
||||
)
|
||||
elif model_args.load_in_8bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
else:
|
||||
quantization_config = None
|
||||
|
||||
return quantization_config
|
||||
|
||||
|
||||
def load_model_and_tokenizer(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
revision=args.model_revision,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
|
||||
quantization_config = get_quantization_config(args)
|
||||
model_kwargs = {
|
||||
"revision": args.model_revision,
|
||||
"attn_implementation": args.attn_implementation,
|
||||
"torch_dtype": torch_dtype,
|
||||
"device_map": "auto",
|
||||
"quantization_config": quantization_config,
|
||||
}
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
|
||||
)
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(args.device)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
|
||||
if tokenizer.pad_token_id is None:
|
||||
pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
all_eos_token_ids = []
|
||||
|
||||
if eos_tokens is not None:
|
||||
all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
|
||||
|
||||
if eos_token_ids is not None:
|
||||
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
|
||||
|
||||
if len(all_eos_token_ids) == 0:
|
||||
all_eos_token_ids.append(tokenizer.eos_token_id)
|
||||
|
||||
return pad_token_id, all_eos_token_ids
|
||||
|
||||
|
||||
class RichInterface:
|
||||
def __init__(self, model_name=None, user_name=None):
|
||||
self._console = Console()
|
||||
if model_name is None:
|
||||
self.model_name = "assistant"
|
||||
else:
|
||||
self.model_name = model_name
|
||||
if user_name is None:
|
||||
self.user_name = "user"
|
||||
else:
|
||||
self.user_name = user_name
|
||||
|
||||
def stream_output(self, output_stream):
|
||||
"""Stream output from a role."""
|
||||
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
|
||||
# Create a Live context for updating the console output
|
||||
text = ""
|
||||
self._console.print(f"[bold blue]<{self.model_name}>:")
|
||||
with Live(console=self._console, refresh_per_second=4) as live:
|
||||
# Read lines from the stream
|
||||
for i, outputs in enumerate(output_stream):
|
||||
if not outputs or i == 0:
|
||||
continue
|
||||
text += outputs
|
||||
# Render the accumulated text as Markdown
|
||||
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||
# in rich. The chatbots output treat "\n" as a new line for
|
||||
# better compatibility with real-world text. However, rendering
|
||||
# in markdown would break the format. It is because standard markdown
|
||||
# treat a single "\n" in normal text as a space.
|
||||
# Our workaround is adding two spaces at the end of each line.
|
||||
# This is not a perfect solution, as it would
|
||||
# introduce trailing spaces (only) in code block, but it works well
|
||||
# especially for console output, because in general the console does not
|
||||
# care about trailing spaces.
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
lines.append(line)
|
||||
if line.startswith("```"):
|
||||
# Code block marker - do not add trailing spaces, as it would
|
||||
# break the syntax highlighting
|
||||
lines.append("\n")
|
||||
else:
|
||||
lines.append(" \n")
|
||||
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
||||
# Update the Live console output
|
||||
live.update(markdown)
|
||||
self._console.print()
|
||||
return text
|
||||
|
||||
def input(self):
|
||||
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
|
||||
self._console.print()
|
||||
return input
|
||||
|
||||
def clear(self):
|
||||
self._console.clear()
|
||||
|
||||
def print_user_message(self, text):
|
||||
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_green(self, text):
|
||||
self._console.print(f"[bold green]{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_red(self, text):
|
||||
self._console.print(f"[bold red]{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_help(self):
|
||||
self._console.print(Markdown(HELP_STRING))
|
||||
self._console.print()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatArguments:
|
||||
r"""
|
||||
Arguments for the chat script.
|
||||
|
||||
Args:
|
||||
model_name_or_path (`str`):
|
||||
Name of the pre-trained model.
|
||||
user (`str` or `None`, *optional*, defaults to `None`):
|
||||
Username to display in chat interface.
|
||||
system_prompt (`str` or `None`, *optional*, defaults to `None`):
|
||||
System prompt.
|
||||
save_folder (`str`, *optional*, defaults to `"./chat_history/"`):
|
||||
Folder to save chat history.
|
||||
device (`str`, *optional*, defaults to `"cpu"`):
|
||||
Device to use for inference.
|
||||
examples_path (`str` or `None`, *optional*, defaults to `None`):
|
||||
Path to a yaml file with examples.
|
||||
max_new_tokens (`int`, *optional*, defaults to `256`):
|
||||
Maximum number of tokens to generate.
|
||||
do_sample (`bool`, *optional*, defaults to `True`):
|
||||
Whether to sample outputs during generation.
|
||||
num_beams (`int`, *optional*, defaults to `1`):
|
||||
Number of beams for beam search.
|
||||
temperature (`float`, *optional*, defaults to `1.0`):
|
||||
Temperature parameter for generation.
|
||||
top_k (`int`, *optional*, defaults to `50`):
|
||||
Value of k for top-k sampling.
|
||||
top_p (`float`, *optional*, defaults to `1.0`):
|
||||
Value of p for nucleus sampling.
|
||||
repetition_penalty (`float`, *optional*, defaults to `1.0`):
|
||||
Repetition penalty.
|
||||
eos_tokens (`str` or `None`, *optional*, defaults to `None`):
|
||||
EOS tokens to stop the generation. If multiple they should be comma separated.
|
||||
eos_token_ids (`str` or `None`, *optional*, defaults to `None`):
|
||||
EOS token IDs to stop the generation. If multiple they should be comma separated.
|
||||
model_revision (`str`, *optional*, defaults to `"main"`):
|
||||
Specific model version to use (can be a branch name, tag name or commit id).
|
||||
torch_dtype (`str` or `None`, *optional*, defaults to `None`):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether to trust remote code when loading a model.
|
||||
attn_implementation (`str` or `None`, *optional*, defaults to `None`):
|
||||
Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case
|
||||
you must install this manually by running `pip install flash-attn --no-build-isolation`.
|
||||
load_in_8bit (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use 8 bit precision for the base model - works only with LoRA.
|
||||
load_in_4bit (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use 4 bit precision for the base model - works only with LoRA.
|
||||
bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
|
||||
Quantization type.
|
||||
use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use nested quantization.
|
||||
"""
|
||||
|
||||
# General settings
|
||||
model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model."})
|
||||
user: Optional[str] = field(default=None, metadata={"help": "Username to display in chat interface."})
|
||||
system_prompt: Optional[str] = field(default=None, metadata={"help": "System prompt."})
|
||||
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history."})
|
||||
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
|
||||
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
||||
|
||||
# Generation settings
|
||||
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
||||
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
||||
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
||||
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation."})
|
||||
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling."})
|
||||
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling."})
|
||||
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
|
||||
eos_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
|
||||
)
|
||||
eos_token_ids: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
||||
)
|
||||
|
||||
# Model loading
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "Specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
|
||||
"the dtype will be automatically derived from the model's weights.",
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False, metadata={"help": "Whether to trust remote code when loading a model."}
|
||||
)
|
||||
attn_implementation: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
|
||||
"which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
|
||||
},
|
||||
)
|
||||
load_in_8bit: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
|
||||
)
|
||||
load_in_4bit: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
|
||||
)
|
||||
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
|
||||
|
||||
|
||||
def chat_command_factory(args: Namespace):
|
||||
"""
|
||||
Factory function used to chat with a local model.
|
||||
"""
|
||||
return ChatCommand(args)
|
||||
|
||||
|
||||
class ChatCommand(BaseTransformersCLICommand):
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
Register this command to argparse so it's available for the transformer-cli
|
||||
|
||||
Args:
|
||||
parser: Root parser to register command-specific arguments
|
||||
"""
|
||||
dataclass_types = (ChatArguments,)
|
||||
chat_parser = parser.add_parser("chat", help=HELP_STRING, dataclass_types=dataclass_types)
|
||||
chat_parser.set_defaults(func=chat_command_factory)
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def run(self):
|
||||
args = self.args
|
||||
if args.examples_path is None:
|
||||
examples = DEFAULT_EXAMPLES
|
||||
else:
|
||||
with open(args.examples_path) as f:
|
||||
examples = yaml.safe_load(f)
|
||||
|
||||
current_args = copy.deepcopy(args)
|
||||
|
||||
if args.user is None:
|
||||
user = get_username()
|
||||
else:
|
||||
user = args.user
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(args)
|
||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||
|
||||
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
||||
|
||||
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
||||
interface.clear()
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
while True:
|
||||
try:
|
||||
user_input = interface.input()
|
||||
|
||||
if user_input == "clear":
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input == "help":
|
||||
interface.print_help()
|
||||
continue
|
||||
|
||||
if user_input == "exit":
|
||||
break
|
||||
|
||||
if user_input == "reset":
|
||||
interface.clear()
|
||||
current_args = copy.deepcopy(args)
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
continue
|
||||
|
||||
if user_input.startswith("save") and len(user_input.split()) < 2:
|
||||
split_input = user_input.split()
|
||||
|
||||
if len(split_input) == 2:
|
||||
filename = split_input[1]
|
||||
else:
|
||||
filename = None
|
||||
filename = save_chat(chat, current_args, filename)
|
||||
interface.print_green(f"Chat saved in {filename}!")
|
||||
continue
|
||||
|
||||
if re.match(SETTING_RE, user_input):
|
||||
current_args, success = parse_settings(user_input, current_args, interface)
|
||||
if success:
|
||||
chat = []
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input.startswith("example") and len(user_input.split()) == 2:
|
||||
example_name = user_input.split()[1]
|
||||
if example_name in examples:
|
||||
interface.clear()
|
||||
chat = []
|
||||
interface.print_user_message(examples[example_name]["text"])
|
||||
user_input = examples[example_name]["text"]
|
||||
else:
|
||||
interface.print_red(
|
||||
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
||||
)
|
||||
continue
|
||||
|
||||
chat.append({"role": "user", "content": user_input})
|
||||
|
||||
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
model.device
|
||||
)
|
||||
attention_mask = torch.ones_like(inputs)
|
||||
generation_kwargs = {
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"streamer": generation_streamer,
|
||||
"max_new_tokens": current_args.max_new_tokens,
|
||||
"do_sample": current_args.do_sample,
|
||||
"num_beams": current_args.num_beams,
|
||||
"temperature": current_args.temperature,
|
||||
"top_k": current_args.top_k,
|
||||
"top_p": current_args.top_p,
|
||||
"repetition_penalty": current_args.repetition_penalty,
|
||||
"pad_token_id": pad_token_id,
|
||||
"eos_token_id": eos_token_ids,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
model_output = interface.stream_output(generation_streamer)
|
||||
thread.join()
|
||||
chat.append({"role": "assistant", "content": model_output})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
@@ -13,9 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from .add_new_model_like import AddNewModelLikeCommand
|
||||
from .chat import ChatCommand
|
||||
from .convert import ConvertCommand
|
||||
from .download import DownloadCommand
|
||||
from .env import EnvironmentCommand
|
||||
@@ -26,10 +27,11 @@ from .user import UserCommands
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
|
||||
parser = HfArgumentParser(prog="Transformers CLI tool", usage="transformers-cli <command> [<args>]")
|
||||
commands_parser = parser.add_subparsers(help="transformers-cli command helpers")
|
||||
|
||||
# Register commands
|
||||
ChatCommand.register_subcommand(commands_parser)
|
||||
ConvertCommand.register_subcommand(commands_parser)
|
||||
DownloadCommand.register_subcommand(commands_parser)
|
||||
EnvironmentCommand.register_subcommand(commands_parser)
|
||||
|
||||
@@ -114,18 +114,23 @@ class HfArgumentParser(ArgumentParser):
|
||||
The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
|
||||
arguments to the parser after initialization and you'll get the output back after parsing as an additional
|
||||
namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
|
||||
|
||||
Args:
|
||||
dataclass_types (`DataClassType` or `Iterable[DataClassType]`, *optional*):
|
||||
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Passed to `argparse.ArgumentParser()` in the regular way.
|
||||
"""
|
||||
|
||||
dataclass_types: Iterable[DataClassType]
|
||||
|
||||
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
|
||||
"""
|
||||
Args:
|
||||
dataclass_types:
|
||||
Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Passed to `argparse.ArgumentParser()` in the regular way.
|
||||
"""
|
||||
def __init__(self, dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, **kwargs):
|
||||
# Make sure dataclass_types is an iterable
|
||||
if dataclass_types is None:
|
||||
dataclass_types = []
|
||||
elif not isinstance(dataclass_types, Iterable):
|
||||
dataclass_types = [dataclass_types]
|
||||
|
||||
# To make the default appear when using --help
|
||||
if "formatter_class" not in kwargs:
|
||||
kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
|
||||
|
||||
Reference in New Issue
Block a user