@@ -24,12 +24,16 @@ contains the API docs for the underlying classes.
|
||||
|
||||
## Agents
|
||||
|
||||
We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models.
|
||||
We provide three types of agents: [`HfAgent`] uses inference endpoints for opensource models, [`LocalAgent`] uses a model of your choice locally and [`OpenAiAgent`] uses OpenAI closed models.
|
||||
|
||||
### HfAgent
|
||||
|
||||
[[autodoc]] HfAgent
|
||||
|
||||
### LocalAgent
|
||||
|
||||
[[autodoc]] LocalAgent
|
||||
|
||||
### OpenAiAgent
|
||||
|
||||
[[autodoc]] OpenAiAgent
|
||||
|
||||
@@ -614,6 +614,7 @@ _import_structure = {
|
||||
"tools": [
|
||||
"Agent",
|
||||
"HfAgent",
|
||||
"LocalAgent",
|
||||
"OpenAiAgent",
|
||||
"PipelineTool",
|
||||
"RemoteTool",
|
||||
@@ -4361,7 +4362,17 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
# Tools
|
||||
from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
|
||||
from .tools import (
|
||||
Agent,
|
||||
HfAgent,
|
||||
LocalAgent,
|
||||
OpenAiAgent,
|
||||
PipelineTool,
|
||||
RemoteTool,
|
||||
Tool,
|
||||
launch_gradio_demo,
|
||||
load_tool,
|
||||
)
|
||||
|
||||
# Trainer
|
||||
from .trainer_callback import (
|
||||
|
||||
@@ -24,7 +24,7 @@ from ..utils import (
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"agents": ["Agent", "HfAgent", "OpenAiAgent"],
|
||||
"agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
|
||||
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ else:
|
||||
_import_structure["translation"] = ["TranslationTool"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agents import Agent, HfAgent, OpenAiAgent
|
||||
from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
|
||||
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
|
||||
|
||||
try:
|
||||
|
||||
@@ -24,6 +24,8 @@ from typing import Dict
|
||||
import requests
|
||||
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
|
||||
|
||||
from ..generation import StoppingCriteria, StoppingCriteriaList
|
||||
from ..models.auto import AutoModelForCausalLM, AutoTokenizer
|
||||
from ..utils import is_openai_available, logging
|
||||
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
|
||||
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
|
||||
@@ -492,3 +494,114 @@ class HfAgent(Agent):
|
||||
if result.endswith(stop_seq):
|
||||
return result[: -len(stop_seq)]
|
||||
return result
|
||||
|
||||
|
||||
class LocalAgent(Agent):
|
||||
"""
|
||||
Agent that uses a local model and tokenizer to generate code.
|
||||
|
||||
Args:
|
||||
model ([`PreTrainedModel`]):
|
||||
The model to use for the agent.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer to use for the agent.
|
||||
chat_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `chat` method.
|
||||
run_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `run` method.
|
||||
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
|
||||
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
|
||||
one of the default tools, that default tool will be overridden.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
|
||||
|
||||
checkpoint = "bigcode/starcoder"
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
|
||||
agent = LocalAgent(model, tokenizer)
|
||||
agent.run("Draw me a picture of rivers and lakes.")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
super().__init__(
|
||||
chat_prompt_template=chat_prompt_template,
|
||||
run_prompt_template=run_prompt_template,
|
||||
additional_tools=additional_tools,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
"""
|
||||
Convenience method to build a `LocalAgent` from a pretrained checkpoint.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
|
||||
kwargs:
|
||||
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import LocalAgent
|
||||
|
||||
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
agent.run("Draw me a picture of rivers and lakes.")
|
||||
```
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
return cls(model, tokenizer)
|
||||
|
||||
@property
|
||||
def _model_device(self):
|
||||
if hasattr(self.model, "hf_device_map"):
|
||||
return list(self.model.hf_device_map.values())[0]
|
||||
for param in self.mode.parameters():
|
||||
return param.device
|
||||
|
||||
def generate_one(self, prompt, stop):
|
||||
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
|
||||
src_len = encoded_inputs["input_ids"].shape[1]
|
||||
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
|
||||
outputs = self.model.generate(
|
||||
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
|
||||
)
|
||||
|
||||
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq):
|
||||
result = result[: -len(stop_seq)]
|
||||
return result
|
||||
|
||||
|
||||
class StopSequenceCriteria(StoppingCriteria):
|
||||
"""
|
||||
This class can be used to stop generation whenever a sequence of tokens is encountered.
|
||||
|
||||
Args:
|
||||
stop_sequences (`str` or `List[str]`):
|
||||
The sequence (or list of sequences) on which to stop execution.
|
||||
tokenizer:
|
||||
The tokenizer used to decode the model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, stop_sequences, tokenizer):
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
self.stop_sequences = stop_sequences
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
||||
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
|
||||
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
|
||||
Reference in New Issue
Block a user