@@ -24,12 +24,16 @@ contains the API docs for the underlying classes.
|
|||||||
|
|
||||||
## Agents
|
## Agents
|
||||||
|
|
||||||
We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models.
|
We provide three types of agents: [`HfAgent`] uses inference endpoints for open-source models, [`LocalAgent`] uses a model of your choice locally, and [`OpenAiAgent`] uses OpenAI closed models.
|
||||||
|
|
||||||
### HfAgent
|
### HfAgent
|
||||||
|
|
||||||
[[autodoc]] HfAgent
|
[[autodoc]] HfAgent
|
||||||
|
|
||||||
|
### LocalAgent
|
||||||
|
|
||||||
|
[[autodoc]] LocalAgent
|
||||||
|
|
||||||
### OpenAiAgent
|
### OpenAiAgent
|
||||||
|
|
||||||
[[autodoc]] OpenAiAgent
|
[[autodoc]] OpenAiAgent
|
||||||
|
|||||||
@@ -614,6 +614,7 @@ _import_structure = {
|
|||||||
"tools": [
|
"tools": [
|
||||||
"Agent",
|
"Agent",
|
||||||
"HfAgent",
|
"HfAgent",
|
||||||
|
"LocalAgent",
|
||||||
"OpenAiAgent",
|
"OpenAiAgent",
|
||||||
"PipelineTool",
|
"PipelineTool",
|
||||||
"RemoteTool",
|
"RemoteTool",
|
||||||
@@ -4361,7 +4362,17 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Tools
|
# Tools
|
||||||
from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
|
from .tools import (
|
||||||
|
Agent,
|
||||||
|
HfAgent,
|
||||||
|
LocalAgent,
|
||||||
|
OpenAiAgent,
|
||||||
|
PipelineTool,
|
||||||
|
RemoteTool,
|
||||||
|
Tool,
|
||||||
|
launch_gradio_demo,
|
||||||
|
load_tool,
|
||||||
|
)
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer_callback import (
|
from .trainer_callback import (
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ..utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"agents": ["Agent", "HfAgent", "OpenAiAgent"],
|
"agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
|
||||||
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
|
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ else:
|
|||||||
_import_structure["translation"] = ["TranslationTool"]
|
_import_structure["translation"] = ["TranslationTool"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import Agent, HfAgent, OpenAiAgent
|
from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
|
||||||
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
|
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ from typing import Dict
|
|||||||
import requests
|
import requests
|
||||||
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
|
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
|
||||||
|
|
||||||
|
from ..generation import StoppingCriteria, StoppingCriteriaList
|
||||||
|
from ..models.auto import AutoModelForCausalLM, AutoTokenizer
|
||||||
from ..utils import is_openai_available, logging
|
from ..utils import is_openai_available, logging
|
||||||
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
|
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
|
||||||
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
|
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
|
||||||
@@ -492,3 +494,114 @@ class HfAgent(Agent):
|
|||||||
if result.endswith(stop_seq):
|
if result.endswith(stop_seq):
|
||||||
return result[: -len(stop_seq)]
|
return result[: -len(stop_seq)]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAgent(Agent):
    """
    Agent that uses a local model and tokenizer to generate code.

    Args:
        model ([`PreTrainedModel`]):
            The model to use for the agent.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer to use for the agent.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent

    checkpoint = "bigcode/starcoder"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    agent = LocalAgent(model, tokenizer)
    agent.run("Draw me a picture of rivers and lakes.")
    ```
    """

    def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        self.model = model
        self.tokenizer = tokenizer
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Convenience method to build a `LocalAgent` from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
            kwargs:
                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`]. The same keyword arguments
                are also forwarded to the tokenizer's `from_pretrained`.

        Example:

        ```py
        import torch
        from transformers import LocalAgent

        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
        agent.run("Draw me a picture of rivers and lakes.")
        ```
        """
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(model, tokenizer)

    @property
    def _model_device(self):
        """Device the encoded prompt is moved to before calling `generate`."""
        if hasattr(self.model, "hf_device_map"):
            # Use the first entry of the device map when the model exposes one.
            return list(self.model.hf_device_map.values())[0]
        # Otherwise fall back to the device of the first parameter (implicitly `None` if there are no parameters).
        for param in self.model.parameters():
            return param.device

    def generate_one(self, prompt, stop):
        """
        Generate a completion for `prompt` with the local model.

        Args:
            prompt (`str`):
                The text to complete.
            stop (`str` or `List[str]`):
                The sequence (or list of sequences) on which generation stops; a trailing stop sequence is removed
                from the returned text.

        Returns:
            `str`: The newly generated text (at most 200 new tokens), without the prompt.
        """
        # Accept a single string, like `StopSequenceCriteria` does, instead of iterating over its characters below.
        if isinstance(stop, str):
            stop = [stop]

        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
        src_len = encoded_inputs["input_ids"].shape[1]
        stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
        outputs = self.model.generate(
            encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
        )

        # Only decode the tokens generated after the prompt.
        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
        # The stopping criteria triggers once the decoded text ends with a stop sequence, so strip it from the result.
        for stop_seq in stop:
            # Skip empty stop sequences: `result[:-0]` would otherwise discard the whole result.
            if stop_seq and result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class StopSequenceCriteria(StoppingCriteria):
    """
    Stopping criteria that halts generation as soon as the decoded output ends with one of the given sequences.

    Args:
        stop_sequences (`str` or `List[str]`):
            The sequence (or list of sequences) on which to stop execution.
        tokenizer:
            The tokenizer used to decode the model outputs.
    """

    def __init__(self, stop_sequences, tokenizer):
        # A single string is wrapped in a list so `__call__` can always iterate over sequences.
        self.stop_sequences = [stop_sequences] if isinstance(stop_sequences, str) else stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Only the first row of `input_ids` is decoded and inspected.
        first_row = input_ids.tolist()[0]
        text = self.tokenizer.decode(first_row)
        for candidate in self.stop_sequences:
            if text.endswith(candidate):
                return True
        return False
|
||||||
|
|||||||
Reference in New Issue
Block a user