diff --git a/README.md b/README.md index 3ebe4e9134..d8a711b48e 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ To chat with a model, the usage pattern is the same. The only difference is you > [!TIP] > You can also chat with a model directly from the command line. > ```shell -> transformers chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct +> transformers chat Qwen/Qwen2.5-0.5B-Instruct > ``` ```py diff --git a/docs/source/en/conversations.md b/docs/source/en/conversations.md index b102f0c09d..94bb9fd591 100644 --- a/docs/source/en/conversations.md +++ b/docs/source/en/conversations.md @@ -27,7 +27,7 @@ This guide shows you how to quickly start chatting with Transformers from the co ## transformers CLI -Chat with a model directly from the command line as shown below. It launches an interactive session with a model. Enter `clear` to reset the conversation, `exit` to terminate the session, and `help` to display all the command options. +After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session. ```bash transformers chat Qwen/Qwen2.5-0.5B-Instruct @@ -37,6 +37,12 @@ transformers chat Qwen/Qwen2.5-0.5B-Instruct +You can launch the CLI with arbitrary `generate` flags, with the format `arg_1=value_1 arg_2=value_2 ...` + +```bash +transformers chat Qwen/Qwen2.5-0.5B-Instruct do_sample=False max_new_tokens=10 +``` + For a full list of options, run the command below. ```bash diff --git a/docs/source/en/llm_tutorial.md b/docs/source/en/llm_tutorial.md index d867657202..a191cdb463 100644 --- a/docs/source/en/llm_tutorial.md +++ b/docs/source/en/llm_tutorial.md @@ -20,9 +20,13 @@ rendered properly in your Markdown viewer. Text generation is the most popular application for large language models (LLMs). A LLM is trained to generate the next word (token) given some initial text (prompt) along with its own generated outputs up to a predefined length or when it reaches an end-of-sequence (`EOS`) token. -In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities. +In Transformers, the [`~GenerationMixin.generate`] API handles text generation, and it is available for all models with generative capabilities. This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid. -This guide will show you the basics of text generation with [`~GenerationMixin.generate`] and some common pitfalls to avoid. +> [!TIP] +> You can also chat with a model directly from the command line. ([reference](./conversations.md#transformers-cli)) +> ```shell +> transformers chat Qwen/Qwen2.5-0.5B-Instruct +> ``` ## Default generate @@ -134,6 +138,20 @@ outputs = model.generate(**inputs, generation_config=generation_config) print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) ``` +## Common Options + +[`~GenerationMixin.generate`] is a powerful tool that can be heavily customized. This can be daunting for a new users. This section contains a list of popular generation options that you can define in most text generation tools in Transformers: [`~GenerationMixin.generate`], [`GenerationConfig`], `pipelines`, the `chat` CLI, ... + +| Option name | Type | Simplified description | +|---|---|---| +| `max_new_tokens` | `int` | Controls the maximum generation length. Be sure to define it, as it usually defaults to a small value. | +| `do_sample` | `bool` | Defines whether generation will sample the next token (`True`), or is greedy instead (`False`). Most use cases should set this flag to `True`. Check [this guide](./generation_strategies.md) for more information. | +| `temperature` | `float` | How unpredictable the next selected token will be. High values (`>0.8`) are good for creative tasks, low values (e.g. `<0.4`) for tasks that require "thinking". Requires `do_sample=True`. | +| `num_beams` | `int` | When set to `>1`, activates the beam search algorithm. Beam search is good on input-grounded tasks. Check [this guide](./generation_strategies.md) for more information. | +| `repetition_penalty` | `float` | Set it to `>1.0` if you're seeing the model repeat itself often. Larger values apply a larger penalty. | +| `eos_token_id` | `List[int]` | The token(s) that will cause generation to stop. The default value is usually good, but you can specify a different token. | + + ## Pitfalls The section below covers some common issues you may encounter during text generation and how to solve them. @@ -286,4 +304,4 @@ Take a look below for some more specific and specialized text generation librari - [SynCode](https://github.com/uiuc-focal-lab/syncode): a library for context-free grammar guided generation (JSON, SQL, Python). - [Text Generation Inference](https://github.com/huggingface/text-generation-inference): a production-ready server for LLMs. - [Text generation web UI](https://github.com/oobabooga/text-generation-webui): a Gradio web UI for text generation. -- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation. \ No newline at end of file +- [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo): additional logits processors for controlling text generation. diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index f276368ce8..5c9bd76bdb 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -13,12 +13,12 @@ # limitations under the License. -import copy import json import os import platform import string import time +import warnings from argparse import ArgumentParser, Namespace from dataclasses import dataclass, field from threading import Thread @@ -42,7 +42,13 @@ if is_rich_available(): if is_torch_available(): import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + GenerationConfig, + TextIteratorStreamer, + ) ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace) @@ -64,25 +70,16 @@ DEFAULT_EXAMPLES = { "socks": {"text": "Why is it important to eat socks after meditating?"}, } -SUPPORTED_GENERATION_KWARGS = [ - "max_new_tokens", - "do_sample", - "num_beams", - "temperature", - "top_p", - "top_k", - "repetition_penalty", -] - # Printed at the start of a chat session HELP_STRING_MINIMAL = """ **TRANSFORMERS CHAT INTERFACE** Chat interface to try out a model. Besides chatting with the model, here are some basic commands: -- **help**: shows all available commands -- **clear**: clears the current conversation and starts a new one -- **exit**: closes the interface +- **!help**: shows all available commands +- **!status**: shows the current status of the model and generation settings +- **!clear**: clears the current conversation and starts a new one +- **!exit**: closes the interface """ @@ -92,18 +89,32 @@ HELP_STRING = f""" **TRANSFORMERS CHAT INTERFACE HELP** Full command list: -- **help**: shows this help message -- **clear**: clears the current conversation and starts a new one -- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example -names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}` -- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are -separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}` -- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set** -- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to +- **!help**: shows this help message +- **!clear**: clears the current conversation and starts a new one +- **!status**: shows the current status of the model and generation settings +- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. +Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}` +- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple +settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument. +If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options +- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to `./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided -- **exit**: closes the interface +- **!exit**: closes the interface """ +# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag) +_DEPRECATION_MAP = [ + ("max_new_tokens", 256, "max_new_tokens"), + ("do_sample", True, "do_sample"), + ("num_beams", 1, "num_beams"), + ("temperature", 1.0, "temperature"), + ("top_k", 50, "top_k"), + ("top_p", 1.0, "top_p"), + ("repetition_penalty", 1.0, "repetition_penalty"), + ("eos_tokens", None, "eos_token_id"), + ("eos_token_ids", None, "eos_token_id"), +] + class RichInterface: def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None): @@ -181,6 +192,14 @@ class RichInterface: self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING)) self._console.print() + def print_status(self, model_name: str, generation_config: GenerationConfig, model_kwargs: dict): + """Prints the status of the model and generation settings to the console.""" + self._console.print(f"[bold blue]Model: {model_name}\n") + if model_kwargs: + self._console.print(f"[bold blue]Model kwargs: {model_kwargs}") + self._console.print(f"[bold blue]{generation_config}") + self._console.print() + @dataclass class ChatArguments: @@ -207,6 +226,17 @@ class ChatArguments: examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."}) # Generation settings + generation_config: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Path to a local generation config file or to a HuggingFace repo containing a " + "`generation_config.json` file. Other generation settings passed as CLI arguments will be applied on " + "top of this generation config." + ), + }, + ) + # Deprecated CLI args start here max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."}) do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."}) num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."}) @@ -222,6 +252,7 @@ class ChatArguments: default=None, metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."}, ) + # Deprecated CLI args end here # Model loading model_revision: str = field( @@ -280,23 +311,66 @@ class ChatCommand(BaseTransformersCLICommand): group = chat_parser.add_argument_group("Positional arguments") group.add_argument( - "model_name_or_path_positional", type=str, nargs="?", default=None, help="Name of the pre-trained model." + "model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model." + ) + group.add_argument( + "generate_flags", + type=str, + default=None, + help=( + "Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, " + "and lists of integers, more advanced parameterization should be set through --generation-config. " + "Example: `transformers chat max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. " + "If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options" + ), + nargs="*", ) - chat_parser.set_defaults(func=chat_command_factory) def __init__(self, args): - args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path + args = self._handle_deprecated_args(args) + self.args = args - if args.model_name_or_path is None: + def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments: + """ + Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new + args. + """ + has_warnings = False + + # 1. Model as a positional argument + args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path + if args.model_name_or_path_positional is None: raise ValueError( "One of the following must be provided:" - "\n- The positional argument containing the model repo;" - "\n- the optional --model_name_or_path argument, containing the model repo" - "\ne.g. transformers chat or transformers chat --model_name_or_path " + "\n- The positional argument containing the model repo, e.g. `transformers chat `" + "\n- the optional --model_name_or_path argument, containing the model repo (deprecated)" ) + elif args.model_name_or_path is not None: + has_warnings = True + warnings.warn( + "The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional " + "argument instead, e.g. `transformers chat `.", + FutureWarning, + ) + # 2. Named generate option args + for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP: + value = getattr(args, deprecated_arg) + if value != default_value: + has_warnings = True + warnings.warn( + f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two " + "alternative solutions to specify this generation option: \n" + "1. Pass `--generation-config ` to specify a generation config.\n" + "2. Pass `generate` flags through positional arguments, e.g. `transformers chat " + f"{new_arg}={value}`", + FutureWarning, + ) - self.args = args + if has_warnings: + print("\n(Press enter to continue)") + input() + return args # ----------------------------------------------------------------------------------------------------------------- # Chat session methods @@ -319,7 +393,7 @@ class ChatCommand(BaseTransformersCLICommand): if filename is None: time_str = time.strftime("%Y-%m-%d_%H-%M-%S") - filename = f"{args.model_name_or_path}/chat_{time_str}.json" + filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json" filename = os.path.join(folder, filename) os.makedirs(os.path.dirname(filename), exist_ok=True) @@ -338,50 +412,95 @@ class ChatCommand(BaseTransformersCLICommand): # ----------------------------------------------------------------------------------------------------------------- # Input parsing methods - @staticmethod - def parse_settings( - user_input: str, current_args: ChatArguments, interface: RichInterface - ) -> tuple[ChatArguments, bool]: - """Parses the settings from the user input into the CLI arguments.""" - settings = user_input[4:].strip().split(";") - settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings] - settings = dict(settings) - error = False + def parse_generate_flags(self, generate_flags: list[str]) -> dict: + """Parses the generate flags from the user input into a dictionary of `generate` kwargs.""" + if len(generate_flags) == 0: + return {} - for name in settings: - if hasattr(current_args, name): - try: - if isinstance(getattr(current_args, name), bool): - if settings[name] == "True": - settings[name] = True - elif settings[name] == "False": - settings[name] = False - else: - raise ValueError - else: - settings[name] = type(getattr(current_args, name))(settings[name]) - except ValueError: - error = True - interface.print_color( - text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.", - color="red", - ) - else: - interface.print_color(text=f"There is no '{name}' setting.", color="red") + # Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed + # into a json string if we: + # 1. Add quotes around each flag name + generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags} - if error: - interface.print_color( - text="There was an issue parsing the settings. No settings have been changed.", - color="red", + # 2. Handle types: + # 2. a. booleans should be lowercase, None should be null + generate_flags_as_dict = { + k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items() + } + generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()} + + # 2. b. strings should be quoted + def is_number(s: str) -> bool: + return s.replace(".", "", 1).isdigit() + + generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()} + # 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :) + # We also mention in the help message that we only accept lists of ints for now. + + # 3. Join the the result into a comma separated string + generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()]) + + # 4. Add the opening/closing brackets + generate_flags_string = "{" + generate_flags_string + "}" + + # 5. Remove quotes around boolean/null and around lists + generate_flags_string = generate_flags_string.replace('"null"', "null") + generate_flags_string = generate_flags_string.replace('"true"', "true") + generate_flags_string = generate_flags_string.replace('"false"', "false") + generate_flags_string = generate_flags_string.replace('"[', "[") + generate_flags_string = generate_flags_string.replace(']"', "]") + + # 6. Replace the `=` with `:` + generate_flags_string = generate_flags_string.replace("=", ":") + + try: + processed_generate_flags = json.loads(generate_flags_string) + except json.JSONDecodeError: + raise ValueError( + "Failed to convert `generate_flags` into a valid JSON object." + "\n`generate_flags` = {generate_flags}" + "\nConverted JSON string = {generate_flags_string}" ) + return processed_generate_flags + + def get_generation_parameterization( + self, args: ChatArguments, tokenizer: AutoTokenizer + ) -> tuple[GenerationConfig, dict]: + """ + Returns a GenerationConfig object holding the generation parameters for the CLI command. + """ + # No generation config arg provided -> use base generation config, apply CLI defaults + if args.generation_config is None: + generation_config = GenerationConfig() + # Apply deprecated CLI args on top of the default generation config + pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) + deprecated_kwargs = { + "max_new_tokens": args.max_new_tokens, + "do_sample": args.do_sample, + "num_beams": args.num_beams, + "temperature": args.temperature, + "top_k": args.top_k, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "pad_token_id": pad_token_id, + "eos_token_id": eos_token_ids, + } + generation_config.update(**deprecated_kwargs) + # generation config arg provided -> use it as the base parameterization else: - for name in settings: - setattr(current_args, name, settings[name]) - interface.print_color(text=f"Set {name} to {settings[name]}.", color="green") + if ".json" in args.generation_config: # is a local file + dirname = os.path.dirname(args.generation_config) + filename = os.path.basename(args.generation_config) + generation_config = GenerationConfig.from_pretrained(dirname, filename) + else: + generation_config = GenerationConfig.from_pretrained(args.generation_config) - time.sleep(1.5) # so the user has time to read the changes - - return current_args, not error + # Finally: parse and apply `generate_flags` + parsed_generate_flags = self.parse_generate_flags(args.generate_flags) + model_kwargs = generation_config.update(**parsed_generate_flags) + # `model_kwargs` contain non-generation flags in `parsed_generate_flags` that should be passed directly to + # `generate` + return generation_config, model_kwargs @staticmethod def parse_eos_tokens( @@ -406,36 +525,6 @@ class ChatCommand(BaseTransformersCLICommand): return pad_token_id, all_eos_token_ids - @staticmethod - def is_valid_setting_command(s: str) -> bool: - # First check the basic structure - if not s.startswith("set ") or "=" not in s: - return False - - # Split into individual assignments - assignments = [a.strip() for a in s[4:].split(";") if a.strip()] - - for assignment in assignments: - # Each assignment should have exactly one '=' - if assignment.count("=") != 1: - return False - - key, value = assignment.split("=", 1) - key = key.strip() - value = value.strip() - if not key or not value: - return False - - # Keys can only have alphabetic characters, spaces and underscores - if not set(key).issubset(ALLOWED_KEY_CHARS): - return False - - # Values can have just about anything that isn't a semicolon - if not set(value).issubset(ALLOWED_VALUE_CHARS): - return False - - return True - # ----------------------------------------------------------------------------------------------------------------- # Model loading and performance automation methods @staticmethod @@ -460,7 +549,7 @@ class ChatCommand(BaseTransformersCLICommand): def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]: tokenizer = AutoTokenizer.from_pretrained( - args.model_name_or_path, + args.model_name_or_path_positional, revision=args.model_revision, trust_remote_code=args.trust_remote_code, ) @@ -475,7 +564,7 @@ class ChatCommand(BaseTransformersCLICommand): "quantization_config": quantization_config, } model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs + args.model_name_or_path_positional, trust_remote_code=args.trust_remote_code, **model_kwargs ) if getattr(model, "hf_device_map", None) is None: @@ -483,6 +572,88 @@ class ChatCommand(BaseTransformersCLICommand): return model, tokenizer + # ----------------------------------------------------------------------------------------------------------------- + # User commands + def handle_non_exit_user_commands( + self, + user_input: str, + args: ChatArguments, + interface: RichInterface, + examples: dict[str, dict[str, str]], + generation_config: GenerationConfig, + model_kwargs: dict, + chat: list[dict], + ) -> tuple[list[dict], GenerationConfig, dict]: + """ + Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the + generation config (e.g. set a new flag). + """ + + if user_input == "!clear": + chat = self.clear_chat_history(args.system_prompt) + interface.clear() + + elif user_input == "!help": + interface.print_help() + + elif user_input.startswith("!save") and len(user_input.split()) < 2: + split_input = user_input.split() + + if len(split_input) == 2: + filename = split_input[1] + else: + filename = None + filename = self.save_chat(chat, args, filename) + interface.print_color(text=f"Chat saved in {filename}!", color="green") + + elif user_input.startswith("!set"): + # splits the new args into a list of strings, each string being a `flag=value` pair (same format as + # `generate_flags`) + new_generate_flags = user_input[4:].strip() + new_generate_flags = new_generate_flags.split() + # sanity check: each member in the list must have an = + for flag in new_generate_flags: + if "=" not in flag: + interface.print_color( + text=( + f"Invalid flag format, missing `=` after `{flag}`. Please use the format " + "`arg_1=value_1 arg_2=value_2 ...`." + ), + color="red", + ) + break + else: + # parses the new args into a dictionary of `generate` kwargs, and updates the corresponding variables + parsed_new_generate_flags = self.parse_generate_flags(new_generate_flags) + new_model_kwargs = generation_config.update(**parsed_new_generate_flags) + model_kwargs.update(**new_model_kwargs) + + elif user_input.startswith("!example") and len(user_input.split()) == 2: + example_name = user_input.split()[1] + if example_name in examples: + interface.clear() + chat = [] + interface.print_user_message(examples[example_name]["text"]) + chat.append({"role": "user", "content": examples[example_name]["text"]}) + else: + example_error = ( + f"Example {example_name} not found in list of available examples: {list(examples.keys())}." + ) + interface.print_color(text=example_error, color="red") + + elif user_input == "!status": + interface.print_status( + model_name=args.model_name_or_path_positional, + generation_config=generation_config, + model_kwargs=model_kwargs, + ) + + else: + interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red") + interface.print_help() + + return chat, generation_config, model_kwargs + # ----------------------------------------------------------------------------------------------------------------- # Main logic def run(self): @@ -498,8 +669,6 @@ class ChatCommand(BaseTransformersCLICommand): with open(args.examples_path) as f: examples = yaml.safe_load(f) - current_args = copy.deepcopy(args) - if args.user is None: user = self.get_username() else: @@ -507,12 +676,11 @@ class ChatCommand(BaseTransformersCLICommand): model, tokenizer = self.load_model_and_tokenizer(args) generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) + generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer) - pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) - - interface = RichInterface(model_name=args.model_name_or_path, user_name=user) + interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user) interface.clear() - chat = self.clear_chat_history(current_args.system_prompt) + chat = self.clear_chat_history(args.system_prompt) # Starts the session with a minimal help message at the top, so that a user doesn't get stuck interface.print_help(minimal=True) @@ -520,57 +688,26 @@ class ChatCommand(BaseTransformersCLICommand): try: user_input = interface.input() - if user_input == "clear": - chat = self.clear_chat_history(current_args.system_prompt) - interface.clear() - continue - - if user_input == "help": - interface.print_help() - continue - - if user_input == "exit": - break - - if user_input == "reset": - interface.clear() - current_args = copy.deepcopy(args) - chat = self.clear_chat_history(current_args.system_prompt) - continue - - if user_input.startswith("save") and len(user_input.split()) < 2: - split_input = user_input.split() - - if len(split_input) == 2: - filename = split_input[1] + # User commands + if user_input.startswith("!"): + # `!exit` is special, it breaks the loop + if user_input == "!exit": + break else: - filename = None - filename = self.save_chat(chat, current_args, filename) - interface.print_color(text=f"Chat saved in {filename}!", color="green") - continue - - if self.is_valid_setting_command(user_input): - current_args, success = self.parse_settings(user_input, current_args, interface) - if success: - chat = [] - interface.clear() - continue - - if user_input.startswith("example") and len(user_input.split()) == 2: - example_name = user_input.split()[1] - if example_name in examples: - interface.clear() - chat = [] - interface.print_user_message(examples[example_name]["text"]) - user_input = examples[example_name]["text"] - else: - example_error = ( - f"Example {example_name} not found in list of available examples: {list(examples.keys())}." + chat, generation_config, model_kwargs = self.handle_non_exit_user_commands( + user_input=user_input, + args=args, + interface=interface, + examples=examples, + generation_config=generation_config, + model_kwargs=model_kwargs, + chat=chat, ) - interface.print_color(text=example_error, color="red") + # `!example` sends a user message to the model + if not user_input.startswith("!example"): continue - - chat.append({"role": "user", "content": user_input}) + else: + chat.append({"role": "user", "content": user_input}) inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( model.device @@ -580,15 +717,8 @@ class ChatCommand(BaseTransformersCLICommand): "inputs": inputs, "attention_mask": attention_mask, "streamer": generation_streamer, - "max_new_tokens": current_args.max_new_tokens, - "do_sample": current_args.do_sample, - "num_beams": current_args.num_beams, - "temperature": current_args.temperature, - "top_k": current_args.top_k, - "top_p": current_args.top_p, - "repetition_penalty": current_args.repetition_penalty, - "pad_token_id": pad_token_id, - "eos_token_id": eos_token_ids, + "generation_config": generation_config, + **model_kwargs, } thread = Thread(target=model.generate, kwargs=generation_kwargs)