|
|
|
@@ -13,12 +13,12 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import platform
|
|
|
|
import platform
|
|
|
|
import string
|
|
|
|
import string
|
|
|
|
import time
|
|
|
|
import time
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from threading import Thread
|
|
|
|
from threading import Thread
|
|
|
|
@@ -42,7 +42,13 @@ if is_rich_available():
|
|
|
|
if is_torch_available():
|
|
|
|
if is_torch_available():
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
|
|
|
|
from transformers import (
|
|
|
|
|
|
|
|
AutoModelForCausalLM,
|
|
|
|
|
|
|
|
AutoTokenizer,
|
|
|
|
|
|
|
|
BitsAndBytesConfig,
|
|
|
|
|
|
|
|
GenerationConfig,
|
|
|
|
|
|
|
|
TextIteratorStreamer,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
|
|
|
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
|
|
|
@@ -64,25 +70,16 @@ DEFAULT_EXAMPLES = {
|
|
|
|
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
|
|
|
"socks": {"text": "Why is it important to eat socks after meditating?"},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
SUPPORTED_GENERATION_KWARGS = [
|
|
|
|
|
|
|
|
"max_new_tokens",
|
|
|
|
|
|
|
|
"do_sample",
|
|
|
|
|
|
|
|
"num_beams",
|
|
|
|
|
|
|
|
"temperature",
|
|
|
|
|
|
|
|
"top_p",
|
|
|
|
|
|
|
|
"top_k",
|
|
|
|
|
|
|
|
"repetition_penalty",
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Printed at the start of a chat session
|
|
|
|
# Printed at the start of a chat session
|
|
|
|
HELP_STRING_MINIMAL = """
|
|
|
|
HELP_STRING_MINIMAL = """
|
|
|
|
|
|
|
|
|
|
|
|
**TRANSFORMERS CHAT INTERFACE**
|
|
|
|
**TRANSFORMERS CHAT INTERFACE**
|
|
|
|
|
|
|
|
|
|
|
|
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
|
|
|
|
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
|
|
|
|
- **help**: shows all available commands
|
|
|
|
- **!help**: shows all available commands
|
|
|
|
- **clear**: clears the current conversation and starts a new one
|
|
|
|
- **!status**: shows the current status of the model and generation settings
|
|
|
|
- **exit**: closes the interface
|
|
|
|
- **!clear**: clears the current conversation and starts a new one
|
|
|
|
|
|
|
|
- **!exit**: closes the interface
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -92,18 +89,32 @@ HELP_STRING = f"""
|
|
|
|
**TRANSFORMERS CHAT INTERFACE HELP**
|
|
|
|
**TRANSFORMERS CHAT INTERFACE HELP**
|
|
|
|
|
|
|
|
|
|
|
|
Full command list:
|
|
|
|
Full command list:
|
|
|
|
- **help**: shows this help message
|
|
|
|
- **!help**: shows this help message
|
|
|
|
- **clear**: clears the current conversation and starts a new one
|
|
|
|
- **!clear**: clears the current conversation and starts a new one
|
|
|
|
- **example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input. Available example
|
|
|
|
- **!status**: shows the current status of the model and generation settings
|
|
|
|
names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
|
|
|
- **!example {{NAME}}**: loads example named `{{NAME}}` from the config and uses it as the user input.
|
|
|
|
- **set {{SETTING_NAME}}={{SETTING_VALUE}};**: changes the system prompt or generation settings (multiple settings are
|
|
|
|
Available example names: `{"`, `".join(DEFAULT_EXAMPLES.keys())}`
|
|
|
|
separated by a ';'). Available settings: `{"`, `".join(SUPPORTED_GENERATION_KWARGS)}`
|
|
|
|
- **!set {{ARG_1}}={{VALUE_1}} {{ARG_2}}={{VALUE_2}}** ...: changes the system prompt or generation settings (multiple
|
|
|
|
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
|
|
|
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
|
|
|
|
- **save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
|
|
|
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
|
|
|
|
|
|
|
|
- **!save {{SAVE_NAME}} (optional)**: saves the current chat and settings to file by default to
|
|
|
|
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
|
|
|
|
`./chat_history/{{MODEL_NAME}}/chat_{{DATETIME}}.yaml` or `{{SAVE_NAME}}` if provided
|
|
|
|
- **exit**: closes the interface
|
|
|
|
- **!exit**: closes the interface
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
|
|
|
|
|
|
|
|
_DEPRECATION_MAP = [
|
|
|
|
|
|
|
|
("max_new_tokens", 256, "max_new_tokens"),
|
|
|
|
|
|
|
|
("do_sample", True, "do_sample"),
|
|
|
|
|
|
|
|
("num_beams", 1, "num_beams"),
|
|
|
|
|
|
|
|
("temperature", 1.0, "temperature"),
|
|
|
|
|
|
|
|
("top_k", 50, "top_k"),
|
|
|
|
|
|
|
|
("top_p", 1.0, "top_p"),
|
|
|
|
|
|
|
|
("repetition_penalty", 1.0, "repetition_penalty"),
|
|
|
|
|
|
|
|
("eos_tokens", None, "eos_token_id"),
|
|
|
|
|
|
|
|
("eos_token_ids", None, "eos_token_id"),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RichInterface:
|
|
|
|
class RichInterface:
|
|
|
|
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
|
|
|
|
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
|
|
|
|
@@ -181,6 +192,14 @@ class RichInterface:
|
|
|
|
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
|
|
|
|
self._console.print(Markdown(HELP_STRING_MINIMAL if minimal else HELP_STRING))
|
|
|
|
self._console.print()
|
|
|
|
self._console.print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_status(self, model_name: str, generation_config: GenerationConfig, model_kwargs: dict):
    """
    Report the current model and generation settings on the console.

    Args:
        model_name: Name shown on the `Model:` line.
        generation_config: Object whose string form is printed as the generation settings.
        model_kwargs: Extra kwargs; their line is only printed when the dict is non-empty.
    """
    console = self._console

    # Collect the styled lines first, then emit them in order
    report = [f"[bold blue]Model: {model_name}\n"]
    if model_kwargs:
        report.append(f"[bold blue]Model kwargs: {model_kwargs}")
    report.append(f"[bold blue]{generation_config}")

    for entry in report:
        console.print(entry)
    # Trailing blank line separates the status report from what follows
    console.print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
class ChatArguments:
|
|
|
|
class ChatArguments:
|
|
|
|
@@ -207,6 +226,17 @@ class ChatArguments:
|
|
|
|
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
|
|
|
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
|
|
|
|
|
|
|
|
|
|
|
|
# Generation settings
|
|
|
|
# Generation settings
|
|
|
|
|
|
|
|
generation_config: Optional[str] = field(
|
|
|
|
|
|
|
|
default=None,
|
|
|
|
|
|
|
|
metadata={
|
|
|
|
|
|
|
|
"help": (
|
|
|
|
|
|
|
|
"Path to a local generation config file or to a HuggingFace repo containing a "
|
|
|
|
|
|
|
|
"`generation_config.json` file. Other generation settings passed as CLI arguments will be applied on "
|
|
|
|
|
|
|
|
"top of this generation config."
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
# Deprecated CLI args start here
|
|
|
|
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
|
|
|
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
|
|
|
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
|
|
|
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
|
|
|
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
|
|
|
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
|
|
|
@@ -222,6 +252,7 @@ class ChatArguments:
|
|
|
|
default=None,
|
|
|
|
default=None,
|
|
|
|
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
|
|
|
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
# Deprecated CLI args end here
|
|
|
|
|
|
|
|
|
|
|
|
# Model loading
|
|
|
|
# Model loading
|
|
|
|
model_revision: str = field(
|
|
|
|
model_revision: str = field(
|
|
|
|
@@ -280,23 +311,66 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
group = chat_parser.add_argument_group("Positional arguments")
|
|
|
|
group = chat_parser.add_argument_group("Positional arguments")
|
|
|
|
group.add_argument(
|
|
|
|
group.add_argument(
|
|
|
|
"model_name_or_path_positional", type=str, nargs="?", default=None, help="Name of the pre-trained model."
|
|
|
|
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
group.add_argument(
|
|
|
|
|
|
|
|
"generate_flags",
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
default=None,
|
|
|
|
|
|
|
|
help=(
|
|
|
|
|
|
|
|
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
|
|
|
|
|
|
|
|
"and lists of integers, more advanced parameterization should be set through --generation-config. "
|
|
|
|
|
|
|
|
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
|
|
|
|
|
|
|
|
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
nargs="*",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
chat_parser.set_defaults(func=chat_command_factory)
|
|
|
|
chat_parser.set_defaults(func=chat_command_factory)
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, args):
|
|
|
|
def __init__(self, args):
|
|
|
|
args.model_name_or_path = args.model_name_or_path_positional or args.model_name_or_path
|
|
|
|
args = self._handle_deprecated_args(args)
|
|
|
|
|
|
|
|
self.args = args
|
|
|
|
|
|
|
|
|
|
|
|
if args.model_name_or_path is None:
|
|
|
|
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
|
|
|
|
|
|
|
|
args.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
has_warnings = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Model as a positional argument
|
|
|
|
|
|
|
|
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
|
|
|
|
|
|
|
|
if args.model_name_or_path_positional is None:
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
|
"One of the following must be provided:"
|
|
|
|
"One of the following must be provided:"
|
|
|
|
"\n- The positional argument containing the model repo;"
|
|
|
|
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
|
|
|
|
"\n- the optional --model_name_or_path argument, containing the model repo"
|
|
|
|
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
|
|
|
|
"\ne.g. transformers chat <model_repo> or transformers chat --model_name_or_path <model_repo>"
|
|
|
|
)
|
|
|
|
|
|
|
|
elif args.model_name_or_path is not None:
|
|
|
|
|
|
|
|
has_warnings = True
|
|
|
|
|
|
|
|
warnings.warn(
|
|
|
|
|
|
|
|
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
|
|
|
|
|
|
|
|
"argument instead, e.g. `transformers chat <model_repo>`.",
|
|
|
|
|
|
|
|
FutureWarning,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
# 2. Named generate option args
|
|
|
|
|
|
|
|
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
|
|
|
|
|
|
|
|
value = getattr(args, deprecated_arg)
|
|
|
|
|
|
|
|
if value != default_value:
|
|
|
|
|
|
|
|
has_warnings = True
|
|
|
|
|
|
|
|
warnings.warn(
|
|
|
|
|
|
|
|
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
|
|
|
|
|
|
|
|
"alternative solutions to specify this generation option: \n"
|
|
|
|
|
|
|
|
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
|
|
|
|
|
|
|
|
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
|
|
|
|
|
|
|
|
f"{new_arg}={value}`",
|
|
|
|
|
|
|
|
FutureWarning,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
self.args = args
|
|
|
|
if has_warnings:
|
|
|
|
|
|
|
|
print("\n(Press enter to continue)")
|
|
|
|
|
|
|
|
input()
|
|
|
|
|
|
|
|
return args
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# Chat session methods
|
|
|
|
# Chat session methods
|
|
|
|
@@ -319,7 +393,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
if filename is None:
|
|
|
|
if filename is None:
|
|
|
|
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
|
|
|
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
|
|
|
filename = f"{args.model_name_or_path}/chat_{time_str}.json"
|
|
|
|
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
|
|
|
|
filename = os.path.join(folder, filename)
|
|
|
|
filename = os.path.join(folder, filename)
|
|
|
|
|
|
|
|
|
|
|
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
|
|
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
|
|
|
@@ -338,50 +412,95 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# Input parsing methods
|
|
|
|
# Input parsing methods
|
|
|
|
@staticmethod
|
|
|
|
def parse_generate_flags(self, generate_flags: list[str]) -> dict:
|
|
|
|
def parse_settings(
|
|
|
|
"""Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
|
|
|
|
user_input: str, current_args: ChatArguments, interface: RichInterface
|
|
|
|
if len(generate_flags) == 0:
|
|
|
|
) -> tuple[ChatArguments, bool]:
|
|
|
|
return {}
|
|
|
|
"""Parses the settings from the user input into the CLI arguments."""
|
|
|
|
|
|
|
|
settings = user_input[4:].strip().split(";")
|
|
|
|
# Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed
|
|
|
|
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
|
|
|
# into a json string if we:
|
|
|
|
settings = dict(settings)
|
|
|
|
# 1. Add quotes around each flag name
|
|
|
|
error = False
|
|
|
|
generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Handle types:
|
|
|
|
|
|
|
|
# 2. a. booleans should be lowercase, None should be null
|
|
|
|
|
|
|
|
generate_flags_as_dict = {
|
|
|
|
|
|
|
|
k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items()
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. b. strings should be quoted
|
|
|
|
|
|
|
|
def is_number(s: str) -> bool:
|
|
|
|
|
|
|
|
return s.replace(".", "", 1).isdigit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
|
|
|
|
|
|
|
|
# 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :)
|
|
|
|
|
|
|
|
# We also mention in the help message that we only accept lists of ints for now.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 3. Join the result into a comma-separated string
|
|
|
|
|
|
|
|
generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 4. Add the opening/closing brackets
|
|
|
|
|
|
|
|
generate_flags_string = "{" + generate_flags_string + "}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 5. Remove quotes around boolean/null and around lists
|
|
|
|
|
|
|
|
generate_flags_string = generate_flags_string.replace('"null"', "null")
|
|
|
|
|
|
|
|
generate_flags_string = generate_flags_string.replace('"true"', "true")
|
|
|
|
|
|
|
|
generate_flags_string = generate_flags_string.replace('"false"', "false")
|
|
|
|
|
|
|
|
generate_flags_string = generate_flags_string.replace('"[', "[")
|
|
|
|
|
|
|
|
generate_flags_string = generate_flags_string.replace(']"', "]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 6. Replace the `=` with `:`
|
|
|
|
|
|
|
|
generate_flags_string = generate_flags_string.replace("=", ":")
|
|
|
|
|
|
|
|
|
|
|
|
for name in settings:
|
|
|
|
|
|
|
|
if hasattr(current_args, name):
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if isinstance(getattr(current_args, name), bool):
|
|
|
|
processed_generate_flags = json.loads(generate_flags_string)
|
|
|
|
if settings[name] == "True":
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
settings[name] = True
|
|
|
|
raise ValueError(
|
|
|
|
elif settings[name] == "False":
|
|
|
|
"Failed to convert `generate_flags` into a valid JSON object."
|
|
|
|
settings[name] = False
|
|
|
|
"\n`generate_flags` = {generate_flags}"
|
|
|
|
else:
|
|
|
|
"\nConverted JSON string = {generate_flags_string}"
|
|
|
|
raise ValueError
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
settings[name] = type(getattr(current_args, name))(settings[name])
|
|
|
|
|
|
|
|
except ValueError:
|
|
|
|
|
|
|
|
error = True
|
|
|
|
|
|
|
|
interface.print_color(
|
|
|
|
|
|
|
|
text=f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.",
|
|
|
|
|
|
|
|
color="red",
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
return processed_generate_flags
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_generation_parameterization(
|
|
|
|
|
|
|
|
self, args: ChatArguments, tokenizer: AutoTokenizer
|
|
|
|
|
|
|
|
) -> tuple[GenerationConfig, dict]:
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Returns a GenerationConfig object holding the generation parameters for the CLI command.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# No generation config arg provided -> use base generation config, apply CLI defaults
|
|
|
|
|
|
|
|
if args.generation_config is None:
|
|
|
|
|
|
|
|
generation_config = GenerationConfig()
|
|
|
|
|
|
|
|
# Apply deprecated CLI args on top of the default generation config
|
|
|
|
|
|
|
|
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
|
|
|
|
|
|
|
deprecated_kwargs = {
|
|
|
|
|
|
|
|
"max_new_tokens": args.max_new_tokens,
|
|
|
|
|
|
|
|
"do_sample": args.do_sample,
|
|
|
|
|
|
|
|
"num_beams": args.num_beams,
|
|
|
|
|
|
|
|
"temperature": args.temperature,
|
|
|
|
|
|
|
|
"top_k": args.top_k,
|
|
|
|
|
|
|
|
"top_p": args.top_p,
|
|
|
|
|
|
|
|
"repetition_penalty": args.repetition_penalty,
|
|
|
|
|
|
|
|
"pad_token_id": pad_token_id,
|
|
|
|
|
|
|
|
"eos_token_id": eos_token_ids,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
generation_config.update(**deprecated_kwargs)
|
|
|
|
|
|
|
|
# generation config arg provided -> use it as the base parameterization
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
interface.print_color(text=f"There is no '{name}' setting.", color="red")
|
|
|
|
if ".json" in args.generation_config: # is a local file
|
|
|
|
|
|
|
|
dirname = os.path.dirname(args.generation_config)
|
|
|
|
if error:
|
|
|
|
filename = os.path.basename(args.generation_config)
|
|
|
|
interface.print_color(
|
|
|
|
generation_config = GenerationConfig.from_pretrained(dirname, filename)
|
|
|
|
text="There was an issue parsing the settings. No settings have been changed.",
|
|
|
|
|
|
|
|
color="red",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
for name in settings:
|
|
|
|
generation_config = GenerationConfig.from_pretrained(args.generation_config)
|
|
|
|
setattr(current_args, name, settings[name])
|
|
|
|
|
|
|
|
interface.print_color(text=f"Set {name} to {settings[name]}.", color="green")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
time.sleep(1.5) # so the user has time to read the changes
|
|
|
|
# Finally: parse and apply `generate_flags`
|
|
|
|
|
|
|
|
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
|
|
|
|
return current_args, not error
|
|
|
|
model_kwargs = generation_config.update(**parsed_generate_flags)
|
|
|
|
|
|
|
|
# `model_kwargs` contain non-generation flags in `parsed_generate_flags` that should be passed directly to
|
|
|
|
|
|
|
|
# `generate`
|
|
|
|
|
|
|
|
return generation_config, model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def parse_eos_tokens(
|
|
|
|
def parse_eos_tokens(
|
|
|
|
@@ -406,36 +525,6 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
return pad_token_id, all_eos_token_ids
|
|
|
|
return pad_token_id, all_eos_token_ids
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
def is_valid_setting_command(s: str) -> bool:
    """
    Checks whether `s` is a well-formed `set` command, i.e. `set key_1=value_1; key_2=value_2; ...`.

    Args:
        s: The raw user input.

    Returns:
        `True` if `s` starts with `set `, contains at least one `=`, and every non-empty `;`-separated
        assignment has exactly one `=`, a non-empty key made of allowed key characters, and a non-empty value
        made of allowed value characters. `False` otherwise.
    """
    # First check the basic structure
    if not s.startswith("set ") or "=" not in s:
        return False

    # Split into individual assignments, dropping empty chunks (e.g. the one after a trailing `;`)
    assignments = [a.strip() for a in s[4:].split(";") if a.strip()]

    # Keys can only have alphabetic characters, spaces and underscores. `ALLOWED_KEY_CHARS` holds letters and
    # whitespace only, so the underscore is added here -- otherwise names such as `max_new_tokens` are rejected.
    allowed_key_chars = ALLOWED_KEY_CHARS | {"_"}

    for assignment in assignments:
        # Each assignment should have exactly one '='
        if assignment.count("=") != 1:
            return False

        key, value = assignment.split("=", 1)
        key = key.strip()
        value = value.strip()
        if not key or not value:
            return False

        if not set(key).issubset(allowed_key_chars):
            return False

        # Values can have just about anything that isn't a semicolon
        if not set(value).issubset(ALLOWED_VALUE_CHARS):
            return False

    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# Model loading and performance automation methods
|
|
|
|
# Model loading and performance automation methods
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
@@ -460,7 +549,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
|
|
|
def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
args.model_name_or_path,
|
|
|
|
args.model_name_or_path_positional,
|
|
|
|
revision=args.model_revision,
|
|
|
|
revision=args.model_revision,
|
|
|
|
trust_remote_code=args.trust_remote_code,
|
|
|
|
trust_remote_code=args.trust_remote_code,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@@ -475,7 +564,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
"quantization_config": quantization_config,
|
|
|
|
"quantization_config": quantization_config,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
|
|
|
|
args.model_name_or_path_positional, trust_remote_code=args.trust_remote_code, **model_kwargs
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if getattr(model, "hf_device_map", None) is None:
|
|
|
|
if getattr(model, "hf_device_map", None) is None:
|
|
|
|
@@ -483,6 +572,88 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
return model, tokenizer
|
|
|
|
return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
# User commands
|
|
|
|
|
|
|
|
def handle_non_exit_user_commands(
    self,
    user_input: str,
    args: ChatArguments,
    interface: RichInterface,
    examples: dict[str, dict[str, str]],
    generation_config: GenerationConfig,
    model_kwargs: dict,
    chat: list[dict],
) -> tuple[list[dict], GenerationConfig, dict]:
    """
    Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
    generation config (e.g. set a new flag).

    Args:
        user_input: The raw command typed by the user (expected to start with `!`).
        args: The CLI arguments; `system_prompt` and `model_name_or_path_positional` are read from it.
        interface: The interface used to print messages and clear the screen.
        examples: Mapping from example name to a dict holding the example under the `"text"` key.
        generation_config: The current generation config, updated in place by `!set`.
        model_kwargs: The current extra `generate` kwargs, updated in place by `!set`.
        chat: The current chat history.

    Returns:
        The (possibly replaced) chat history, the generation config and the model kwargs.
    """

    if user_input == "!clear":
        chat = self.clear_chat_history(args.system_prompt)
        interface.clear()

    elif user_input == "!help":
        interface.print_help()

    # `!save` accepts an optional filename, so the command has one or two tokens. (The previous `< 2` check made
    # the filename branch unreachable and reported `!save <name>` as an invalid command.)
    elif user_input.startswith("!save") and len(user_input.split()) <= 2:
        split_input = user_input.split()
        filename = split_input[1] if len(split_input) == 2 else None
        filename = self.save_chat(chat, args, filename)
        interface.print_color(text=f"Chat saved in {filename}!", color="green")

    elif user_input.startswith("!set"):
        # splits the new args into a list of strings, each string being a `flag=value` pair (same format as
        # `generate_flags`)
        new_generate_flags = user_input[4:].strip().split()
        # sanity check: each member in the list must have an =
        for flag in new_generate_flags:
            if "=" not in flag:
                interface.print_color(
                    text=(
                        f"Invalid flag format, missing `=` after `{flag}`. Please use the format "
                        "`arg_1=value_1 arg_2=value_2 ...`."
                    ),
                    color="red",
                )
                break
        else:
            # only reached when no flag was rejected above: parses the new args into a dictionary of `generate`
            # kwargs, and updates the corresponding variables
            parsed_new_generate_flags = self.parse_generate_flags(new_generate_flags)
            new_model_kwargs = generation_config.update(**parsed_new_generate_flags)
            model_kwargs.update(**new_model_kwargs)

    elif user_input.startswith("!example") and len(user_input.split()) == 2:
        example_name = user_input.split()[1]
        if example_name in examples:
            # the example starts a fresh conversation, with the example text as the first user message
            interface.clear()
            chat = []
            interface.print_user_message(examples[example_name]["text"])
            chat.append({"role": "user", "content": examples[example_name]["text"]})
        else:
            example_error = (
                f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
            )
            interface.print_color(text=example_error, color="red")

    elif user_input == "!status":
        interface.print_status(
            model_name=args.model_name_or_path_positional,
            generation_config=generation_config,
            model_kwargs=model_kwargs,
        )

    else:
        interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
        interface.print_help()

    return chat, generation_config, model_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# -----------------------------------------------------------------------------------------------------------------
|
|
|
|
# Main logic
|
|
|
|
# Main logic
|
|
|
|
def run(self):
|
|
|
|
def run(self):
|
|
|
|
@@ -498,8 +669,6 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
with open(args.examples_path) as f:
|
|
|
|
with open(args.examples_path) as f:
|
|
|
|
examples = yaml.safe_load(f)
|
|
|
|
examples = yaml.safe_load(f)
|
|
|
|
|
|
|
|
|
|
|
|
current_args = copy.deepcopy(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args.user is None:
|
|
|
|
if args.user is None:
|
|
|
|
user = self.get_username()
|
|
|
|
user = self.get_username()
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
@@ -507,12 +676,11 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
|
|
|
|
|
|
|
|
model, tokenizer = self.load_model_and_tokenizer(args)
|
|
|
|
model, tokenizer = self.load_model_and_tokenizer(args)
|
|
|
|
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
|
|
|
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
|
|
|
|
|
|
|
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
|
|
|
|
|
|
|
|
|
|
|
|
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
|
|
|
|
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
|
|
|
|
|
|
|
|
|
|
|
|
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
|
|
|
|
|
|
|
interface.clear()
|
|
|
|
interface.clear()
|
|
|
|
chat = self.clear_chat_history(current_args.system_prompt)
|
|
|
|
chat = self.clear_chat_history(args.system_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
|
|
|
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
|
|
|
interface.print_help(minimal=True)
|
|
|
|
interface.print_help(minimal=True)
|
|
|
|
@@ -520,56 +688,25 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
user_input = interface.input()
|
|
|
|
user_input = interface.input()
|
|
|
|
|
|
|
|
|
|
|
|
if user_input == "clear":
|
|
|
|
# User commands
|
|
|
|
chat = self.clear_chat_history(current_args.system_prompt)
|
|
|
|
if user_input.startswith("!"):
|
|
|
|
interface.clear()
|
|
|
|
# `!exit` is special, it breaks the loop
|
|
|
|
continue
|
|
|
|
if user_input == "!exit":
|
|
|
|
|
|
|
|
|
|
|
|
if user_input == "help":
|
|
|
|
|
|
|
|
interface.print_help()
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if user_input == "exit":
|
|
|
|
|
|
|
|
break
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
if user_input == "reset":
|
|
|
|
|
|
|
|
interface.clear()
|
|
|
|
|
|
|
|
current_args = copy.deepcopy(args)
|
|
|
|
|
|
|
|
chat = self.clear_chat_history(current_args.system_prompt)
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if user_input.startswith("save") and len(user_input.split()) < 2:
|
|
|
|
|
|
|
|
split_input = user_input.split()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(split_input) == 2:
|
|
|
|
|
|
|
|
filename = split_input[1]
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
filename = None
|
|
|
|
chat, generation_config, model_kwargs = self.handle_non_exit_user_commands(
|
|
|
|
filename = self.save_chat(chat, current_args, filename)
|
|
|
|
user_input=user_input,
|
|
|
|
interface.print_color(text=f"Chat saved in {filename}!", color="green")
|
|
|
|
args=args,
|
|
|
|
continue
|
|
|
|
interface=interface,
|
|
|
|
|
|
|
|
examples=examples,
|
|
|
|
if self.is_valid_setting_command(user_input):
|
|
|
|
generation_config=generation_config,
|
|
|
|
current_args, success = self.parse_settings(user_input, current_args, interface)
|
|
|
|
model_kwargs=model_kwargs,
|
|
|
|
if success:
|
|
|
|
chat=chat,
|
|
|
|
chat = []
|
|
|
|
|
|
|
|
interface.clear()
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if user_input.startswith("example") and len(user_input.split()) == 2:
|
|
|
|
|
|
|
|
example_name = user_input.split()[1]
|
|
|
|
|
|
|
|
if example_name in examples:
|
|
|
|
|
|
|
|
interface.clear()
|
|
|
|
|
|
|
|
chat = []
|
|
|
|
|
|
|
|
interface.print_user_message(examples[example_name]["text"])
|
|
|
|
|
|
|
|
user_input = examples[example_name]["text"]
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
example_error = (
|
|
|
|
|
|
|
|
f"Example {example_name} not found in list of available examples: {list(examples.keys())}."
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
interface.print_color(text=example_error, color="red")
|
|
|
|
# `!example` sends a user message to the model
|
|
|
|
|
|
|
|
if not user_input.startswith("!example"):
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
else:
|
|
|
|
chat.append({"role": "user", "content": user_input})
|
|
|
|
chat.append({"role": "user", "content": user_input})
|
|
|
|
|
|
|
|
|
|
|
|
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
|
|
|
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
|
|
|
@@ -580,15 +717,8 @@ class ChatCommand(BaseTransformersCLICommand):
|
|
|
|
"inputs": inputs,
|
|
|
|
"inputs": inputs,
|
|
|
|
"attention_mask": attention_mask,
|
|
|
|
"attention_mask": attention_mask,
|
|
|
|
"streamer": generation_streamer,
|
|
|
|
"streamer": generation_streamer,
|
|
|
|
"max_new_tokens": current_args.max_new_tokens,
|
|
|
|
"generation_config": generation_config,
|
|
|
|
"do_sample": current_args.do_sample,
|
|
|
|
**model_kwargs,
|
|
|
|
"num_beams": current_args.num_beams,
|
|
|
|
|
|
|
|
"temperature": current_args.temperature,
|
|
|
|
|
|
|
|
"top_k": current_args.top_k,
|
|
|
|
|
|
|
|
"top_p": current_args.top_p,
|
|
|
|
|
|
|
|
"repetition_penalty": current_args.repetition_penalty,
|
|
|
|
|
|
|
|
"pad_token_id": pad_token_id,
|
|
|
|
|
|
|
|
"eos_token_id": eos_token_ids,
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
|
|
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
|
|
|
|