[chat] generate parameterization powered by GenerationConfig and UX-related changes (#38047)

* accept arbitrary kwargs

* move user commands to a separate fn

* work with generation config files

* rm cmmt

* docs

* base generate flag doc section

* nits

* nits

* nits

* no <br>

* better basic args description
This commit is contained in:
Joao Gante
2025-05-12 14:04:41 +01:00
committed by GitHub
parent a5c6172c81
commit 8efe3a9d77
4 changed files with 326 additions and 172 deletions

View File

@@ -120,7 +120,7 @@ To chat with a model, the usage pattern is the same. The only difference is you
> [!TIP]
> You can also chat with a model directly from the command line.
> ```shell
> transformers chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
> ```
```py

View File

@@ -27,7 +27,7 @@ This guide shows you how to quickly start chatting with Transformers from the co
## transformers CLI
Chat with a model directly from the command line as shown below. It launches an interactive session with a model. Enter `clear` to reset the conversation, `exit` to terminate the session, and `help` to display all the command options.
After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
```bash
transformers chat Qwen/Qwen2.5-0.5B-Instruct
@@ -37,6 +37,12 @@ transformers chat Qwen/Qwen2.5-0.5B-Instruct
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers-chat-cli.png"/>
</div>
You can launch the CLI with arbitrary `generate` flags, with the format `arg_1=value_1 arg_2=value_2 ...`
```bash
transformers chat Qwen/Qwen2.5-0.5B-Instruct do_sample=False max_new_tokens=10
```
For a full list of options, run the command below.
```bash

View File

@@ -20,9 +20,13 @@ rendered properly in your Markdown viewer.
Text generation is the most popular application for large language models (LLMs). A LLM is trained to generate the next word (token) given some initial text (prompt) along with its own generated outputs up to a predefined length or when it reaches an end-of-sequence (`EOS`) token.
In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities.
In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities. This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
> [!TIP]
> You can also chat with a model directly from the command line. ([reference](./conversations.md#transformers-cli))
> ```shell
> transformers chat Qwen/Qwen2.5-0.5B-Instruct
> ```
## Default generate
@@ -134,6 +138,20 @@ outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
## Common Options
[`~GenerationMixin.generate`] is a powerful tool that can be heavily customized. This can be daunting for a new users. This section contains a list of popular generation options that you can define in most text generation tools in Transformers: [`~GenerationMixin.generate`], [`GenerationConfig`], `pipelines`, the `chat` CLI, ...
| Option name | Type | Simplified description |
|---|---|---|
| `max_new_tokens` | `int` | Controls the maximum generation length. Be sure to define it, as it usually defaults to a small value. |
| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies.md) for more information. |
| `temperature` | `float` | How unpredictable the next selected token will be. High values (`>0.8`) are good for creative tasks, low values (e.g. `<0.4`) for tasks that require "thinking". Requires `do_sample=True`. |
| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies.md) for more information. |
| `repetition_penalty` | `float` | Set it to `>1.0` if you're seeing the model repeat itself often. Larger values apply a larger penalty. |
| `eos_token_id` | `List[int]` | The token(s) that will cause generation to stop. The default value is usually good, but you can specify a different token. |
## Pitfalls
The section below covers some common issues you may encounter during text generation and how to solve them.
@@ -286,4 +304,4 @@ Take a look below for some more specific and specialized text generation librari
- [SynCode](https://github.com/uiuc-focal-lab/syncode): a library for context-free grammar guided generation (JSON, SQL, Python).
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a production-ready server for LLMs.
- [Text generation web UI](https://github.com/oobabooga/text-generation-webui): a Gradio web UI for text generation.
- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.

View File

@@ -13,12 +13,12 @@
# limitations under the License.
import copy
import json
import os
import platform
import string
import time
import warnings
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
@@ -42,7 +42,13 @@ if is_rich_available():
if is_torch_available():
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
TextIteratorStreamer,
)
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@@ -64,25 +70,16 @@ DEFAULT_EXAMPLES = {
"socks": {"text": "Why is it important to eat socks after meditating?"},
}
SUPPORTED_GENERATION_KWARGS = [
"max_new_tokens",
"do_sample",
"num_beams",
"temperature",
"top_p",
"top_k",
"repetition_penalty",
]
# Printed at the start of a chat session
HELP_STRING_MINIMAL = """
**TRANSFORMERS CHAT INTERFACE**
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **help**: shows all available commands
- **clear**: clears the current conversation and starts a new one
- **exit**: closes the interface
- **!help**: shows all available commands
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
"""
@@ -92,18 +89,32 @@ HELP_STRING = f"""
**TRANSFORMERS CHAT INTERFACE HELP**
Full command list:
- **help**: shows this help message
- **clear**: clears the current conversation and starts a new one
- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
- **!help**: shows this help message
- **!clear**: clears the current conversation and starts a new one
- **!status**: shows the current status of the model and generation settings
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
- **exit**: closes the interface
- **!exit**: closes the interface
"""
# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
_DEPRECATION_MAP = [
("max_new_tokens", 256, "max_new_tokens"),
("do_sample", True, "do_sample"),
("num_beams", 1, "num_beams"),
("temperature", 1.0, "temperature"),
("top_k", 50, "top_k"),
("top_p", 1.0, "top_p"),
("repetition_penalty", 1.0, "repetition_penalty"),
("eos_tokens", None, "eos_token_id"),
("eos_token_ids", None, "eos_token_id"),
]
class RichInterface:
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
@@ -181,6 +192,14 @@ class RichInterface:
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
self._console.print()
def print_status(self, model_name: str, generation_config: GenerationConfig, model_kwargs: dict):
"""Prints the status of the model and generation settings to the console."""
self._console.print(f"[bold blue]Model: {model_name}\n")
if model_kwargs:
self._console.print(f"[bold blue]Model kwargs: {model_kwargs}")
self._console.print(f"[bold blue]{generation_config}")
self._console.print()
@dataclass
class ChatArguments:
@@ -207,6 +226,17 @@ class ChatArguments:
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
# Generation settings
generation_config: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to a local generation config file or to a HuggingFace repo containing a "
"`generation_config.json` file. Other generation settings passed as CLI arguments will be applied on "
"top of this generation config."
),
},
)
# Deprecated CLI args start here
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
@@ -222,6 +252,7 @@ class ChatArguments:
default=None,
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
)
# Deprecated CLI args end here
# Model loading
model_revision: str = field(
@@ -280,23 +311,66 @@ class ChatCommand(BaseTransformersCLICommand):
group = chat_parser.add_argument_group("Positional arguments")
group.add_argument(
"model_name_or_path_positional", type=str, nargs="?", default=None, help="Name of the pre-trained model."
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
)
group.add_argument(
"generate_flags",
type=str,
default=None,
help=(
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
"and lists of integers, more advanced parameterization should be set through --generation-config. "
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
),
nargs="*",
)
chat_parser.set_defaults(func=chat_command_factory)
def __init__(self, args):
args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path
args = self._handle_deprecated_args(args)
self.args = args
if args.model_name_or_path is None:
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
"""
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
args.
"""
has_warnings = False
# 1. Model as a positional argument
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
if args.model_name_or_path_positional is None:
raise ValueError(
"One of the following must be provided:"
"\n- The positional argument containing the model repo;"
"\n- the optional --model_name_or_path argument, containing the model repo"
"\ne.g. transformers chat <model_repo> or transformers chat --model_name_or_path <model_repo>"
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
)
elif args.model_name_or_path is not None:
has_warnings = True
warnings.warn(
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
"argument instead, e.g. `transformers chat <model_repo>`.",
FutureWarning,
)
# 2. Named generate option args
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
value = getattr(args, deprecated_arg)
if value != default_value:
has_warnings = True
warnings.warn(
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
"alternative solutions to specify this generation option: \n"
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
f"{new_arg}={value}`",
FutureWarning,
)
self.args = args
if has_warnings:
print("\n(Press enter to continue)")
input()
return args
# -----------------------------------------------------------------------------------------------------------------
# Chat session methods
@@ -319,7 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
if filename is None:
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{args.model_name_or_path}/chat_{time_str}.json"
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -338,50 +412,95 @@ class ChatCommand(BaseTransformersCLICommand):
# -----------------------------------------------------------------------------------------------------------------
# Input parsing methods
@staticmethod
def parse_settings(
user_input: str, current_args: ChatArguments, interface: RichInterface
) -> tuple[ChatArguments, bool]:
"""Parses the settings from the user input into the CLI arguments."""
settings = user_input[4:].strip().split(";")
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
settings = dict(settings)
error = False
def parse_generate_flags(self, generate_flags: list[str]) -> dict:
"""Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
if len(generate_flags) == 0:
return {}
for name in settings:
if hasattr(current_args, name):
try:
if isinstance(getattr(current_args, name), bool):
if settings[name] == "True":
settings[name] = True
elif settings[name] == "False":
settings[name] = False
else:
raise ValueError
else:
settings[name] = type(getattr(current_args, name))(settings[name])
except ValueError:
error = True
interface.print_color(
text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
color="red",
)
else:
interface.print_color(text=f"There is no '{name}' setting.", color="red")
# Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
# into a json string if we:
# 1. Add quotes around each flag name
generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
if error:
interface.print_color(
text="There was an issue parsing the settings. No settings have been changed.",
color="red",
# 2. Handle types:
# 2. a. booleans should be lowercase, None should be null
generate_flags_as_dict = {
k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
}
generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
# 2. b. strings should be quoted
def is_number(s: str) -> bool:
return s.replace(".", "", 1).isdigit()
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
# 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
# We also mention in the help message that we only accept lists of ints for now.
# 3. Join the the result into a comma separated string
generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
# 4. Add the opening/closing brackets
generate_flags_string = "{" + generate_flags_string + "}"
# 5. Remove quotes around boolean/null and around lists
generate_flags_string = generate_flags_string.replace('"null"', "null")
generate_flags_string = generate_flags_string.replace('"true"', "true")
generate_flags_string = generate_flags_string.replace('"false"', "false")
generate_flags_string = generate_flags_string.replace('"[', "[")
generate_flags_string = generate_flags_string.replace(']"', "]")
# 6. Replace the `=` with `:`
generate_flags_string = generate_flags_string.replace("=", ":")
try:
processed_generate_flags = json.loads(generate_flags_string)
except json.JSONDecodeError:
raise ValueError(
"Failed to convert `generate_flags` into a valid JSON object."
"\n`generate_flags` = {generate_flags}"
"\nConverted JSON string = {generate_flags_string}"
)
return processed_generate_flags
def get_generation_parameterization(
self, args: ChatArguments, tokenizer: AutoTokenizer
) -> tuple[GenerationConfig, dict]:
"""
Returns a GenerationConfig object holding the generation parameters for the CLI command.
"""
# No generation config arg provided -> use base generation config, apply CLI defaults
if args.generation_config is None:
generation_config = GenerationConfig()
# Apply deprecated CLI args on top of the default generation config
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
deprecated_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
"num_beams": args.num_beams,
"temperature": args.temperature,
"top_k": args.top_k,
"top_p": args.top_p,
"repetition_penalty": args.repetition_penalty,
"pad_token_id": pad_token_id,
"eos_token_id": eos_token_ids,
}
generation_config.update(**deprecated_kwargs)
# generation config arg provided -> use it as the base parameterization
else:
for name in settings:
setattr(current_args, name, settings[name])
interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
if ".json" in args.generation_config: # is a local file
dirname = os.path.dirname(args.generation_config)
filename = os.path.basename(args.generation_config)
generation_config = GenerationConfig.from_pretrained(dirname, filename)
else:
generation_config = GenerationConfig.from_pretrained(args.generation_config)
time.sleep(1.5) # so the user has time to read the changes
return current_args, not error
# Finally: parse and apply `generate_flags`
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
model_kwargs = generation_config.update(**parsed_generate_flags)
# `model_kwargs` contain non-generation flags in `parsed_generate_flags` that should be passed directly to
# `generate`
return generation_config, model_kwargs
@staticmethod
def parse_eos_tokens(
@@ -406,36 +525,6 @@ class ChatCommand(BaseTransformersCLICommand):
return pad_token_id, all_eos_token_ids
@staticmethod
def is_valid_setting_command(s: str) -> bool:
# First check the basic structure
if not s.startswith("set ") or "=" not in s:
return False
# Split into individual assignments
assignments = [a.strip() for a in s[4:].split(";") if a.strip()]
for assignment in assignments:
# Each assignment should have exactly one '='
if assignment.count("=") != 1:
return False
key, value = assignment.split("=", 1)
key = key.strip()
value = value.strip()
if not key or not value:
return False
# Keys can only have alphabetic characters, spaces and underscores
if not set(key).issubset(ALLOWED_KEY_CHARS):
return False
# Values can have just about anything that isn't a semicolon
if not set(value).issubset(ALLOWED_VALUE_CHARS):
return False
return True
# -----------------------------------------------------------------------------------------------------------------
# Model loading and performance automation methods
@staticmethod
@@ -460,7 +549,7 @@ class ChatCommand(BaseTransformersCLICommand):
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
args.model_name_or_path_positional,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
)
@@ -475,7 +564,7 @@ class ChatCommand(BaseTransformersCLICommand):
"quantization_config": quantization_config,
}
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
args.model_name_or_path_positional, trust_remote_code=args.trust_remote_code, **model_kwargs
)
if getattr(model, "hf_device_map", None) is None:
@@ -483,6 +572,88 @@ class ChatCommand(BaseTransformersCLICommand):
return model, tokenizer
# -----------------------------------------------------------------------------------------------------------------
# User commands
def handle_non_exit_user_commands(
self,
user_input: str,
args: ChatArguments,
interface: RichInterface,
examples: dict[str, dict[str, str]],
generation_config: GenerationConfig,
model_kwargs: dict,
chat: list[dict],
) -> tuple[list[dict], GenerationConfig, dict]:
"""
Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
generation config (e.g. set a new flag).
"""
if user_input == "!clear":
chat = self.clear_chat_history(args.system_prompt)
interface.clear()
elif user_input == "!help":
interface.print_help()
elif user_input.startswith("!save") and len(user_input.split()) < 2:
split_input = user_input.split()
if len(split_input) == 2:
filename = split_input[1]
else:
filename = None
filename = self.save_chat(chat, args, filename)
interface.print_color(text=f"Chat saved in {filename}!", color="green")
elif user_input.startswith("!set"):
# splits the new args into a list of strings, each string being a `flag=value` pair (same format as
# `generate_flags`)
new_generate_flags = user_input[4:].strip()
new_generate_flags = new_generate_flags.split()
# sanity check: each member in the list must have an =
for flag in new_generate_flags:
if "=" not in flag:
interface.print_color(
text=(
f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
"`arg_1=value_1 arg_2=value_2 ...`."
),
color="red",
)
break
else:
# parses the new args into a dictionary of `generate` kwargs, and updates the corresponding variables
parsed_new_generate_flags = self.parse_generate_flags(new_generate_flags)
new_model_kwargs = generation_config.update(**parsed_new_generate_flags)
model_kwargs.update(**new_model_kwargs)
elif user_input.startswith("!example") and len(user_input.split()) == 2:
example_name = user_input.split()[1]
if example_name in examples:
interface.clear()
chat = []
interface.print_user_message(examples[example_name]["text"])
chat.append({"role": "user", "content": examples[example_name]["text"]})
else:
example_error = (
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
)
interface.print_color(text=example_error, color="red")
elif user_input == "!status":
interface.print_status(
model_name=args.model_name_or_path_positional,
generation_config=generation_config,
model_kwargs=model_kwargs,
)
else:
interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
interface.print_help()
return chat, generation_config, model_kwargs
# -----------------------------------------------------------------------------------------------------------------
# Main logic
def run(self):
@@ -498,8 +669,6 @@ class ChatCommand(BaseTransformersCLICommand):
with open(args.examples_path) as f:
examples = yaml.safe_load(f)
current_args = copy.deepcopy(args)
if args.user is None:
user = self.get_username()
else:
@@ -507,12 +676,11 @@ class ChatCommand(BaseTransformersCLICommand):
model, tokenizer = self.load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
interface.clear()
chat = self.clear_chat_history(current_args.system_prompt)
chat = self.clear_chat_history(args.system_prompt)
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
interface.print_help(minimal=True)
@@ -520,57 +688,26 @@ class ChatCommand(BaseTransformersCLICommand):
try:
user_input = interface.input()
if user_input == "clear":
chat = self.clear_chat_history(current_args.system_prompt)
interface.clear()
continue
if user_input == "help":
interface.print_help()
continue
if user_input == "exit":
break
if user_input == "reset":
interface.clear()
current_args = copy.deepcopy(args)
chat = self.clear_chat_history(current_args.system_prompt)
continue
if user_input.startswith("save") and len(user_input.split()) < 2:
split_input = user_input.split()
if len(split_input) == 2:
filename = split_input[1]
# User commands
if user_input.startswith("!"):
# `!exit` is special, it breaks the loop
if user_input == "!exit":
break
else:
filename = None
filename = self.save_chat(chat, current_args, filename)
interface.print_color(text=f"Chat saved in {filename}!", color="green")
continue
if self.is_valid_setting_command(user_input):
current_args, success = self.parse_settings(user_input, current_args, interface)
if success:
chat = []
interface.clear()
continue
if user_input.startswith("example") and len(user_input.split()) == 2:
example_name = user_input.split()[1]
if example_name in examples:
interface.clear()
chat = []
interface.print_user_message(examples[example_name]["text"])
user_input = examples[example_name]["text"]
else:
example_error = (
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
chat, generation_config, model_kwargs = self.handle_non_exit_user_commands(
user_input=user_input,
args=args,
interface=interface,
examples=examples,
generation_config=generation_config,
model_kwargs=model_kwargs,
chat=chat,
)
interface.print_color(text=example_error, color="red")
# `!example` sends a user message to the model
if not user_input.startswith("!example"):
continue
chat.append({"role": "user", "content": user_input})
else:
chat.append({"role": "user", "content": user_input})
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
@@ -580,15 +717,8 @@ class ChatCommand(BaseTransformersCLICommand):
"inputs": inputs,
"attention_mask": attention_mask,
"streamer": generation_streamer,
"max_new_tokens": current_args.max_new_tokens,
"do_sample": current_args.do_sample,
"num_beams": current_args.num_beams,
"temperature": current_args.temperature,
"top_k": current_args.top_k,
"top_p": current_args.top_p,
"repetition_penalty": current_args.repetition_penalty,
"pad_token_id": pad_token_id,
"eos_token_id": eos_token_ids,
"generation_config": generation_config,
**model_kwargs,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)