diff --git a/README.md b/README.md
index 3ebe4e9134..d8a711b48e 100644
--- a/README.md
+++ b/README.md
@@ -120,7 +120,7 @@ To chat with a model, the usage pattern is the same. The only difference is you
> [!TIP]
> You can also chat with a model directly from the command line.
> ```shell
-> transformers chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
+> transformers chat Qwen/Qwen2.5-0.5B-Instruct
> ```
```py
diff --git a/docs/source/en/conversations.md b/docs/source/en/conversations.md
index b102f0c09d..94bb9fd591 100644
--- a/docs/source/en/conversations.md
+++ b/docs/source/en/conversations.md
@@ -27,7 +27,7 @@ This guide shows you how to quickly start chatting with Transformers from the co
## transformers CLI
-Chat with a model directly from the command line as shown below. It launches an interactive session with a model. Enter `clear` to reset the conversation, `exit` to terminate the session, and `help` to display all the command options.
+After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
```bash
transformers chat Qwen/Qwen2.5-0.5B-Instruct
@@ -37,6 +37,12 @@ transformers chat Qwen/Qwen2.5-0.5B-Instruct
+You can launch the CLI with arbitrary `generate` flags, with the format `arg_1=value_1 arg_2=value_2 ...`
+
+```bash
+transformers chat Qwen/Qwen2.5-0.5B-Instruct do_sample=False max_new_tokens=10
+```
+
For a full list of options, run the command below.
```bash
diff --git a/docs/source/en/llm_tutorial.md b/docs/source/en/llm_tutorial.md
index d867657202..a191cdb463 100644
--- a/docs/source/en/llm_tutorial.md
+++ b/docs/source/en/llm_tutorial.md
@@ -20,9 +20,13 @@ rendered properly in your Markdown viewer.
Text generation is the most popular application for large language models (LLMs). A LLM is trained to generate the next word (token) given some initial text (prompt) along with its own generated outputs up to a predefined length or when it reaches an end-of-sequence (`EOS`) token.
-In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities.
+In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities. This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
-This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid.
+> [!TIP]
+> You can also chat with a model directly from the command line. ([reference](./conversations.md#transformers-cli))
+> ```shell
+> transformers chat Qwen/Qwen2.5-0.5B-Instruct
+> ```
## Default generate
@@ -134,6 +138,20 @@ outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
+## Common Options
+
+[`~GenerationMixin.generate`] is a powerful tool that can be heavily customized. This can be daunting for a new users. This section contains a list of popular generation options that you can define in most text generation tools in Transformers: [`~GenerationMixin.generate`], [`GenerationConfig`], `pipelines`, the `chat` CLI, ...
+
+| Option name | Type | Simplified description |
+|---|---|---|
+| `max_new_tokens` | `int` | Controls the maximum generation length. Be sure to define it, as it usually defaults to a small value. |
+| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies.md) for more information. |
+| `temperature` | `float` | How unpredictable the next selected token will be. High values (`>0.8`) are good for creative tasks, low values (e.g. `<0.4`) for tasks that require "thinking". Requires `do_sample=True`. |
+| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies.md) for more information. |
+| `repetition_penalty` | `float` | Set it to `>1.0` if you're seeing the model repeat itself often. Larger values apply a larger penalty. |
+| `eos_token_id` | `List[int]` | The token(s) that will cause generation to stop. The default value is usually good, but you can specify a different token. |
+
+
## Pitfalls
The section below covers some common issues you may encounter during text generation and how to solve them.
@@ -286,4 +304,4 @@ Take a look below for some more specific and specialized text generation librari
- [SynCode](https://github.com/uiuc-focal-lab/syncode): a library for context-free grammar guided generation (JSON, SQL, Python).
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a production-ready server for LLMs.
- [Text generation web UI](https://github.com/oobabooga/text-generation-webui): a Gradio web UI for text generation.
-- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
\ No newline at end of file
+- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation.
diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py
index f276368ce8..5c9bd76bdb 100644
--- a/src/transformers/commands/chat.py
+++ b/src/transformers/commands/chat.py
@@ -13,12 +13,12 @@
# limitations under the License.
-import copy
import json
import os
import platform
import string
import time
+import warnings
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
@@ -42,7 +42,13 @@ if is_rich_available():
if is_torch_available():
import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+ from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ BitsAndBytesConfig,
+ GenerationConfig,
+ TextIteratorStreamer,
+ )
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@@ -64,25 +70,16 @@ DEFAULT_EXAMPLES = {
"socks": {"text": "Why is it important to eat socks after meditating?"},
}
-SUPPORTED_GENERATION_KWARGS = [
- "max_new_tokens",
- "do_sample",
- "num_beams",
- "temperature",
- "top_p",
- "top_k",
- "repetition_penalty",
-]
-
# Printed at the start of a chat session
HELP_STRING_MINIMAL = """
**TRANSFORMERS CHAT INTERFACE**
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
-- **help**: shows all available commands
-- **clear**: clears the current conversation and starts a new one
-- **exit**: closes the interface
+- **!help**: shows all available commands
+- **!status**: shows the current status of the model and generation settings
+- **!clear**: clears the current conversation and starts a new one
+- **!exit**: closes the interface
"""
@@ -92,18 +89,32 @@ HELP_STRING = f"""
**TRANSFORMERS CHAT INTERFACE HELP**
Full command list:
-- **help**: shows this help message
-- **clear**: clears the current conversation and starts a new one
-- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
-names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
-- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
-separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
-- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
-- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
+- **!help**: shows this help message
+- **!clear**: clears the current conversation and starts a new one
+- **!status**: shows the current status of the model and generation settings
+- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
+Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
+- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
+settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
+If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
+- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
-- **exit**: closes the interface
+- **!exit**: closes the interface
"""
+# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
+_DEPRECATION_MAP = [
+ ("max_new_tokens", 256, "max_new_tokens"),
+ ("do_sample", True, "do_sample"),
+ ("num_beams", 1, "num_beams"),
+ ("temperature", 1.0, "temperature"),
+ ("top_k", 50, "top_k"),
+ ("top_p", 1.0, "top_p"),
+ ("repetition_penalty", 1.0, "repetition_penalty"),
+ ("eos_tokens", None, "eos_token_id"),
+ ("eos_token_ids", None, "eos_token_id"),
+]
+
class RichInterface:
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
@@ -181,6 +192,14 @@ class RichInterface:
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
self._console.print()
+ def print_status(self, model_name: str, generation_config: GenerationConfig, model_kwargs: dict):
+ """Prints the status of the model and generation settings to the console."""
+ self._console.print(f"[bold blue]Model: {model_name}\n")
+ if model_kwargs:
+ self._console.print(f"[bold blue]Model kwargs: {model_kwargs}")
+ self._console.print(f"[bold blue]{generation_config}")
+ self._console.print()
+
@dataclass
class ChatArguments:
@@ -207,6 +226,17 @@ class ChatArguments:
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
# Generation settings
+ generation_config: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Path to a local generation config file or to a HuggingFace repo containing a "
+ "`generation_config.json` file. Other generation settings passed as CLI arguments will be applied on "
+ "top of this generation config."
+ ),
+ },
+ )
+ # Deprecated CLI args start here
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
@@ -222,6 +252,7 @@ class ChatArguments:
default=None,
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
)
+ # Deprecated CLI args end here
# Model loading
model_revision: str = field(
@@ -280,23 +311,66 @@ class ChatCommand(BaseTransformersCLICommand):
group = chat_parser.add_argument_group("Positional arguments")
group.add_argument(
- "model_name_or_path_positional", type=str, nargs="?", default=None, help="Name of the pre-trained model."
+ "model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
+ )
+ group.add_argument(
+ "generate_flags",
+ type=str,
+ default=None,
+ help=(
+ "Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
+ "and lists of integers, more advanced parameterization should be set through --generation-config. "
+ "Example: `transformers chat max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
+ "If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
+ ),
+ nargs="*",
)
-
chat_parser.set_defaults(func=chat_command_factory)
def __init__(self, args):
- args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path
+ args = self._handle_deprecated_args(args)
+ self.args = args
- if args.model_name_or_path is None:
+ def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
+ """
+ Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
+ args.
+ """
+ has_warnings = False
+
+ # 1. Model as a positional argument
+ args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
+ if args.model_name_or_path_positional is None:
raise ValueError(
"One of the following must be provided:"
- "\n- The positional argument containing the model repo;"
- "\n- the optional --model_name_or_path argument, containing the model repo"
- "\ne.g. transformers chat or transformers chat --model_name_or_path "
+ "\n- The positional argument containing the model repo, e.g. `transformers chat `"
+ "\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
)
+ elif args.model_name_or_path is not None:
+ has_warnings = True
+ warnings.warn(
+ "The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
+ "argument instead, e.g. `transformers chat `.",
+ FutureWarning,
+ )
+ # 2. Named generate option args
+ for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
+ value = getattr(args, deprecated_arg)
+ if value != default_value:
+ has_warnings = True
+ warnings.warn(
+ f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
+ "alternative solutions to specify this generation option: \n"
+ "1. Pass `--generation-config ` to specify a generation config.\n"
+ "2. Pass `generate` flags through positional arguments, e.g. `transformers chat "
+ f"{new_arg}={value}`",
+ FutureWarning,
+ )
- self.args = args
+ if has_warnings:
+ print("\n(Press enter to continue)")
+ input()
+ return args
# -----------------------------------------------------------------------------------------------------------------
# Chat session methods
@@ -319,7 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
if filename is None:
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
- filename = f"{args.model_name_or_path}/chat_{time_str}.json"
+ filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -338,50 +412,95 @@ class ChatCommand(BaseTransformersCLICommand):
# -----------------------------------------------------------------------------------------------------------------
# Input parsing methods
- @staticmethod
- def parse_settings(
- user_input: str, current_args: ChatArguments, interface: RichInterface
- ) -> tuple[ChatArguments, bool]:
- """Parses the settings from the user input into the CLI arguments."""
- settings = user_input[4:].strip().split(";")
- settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
- settings = dict(settings)
- error = False
+ def parse_generate_flags(self, generate_flags: list[str]) -> dict:
+ """Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
+ if len(generate_flags) == 0:
+ return {}
- for name in settings:
- if hasattr(current_args, name):
- try:
- if isinstance(getattr(current_args, name), bool):
- if settings[name] == "True":
- settings[name] = True
- elif settings[name] == "False":
- settings[name] = False
- else:
- raise ValueError
- else:
- settings[name] = type(getattr(current_args, name))(settings[name])
- except ValueError:
- error = True
- interface.print_color(
- text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
- color="red",
- )
- else:
- interface.print_color(text=f"There is no '{name}' setting.", color="red")
+ # Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
+ # into a json string if we:
+ # 1. Add quotes around each flag name
+ generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
- if error:
- interface.print_color(
- text="There was an issue parsing the settings. No settings have been changed.",
- color="red",
+ # 2. Handle types:
+ # 2. a. booleans should be lowercase, None should be null
+ generate_flags_as_dict = {
+ k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
+ }
+ generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
+
+ # 2. b. strings should be quoted
+ def is_number(s: str) -> bool:
+ return s.replace(".", "", 1).isdigit()
+
+ generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
+ # 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
+ # We also mention in the help message that we only accept lists of ints for now.
+
+ # 3. Join the the result into a comma separated string
+ generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
+
+ # 4. Add the opening/closing brackets
+ generate_flags_string = "{" + generate_flags_string + "}"
+
+ # 5. Remove quotes around boolean/null and around lists
+ generate_flags_string = generate_flags_string.replace('"null"', "null")
+ generate_flags_string = generate_flags_string.replace('"true"', "true")
+ generate_flags_string = generate_flags_string.replace('"false"', "false")
+ generate_flags_string = generate_flags_string.replace('"[', "[")
+ generate_flags_string = generate_flags_string.replace(']"', "]")
+
+ # 6. Replace the `=` with `:`
+ generate_flags_string = generate_flags_string.replace("=", ":")
+
+ try:
+ processed_generate_flags = json.loads(generate_flags_string)
+ except json.JSONDecodeError:
+ raise ValueError(
+ "Failed to convert `generate_flags` into a valid JSON object."
+ "\n`generate_flags` = {generate_flags}"
+ "\nConverted JSON string = {generate_flags_string}"
)
+ return processed_generate_flags
+
+ def get_generation_parameterization(
+ self, args: ChatArguments, tokenizer: AutoTokenizer
+ ) -> tuple[GenerationConfig, dict]:
+ """
+ Returns a GenerationConfig object holding the generation parameters for the CLI command.
+ """
+ # No generation config arg provided -> use base generation config, apply CLI defaults
+ if args.generation_config is None:
+ generation_config = GenerationConfig()
+ # Apply deprecated CLI args on top of the default generation config
+ pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
+ deprecated_kwargs = {
+ "max_new_tokens": args.max_new_tokens,
+ "do_sample": args.do_sample,
+ "num_beams": args.num_beams,
+ "temperature": args.temperature,
+ "top_k": args.top_k,
+ "top_p": args.top_p,
+ "repetition_penalty": args.repetition_penalty,
+ "pad_token_id": pad_token_id,
+ "eos_token_id": eos_token_ids,
+ }
+ generation_config.update(**deprecated_kwargs)
+ # generation config arg provided -> use it as the base parameterization
else:
- for name in settings:
- setattr(current_args, name, settings[name])
- interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
+ if ".json" in args.generation_config: # is a local file
+ dirname = os.path.dirname(args.generation_config)
+ filename = os.path.basename(args.generation_config)
+ generation_config = GenerationConfig.from_pretrained(dirname, filename)
+ else:
+ generation_config = GenerationConfig.from_pretrained(args.generation_config)
- time.sleep(1.5) # so the user has time to read the changes
-
- return current_args, not error
+ # Finally: parse and apply `generate_flags`
+ parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
+ model_kwargs = generation_config.update(**parsed_generate_flags)
+ # `model_kwargs` contain non-generation flags in `parsed_generate_flags` that should be passed directly to
+ # `generate`
+ return generation_config, model_kwargs
@staticmethod
def parse_eos_tokens(
@@ -406,36 +525,6 @@ class ChatCommand(BaseTransformersCLICommand):
return pad_token_id, all_eos_token_ids
- @staticmethod
- def is_valid_setting_command(s: str) -> bool:
- # First check the basic structure
- if not s.startswith("set ") or "=" not in s:
- return False
-
- # Split into individual assignments
- assignments = [a.strip() for a in s[4:].split(";") if a.strip()]
-
- for assignment in assignments:
- # Each assignment should have exactly one '='
- if assignment.count("=") != 1:
- return False
-
- key, value = assignment.split("=", 1)
- key = key.strip()
- value = value.strip()
- if not key or not value:
- return False
-
- # Keys can only have alphabetic characters, spaces and underscores
- if not set(key).issubset(ALLOWED_KEY_CHARS):
- return False
-
- # Values can have just about anything that isn't a semicolon
- if not set(value).issubset(ALLOWED_VALUE_CHARS):
- return False
-
- return True
-
# -----------------------------------------------------------------------------------------------------------------
# Model loading and performance automation methods
@staticmethod
@@ -460,7 +549,7 @@ class ChatCommand(BaseTransformersCLICommand):
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
tokenizer = AutoTokenizer.from_pretrained(
- args.model_name_or_path,
+ args.model_name_or_path_positional,
revision=args.model_revision,
trust_remote_code=args.trust_remote_code,
)
@@ -475,7 +564,7 @@ class ChatCommand(BaseTransformersCLICommand):
"quantization_config": quantization_config,
}
model = AutoModelForCausalLM.from_pretrained(
- args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
+ args.model_name_or_path_positional, trust_remote_code=args.trust_remote_code, **model_kwargs
)
if getattr(model, "hf_device_map", None) is None:
@@ -483,6 +572,88 @@ class ChatCommand(BaseTransformersCLICommand):
return model, tokenizer
+ # -----------------------------------------------------------------------------------------------------------------
+ # User commands
+ def handle_non_exit_user_commands(
+ self,
+ user_input: str,
+ args: ChatArguments,
+ interface: RichInterface,
+ examples: dict[str, dict[str, str]],
+ generation_config: GenerationConfig,
+ model_kwargs: dict,
+ chat: list[dict],
+ ) -> tuple[list[dict], GenerationConfig, dict]:
+ """
+ Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
+ generation config (e.g. set a new flag).
+ """
+
+ if user_input == "!clear":
+ chat = self.clear_chat_history(args.system_prompt)
+ interface.clear()
+
+ elif user_input == "!help":
+ interface.print_help()
+
+ elif user_input.startswith("!save") and len(user_input.split()) < 2:
+ split_input = user_input.split()
+
+ if len(split_input) == 2:
+ filename = split_input[1]
+ else:
+ filename = None
+ filename = self.save_chat(chat, args, filename)
+ interface.print_color(text=f"Chat saved in {filename}!", color="green")
+
+ elif user_input.startswith("!set"):
+ # splits the new args into a list of strings, each string being a `flag=value` pair (same format as
+ # `generate_flags`)
+ new_generate_flags = user_input[4:].strip()
+ new_generate_flags = new_generate_flags.split()
+ # sanity check: each member in the list must have an =
+ for flag in new_generate_flags:
+ if "=" not in flag:
+ interface.print_color(
+ text=(
+ f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
+ "`arg_1=value_1 arg_2=value_2 ...`."
+ ),
+ color="red",
+ )
+ break
+ else:
+ # parses the new args into a dictionary of `generate` kwargs, and updates the corresponding variables
+ parsed_new_generate_flags = self.parse_generate_flags(new_generate_flags)
+ new_model_kwargs = generation_config.update(**parsed_new_generate_flags)
+ model_kwargs.update(**new_model_kwargs)
+
+ elif user_input.startswith("!example") and len(user_input.split()) == 2:
+ example_name = user_input.split()[1]
+ if example_name in examples:
+ interface.clear()
+ chat = []
+ interface.print_user_message(examples[example_name]["text"])
+ chat.append({"role": "user", "content": examples[example_name]["text"]})
+ else:
+ example_error = (
+ f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
+ )
+ interface.print_color(text=example_error, color="red")
+
+ elif user_input == "!status":
+ interface.print_status(
+ model_name=args.model_name_or_path_positional,
+ generation_config=generation_config,
+ model_kwargs=model_kwargs,
+ )
+
+ else:
+ interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
+ interface.print_help()
+
+ return chat, generation_config, model_kwargs
+
# -----------------------------------------------------------------------------------------------------------------
# Main logic
def run(self):
@@ -498,8 +669,6 @@ class ChatCommand(BaseTransformersCLICommand):
with open(args.examples_path) as f:
examples = yaml.safe_load(f)
- current_args = copy.deepcopy(args)
-
if args.user is None:
user = self.get_username()
else:
@@ -507,12 +676,11 @@ class ChatCommand(BaseTransformersCLICommand):
model, tokenizer = self.load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
+ generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
- pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
-
- interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
+ interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
interface.clear()
- chat = self.clear_chat_history(current_args.system_prompt)
+ chat = self.clear_chat_history(args.system_prompt)
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
interface.print_help(minimal=True)
@@ -520,57 +688,26 @@ class ChatCommand(BaseTransformersCLICommand):
try:
user_input = interface.input()
- if user_input == "clear":
- chat = self.clear_chat_history(current_args.system_prompt)
- interface.clear()
- continue
-
- if user_input == "help":
- interface.print_help()
- continue
-
- if user_input == "exit":
- break
-
- if user_input == "reset":
- interface.clear()
- current_args = copy.deepcopy(args)
- chat = self.clear_chat_history(current_args.system_prompt)
- continue
-
- if user_input.startswith("save") and len(user_input.split()) < 2:
- split_input = user_input.split()
-
- if len(split_input) == 2:
- filename = split_input[1]
+ # User commands
+ if user_input.startswith("!"):
+ # `!exit` is special, it breaks the loop
+ if user_input == "!exit":
+ break
else:
- filename = None
- filename = self.save_chat(chat, current_args, filename)
- interface.print_color(text=f"Chat saved in {filename}!", color="green")
- continue
-
- if self.is_valid_setting_command(user_input):
- current_args, success = self.parse_settings(user_input, current_args, interface)
- if success:
- chat = []
- interface.clear()
- continue
-
- if user_input.startswith("example") and len(user_input.split()) == 2:
- example_name = user_input.split()[1]
- if example_name in examples:
- interface.clear()
- chat = []
- interface.print_user_message(examples[example_name]["text"])
- user_input = examples[example_name]["text"]
- else:
- example_error = (
- f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
+ chat, generation_config, model_kwargs = self.handle_non_exit_user_commands(
+ user_input=user_input,
+ args=args,
+ interface=interface,
+ examples=examples,
+ generation_config=generation_config,
+ model_kwargs=model_kwargs,
+ chat=chat,
)
- interface.print_color(text=example_error, color="red")
+ # `!example` sends a user message to the model
+ if not user_input.startswith("!example"):
continue
-
- chat.append({"role": "user", "content": user_input})
+ else:
+ chat.append({"role": "user", "content": user_input})
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
@@ -580,15 +717,8 @@ class ChatCommand(BaseTransformersCLICommand):
"inputs": inputs,
"attention_mask": attention_mask,
"streamer": generation_streamer,
- "max_new_tokens": current_args.max_new_tokens,
- "do_sample": current_args.do_sample,
- "num_beams": current_args.num_beams,
- "temperature": current_args.temperature,
- "top_k": current_args.top_k,
- "top_p": current_args.top_p,
- "repetition_penalty": current_args.repetition_penalty,
- "pad_token_id": pad_token_id,
- "eos_token_id": eos_token_ids,
+ "generation_config": generation_config,
+ **model_kwargs,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)