Decorator for easier tool building (#33439)

* Decorator for tool building
Aymeric Roucher
2024-09-18 11:07:51 +02:00
committed by GitHub
parent fee86516a4
commit e6d9f39dd7
21 changed files with 292 additions and 111 deletions


@@ -325,62 +325,37 @@ model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id) print(model.id)
``` ```
This code can be converted into a class that inherits from the [`Tool`] superclass. This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator:
The custom tool needs: ```py
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`. from transformers import tool
- An attribute `description` is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.
@tool
def model_download_counter(task: str) -> str:
"""
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint.
```python Args:
from transformers import Tool task: The task for which
from huggingface_hub import list_models """
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = (
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
"It returns the name of the checkpoint."
)
inputs = {
"task": {
"type": "text",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "text"
def forward(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id return model.id
``` ```
Now that the custom `HFModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. The function needs:
- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's call it `model_download_counter`.
- Type hints on both the inputs and the output.
- A description that includes an 'Args:' part where each argument is described (without a type indication this time, since the type is pulled from the type hint).
All of these are automatically baked into the agent's system prompt upon initialization, so strive to make them as clear as possible!
> [!TIP]
> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).
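
To make the three requirements above concrete, here is a minimal sketch of how a name, a description, typed inputs and an output type can be read off a plain function using only the standard library. It is not the code path used by the `tool` decorator (which relies on `get_json_schema`); `describe_function` and `PY_TO_SCHEMA` are hypothetical names introduced for illustration, and the example function returns a placeholder instead of querying the Hub.

```python
import inspect
from typing import get_type_hints

# Hypothetical mapping from Python types to schema type names (illustration only).
PY_TO_SCHEMA = {str: "string", int: "integer", float: "number"}


def describe_function(func):
    """Collect a name, description, inputs and output type from a function."""
    hints = get_type_hints(func)
    if "return" not in hints:
        raise TypeError("The function needs a return type hint.")

    # getdoc() strips the common indentation from the docstring.
    doc = inspect.getdoc(func) or ""
    description, _, args_section = doc.partition("Args:")

    # Parse "name: text" lines from the Args: section.
    arg_docs = {}
    for line in args_section.splitlines():
        name, sep, text = line.strip().partition(":")
        if sep:
            arg_docs[name.strip()] = text.strip()

    inputs = {
        name: {
            "type": PY_TO_SCHEMA.get(hints.get(name), "any"),
            "description": arg_docs.get(name, ""),
        }
        for name in inspect.signature(func).parameters
    }
    return {
        "name": func.__name__,
        "description": " ".join(description.split()),
        "inputs": inputs,
        "output_type": PY_TO_SCHEMA.get(hints["return"], "any"),
    }


def model_download_counter(task: str) -> str:
    """
    Returns the most downloaded model of a given task.

    Args:
        task: The task category, such as text-classification.
    """
    return "placeholder-model-id"  # stand-in value, no Hub call is made


schema = describe_function(model_download_counter)
print(schema)
```

The type comes from the hint and the per-argument text from the docstring, which is why the 'Args:' part does not need to repeat the type.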
```python Then you can directly initialize your agent:
from model_downloads import HFModelDownloadsTool ```py
from transformers import CodeAgent
tool = HFModelDownloadsTool()
```
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.
```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
```python
from transformers import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads")
agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine) agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
agent.run( agent.run(
"Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?" "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
@@ -400,7 +375,6 @@ print(f"The most downloaded model for the 'text-to-video' task is {most_download
And the output: And the output:
`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."` `"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
### Manage your agent's toolbox ### Manage your agent's toolbox
If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool. If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.


@@ -60,7 +60,68 @@ manager_agent.run("Who is the CEO of Hugging Face?")
> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia). > For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).
## Use tools from gradio or LangChain ## Advanced tool usage
### Directly define a tool by subclassing Tool, and share it to the Hub
Let's revisit the tool example from the main documentation, which we implemented with the `tool` decorator.
If you need more flexibility, such as custom attributes for your tool, you can build it with the fine-grained method: a class that inherits from the [`Tool`] superclass.
The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
- An attribute `description`, which is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema).
```python
from transformers import Tool
from huggingface_hub import list_models
class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = """
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint."""
inputs = {
"task": {
"type": "string",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "string"
def forward(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id
```
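
The constraints on `inputs` and `output_type` can be checked before instantiating a tool. The sketch below mirrors, in simplified form, the checks this commit adds in `Tool.validate_arguments`: the list of authorized type names is taken from that method, while `check_tool_spec` is a hypothetical helper that collects problems instead of raising.

```python
# Type names accepted by Tool.validate_arguments in this commit.
AUTHORIZED_TYPES = ["string", "integer", "number", "image", "audio", "any"]


def check_tool_spec(inputs: dict, output_type: str) -> list:
    """Return a list of problems; an empty list means the spec looks valid."""
    problems = []
    for name, spec in inputs.items():
        if "type" not in spec:
            problems.append(f"Input '{name}' should specify a type.")
        elif spec["type"] not in AUTHORIZED_TYPES:
            problems.append(f"Input '{name}': type '{spec['type']}' is not authorized.")
        if "description" not in spec:
            problems.append(f"Input '{name}' should have a description.")
    if output_type not in AUTHORIZED_TYPES:
        problems.append(f"Output type '{output_type}' is not authorized.")
    return problems


# The pre-commit spelling "text" is rejected; "string" passes.
old_style = check_tool_spec({"task": {"type": "text", "description": "the task category"}}, "text")
new_style = check_tool_spec({"task": {"type": "string", "description": "the task category"}}, "string")
print(old_style)
print(new_style)
```

Running a check like this on existing custom tools is a quick way to spot any that still declare `"text"` after the rename to `"string"`.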
Now that the custom `HFModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
```python
from model_downloads import HFModelDownloadsTool
tool = HFModelDownloadsTool()
```
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.
```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
```python
from transformers import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads")
```
### Use gradio-tools ### Use gradio-tools


@@ -60,6 +60,10 @@ We provide two types of agents, based on the main [`Agent`] class:
[[autodoc]] load_tool [[autodoc]] load_tool
### tool
[[autodoc]] tool
### Tool ### Tool
[[autodoc]] Tool [[autodoc]] Tool


@@ -70,6 +70,7 @@ _import_structure = {
"launch_gradio_demo", "launch_gradio_demo",
"load_tool", "load_tool",
"stream_to_gradio", "stream_to_gradio",
"tool",
], ],
"audio_utils": [], "audio_utils": [],
"benchmark": [], "benchmark": [],
@@ -4819,6 +4820,7 @@ if TYPE_CHECKING:
launch_gradio_demo, launch_gradio_demo,
load_tool, load_tool,
stream_to_gradio, stream_to_gradio,
tool,
) )
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig


@@ -27,7 +27,7 @@ _import_structure = {
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfApiEngine", "TransformersEngine"], "llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"], "monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"], "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
} }
try: try:
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfApiEngine, TransformersEngine from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
try: try:
if not is_torch_available(): if not is_torch_available():


@@ -234,7 +234,7 @@ class AgentAudio(AgentType, str):
return self._path return self._path
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio} AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage} INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
if is_torch_available(): if is_torch_available():


@@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from .. import is_torch_available from .. import is_torch_available
from ..utils import logging as transformers_logging from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole from .llm_engine import HfApiEngine, MessageRole
from .prompts import ( from .prompts import (
@@ -626,10 +626,9 @@ class CodeAgent(Agent):
Example: Example:
```py ```py
from transformers.agents import CodeAgent, PythonInterpreterTool from transformers.agents import CodeAgent
python_interpreter = PythonInterpreterTool() agent = CodeAgent(tools=[])
agent = CodeAgent(tools=[python_interpreter])
agent.run("What is the result of 2 power 3.7384?") agent.run("What is the result of 2 power 3.7384?")
``` ```
""" """
@@ -1019,20 +1018,17 @@ class ReactJsonAgent(ReactAgent):
arguments = {} arguments = {}
observation = self.execute_tool_call(tool_name, arguments) observation = self.execute_tool_call(tool_name, arguments)
observation_type = type(observation) observation_type = type(observation)
if observation_type == AgentText: if observation_type in [AgentImage, AgentAudio]:
updated_information = str(observation).strip()
else:
# TODO: observation naming could allow for different names of same type
if observation_type == AgentImage: if observation_type == AgentImage:
observation_name = "image.png" observation_name = "image.png"
elif observation_type == AgentAudio: elif observation_type == AgentAudio:
observation_name = "audio.mp3" observation_name = "audio.mp3"
else: # TODO: observation naming could allow for different names of same type
observation_name = "object.object"
self.state[observation_name] = observation self.state[observation_name] = observation
updated_information = f"Stored '{observation_name}' in memory." updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
self.logger.info(updated_information) self.logger.info(updated_information)
current_step_logs["observation"] = updated_information current_step_logs["observation"] = updated_information
return current_step_logs return current_step_logs


@@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter" name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations." description = "This is a tool that evaluates python code. It can be used to perform calculations."
output_type = "text" output_type = "string"
available_tools = BASE_PYTHON_TOOLS.copy()
def __init__(self, *args, authorized_imports=None, **kwargs): def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None: if authorized_imports is None:
@@ -162,7 +161,7 @@ class PythonInterpreterTool(Tool):
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = { self.inputs = {
"code": { "code": {
"type": "text", "type": "string",
"description": ( "description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}." f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
@@ -173,7 +172,7 @@ class PythonInterpreterTool(Tool):
def forward(self, code): def forward(self, code):
output = str( output = str(
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports) evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
) )
return output return output
@@ -181,7 +180,7 @@ class PythonInterpreterTool(Tool):
class FinalAnswerTool(Tool): class FinalAnswerTool(Tool):
name = "final_answer" name = "final_answer"
description = "Provides a final answer to the given problem." description = "Provides a final answer to the given problem."
inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}} inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
output_type = "any" output_type = "any"
def forward(self, answer): def forward(self, answer):


@@ -31,7 +31,7 @@ if is_vision_available():
class DocumentQuestionAnsweringTool(PipelineTool): class DocumentQuestionAnsweringTool(PipelineTool):
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa" default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question." description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
name = "document_qa" name = "document_qa"
pre_processor_class = AutoProcessor pre_processor_class = AutoProcessor
model_class = VisionEncoderDecoderModel model_class = VisionEncoderDecoderModel
@@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool):
"type": "image", "type": "image",
"description": "The image containing the information. Can be a PIL Image or a string path to the image.", "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
}, },
"question": {"type": "text", "description": "The question in English"}, "question": {"type": "string", "description": "The question in English"},
} }
output_type = "text" output_type = "string"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if not is_vision_available(): if not is_vision_available():


@@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool):
"type": "image", "type": "image",
"description": "The image containing the information. Can be a PIL Image or a string path to the image.", "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
}, },
"question": {"type": "text", "description": "The question in English"}, "question": {"type": "string", "description": "The question in English"},
} }
output_type = "text" output_type = "string"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"]) requires_backends(self, ["vision"])


@@ -199,7 +199,7 @@ Thought: I will now generate an image showcasing the oldest person.
Action: Action:
{ {
"action": "image_generator", "action": "image_generator",
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""} "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
}<end_action> }<end_action>
Observation: "image.png" Observation: "image.png"


@@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool):
name = "web_search" name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'.""" Each result has keys 'title', 'href' and 'body'."""
inputs = {"query": {"type": "text", "description": "The search query to perform."}} inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "any" output_type = "any"
def forward(self, query: str) -> str: def forward(self, query: str) -> str:
@@ -45,11 +45,11 @@ class VisitWebpageTool(Tool):
description = "Visits a webpage at the given url and returns its content as a markdown string." description = "Visits a webpage at the given url and returns its content as a markdown string."
inputs = { inputs = {
"url": { "url": {
"type": "text", "type": "string",
"description": "The url of the webpage to visit.", "description": "The url of the webpage to visit.",
} }
} }
output_type = "text" output_type = "string"
def forward(self, url: str) -> str: def forward(self, url: str) -> str:
try: try:


@@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool):
model_class = WhisperForConditionalGeneration model_class = WhisperForConditionalGeneration
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
output_type = "text" output_type = "string"
def encode(self, audio): def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt") return self.pre_processor(audio, return_tensors="pt")


@@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool):
model_class = SpeechT5ForTextToSpeech model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan post_processor_class = SpeechT5HifiGan
inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}} inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
output_type = "audio" output_type = "audio"
def setup(self): def setup(self):


@@ -16,12 +16,13 @@
# limitations under the License. # limitations under the License.
import base64 import base64
import importlib import importlib
import inspect
import io import io
import json import json
import os import os
import tempfile import tempfile
from functools import lru_cache from functools import lru_cache, wraps
from typing import Any, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
@@ -35,7 +36,9 @@ from ..dynamic_module_utils import (
from ..models.auto import AutoProcessor from ..models.auto import AutoProcessor
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
TypeHintParsingException,
cached_file, cached_file,
get_json_schema,
is_accelerate_available, is_accelerate_available,
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
@@ -84,6 +87,20 @@ launch_gradio_demo({class_name})
""" """
def validate_after_init(cls):
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
if not isinstance(self, PipelineTool):
self.validate_arguments()
cls.__init__ = new_init
return cls
@validate_after_init
class Tool: class Tool:
""" """
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
@@ -114,17 +131,35 @@ class Tool:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.is_initialized = False self.is_initialized = False
def validate_attributes(self): def validate_arguments(self):
required_attributes = { required_attributes = {
"description": str, "description": str,
"name": str, "name": str,
"inputs": Dict, "inputs": Dict,
"output_type": type, "output_type": str,
} }
authorized_types = ["string", "integer", "number", "image", "audio", "any"]
for attr, expected_type in required_attributes.items(): for attr, expected_type in required_attributes.items():
attr_value = getattr(self, attr, None) attr_value = getattr(self, attr, None)
if not isinstance(attr_value, expected_type): if not isinstance(attr_value, expected_type):
raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}") raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
for input_name, input_content in self.inputs.items():
assert "type" in input_content, f"Input '{input_name}' should specify a type."
if input_content["type"] not in authorized_types:
raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
)
assert "description" in input_content, f"Input '{input_name}' should have a description."
assert getattr(self, "output_type", None) in authorized_types
if not isinstance(self, PipelineTool):
signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return NotImplemented("Write this method in your subclass of `Tool`.") return NotImplemented("Write this method in your subclass of `Tool`.")
@@ -382,7 +417,7 @@ class Tool:
super().__init__() super().__init__()
self.name = _gradio_tool.name self.name = _gradio_tool.name
self.description = _gradio_tool.description self.description = _gradio_tool.description
self.output_type = "text" self.output_type = "string"
self._gradio_tool = _gradio_tool self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys()) func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
self.inputs = {key: "" for key in func_args} self.inputs = {key: "" for key in func_args}
@@ -404,7 +439,7 @@ class Tool:
self.name = _langchain_tool.name.lower() self.name = _langchain_tool.name.lower()
self.description = _langchain_tool.description self.description = _langchain_tool.description
self.inputs = parse_langchain_args(_langchain_tool.args) self.inputs = parse_langchain_args(_langchain_tool.args)
self.output_type = "text" self.output_type = "string"
self.langchain_tool = _langchain_tool self.langchain_tool = _langchain_tool
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -421,6 +456,7 @@ class Tool:
DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
- {{ tool.name }}: {{ tool.description }} - {{ tool.name }}: {{ tool.description }}
Takes inputs: {{tool.inputs}} Takes inputs: {{tool.inputs}}
Returns an output of type: {{tool.output_type}}
""" """
@@ -621,18 +657,18 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs = [] gradio_inputs = []
for input_name, input_details in tool_class.inputs.items(): for input_name, input_details in tool_class.inputs.items():
input_type = input_details["type"] input_type = input_details["type"]
if input_type == "text": if input_type == "image":
gradio_inputs.append(gr.Textbox(label=input_name))
elif input_type == "image":
gradio_inputs.append(gr.Image(label=input_name)) gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio": elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name)) gradio_inputs.append(gr.Audio(label=input_name))
elif input_type in ["string", "integer", "number"]:
gradio_inputs.append(gr.Textbox(label=input_name))
else: else:
error_message = f"Input type '{input_type}' not supported." error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message) raise ValueError(error_message)
gradio_output = tool_class.output_type gradio_output = tool_class.output_type
assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported." assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
gr.Interface( gr.Interface(
fn=fn, fn=fn,
@@ -808,3 +844,37 @@ class ToolCollection:
self._collection = get_collection(collection_slug, token=token) self._collection = get_collection(collection_slug, token=token)
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"} self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids} self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
def tool(tool_function: Callable) -> Tool:
"""
Converts a function into an instance of a Tool subclass.
Args:
tool_function: Your function. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described.
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
class_name = f"{parameters['name'].capitalize()}Tool"
class SpecificTool(Tool):
name = parameters["name"]
description = parameters["description"]
inputs = parameters["parameters"]["properties"]
output_type = parameters["return"]["type"]
@wraps(tool_function)
def forward(self, *args, **kwargs):
return tool_function(*args, **kwargs)
original_signature = inspect.signature(tool_function)
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
original_signature.parameters.values()
)
new_signature = original_signature.replace(parameters=new_parameters)
SpecificTool.forward.__signature__ = new_signature
SpecificTool.__name__ = class_name
return SpecificTool()
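
The signature handling at the end of `tool` can be seen in isolation with the standard library alone. In the sketch below, `greet` and `Wrapper` are hypothetical stand-ins for the decorated function and the generated `SpecificTool` class. The point it illustrates: once `self` is prepended to the copied signature, `inspect.signature` on the bound `forward` method reports exactly the original function's parameters, which is what `validate_arguments` compares against the keys of `inputs`.

```python
import inspect
from functools import wraps


def greet(name: str, punctuation: str = "!") -> str:
    """Hypothetical function standing in for a user-defined tool function."""
    return f"Hello, {name}{punctuation}"


class Wrapper:
    """Hypothetical stand-in for the class generated inside the decorator."""

    @wraps(greet)
    def forward(self, *args, **kwargs):
        # Delegate to the wrapped function, dropping `self`.
        return greet(*args, **kwargs)


# Copy the function's signature and prepend an explicit `self` parameter.
original = inspect.signature(greet)
self_param = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
Wrapper.forward.__signature__ = original.replace(
    parameters=[self_param, *original.parameters.values()]
)

# On a bound method, inspect drops the first parameter (`self`).
bound_names = list(inspect.signature(Wrapper().forward).parameters)
print(bound_names)
print(Wrapper().forward("Ada"))
```

Prepending `self` matters because `inspect` removes the first parameter when it inspects a bound method; without the extra parameter, a real argument would be the one removed.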


@@ -249,17 +249,17 @@ class TranslationTool(PipelineTool):
model_class = AutoModelForSeq2SeqLM model_class = AutoModelForSeq2SeqLM
inputs = { inputs = {
"text": {"type": "text", "description": "The text to translate"}, "text": {"type": "string", "description": "The text to translate"},
"src_lang": { "src_lang": {
"type": "text", "type": "string",
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'", "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
}, },
"tgt_lang": { "tgt_lang": {
"type": "text", "type": "string",
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'", "description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
}, },
} }
output_type = "text" output_type = "string"
def encode(self, text, src_lang, tgt_lang): def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code: if src_lang not in self.lang_to_code:


@@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args,
from packaging import version from packaging import version
from .import_utils import is_jinja_available from .import_utils import is_jinja_available, is_torch_available, is_vision_available
if is_jinja_available(): if is_jinja_available():
@@ -32,6 +32,12 @@ if is_jinja_available():
else: else:
jinja2 = None jinja2 = None
if is_vision_available():
from PIL.Image import Image
if is_torch_available():
from torch import Tensor
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
# Extracts the initial segment of the docstring, containing the function description # Extracts the initial segment of the docstring, containing the function description
@@ -70,6 +76,8 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
float: {"type": "number"}, float: {"type": "number"},
str: {"type": "string"}, str: {"type": "string"},
bool: {"type": "boolean"}, bool: {"type": "boolean"},
Image: {"type": "image"},
Tensor: {"type": "audio"},
Any: {}, Any: {},
} }
return type_mapping.get(param_type, {"type": "object"}) return type_mapping.get(param_type, {"type": "object"})
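
One thing to watch in the hunk above: `Image` and `Tensor` are imported only when vision or torch is available, but the mapping references both names unconditionally. The following sketch shows one possible way to register the optional entries only when their import succeeds. `build_type_mapping` and `json_schema_type` are hypothetical helpers written for illustration, not the functions in this commit.

```python
from typing import Any


def build_type_mapping() -> dict:
    """Build the type mapping, adding optional entries only when importable."""
    mapping = {
        int: {"type": "integer"},
        float: {"type": "number"},
        str: {"type": "string"},
        bool: {"type": "boolean"},
        Any: {},
    }
    try:
        from PIL.Image import Image  # optional dependency

        mapping[Image] = {"type": "image"}
    except ImportError:
        pass
    try:
        from torch import Tensor  # optional dependency

        mapping[Tensor] = {"type": "audio"}
    except ImportError:
        pass
    return mapping


def json_schema_type(param_type) -> dict:
    """Look up a type, falling back to a generic object schema."""
    return build_type_mapping().get(param_type, {"type": "object"})


print(json_schema_type(str))
print(json_schema_type(bytes))
```

With this shape, a missing optional dependency just means its entry is absent, and unknown types still fall back to `{"type": "object"}`.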


@@ -68,7 +68,6 @@ Thought: I should multiply 2 by 3.6452. special_marker
Code: Code:
```py ```py
result = 2**3.6452 result = 2**3.6452
print(result)
```<end_code> ```<end_code>
""" """
else: # We're at step 2 else: # We're at step 2
@@ -181,7 +180,6 @@ Action:
assert isinstance(output, float) assert isinstance(output, float)
assert output == 7.2904 assert output == 7.2904
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
assert agent.logs[2]["tool_call"] == { assert agent.logs[2]["tool_call"] == {
"tool_arguments": "final_answer(7.2904)", "tool_arguments": "final_answer(7.2904)",
"tool_name": "code interpreter", "tool_name": "code interpreter",
@@ -234,7 +232,7 @@ Action:
         # check that python_interpreter base tool does not get added to code agents
         agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
-        assert len(agent.toolbox.tools) == 6  # added final_answer tool + 5 base tools (excluding interpreter)
+        assert len(agent.toolbox.tools) == 7  # added final_answer tool + 6 base tools (excluding interpreter)
     def test_function_persistence_across_steps(self):
         agent = ReactCodeAgent(


@@ -19,8 +19,9 @@ from pathlib import Path
 import numpy as np
 from PIL import Image
-from transformers import is_torch_available, load_tool
+from transformers import is_torch_available
 from transformers.agents.agent_types import AGENT_TYPE_MAPPING
+from transformers.agents.default_tools import FinalAnswerTool
 from transformers.testing_utils import get_tests_dir, require_torch
 from .test_tools_common import ToolTesterMixin
@@ -33,8 +34,7 @@ if is_torch_available():
 class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
     def setUp(self):
         self.inputs = {"answer": "Final answer"}
-        self.tool = load_tool("final_answer")
-        self.tool.setup()
+        self.tool = FinalAnswerTool()
     def test_exact_match_arg(self):
         result = self.tool("Final answer")
@@ -52,7 +52,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
             )
         }
         inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
-        return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
+        return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
     @require_torch
     def test_agent_type_output(self):


@@ -391,8 +391,9 @@ else:
         code = """char='a'
 if char.isalpha():
     print('2')"""
-        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
-        assert result == "2"
+        state = {}
+        evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
+        assert state["print_outputs"] == "2\n"
     def test_imports(self):
         code = "import math\nmath.sqrt(4)"
@@ -469,7 +470,7 @@ if char.isalpha():
         code = "print('Hello world!')\nprint('Ok no one cares')"
         state = {}
         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
-        assert result == "Ok no one cares"
+        assert result is None
         assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
         # test print in function
@@ -593,8 +594,7 @@ except ValueError as e:
     def test_print(self):
         code = "print(min([1, 2, 3]))"
         state = {}
-        result = evaluate_python_code(code, {"min": min, "print": print}, state=state)
-        assert result == "1"
+        evaluate_python_code(code, {"min": min, "print": print}, state=state)
         assert state["print_outputs"] == "1\n"
     def test_types_as_objects(self):
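
Review note: the updated assertions in this file expect printed text to be accumulated in `state["print_outputs"]` rather than returned as the result. A minimal sketch of that behaviour, assuming a hypothetical helper `run_with_captured_prints` built on the built-in `exec`, could look like the following. It is not how `evaluate_python_code` is implemented, which uses its own interpreter; it only reproduces the observable contract the tests check.

```python
from typing import Any, Dict


def run_with_captured_prints(code: str, state: Dict[str, Any]) -> None:
    """Execute `code`, appending everything it prints to state["print_outputs"].

    Hypothetical helper for illustration only. Like the updated tests expect,
    it returns None instead of the printed text.
    """
    state["print_outputs"] = ""

    def capturing_print(*args: Any, sep: str = " ", end: str = "\n") -> None:
        # Mimic print() formatting, but write into the shared state dict.
        state["print_outputs"] += sep.join(str(arg) for arg in args) + end

    # A global named "print" shadows the built-in for the executed code.
    exec(code, {"print": capturing_print})


if __name__ == "__main__":
    state: Dict[str, Any] = {}
    result = run_with_captured_prints("print(min([1, 2, 3]))", state)
    assert result is None
    assert state["print_outputs"] == "1\n"
```

Keeping the printed text in `state` means a test can check the output and the return value independently, as the `Hello world!` test above does.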


@@ -12,13 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import unittest
 from pathlib import Path
 from typing import Dict, Union
 import numpy as np
+import pytest
 from transformers import is_torch_available, is_vision_available
 from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
+from transformers.agents.tools import Tool, tool
 from transformers.testing_utils import get_tests_dir, is_agent_test
@@ -29,7 +32,7 @@ if is_vision_available():
     from PIL import Image
-AUTHORIZED_TYPES = ["text", "audio", "image", "any"]
+AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
 def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
@@ -38,7 +41,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
     for input_name, input_desc in tool_inputs.items():
         input_type = input_desc["type"]
-        if input_type == "text":
+        if input_type == "string":
             inputs[input_name] = "Text input"
         elif input_type == "image":
             inputs[input_name] = Image.open(
@@ -54,7 +57,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
 def output_type(output):
     if isinstance(output, (str, AgentText)):
-        return "text"
+        return "string"
     elif isinstance(output, (Image.Image, AgentImage)):
         return "image"
     elif isinstance(output, (torch.Tensor, AgentAudio)):
@@ -100,3 +103,69 @@ class ToolTesterMixin:
         for _input, expected_input in zip(inputs, self.tool.inputs.values()):
             input_type = expected_input["type"]
             _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+class ToolTests(unittest.TestCase):
+    def test_tool_init_with_decorator(self):
+        @tool
+        def coolfunc(a: str, b: int) -> float:
+            """Cool function
+            Args:
+                a: The first argument
+                b: The second one
+            """
+            return b + 2, a
+        assert coolfunc.output_type == "number"
+    def test_tool_init_vanilla(self):
+        class HFModelDownloadsTool(Tool):
+            name = "model_download_counter"
+            description = """
+            This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
+            It returns the name of the checkpoint."""
+            inputs = {
+                "task": {
+                    "type": "string",
+                    "description": "the task category (such as text-classification, depth-estimation, etc)",
+                }
+            }
+            output_type = "integer"
+            def forward(self, task):
+                return "best model"
+        tool = HFModelDownloadsTool()
+        assert list(tool.inputs.keys())[0] == "task"
+    def test_tool_init_decorator_raises_issues(self):
+        with pytest.raises(Exception) as e:
+            @tool
+            def coolfunc(a: str, b: int):
+                """Cool function
+                Args:
+                    a: The first argument
+                    b: The second one
+                """
+                return a + b
+            assert coolfunc.output_type == "number"
+        assert "Tool return type not found" in str(e)
+        with pytest.raises(Exception) as e:
+            @tool
+            def coolfunc(a: str, b: int) -> int:
+                """Cool function
+                Args:
+                    a: The first argument
+                """
+                return b + a
+            assert coolfunc.output_type == "number"
+        assert "docstring has no description for the argument" in str(e)
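
Review note: the new tests exercise two failure modes of the decorator (a missing return annotation and an undocumented argument) plus the mapping of the return hint to `output_type`. The sketch below shows one way such checks could be written using only the standard library. The decorator name `simple_tool`, the `JSON_TYPES` table, and the docstring parsing are all hypothetical, and the error messages are borrowed from the test assertions rather than from the library's code. It is not the `tool` decorator from `transformers`.

```python
import inspect
import re
from typing import Any, Callable, Dict, get_type_hints

# Hypothetical type table, mirroring the type names used in the tests above.
JSON_TYPES: Dict[Any, str] = {int: "integer", float: "number", str: "string", bool: "boolean"}


def simple_tool(func: Callable) -> Callable:
    """Attach `inputs` and `output_type` to `func`, validating hints and docstring."""
    hints = get_type_hints(func)
    if "return" not in hints:
        raise TypeError(f"Tool return type not found for '{func.__name__}'")

    # Collect "name: description" lines that follow an "Args:" header.
    described: Dict[str, str] = {}
    in_args = False
    for line in (inspect.getdoc(func) or "").splitlines():
        stripped = line.strip()
        if stripped == "Args:":
            in_args = True
        elif in_args:
            match = re.match(r"(\w+):\s*(.+)", stripped)
            if match:
                described[match.group(1)] = match.group(2)

    inputs: Dict[str, Dict[str, str]] = {}
    for name in inspect.signature(func).parameters:
        if name not in described:
            raise ValueError(f"docstring has no description for the argument '{name}'")
        inputs[name] = {
            "type": JSON_TYPES.get(hints.get(name), "object"),
            "description": described[name],
        }

    func.inputs = inputs
    func.output_type = JSON_TYPES.get(hints["return"], "object")
    return func


@simple_tool
def coolfunc(a: str, b: int) -> float:
    """Cool function

    Args:
        a: The first argument
        b: The second one
    """
    return float(b + 2)


if __name__ == "__main__":
    print(coolfunc.output_type)  # derived from the `-> float` annotation
    print(coolfunc.inputs)
```

Both checks run at decoration time, so a badly documented tool fails when its module is imported rather than when an agent first calls it.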