[chat] generate parameterization powered by GenerationConfig and UX-related changes (#38047)
* accept arbitrary kwargs * move user commands to a separate fn * work with generation config files * rm cmmt * docs * base generate flag doc section * nits * nits * nits * no <br> * better basic args description
This commit is contained in:
@@ -120,7 +120,7 @@ To chat with a model, the usage pattern is the same. The only difference is you
|
|||||||
> [!TIP]
|
> [!TIP]
|
||||||
> You can also chat with a model directly from the command line.
|
> You can also chat with a model directly from the command line.
|
||||||
> ```shell
|
> ```shell
|
||||||
> transformers chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
```py
|
```py
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ This guide shows you how to quickly start chatting with Transformers from the co
|
|||||||
|
|
||||||
## transformers CLI
|
## transformers CLI
|
||||||
|
|
||||||
Chat with a model directly from the command line as shown below. It launches an interactive session with a model. Enter `clear` to reset the conversation, `exit` to terminate the session, and `help` to display all the command options.
|
After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||||
@@ -37,6 +37,12 @@ transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
|||||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers-chat-cli.png"/>
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers-chat-cli.png"/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
You can launch the CLI with arbitrary `generate` flags, with the format `arg_1=value_1 arg_2=value_2 ...`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
transformers chat Qwen/Qwen2.5-0.5B-Instruct do_sample=False max_new_tokens=10
|
||||||
|
```
|
||||||
|
|
||||||
For a full list of options, run the command below.
|
For a full list of options, run the command below.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -20,9 +20,13 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
Text generation is the most popular application for large language models (LLMs). A LLM is trained to generate the next word (token) given some initial text (prompt) along with its own generated outputs up to a predefined length or when it reaches an end-of-sequence (`EOS`) token.
|
Text generation is the most popular application for large language models (LLMs). A LLM is trained to generate the next word (token) given some initial text (prompt) along with its own generated outputs up to a predefined length or when it reaches an end-of-sequence (`EOS`) token.
|
||||||
|
|
||||||
In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities.
|
In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities. This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
|
||||||
|
|
||||||
This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
|
> [!TIP]
|
||||||
|
> You can also chat with a model directly from the command line. ([reference](./conversations.md#transformers-cli))
|
||||||
|
> ```shell
|
||||||
|
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
|
||||||
|
> ```
|
||||||
|
|
||||||
## Default generate
|
## Default generate
|
||||||
|
|
||||||
@@ -134,6 +138,20 @@ outputs = model.generate(**inputs, generation_config=generation_config)
|
|||||||
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Common Options
|
||||||
|
|
||||||
|
[`~GenerationMixin.generate`] is a powerful tool that can be heavily customized. This can be daunting for a new users. This section contains a list of popular generation options that you can define in most text generation tools in Transformers: [`~GenerationMixin.generate`], [`GenerationConfig`], `pipelines`, the `chat` CLI, ...
|
||||||
|
|
||||||
|
| Option name | Type | Simplified description |
|
||||||
|
|---|---|---|
|
||||||
|
| `max_new_tokens` | `int` | Controls the maximum generation length. Be sure to define it, as it usually defaults to a small value. |
|
||||||
|
| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies.md) for more information. |
|
||||||
|
| `temperature` | `float` | How unpredictable the next selected token will be. High values (`>0.8`) are good for creative tasks, low values (e.g. `<0.4`) for tasks that require "thinking". Requires `do_sample=True`. |
|
||||||
|
| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies.md) for more information. |
|
||||||
|
| `repetition_penalty` | `float` | Set it to `>1.0` if you're seeing the model repeat itself often. Larger values apply a larger penalty. |
|
||||||
|
| `eos_token_id` | `List[int]` | The token(s) that will cause generation to stop. The default value is usually good, but you can specify a different token. |
|
||||||
|
|
||||||
|
|
||||||
## Pitfalls
|
## Pitfalls
|
||||||
|
|
||||||
The section below covers some common issues you may encounter during text generation and how to solve them.
|
The section below covers some common issues you may encounter during text generation and how to solve them.
|
||||||
@@ -286,4 +304,4 @@ Take a look below for some more specific and specialized text generation librari
|
|||||||
- [SynCode](https://github.com/uiuc-focal-lab/syncode): a library for context-free grammar guided generation (JSON, SQL, Python).
|
- [SynCode](https://github.com/uiuc-focal-lab/syncode): a library for context-free grammar guided generation (JSON, SQL, Python).
|
||||||
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a production-ready server for LLMs.
|
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a production-ready server for LLMs.
|
||||||
- [Text generation web UI](https://github.com/oobabooga/text-generation-webui): a Gradio web UI for text generation.
|
- [Text generation web UI](https://github.com/oobabooga/text-generation-webui): a Gradio web UI for text generation.
|
||||||
- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
|
- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
|
||||||
|
|||||||
@@ -13,12 +13,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@@ -42,7 +42,13 @@ if is_rich_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
GenerationConfig,
|
||||||
|
TextIteratorStreamer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
||||||
@@ -64,25 +70,16 @@ DEFAULT_EXAMPLES = {
|
|||||||
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
||||||
}
|
}
|
||||||
|
|
||||||
SUPPORTED_GENERATION_KWARGS = [
|
|
||||||
"max_new_tokens",
|
|
||||||
"do_sample",
|
|
||||||
"num_beams",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
"repetition_penalty",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Printed at the start of a chat session
|
# Printed at the start of a chat session
|
||||||
HELP_STRING_MINIMAL = """
|
HELP_STRING_MINIMAL = """
|
||||||
|
|
||||||
**TRANSFORMERS CHAT INTERFACE**
|
**TRANSFORMERS CHAT INTERFACE**
|
||||||
|
|
||||||
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
|
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
|
||||||
- **help**: shows all available commands
|
- **!help**: shows all available commands
|
||||||
- **clear**: clears the current conversation and starts a new one
|
- **!status**: shows the current status of the model and generation settings
|
||||||
- **exit**: closes the interface
|
- **!clear**: clears the current conversation and starts a new one
|
||||||
|
- **!exit**: closes the interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -92,18 +89,32 @@ HELP_STRING = f"""
|
|||||||
**TRANSFORMERS CHAT INTERFACE HELP**
|
**TRANSFORMERS CHAT INTERFACE HELP**
|
||||||
|
|
||||||
Full command list:
|
Full command list:
|
||||||
- **help**: shows this help message
|
- **!help**: shows this help message
|
||||||
- **clear**: clears the current conversation and starts a new one
|
- **!clear**: clears the current conversation and starts a new one
|
||||||
- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
|
- **!status**: shows the current status of the model and generation settings
|
||||||
names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
|
||||||
- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
|
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
||||||
separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
|
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
|
||||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
|
||||||
- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
|
||||||
|
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
||||||
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
|
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
|
||||||
- **exit**: closes the interface
|
- **!exit**: closes the interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
|
||||||
|
_DEPRECATION_MAP = [
|
||||||
|
("max_new_tokens", 256, "max_new_tokens"),
|
||||||
|
("do_sample", True, "do_sample"),
|
||||||
|
("num_beams", 1, "num_beams"),
|
||||||
|
("temperature", 1.0, "temperature"),
|
||||||
|
("top_k", 50, "top_k"),
|
||||||
|
("top_p", 1.0, "top_p"),
|
||||||
|
("repetition_penalty", 1.0, "repetition_penalty"),
|
||||||
|
("eos_tokens", None, "eos_token_id"),
|
||||||
|
("eos_token_ids", None, "eos_token_id"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class RichInterface:
|
class RichInterface:
|
||||||
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
|
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
|
||||||
@@ -181,6 +192,14 @@ class RichInterface:
|
|||||||
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
|
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
|
||||||
self._console.print()
|
self._console.print()
|
||||||
|
|
||||||
|
def print_status(self, model_name: str, generation_config: GenerationConfig, model_kwargs: dict):
|
||||||
|
"""Prints the status of the model and generation settings to the console."""
|
||||||
|
self._console.print(f"[bold blue]Model: {model_name}\n")
|
||||||
|
if model_kwargs:
|
||||||
|
self._console.print(f"[bold blue]Model kwargs: {model_kwargs}")
|
||||||
|
self._console.print(f"[bold blue]{generation_config}")
|
||||||
|
self._console.print()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatArguments:
|
class ChatArguments:
|
||||||
@@ -207,6 +226,17 @@ class ChatArguments:
|
|||||||
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
||||||
|
|
||||||
# Generation settings
|
# Generation settings
|
||||||
|
generation_config: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Path to a local generation config file or to a HuggingFace repo containing a "
|
||||||
|
"`generation_config.json` file. Other generation settings passed as CLI arguments will be applied on "
|
||||||
|
"top of this generation config."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Deprecated CLI args start here
|
||||||
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
||||||
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
||||||
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
||||||
@@ -222,6 +252,7 @@ class ChatArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
||||||
)
|
)
|
||||||
|
# Deprecated CLI args end here
|
||||||
|
|
||||||
# Model loading
|
# Model loading
|
||||||
model_revision: str = field(
|
model_revision: str = field(
|
||||||
@@ -280,23 +311,66 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
group = chat_parser.add_argument_group("Positional arguments")
|
group = chat_parser.add_argument_group("Positional arguments")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"model_name_or_path_positional", type=str, nargs="?", default=None, help="Name of the pre-trained model."
|
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"generate_flags",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
|
||||||
|
"and lists of integers, more advanced parameterization should be set through --generation-config. "
|
||||||
|
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
|
||||||
|
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
|
||||||
|
),
|
||||||
|
nargs="*",
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_parser.set_defaults(func=chat_command_factory)
|
chat_parser.set_defaults(func=chat_command_factory)
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path
|
args = self._handle_deprecated_args(args)
|
||||||
|
self.args = args
|
||||||
|
|
||||||
if args.model_name_or_path is None:
|
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
|
||||||
|
"""
|
||||||
|
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
|
||||||
|
args.
|
||||||
|
"""
|
||||||
|
has_warnings = False
|
||||||
|
|
||||||
|
# 1. Model as a positional argument
|
||||||
|
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
|
||||||
|
if args.model_name_or_path_positional is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"One of the following must be provided:"
|
"One of the following must be provided:"
|
||||||
"\n- The positional argument containing the model repo;"
|
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
|
||||||
"\n- the optional --model_name_or_path argument, containing the model repo"
|
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
|
||||||
"\ne.g. transformers chat <model_repo> or transformers chat --model_name_or_path <model_repo>"
|
|
||||||
)
|
)
|
||||||
|
elif args.model_name_or_path is not None:
|
||||||
|
has_warnings = True
|
||||||
|
warnings.warn(
|
||||||
|
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
|
||||||
|
"argument instead, e.g. `transformers chat <model_repo>`.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
# 2. Named generate option args
|
||||||
|
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
|
||||||
|
value = getattr(args, deprecated_arg)
|
||||||
|
if value != default_value:
|
||||||
|
has_warnings = True
|
||||||
|
warnings.warn(
|
||||||
|
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
|
||||||
|
"alternative solutions to specify this generation option: \n"
|
||||||
|
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
|
||||||
|
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
|
||||||
|
f"{new_arg}={value}`",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
|
||||||
self.args = args
|
if has_warnings:
|
||||||
|
print("\n(Press enter to continue)")
|
||||||
|
input()
|
||||||
|
return args
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------------------------------------------
|
||||||
# Chat session methods
|
# Chat session methods
|
||||||
@@ -319,7 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
if filename is None:
|
if filename is None:
|
||||||
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
filename = f"{args.model_name_or_path}/chat_{time_str}.json"
|
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
|
||||||
filename = os.path.join(folder, filename)
|
filename = os.path.join(folder, filename)
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||||
@@ -338,50 +412,95 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------------------------------------------
|
||||||
# Input parsing methods
|
# Input parsing methods
|
||||||
@staticmethod
|
def parse_generate_flags(self, generate_flags: list[str]) -> dict:
|
||||||
def parse_settings(
|
"""Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
|
||||||
user_input: str, current_args: ChatArguments, interface: RichInterface
|
if len(generate_flags) == 0:
|
||||||
) -> tuple[ChatArguments, bool]:
|
return {}
|
||||||
"""Parses the settings from the user input into the CLI arguments."""
|
|
||||||
settings = user_input[4:].strip().split(";")
|
|
||||||
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
|
||||||
settings = dict(settings)
|
|
||||||
error = False
|
|
||||||
|
|
||||||
for name in settings:
|
# Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
|
||||||
if hasattr(current_args, name):
|
# into a json string if we:
|
||||||
try:
|
# 1. Add quotes around each flag name
|
||||||
if isinstance(getattr(current_args, name), bool):
|
generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
|
||||||
if settings[name] == "True":
|
|
||||||
settings[name] = True
|
|
||||||
elif settings[name] == "False":
|
|
||||||
settings[name] = False
|
|
||||||
else:
|
|
||||||
raise ValueError
|
|
||||||
else:
|
|
||||||
settings[name] = type(getattr(current_args, name))(settings[name])
|
|
||||||
except ValueError:
|
|
||||||
error = True
|
|
||||||
interface.print_color(
|
|
||||||
text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
|
|
||||||
color="red",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
interface.print_color(text=f"There is no '{name}' setting.", color="red")
|
|
||||||
|
|
||||||
if error:
|
# 2. Handle types:
|
||||||
interface.print_color(
|
# 2. a. booleans should be lowercase, None should be null
|
||||||
text="There was an issue parsing the settings. No settings have been changed.",
|
generate_flags_as_dict = {
|
||||||
color="red",
|
k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
|
||||||
|
}
|
||||||
|
generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
|
||||||
|
|
||||||
|
# 2. b. strings should be quoted
|
||||||
|
def is_number(s: str) -> bool:
|
||||||
|
return s.replace(".", "", 1).isdigit()
|
||||||
|
|
||||||
|
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
|
||||||
|
# 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
|
||||||
|
# We also mention in the help message that we only accept lists of ints for now.
|
||||||
|
|
||||||
|
# 3. Join the the result into a comma separated string
|
||||||
|
generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
|
||||||
|
|
||||||
|
# 4. Add the opening/closing brackets
|
||||||
|
generate_flags_string = "{" + generate_flags_string + "}"
|
||||||
|
|
||||||
|
# 5. Remove quotes around boolean/null and around lists
|
||||||
|
generate_flags_string = generate_flags_string.replace('"null"', "null")
|
||||||
|
generate_flags_string = generate_flags_string.replace('"true"', "true")
|
||||||
|
generate_flags_string = generate_flags_string.replace('"false"', "false")
|
||||||
|
generate_flags_string = generate_flags_string.replace('"[', "[")
|
||||||
|
generate_flags_string = generate_flags_string.replace(']"', "]")
|
||||||
|
|
||||||
|
# 6. Replace the `=` with `:`
|
||||||
|
generate_flags_string = generate_flags_string.replace("=", ":")
|
||||||
|
|
||||||
|
try:
|
||||||
|
processed_generate_flags = json.loads(generate_flags_string)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(
|
||||||
|
"Failed to convert `generate_flags` into a valid JSON object."
|
||||||
|
"\n`generate_flags` = {generate_flags}"
|
||||||
|
"\nConverted JSON string = {generate_flags_string}"
|
||||||
)
|
)
|
||||||
|
return processed_generate_flags
|
||||||
|
|
||||||
|
def get_generation_parameterization(
|
||||||
|
self, args: ChatArguments, tokenizer: AutoTokenizer
|
||||||
|
) -> tuple[GenerationConfig, dict]:
|
||||||
|
"""
|
||||||
|
Returns a GenerationConfig object holding the generation parameters for the CLI command.
|
||||||
|
"""
|
||||||
|
# No generation config arg provided -> use base generation config, apply CLI defaults
|
||||||
|
if args.generation_config is None:
|
||||||
|
generation_config = GenerationConfig()
|
||||||
|
# Apply deprecated CLI args on top of the default generation config
|
||||||
|
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
||||||
|
deprecated_kwargs = {
|
||||||
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"do_sample": args.do_sample,
|
||||||
|
"num_beams": args.num_beams,
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_k": args.top_k,
|
||||||
|
"top_p": args.top_p,
|
||||||
|
"repetition_penalty": args.repetition_penalty,
|
||||||
|
"pad_token_id": pad_token_id,
|
||||||
|
"eos_token_id": eos_token_ids,
|
||||||
|
}
|
||||||
|
generation_config.update(**deprecated_kwargs)
|
||||||
|
# generation config arg provided -> use it as the base parameterization
|
||||||
else:
|
else:
|
||||||
for name in settings:
|
if ".json" in args.generation_config: # is a local file
|
||||||
setattr(current_args, name, settings[name])
|
dirname = os.path.dirname(args.generation_config)
|
||||||
interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
|
filename = os.path.basename(args.generation_config)
|
||||||
|
generation_config = GenerationConfig.from_pretrained(dirname, filename)
|
||||||
|
else:
|
||||||
|
generation_config = GenerationConfig.from_pretrained(args.generation_config)
|
||||||
|
|
||||||
time.sleep(1.5) # so the user has time to read the changes
|
# Finally: parse and apply `generate_flags`
|
||||||
|
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
|
||||||
return current_args, not error
|
model_kwargs = generation_config.update(**parsed_generate_flags)
|
||||||
|
# `model_kwargs` contain non-generation flags in `parsed_generate_flags` that should be passed directly to
|
||||||
|
# `generate`
|
||||||
|
return generation_config, model_kwargs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_eos_tokens(
|
def parse_eos_tokens(
|
||||||
@@ -406,36 +525,6 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return pad_token_id, all_eos_token_ids
|
return pad_token_id, all_eos_token_ids
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_valid_setting_command(s: str) -> bool:
|
|
||||||
# First check the basic structure
|
|
||||||
if not s.startswith("set ") or "=" not in s:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Split into individual assignments
|
|
||||||
assignments = [a.strip() for a in s[4:].split(";") if a.strip()]
|
|
||||||
|
|
||||||
for assignment in assignments:
|
|
||||||
# Each assignment should have exactly one '='
|
|
||||||
if assignment.count("=") != 1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
key, value = assignment.split("=", 1)
|
|
||||||
key = key.strip()
|
|
||||||
value = value.strip()
|
|
||||||
if not key or not value:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Keys can only have alphabetic characters, spaces and underscores
|
|
||||||
if not set(key).issubset(ALLOWED_KEY_CHARS):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Values can have just about anything that isn't a semicolon
|
|
||||||
if not set(value).issubset(ALLOWED_VALUE_CHARS):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------------------------------------------
|
||||||
# Model loading and performance automation methods
|
# Model loading and performance automation methods
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -460,7 +549,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path_positional,
|
||||||
revision=args.model_revision,
|
revision=args.model_revision,
|
||||||
trust_remote_code=args.trust_remote_code,
|
trust_remote_code=args.trust_remote_code,
|
||||||
)
|
)
|
||||||
@@ -475,7 +564,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
"quantization_config": quantization_config,
|
"quantization_config": quantization_config,
|
||||||
}
|
}
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
|
args.model_name_or_path_positional, trust_remote_code=args.trust_remote_code, **model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(model, "hf_device_map", None) is None:
|
if getattr(model, "hf_device_map", None) is None:
|
||||||
@@ -483,6 +572,88 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------------------------------------------
|
||||||
|
# User commands
|
||||||
|
def handle_non_exit_user_commands(
|
||||||
|
self,
|
||||||
|
user_input: str,
|
||||||
|
args: ChatArguments,
|
||||||
|
interface: RichInterface,
|
||||||
|
examples: dict[str, dict[str, str]],
|
||||||
|
generation_config: GenerationConfig,
|
||||||
|
model_kwargs: dict,
|
||||||
|
chat: list[dict],
|
||||||
|
) -> tuple[list[dict], GenerationConfig, dict]:
|
||||||
|
"""
|
||||||
|
Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
|
||||||
|
generation config (e.g. set a new flag).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if user_input == "!clear":
|
||||||
|
chat = self.clear_chat_history(args.system_prompt)
|
||||||
|
interface.clear()
|
||||||
|
|
||||||
|
elif user_input == "!help":
|
||||||
|
interface.print_help()
|
||||||
|
|
||||||
|
elif user_input.startswith("!save") and len(user_input.split()) < 2:
|
||||||
|
split_input = user_input.split()
|
||||||
|
|
||||||
|
if len(split_input) == 2:
|
||||||
|
filename = split_input[1]
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
filename = self.save_chat(chat, args, filename)
|
||||||
|
interface.print_color(text=f"Chat saved in {filename}!", color="green")
|
||||||
|
|
||||||
|
elif user_input.startswith("!set"):
|
||||||
|
# splits the new args into a list of strings, each string being a `flag=value` pair (same format as
|
||||||
|
# `generate_flags`)
|
||||||
|
new_generate_flags = user_input[4:].strip()
|
||||||
|
new_generate_flags = new_generate_flags.split()
|
||||||
|
# sanity check: each member in the list must have an =
|
||||||
|
for flag in new_generate_flags:
|
||||||
|
if "=" not in flag:
|
||||||
|
interface.print_color(
|
||||||
|
text=(
|
||||||
|
f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
|
||||||
|
"`arg_1=value_1 arg_2=value_2 ...`."
|
||||||
|
),
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# parses the new args into a dictionary of `generate` kwargs, and updates the corresponding variables
|
||||||
|
parsed_new_generate_flags = self.parse_generate_flags(new_generate_flags)
|
||||||
|
new_model_kwargs = generation_config.update(**parsed_new_generate_flags)
|
||||||
|
model_kwargs.update(**new_model_kwargs)
|
||||||
|
|
||||||
|
elif user_input.startswith("!example") and len(user_input.split()) == 2:
|
||||||
|
example_name = user_input.split()[1]
|
||||||
|
if example_name in examples:
|
||||||
|
interface.clear()
|
||||||
|
chat = []
|
||||||
|
interface.print_user_message(examples[example_name]["text"])
|
||||||
|
chat.append({"role": "user", "content": examples[example_name]["text"]})
|
||||||
|
else:
|
||||||
|
example_error = (
|
||||||
|
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
||||||
|
)
|
||||||
|
interface.print_color(text=example_error, color="red")
|
||||||
|
|
||||||
|
elif user_input == "!status":
|
||||||
|
interface.print_status(
|
||||||
|
model_name=args.model_name_or_path_positional,
|
||||||
|
generation_config=generation_config,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
|
||||||
|
interface.print_help()
|
||||||
|
|
||||||
|
return chat, generation_config, model_kwargs
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------------------------------------------
|
||||||
# Main logic
|
# Main logic
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -498,8 +669,6 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
with open(args.examples_path) as f:
|
with open(args.examples_path) as f:
|
||||||
examples = yaml.safe_load(f)
|
examples = yaml.safe_load(f)
|
||||||
|
|
||||||
current_args = copy.deepcopy(args)
|
|
||||||
|
|
||||||
if args.user is None:
|
if args.user is None:
|
||||||
user = self.get_username()
|
user = self.get_username()
|
||||||
else:
|
else:
|
||||||
@@ -507,12 +676,11 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
model, tokenizer = self.load_model_and_tokenizer(args)
|
model, tokenizer = self.load_model_and_tokenizer(args)
|
||||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||||
|
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
|
||||||
|
|
||||||
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
|
||||||
|
|
||||||
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
|
||||||
interface.clear()
|
interface.clear()
|
||||||
chat = self.clear_chat_history(current_args.system_prompt)
|
chat = self.clear_chat_history(args.system_prompt)
|
||||||
|
|
||||||
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
||||||
interface.print_help(minimal=True)
|
interface.print_help(minimal=True)
|
||||||
@@ -520,57 +688,26 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
try:
|
try:
|
||||||
user_input = interface.input()
|
user_input = interface.input()
|
||||||
|
|
||||||
if user_input == "clear":
|
# User commands
|
||||||
chat = self.clear_chat_history(current_args.system_prompt)
|
if user_input.startswith("!"):
|
||||||
interface.clear()
|
# `!exit` is special, it breaks the loop
|
||||||
continue
|
if user_input == "!exit":
|
||||||
|
break
|
||||||
if user_input == "help":
|
|
||||||
interface.print_help()
|
|
||||||
continue
|
|
||||||
|
|
||||||
if user_input == "exit":
|
|
||||||
break
|
|
||||||
|
|
||||||
if user_input == "reset":
|
|
||||||
interface.clear()
|
|
||||||
current_args = copy.deepcopy(args)
|
|
||||||
chat = self.clear_chat_history(current_args.system_prompt)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if user_input.startswith("save") and len(user_input.split()) < 2:
|
|
||||||
split_input = user_input.split()
|
|
||||||
|
|
||||||
if len(split_input) == 2:
|
|
||||||
filename = split_input[1]
|
|
||||||
else:
|
else:
|
||||||
filename = None
|
chat, generation_config, model_kwargs = self.handle_non_exit_user_commands(
|
||||||
filename = self.save_chat(chat, current_args, filename)
|
user_input=user_input,
|
||||||
interface.print_color(text=f"Chat saved in {filename}!", color="green")
|
args=args,
|
||||||
continue
|
interface=interface,
|
||||||
|
examples=examples,
|
||||||
if self.is_valid_setting_command(user_input):
|
generation_config=generation_config,
|
||||||
current_args, success = self.parse_settings(user_input, current_args, interface)
|
model_kwargs=model_kwargs,
|
||||||
if success:
|
chat=chat,
|
||||||
chat = []
|
|
||||||
interface.clear()
|
|
||||||
continue
|
|
||||||
|
|
||||||
if user_input.startswith("example") and len(user_input.split()) == 2:
|
|
||||||
example_name = user_input.split()[1]
|
|
||||||
if example_name in examples:
|
|
||||||
interface.clear()
|
|
||||||
chat = []
|
|
||||||
interface.print_user_message(examples[example_name]["text"])
|
|
||||||
user_input = examples[example_name]["text"]
|
|
||||||
else:
|
|
||||||
example_error = (
|
|
||||||
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
|
||||||
)
|
)
|
||||||
interface.print_color(text=example_error, color="red")
|
# `!example` sends a user message to the model
|
||||||
|
if not user_input.startswith("!example"):
|
||||||
continue
|
continue
|
||||||
|
else:
|
||||||
chat.append({"role": "user", "content": user_input})
|
chat.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||||
model.device
|
model.device
|
||||||
@@ -580,15 +717,8 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
"inputs": inputs,
|
"inputs": inputs,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"streamer": generation_streamer,
|
"streamer": generation_streamer,
|
||||||
"max_new_tokens": current_args.max_new_tokens,
|
"generation_config": generation_config,
|
||||||
"do_sample": current_args.do_sample,
|
**model_kwargs,
|
||||||
"num_beams": current_args.num_beams,
|
|
||||||
"temperature": current_args.temperature,
|
|
||||||
"top_k": current_args.top_k,
|
|
||||||
"top_p": current_args.top_p,
|
|
||||||
"repetition_penalty": current_args.repetition_penalty,
|
|
||||||
"pad_token_id": pad_token_id,
|
|
||||||
"eos_token_id": eos_token_ids,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user