Decorator for easier tool building (#33439)

* Decorator for tool building
This commit is contained in:
Aymeric Roucher
2024-09-18 11:07:51 +02:00
committed by GitHub
parent fee86516a4
commit e6d9f39dd7
21 changed files with 292 additions and 111 deletions

View File

@@ -325,62 +325,37 @@ model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id)
```
This code can be converted into a class that inherits from the [`Tool`] superclass.
This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator:
The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
- An attribute `description` is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.
```py
from transformers import tool
@tool
def model_download_counter(task: str) -> str:
"""
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint.
```python
from transformers import Tool
from huggingface_hub import list_models
class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = (
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
"It returns the name of the checkpoint."
)
inputs = {
"task": {
"type": "text",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "text"
def forward(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id
Args:
task: The task for which
"""
model = next(iter(list_models(filter="text-classification", sort="downloads", direction=-1)))
return model.id
```
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
The function needs:
- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's put `model_download_counter`.
- Type hints on both inputs and output
- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint).
All these will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible!
> [!TIP]
> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template).
```python
from model_downloads import HFModelDownloadsTool
tool = HFModelDownloadsTool()
```
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
```python
from transformers import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads")
Then you can directly initialize your agent:
```py
from transformers import CodeAgent
agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
agent.run(
"Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
@@ -400,7 +375,6 @@ print(f"The most downloaded model for the 'text-to-video' task is {most_download
And the output:
`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
### Manage your agent's toolbox
If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.

View File

@@ -60,7 +60,68 @@ manager_agent.run("Who is the CEO of Hugging Face?")
> For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia).
## Use tools from gradio or LangChain
## Advanced tool usage
### Directly define a tool by subclassing Tool, and share it to the Hub
Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator.
If you need to add variation, like custom attributes for your too, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass.
The custom tool needs:
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
- An attribute `description` is used to populate the agent's system prompt.
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
- An `output_type` attribute, which specifies the output type.
- A `forward` method which contains the inference code to be executed.
The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema).
```python
from transformers import Tool
from huggingface_hub import list_models
class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = """
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint."""
inputs = {
"task": {
"type": "string",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "string"
def forward(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id
```
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
```python
from model_downloads import HFModelDownloadsTool
tool = HFModelDownloadsTool()
```
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
```python
tool.push_to_hub("{your_username}/hf-model-downloads")
```
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
```python
from transformers import load_tool, CodeAgent
model_download_tool = load_tool("m-ric/hf-model-downloads")
```
### Use gradio-tools

View File

@@ -60,6 +60,10 @@ We provide two types of agents, based on the main [`Agent`] class:
[[autodoc]] load_tool
### tool
[[autodoc]] tool
### Tool
[[autodoc]] Tool

View File

@@ -70,6 +70,7 @@ _import_structure = {
"launch_gradio_demo",
"load_tool",
"stream_to_gradio",
"tool",
],
"audio_utils": [],
"benchmark": [],
@@ -4819,6 +4820,7 @@ if TYPE_CHECKING:
launch_gradio_demo,
load_tool,
stream_to_gradio,
tool,
)
from .configuration_utils import PretrainedConfig

View File

@@ -27,7 +27,7 @@ _import_structure = {
"agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"],
}
try:
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool
try:
if not is_torch_available():

View File

@@ -234,7 +234,7 @@ class AgentAudio(AgentType, str):
return self._path
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
if is_torch_available():

View File

@@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from .. import is_torch_available
from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText
from .agent_types import AgentAudio, AgentImage
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfApiEngine, MessageRole
from .prompts import (
@@ -626,10 +626,9 @@ class CodeAgent(Agent):
Example:
```py
from transformers.agents import CodeAgent, PythonInterpreterTool
from transformers.agents import CodeAgent
python_interpreter = PythonInterpreterTool()
agent = CodeAgent(tools=[python_interpreter])
agent = CodeAgent(tools=[])
agent.run("What is the result of 2 power 3.7384?")
```
"""
@@ -1019,20 +1018,17 @@ class ReactJsonAgent(ReactAgent):
arguments = {}
observation = self.execute_tool_call(tool_name, arguments)
observation_type = type(observation)
if observation_type == AgentText:
updated_information = str(observation).strip()
else:
# TODO: observation naming could allow for different names of same type
if observation_type in [AgentImage, AgentAudio]:
if observation_type == AgentImage:
observation_name = "image.png"
elif observation_type == AgentAudio:
observation_name = "audio.mp3"
else:
observation_name = "object.object"
# TODO: observation naming could allow for different names of same type
self.state[observation_name] = observation
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
self.logger.info(updated_information)
current_step_logs["observation"] = updated_information
return current_step_logs

View File

@@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."
output_type = "text"
available_tools = BASE_PYTHON_TOOLS.copy()
output_type = "string"
def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None:
@@ -162,7 +161,7 @@ class PythonInterpreterTool(Tool):
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = {
"code": {
"type": "text",
"type": "string",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
@@ -173,7 +172,7 @@ class PythonInterpreterTool(Tool):
def forward(self, code):
output = str(
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
)
return output
@@ -181,7 +180,7 @@ class PythonInterpreterTool(Tool):
class FinalAnswerTool(Tool):
name = "final_answer"
description = "Provides a final answer to the given problem."
inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}}
inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
output_type = "any"
def forward(self, answer):

View File

@@ -31,7 +31,7 @@ if is_vision_available():
class DocumentQuestionAnsweringTool(PipelineTool):
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question."
description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
name = "document_qa"
pre_processor_class = AutoProcessor
model_class = VisionEncoderDecoderModel
@@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool):
"type": "image",
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
},
"question": {"type": "text", "description": "The question in English"},
"question": {"type": "string", "description": "The question in English"},
}
output_type = "text"
output_type = "string"
def __init__(self, *args, **kwargs):
if not is_vision_available():

View File

@@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool):
"type": "image",
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
},
"question": {"type": "text", "description": "The question in English"},
"question": {"type": "string", "description": "The question in English"},
}
output_type = "text"
output_type = "string"
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])

View File

@@ -199,7 +199,7 @@ Thought: I will now generate an image showcasing the oldest person.
Action:
{
"action": "image_generator",
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
"action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
}<end_action>
Observation: "image.png"

View File

@@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool):
name = "web_search"
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'."""
inputs = {"query": {"type": "text", "description": "The search query to perform."}}
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "any"
def forward(self, query: str) -> str:
@@ -45,11 +45,11 @@ class VisitWebpageTool(Tool):
description = "Visits a wbepage at the given url and returns its content as a markdown string."
inputs = {
"url": {
"type": "text",
"type": "string",
"description": "The url of the webpage to visit.",
}
}
output_type = "text"
output_type = "string"
def forward(self, url: str) -> str:
try:

View File

@@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool):
model_class = WhisperForConditionalGeneration
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
output_type = "text"
output_type = "string"
def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt")

View File

@@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool):
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan
inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}}
inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}}
output_type = "audio"
def setup(self):

View File

@@ -16,12 +16,13 @@
# limitations under the License.
import base64
import importlib
import inspect
import io
import json
import os
import tempfile
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
from functools import lru_cache, wraps
from typing import Any, Callable, Dict, List, Optional, Union
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
@@ -35,7 +36,9 @@ from ..dynamic_module_utils import (
from ..models.auto import AutoProcessor
from ..utils import (
CONFIG_NAME,
TypeHintParsingException,
cached_file,
get_json_schema,
is_accelerate_available,
is_torch_available,
is_vision_available,
@@ -84,6 +87,20 @@ launch_gradio_demo({class_name})
"""
def validate_after_init(cls):
original_init = cls.__init__
@wraps(original_init)
def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
if not isinstance(self, PipelineTool):
self.validate_arguments()
cls.__init__ = new_init
return cls
@validate_after_init
class Tool:
"""
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
@@ -114,17 +131,35 @@ class Tool:
def __init__(self, *args, **kwargs):
self.is_initialized = False
def validate_attributes(self):
def validate_arguments(self):
required_attributes = {
"description": str,
"name": str,
"inputs": Dict,
"output_type": type,
"output_type": str,
}
authorized_types = ["string", "integer", "number", "image", "audio", "any"]
for attr, expected_type in required_attributes.items():
attr_value = getattr(self, attr, None)
if not isinstance(attr_value, expected_type):
raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}")
raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
for input_name, input_content in self.inputs.items():
assert "type" in input_content, f"Input '{input_name}' should specify a type."
if input_content["type"] not in authorized_types:
raise Exception(
f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
)
assert "description" in input_content, f"Input '{input_name}' should have a description."
assert getattr(self, "output_type", None) in authorized_types
if not isinstance(self, PipelineTool):
signature = inspect.signature(self.forward)
if not set(signature.parameters.keys()) == set(self.inputs.keys()):
raise Exception(
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)
def forward(self, *args, **kwargs):
return NotImplemented("Write this method in your subclass of `Tool`.")
@@ -382,7 +417,7 @@ class Tool:
super().__init__()
self.name = _gradio_tool.name
self.description = _gradio_tool.description
self.output_type = "text"
self.output_type = "string"
self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
self.inputs = {key: "" for key in func_args}
@@ -404,7 +439,7 @@ class Tool:
self.name = _langchain_tool.name.lower()
self.description = _langchain_tool.description
self.inputs = parse_langchain_args(_langchain_tool.args)
self.output_type = "text"
self.output_type = "string"
self.langchain_tool = _langchain_tool
def forward(self, *args, **kwargs):
@@ -421,6 +456,7 @@ class Tool:
DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
- {{ tool.name }}: {{ tool.description }}
Takes inputs: {{tool.inputs}}
Returns an output of type: {{tool.output_type}}
"""
@@ -621,18 +657,18 @@ def launch_gradio_demo(tool_class: Tool):
gradio_inputs = []
for input_name, input_details in tool_class.inputs.items():
input_type = input_details["type"]
if input_type == "text":
gradio_inputs.append(gr.Textbox(label=input_name))
elif input_type == "image":
if input_type == "image":
gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name))
elif input_type in ["string", "integer", "number"]:
gradio_inputs.append(gr.Textbox(label=input_name))
else:
error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message)
gradio_output = tool_class.output_type
assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported."
assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
gr.Interface(
fn=fn,
@@ -808,3 +844,37 @@ class ToolCollection:
self._collection = get_collection(collection_slug, token=token)
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
def tool(tool_function: Callable) -> Tool:
"""
Converts a function into an instance of a Tool subclass.
Args:
tool_function: Your function. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described.
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
class_name = f"{parameters['name'].capitalize()}Tool"
class SpecificTool(Tool):
name = parameters["name"]
description = parameters["description"]
inputs = parameters["parameters"]["properties"]
output_type = parameters["return"]["type"]
@wraps(tool_function)
def forward(self, *args, **kwargs):
return tool_function(*args, **kwargs)
original_signature = inspect.signature(tool_function)
new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
original_signature.parameters.values()
)
new_signature = original_signature.replace(parameters=new_parameters)
SpecificTool.forward.__signature__ = new_signature
SpecificTool.__name__ = class_name
return SpecificTool()

View File

@@ -249,17 +249,17 @@ class TranslationTool(PipelineTool):
model_class = AutoModelForSeq2SeqLM
inputs = {
"text": {"type": "text", "description": "The text to translate"},
"text": {"type": "string", "description": "The text to translate"},
"src_lang": {
"type": "text",
"type": "string",
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
},
"tgt_lang": {
"type": "text",
"type": "string",
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
},
}
output_type = "text"
output_type = "string"
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:

View File

@@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args,
from packaging import version
from .import_utils import is_jinja_available
from .import_utils import is_jinja_available, is_torch_available, is_vision_available
if is_jinja_available():
@@ -32,6 +32,12 @@ if is_jinja_available():
else:
jinja2 = None
if is_vision_available():
from PIL.Image import Image
if is_torch_available():
from torch import Tensor
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
# Extracts the initial segment of the docstring, containing the function description
@@ -70,6 +76,8 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
float: {"type": "number"},
str: {"type": "string"},
bool: {"type": "boolean"},
Image: {"type": "image"},
Tensor: {"type": "audio"},
Any: {},
}
return type_mapping.get(param_type, {"type": "object"})

View File

@@ -68,7 +68,6 @@ Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = 2**3.6452
print(result)
```<end_code>
"""
else: # We're at step 2
@@ -181,7 +180,6 @@ Action:
assert isinstance(output, float)
assert output == 7.2904
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
assert agent.logs[2]["tool_call"] == {
"tool_arguments": "final_answer(7.2904)",
"tool_name": "code interpreter",
@@ -234,7 +232,7 @@ Action:
# check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter)
def test_function_persistence_across_steps(self):
agent = ReactCodeAgent(

View File

@@ -19,8 +19,9 @@ from pathlib import Path
import numpy as np
from PIL import Image
from transformers import is_torch_available, load_tool
from transformers import is_torch_available
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import FinalAnswerTool
from transformers.testing_utils import get_tests_dir, require_torch
from .test_tools_common import ToolTesterMixin
@@ -33,8 +34,7 @@ if is_torch_available():
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.inputs = {"answer": "Final answer"}
self.tool = load_tool("final_answer")
self.tool.setup()
self.tool = FinalAnswerTool()
def test_exact_match_arg(self):
result = self.tool("Final answer")
@@ -52,7 +52,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
)
}
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
@require_torch
def test_agent_type_output(self):

View File

@@ -391,8 +391,9 @@ else:
code = """char='a'
if char.isalpha():
print('2')"""
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "2"
state = {}
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
assert state["print_outputs"] == "2\n"
def test_imports(self):
code = "import math\nmath.sqrt(4)"
@@ -469,7 +470,7 @@ if char.isalpha():
code = "print('Hello world!')\nprint('Ok no one cares')"
state = {}
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
assert result == "Ok no one cares"
assert result is None
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
# test print in function
@@ -593,8 +594,7 @@ except ValueError as e:
def test_print(self):
code = "print(min([1, 2, 3]))"
state = {}
result = evaluate_python_code(code, {"min": min, "print": print}, state=state)
assert result == "1"
evaluate_python_code(code, {"min": min, "print": print}, state=state)
assert state["print_outputs"] == "1\n"
def test_types_as_objects(self):

View File

@@ -12,13 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from typing import Dict, Union
import numpy as np
import pytest
from transformers import is_torch_available, is_vision_available
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
from transformers.agents.tools import Tool, tool
from transformers.testing_utils import get_tests_dir, is_agent_test
@@ -29,7 +32,7 @@ if is_vision_available():
from PIL import Image
AUTHORIZED_TYPES = ["text", "audio", "image", "any"]
AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
@@ -38,7 +41,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
for input_name, input_desc in tool_inputs.items():
input_type = input_desc["type"]
if input_type == "text":
if input_type == "string":
inputs[input_name] = "Text input"
elif input_type == "image":
inputs[input_name] = Image.open(
@@ -54,7 +57,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
def output_type(output):
if isinstance(output, (str, AgentText)):
return "text"
return "string"
elif isinstance(output, (Image.Image, AgentImage)):
return "image"
elif isinstance(output, (torch.Tensor, AgentAudio)):
@@ -100,3 +103,69 @@ class ToolTesterMixin:
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
class ToolTests(unittest.TestCase):
def test_tool_init_with_decorator(self):
@tool
def coolfunc(a: str, b: int) -> float:
"""Cool function
Args:
a: The first argument
b: The second one
"""
return b + 2, a
assert coolfunc.output_type == "number"
def test_tool_init_vanilla(self):
class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = """
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
It returns the name of the checkpoint."""
inputs = {
"task": {
"type": "string",
"description": "the task category (such as text-classification, depth-estimation, etc)",
}
}
output_type = "integer"
def forward(self, task):
return "best model"
tool = HFModelDownloadsTool()
assert list(tool.inputs.keys())[0] == "task"
def test_tool_init_decorator_raises_issues(self):
with pytest.raises(Exception) as e:
@tool
def coolfunc(a: str, b: int):
"""Cool function
Args:
a: The first argument
b: The second one
"""
return a + b
assert coolfunc.output_type == "number"
assert "Tool return type not found" in str(e)
with pytest.raises(Exception) as e:
@tool
def coolfunc(a: str, b: int) -> int:
"""Cool function
Args:
a: The first argument
"""
return b + a
assert coolfunc.output_type == "number"
assert "docstring has no description for the argument" in str(e)