Decorator for easier tool building (#33439)
* Decorator for tool building
This commit is contained in:
@@ -325,62 +325,37 @@ model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
print(model.id)
|
||||
```
|
||||
|
||||
This code can be converted into a class that inherits from the [`Tool`] superclass.
|
||||
This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator:
|
||||
|
||||
|
||||
The custom tool needs:
|
||||
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
|
||||
- An attribute `description` is used to populate the agent's system prompt.
|
||||
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
|
||||
- An `output_type` attribute, which specifies the output type.
|
||||
- A `forward` method which contains the inference code to be executed.
|
||||
```py
|
||||
from transformers import tool
|
||||
|
||||
@tool
|
||||
def model_download_counter(task: str) -> str:
|
||||
"""
|
||||
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
|
||||
It returns the name of the checkpoint.
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = (
|
||||
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
|
||||
"It returns the name of the checkpoint."
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"task": {
|
||||
"type": "text",
|
||||
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
||||
}
|
||||
}
|
||||
output_type = "text"
|
||||
|
||||
def forward(self, task: str):
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
Args:
|
||||
task: The task for which
|
||||
"""
|
||||
model = next(iter(list_models(filter="text-classification", sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
|
||||
The function needs:
|
||||
- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's put `model_download_counter`.
|
||||
- Type hints on both inputs and output
|
||||
- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint).
|
||||
All these will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!
|
||||
|
||||
> [!TIP]
|
||||
> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
|
||||
|
||||
```python
|
||||
tool.push_to_hub("{your_username}/hf-model-downloads")
|
||||
```
|
||||
|
||||
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||
|
||||
```python
|
||||
from transformers import load_tool, CodeAgent
|
||||
|
||||
model_download_tool = load_tool("m-ric/hf-model-downloads")
|
||||
Then you can directly initialize your agent:
|
||||
```py
|
||||
from transformers import CodeAgent
|
||||
agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
|
||||
agent.run(
|
||||
"Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
|
||||
@@ -400,7 +375,6 @@ print(f"The most downloaded model for the 'text-to-video' task is {most_download
|
||||
And the output:
|
||||
`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
|
||||
|
||||
|
||||
### Manage your agent's toolbox
|
||||
|
||||
If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.
|
||||
|
||||
@@ -60,7 +60,68 @@ manager_agent.run("Who is the CEO of Hugging Face?")
|
||||
> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).
|
||||
|
||||
|
||||
## Use tools from gradio or LangChain
|
||||
## Advanced tool usage
|
||||
|
||||
### Directly define a tool by subclassing Tool, and share it to the Hub
|
||||
|
||||
Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.
|
||||
|
||||
If you need to add variation, like custom attributes for your too, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
|
||||
|
||||
The custom tool needs:
|
||||
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
|
||||
- An attribute `description` is used to populate the agent's system prompt.
|
||||
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
|
||||
- An `output_type` attribute, which specifies the output type.
|
||||
- A `forward` method which contains the inference code to be executed.
|
||||
|
||||
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema).
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = """
|
||||
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
|
||||
It returns the name of the checkpoint."""
|
||||
|
||||
inputs = {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
||||
}
|
||||
}
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, task: str):
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
|
||||
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
|
||||
|
||||
```python
|
||||
tool.push_to_hub("{your_username}/hf-model-downloads")
|
||||
```
|
||||
|
||||
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||
|
||||
```python
|
||||
from transformers import load_tool, CodeAgent
|
||||
|
||||
model_download_tool = load_tool("m-ric/hf-model-downloads")
|
||||
```
|
||||
|
||||
### Use gradio-tools
|
||||
|
||||
|
||||
@@ -60,6 +60,10 @@ We provide two types of agents, based on the main [`Agent`] class:
|
||||
|
||||
[[autodoc]] load_tool
|
||||
|
||||
### tool
|
||||
|
||||
[[autodoc]] tool
|
||||
|
||||
### Tool
|
||||
|
||||
[[autodoc]] Tool
|
||||
|
||||
@@ -70,6 +70,7 @@ _import_structure = {
|
||||
"launch_gradio_demo",
|
||||
"load_tool",
|
||||
"stream_to_gradio",
|
||||
"tool",
|
||||
],
|
||||
"audio_utils": [],
|
||||
"benchmark": [],
|
||||
@@ -4819,6 +4820,7 @@ if TYPE_CHECKING:
|
||||
launch_gradio_demo,
|
||||
load_tool,
|
||||
stream_to_gradio,
|
||||
tool,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ _import_structure = {
|
||||
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
||||
"llm_engine": ["HfApiEngine", "TransformersEngine"],
|
||||
"monitoring": ["stream_to_gradio"],
|
||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
|
||||
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from .llm_engine import HfApiEngine, TransformersEngine
|
||||
from .monitoring import stream_to_gradio
|
||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
||||
@@ -234,7 +234,7 @@ class AgentAudio(AgentType, str):
|
||||
return self._path
|
||||
|
||||
|
||||
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from .. import is_torch_available
|
||||
from ..utils import logging as transformers_logging
|
||||
from ..utils.import_utils import is_pygments_available
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .agent_types import AgentAudio, AgentImage
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||
from .llm_engine import HfApiEngine, MessageRole
|
||||
from .prompts import (
|
||||
@@ -626,10 +626,9 @@ class CodeAgent(Agent):
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers.agents import CodeAgent, PythonInterpreterTool
|
||||
from transformers.agents import CodeAgent
|
||||
|
||||
python_interpreter = PythonInterpreterTool()
|
||||
agent = CodeAgent(tools=[python_interpreter])
|
||||
agent = CodeAgent(tools=[])
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
@@ -1019,20 +1018,17 @@ class ReactJsonAgent(ReactAgent):
|
||||
arguments = {}
|
||||
observation = self.execute_tool_call(tool_name, arguments)
|
||||
observation_type = type(observation)
|
||||
if observation_type == AgentText:
|
||||
updated_information = str(observation).strip()
|
||||
else:
|
||||
# TODO: observation naming could allow for different names of same type
|
||||
if observation_type in [AgentImage, AgentAudio]:
|
||||
if observation_type == AgentImage:
|
||||
observation_name = "image.png"
|
||||
elif observation_type == AgentAudio:
|
||||
observation_name = "audio.mp3"
|
||||
else:
|
||||
observation_name = "object.object"
|
||||
# TODO: observation naming could allow for different names of same type
|
||||
|
||||
self.state[observation_name] = observation
|
||||
updated_information = f"Stored '{observation_name}' in memory."
|
||||
|
||||
else:
|
||||
updated_information = str(observation).strip()
|
||||
self.logger.info(updated_information)
|
||||
current_step_logs["observation"] = updated_information
|
||||
return current_step_logs
|
||||
|
||||
@@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool):
|
||||
name = "python_interpreter"
|
||||
description = "This is a tool that evaluates python code. It can be used to perform calculations."
|
||||
|
||||
output_type = "text"
|
||||
available_tools = BASE_PYTHON_TOOLS.copy()
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, *args, authorized_imports=None, **kwargs):
|
||||
if authorized_imports is None:
|
||||
@@ -162,7 +161,7 @@ class PythonInterpreterTool(Tool):
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
|
||||
self.inputs = {
|
||||
"code": {
|
||||
"type": "text",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
||||
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
|
||||
@@ -173,7 +172,7 @@ class PythonInterpreterTool(Tool):
|
||||
|
||||
def forward(self, code):
|
||||
output = str(
|
||||
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
|
||||
evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -181,7 +180,7 @@ class PythonInterpreterTool(Tool):
|
||||
class FinalAnswerTool(Tool):
|
||||
name = "final_answer"
|
||||
description = "Provides a final answer to the given problem."
|
||||
inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}}
|
||||
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, answer):
|
||||
|
||||
@@ -31,7 +31,7 @@ if is_vision_available():
|
||||
|
||||
class DocumentQuestionAnsweringTool(PipelineTool):
|
||||
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
|
||||
description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question."
|
||||
description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
|
||||
name = "document_qa"
|
||||
pre_processor_class = AutoProcessor
|
||||
model_class = VisionEncoderDecoderModel
|
||||
@@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool):
|
||||
"type": "image",
|
||||
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
|
||||
},
|
||||
"question": {"type": "text", "description": "The question in English"},
|
||||
"question": {"type": "string", "description": "The question in English"},
|
||||
}
|
||||
output_type = "text"
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if not is_vision_available():
|
||||
|
||||
@@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool):
|
||||
"type": "image",
|
||||
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
|
||||
},
|
||||
"question": {"type": "text", "description": "The question in English"},
|
||||
"question": {"type": "string", "description": "The question in English"},
|
||||
}
|
||||
output_type = "text"
|
||||
output_type = "string"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
@@ -199,7 +199,7 @@ Thought: I will now generate an image showcasing the oldest person.
|
||||
Action:
|
||||
{
|
||||
"action": "image_generator",
|
||||
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
|
||||
"action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
|
||||
}<end_action>
|
||||
Observation: "image.png"
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool):
|
||||
name = "web_search"
|
||||
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||
Each result has keys 'title', 'href' and 'body'."""
|
||||
inputs = {"query": {"type": "text", "description": "The search query to perform."}}
|
||||
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, query: str) -> str:
|
||||
@@ -45,11 +45,11 @@ class VisitWebpageTool(Tool):
|
||||
description = "Visits a wbepage at the given url and returns its content as a markdown string."
|
||||
inputs = {
|
||||
"url": {
|
||||
"type": "text",
|
||||
"type": "string",
|
||||
"description": "The url of the webpage to visit.",
|
||||
}
|
||||
}
|
||||
output_type = "text"
|
||||
output_type = "string"
|
||||
|
||||
def forward(self, url: str) -> str:
|
||||
try:
|
||||
|
||||
@@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool):
|
||||
model_class = WhisperForConditionalGeneration
|
||||
|
||||
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
|
||||
output_type = "text"
|
||||
output_type = "string"
|
||||
|
||||
def encode(self, audio):
|
||||
return self.pre_processor(audio, return_tensors="pt")
|
||||
|
||||
@@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool):
|
||||
model_class = SpeechT5ForTextToSpeech
|
||||
post_processor_class = SpeechT5HifiGan
|
||||
|
||||
inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}}
|
||||
inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
|
||||
output_type = "audio"
|
||||
|
||||
def setup(self):
|
||||
|
||||
@@ -16,12 +16,13 @@
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
|
||||
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
||||
@@ -35,7 +36,9 @@ from ..dynamic_module_utils import (
|
||||
from ..models.auto import AutoProcessor
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
TypeHintParsingException,
|
||||
cached_file,
|
||||
get_json_schema,
|
||||
is_accelerate_available,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
@@ -84,6 +87,20 @@ launch_gradio_demo({class_name})
|
||||
"""
|
||||
|
||||
|
||||
def validate_after_init(cls):
|
||||
original_init = cls.__init__
|
||||
|
||||
@wraps(original_init)
|
||||
def new_init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
if not isinstance(self, PipelineTool):
|
||||
self.validate_arguments()
|
||||
|
||||
cls.__init__ = new_init
|
||||
return cls
|
||||
|
||||
|
||||
@validate_after_init
|
||||
class Tool:
|
||||
"""
|
||||
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
|
||||
@@ -114,17 +131,35 @@ class Tool:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.is_initialized = False
|
||||
|
||||
def validate_attributes(self):
|
||||
def validate_arguments(self):
|
||||
required_attributes = {
|
||||
"description": str,
|
||||
"name": str,
|
||||
"inputs": Dict,
|
||||
"output_type": type,
|
||||
"output_type": str,
|
||||
}
|
||||
authorized_types = ["string", "integer", "number", "image", "audio", "any"]
|
||||
|
||||
for attr, expected_type in required_attributes.items():
|
||||
attr_value = getattr(self, attr, None)
|
||||
if not isinstance(attr_value, expected_type):
|
||||
raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}")
|
||||
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
|
||||
for input_name, input_content in self.inputs.items():
|
||||
assert "type" in input_content, f"Input '{input_name}' should specify a type."
|
||||
if input_content["type"] not in authorized_types:
|
||||
raise Exception(
|
||||
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
|
||||
)
|
||||
assert "description" in input_content, f"Input '{input_name}' should have a description."
|
||||
|
||||
assert getattr(self, "output_type", None) in authorized_types
|
||||
|
||||
if not isinstance(self, PipelineTool):
|
||||
signature = inspect.signature(self.forward)
|
||||
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
|
||||
raise Exception(
|
||||
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return NotImplemented("Write this method in your subclass of `Tool`.")
|
||||
@@ -382,7 +417,7 @@ class Tool:
|
||||
super().__init__()
|
||||
self.name = _gradio_tool.name
|
||||
self.description = _gradio_tool.description
|
||||
self.output_type = "text"
|
||||
self.output_type = "string"
|
||||
self._gradio_tool = _gradio_tool
|
||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
|
||||
self.inputs = {key: "" for key in func_args}
|
||||
@@ -404,7 +439,7 @@ class Tool:
|
||||
self.name = _langchain_tool.name.lower()
|
||||
self.description = _langchain_tool.description
|
||||
self.inputs = parse_langchain_args(_langchain_tool.args)
|
||||
self.output_type = "text"
|
||||
self.output_type = "string"
|
||||
self.langchain_tool = _langchain_tool
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@@ -421,6 +456,7 @@ class Tool:
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
|
||||
- {{ tool.name }}: {{ tool.description }}
|
||||
Takes inputs: {{tool.inputs}}
|
||||
Returns an output of type: {{tool.output_type}}
|
||||
"""
|
||||
|
||||
|
||||
@@ -621,18 +657,18 @@ def launch_gradio_demo(tool_class: Tool):
|
||||
gradio_inputs = []
|
||||
for input_name, input_details in tool_class.inputs.items():
|
||||
input_type = input_details["type"]
|
||||
if input_type == "text":
|
||||
gradio_inputs.append(gr.Textbox(label=input_name))
|
||||
elif input_type == "image":
|
||||
if input_type == "image":
|
||||
gradio_inputs.append(gr.Image(label=input_name))
|
||||
elif input_type == "audio":
|
||||
gradio_inputs.append(gr.Audio(label=input_name))
|
||||
elif input_type in ["string", "integer", "number"]:
|
||||
gradio_inputs.append(gr.Textbox(label=input_name))
|
||||
else:
|
||||
error_message = f"Input type '{input_type}' not supported."
|
||||
raise ValueError(error_message)
|
||||
|
||||
gradio_output = tool_class.output_type
|
||||
assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported."
|
||||
assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
|
||||
|
||||
gr.Interface(
|
||||
fn=fn,
|
||||
@@ -808,3 +844,37 @@ class ToolCollection:
|
||||
self._collection = get_collection(collection_slug, token=token)
|
||||
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
|
||||
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
|
||||
|
||||
|
||||
def tool(tool_function: Callable) -> Tool:
|
||||
"""
|
||||
Converts a function into an instance of a Tool subclass.
|
||||
|
||||
Args:
|
||||
tool_function: Your function. Should have type hints for each input and a type hint for the output.
|
||||
Should also have a docstring description including an 'Args:' part where each argument is described.
|
||||
"""
|
||||
parameters = get_json_schema(tool_function)["function"]
|
||||
if "return" not in parameters:
|
||||
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
|
||||
class_name = f"{parameters['name'].capitalize()}Tool"
|
||||
|
||||
class SpecificTool(Tool):
|
||||
name = parameters["name"]
|
||||
description = parameters["description"]
|
||||
inputs = parameters["parameters"]["properties"]
|
||||
output_type = parameters["return"]["type"]
|
||||
|
||||
@wraps(tool_function)
|
||||
def forward(self, *args, **kwargs):
|
||||
return tool_function(*args, **kwargs)
|
||||
|
||||
original_signature = inspect.signature(tool_function)
|
||||
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
|
||||
original_signature.parameters.values()
|
||||
)
|
||||
new_signature = original_signature.replace(parameters=new_parameters)
|
||||
SpecificTool.forward.__signature__ = new_signature
|
||||
|
||||
SpecificTool.__name__ = class_name
|
||||
return SpecificTool()
|
||||
|
||||
@@ -249,17 +249,17 @@ class TranslationTool(PipelineTool):
|
||||
model_class = AutoModelForSeq2SeqLM
|
||||
|
||||
inputs = {
|
||||
"text": {"type": "text", "description": "The text to translate"},
|
||||
"text": {"type": "string", "description": "The text to translate"},
|
||||
"src_lang": {
|
||||
"type": "text",
|
||||
"type": "string",
|
||||
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
|
||||
},
|
||||
"tgt_lang": {
|
||||
"type": "text",
|
||||
"type": "string",
|
||||
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
|
||||
},
|
||||
}
|
||||
output_type = "text"
|
||||
output_type = "string"
|
||||
|
||||
def encode(self, text, src_lang, tgt_lang):
|
||||
if src_lang not in self.lang_to_code:
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args,
|
||||
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import is_jinja_available
|
||||
from .import_utils import is_jinja_available, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_jinja_available():
|
||||
@@ -32,6 +32,12 @@ if is_jinja_available():
|
||||
else:
|
||||
jinja2 = None
|
||||
|
||||
if is_vision_available():
|
||||
from PIL.Image import Image
|
||||
|
||||
if is_torch_available():
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
|
||||
# Extracts the initial segment of the docstring, containing the function description
|
||||
@@ -70,6 +76,8 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||
float: {"type": "number"},
|
||||
str: {"type": "string"},
|
||||
bool: {"type": "boolean"},
|
||||
Image: {"type": "image"},
|
||||
Tensor: {"type": "audio"},
|
||||
Any: {},
|
||||
}
|
||||
return type_mapping.get(param_type, {"type": "object"})
|
||||
|
||||
@@ -68,7 +68,6 @@ Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = 2**3.6452
|
||||
print(result)
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
@@ -181,7 +180,6 @@ Action:
|
||||
assert isinstance(output, float)
|
||||
assert output == 7.2904
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
|
||||
assert agent.logs[2]["tool_call"] == {
|
||||
"tool_arguments": "final_answer(7.2904)",
|
||||
"tool_name": "code interpreter",
|
||||
@@ -234,7 +232,7 @@ Action:
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
|
||||
assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter)
|
||||
|
||||
def test_function_persistence_across_steps(self):
|
||||
agent = ReactCodeAgent(
|
||||
|
||||
@@ -19,8 +19,9 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
from transformers import is_torch_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import FinalAnswerTool
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
@@ -33,8 +34,7 @@ if is_torch_available():
|
||||
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.inputs = {"answer": "Final answer"}
|
||||
self.tool = load_tool("final_answer")
|
||||
self.tool.setup()
|
||||
self.tool = FinalAnswerTool()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Final answer")
|
||||
@@ -52,7 +52,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
)
|
||||
}
|
||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||
return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||
|
||||
@require_torch
|
||||
def test_agent_type_output(self):
|
||||
|
||||
@@ -391,8 +391,9 @@ else:
|
||||
code = """char='a'
|
||||
if char.isalpha():
|
||||
print('2')"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "2"
|
||||
state = {}
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert state["print_outputs"] == "2\n"
|
||||
|
||||
def test_imports(self):
|
||||
code = "import math\nmath.sqrt(4)"
|
||||
@@ -469,7 +470,7 @@ if char.isalpha():
|
||||
code = "print('Hello world!')\nprint('Ok no one cares')"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert result == "Ok no one cares"
|
||||
assert result is None
|
||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||
|
||||
# test print in function
|
||||
@@ -593,8 +594,7 @@ except ValueError as e:
|
||||
def test_print(self):
|
||||
code = "print(min([1, 2, 3]))"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"min": min, "print": print}, state=state)
|
||||
assert result == "1"
|
||||
evaluate_python_code(code, {"min": min, "print": print}, state=state)
|
||||
assert state["print_outputs"] == "1\n"
|
||||
|
||||
def test_types_as_objects(self):
|
||||
|
||||
@@ -12,13 +12,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
from transformers.agents.tools import Tool, tool
|
||||
from transformers.testing_utils import get_tests_dir, is_agent_test
|
||||
|
||||
|
||||
@@ -29,7 +32,7 @@ if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = ["text", "audio", "image", "any"]
|
||||
AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
|
||||
|
||||
|
||||
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||
@@ -38,7 +41,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||
for input_name, input_desc in tool_inputs.items():
|
||||
input_type = input_desc["type"]
|
||||
|
||||
if input_type == "text":
|
||||
if input_type == "string":
|
||||
inputs[input_name] = "Text input"
|
||||
elif input_type == "image":
|
||||
inputs[input_name] = Image.open(
|
||||
@@ -54,7 +57,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||
|
||||
def output_type(output):
|
||||
if isinstance(output, (str, AgentText)):
|
||||
return "text"
|
||||
return "string"
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
return "image"
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
@@ -100,3 +103,69 @@ class ToolTesterMixin:
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
|
||||
class ToolTests(unittest.TestCase):
|
||||
def test_tool_init_with_decorator(self):
|
||||
@tool
|
||||
def coolfunc(a: str, b: int) -> float:
|
||||
"""Cool function
|
||||
|
||||
Args:
|
||||
a: The first argument
|
||||
b: The second one
|
||||
"""
|
||||
return b + 2, a
|
||||
|
||||
assert coolfunc.output_type == "number"
|
||||
|
||||
def test_tool_init_vanilla(self):
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = """
|
||||
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
|
||||
It returns the name of the checkpoint."""
|
||||
|
||||
inputs = {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
||||
}
|
||||
}
|
||||
output_type = "integer"
|
||||
|
||||
def forward(self, task):
|
||||
return "best model"
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
assert list(tool.inputs.keys())[0] == "task"
|
||||
|
||||
def test_tool_init_decorator_raises_issues(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
@tool
|
||||
def coolfunc(a: str, b: int):
|
||||
"""Cool function
|
||||
|
||||
Args:
|
||||
a: The first argument
|
||||
b: The second one
|
||||
"""
|
||||
return a + b
|
||||
|
||||
assert coolfunc.output_type == "number"
|
||||
assert "Tool return type not found" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
@tool
|
||||
def coolfunc(a: str, b: int) -> int:
|
||||
"""Cool function
|
||||
|
||||
Args:
|
||||
a: The first argument
|
||||
"""
|
||||
return b + a
|
||||
|
||||
assert coolfunc.output_type == "number"
|
||||
assert "docstring has no description for the argument" in str(e)
|
||||
|
||||
Reference in New Issue
Block a user