From e6d9f39dd7551d1a95be081cbb59f94c54c3dbf6 Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Wed, 18 Sep 2024 11:07:51 +0200 Subject: [PATCH] Decorator for easier tool building (#33439) * Decorator for tool building --- docs/source/en/agents.md | 72 +++++---------- docs/source/en/agents_advanced.md | 63 ++++++++++++- docs/source/en/main_classes/agent.md | 4 + src/transformers/__init__.py | 2 + src/transformers/agents/__init__.py | 4 +- src/transformers/agents/agent_types.py | 2 +- src/transformers/agents/agents.py | 18 ++-- src/transformers/agents/default_tools.py | 9 +- .../agents/document_question_answering.py | 6 +- .../agents/image_question_answering.py | 4 +- src/transformers/agents/prompts.py | 2 +- src/transformers/agents/search.py | 6 +- src/transformers/agents/speech_to_text.py | 2 +- src/transformers/agents/text_to_speech.py | 2 +- src/transformers/agents/tools.py | 92 ++++++++++++++++--- src/transformers/agents/translation.py | 8 +- src/transformers/utils/chat_template_utils.py | 10 +- tests/agents/test_agents.py | 4 +- tests/agents/test_final_answer.py | 8 +- tests/agents/test_python_interpreter.py | 10 +- tests/agents/test_tools_common.py | 75 ++++++++++++++- 21 files changed, 292 insertions(+), 111 deletions(-) diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md index 0b889f4eec..ac06c04d9b 100644 --- a/docs/source/en/agents.md +++ b/docs/source/en/agents.md @@ -325,62 +325,37 @@ model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) print(model.id) ``` -This code can be converted into a class that inherits from the [`Tool`] superclass. +This code can quickly be converted into a tool, just by wrapping it in a function and adding the `tool` decorator: -The custom tool needs: -- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`. -- An attribute `description` is used to populate the agent's system prompt. -- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input. -- An `output_type` attribute, which specifies the output type. -- A `forward` method which contains the inference code to be executed. +```py +from transformers import tool +@tool +def model_download_counter(task: str) -> str: + """ + This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. + It returns the name of the checkpoint. -```python -from transformers import Tool -from huggingface_hub import list_models - -class HFModelDownloadsTool(Tool): - name = "model_download_counter" - description = ( - "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. " - "It returns the name of the checkpoint." - ) - - inputs = { - "task": { - "type": "text", - "description": "the task category (such as text-classification, depth-estimation, etc)", - } - } - output_type = "text" - - def forward(self, task: str): - model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) - return model.id + Args: + task: The task for which + """ + model = next(iter(list_models(filter="text-classification", sort="downloads", direction=-1))) + return model.id ``` -Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. +The function needs: +- A clear name. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's put `model_download_counter`. +- Type hints on both inputs and output +- A description, that includes an 'Args:' part where each argument is described (without a type indication this time, it will be pulled from the type hint). +All these will be automatically baked into the agent's system prompt upon initialization: so strive to make them as clear as possible! +> [!TIP] +> This definition format is the same as tool schemas used in `apply_chat_template`, the only difference is the added `tool` decorator: read more on our tool use API [here](https://huggingface.co/blog/unified-tool-use#passing-tools-to-a-chat-template). -```python -from model_downloads import HFModelDownloadsTool - -tool = HFModelDownloadsTool() -``` - -You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. - -```python -tool.push_to_hub("{your_username}/hf-model-downloads") -``` - -Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. - -```python -from transformers import load_tool, CodeAgent - -model_download_tool = load_tool("m-ric/hf-model-downloads") +Then you can directly initialize your agent: +```py +from transformers import CodeAgent agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine) agent.run( "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?" @@ -400,7 +375,6 @@ print(f"The most downloaded model for the 'text-to-video' task is {most_download And the output: `"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."` - ### Manage your agent's toolbox If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool. diff --git a/docs/source/en/agents_advanced.md b/docs/source/en/agents_advanced.md index 399eeb9b70..2327357525 100644 --- a/docs/source/en/agents_advanced.md +++ b/docs/source/en/agents_advanced.md @@ -60,7 +60,68 @@ manager_agent.run("Who is the CEO of Hugging Face?") > For an in-depth example of an efficient multi-agent implementation, see [how we pushed our multi-agent system to the top of the GAIA leaderboard](https://huggingface.co/blog/beating-gaia). -## Use tools from gradio or LangChain +## Advanced tool usage + +### Directly define a tool by subclassing Tool, and share it to the Hub + +Let's take again the tool example from main documentation, for which we had implemented a `tool` decorator. + +If you need to add variation, like custom attributes for your too, you can build your tool following the fine-grained method: building a class that inherits from the [`Tool`] superclass. + +The custom tool needs: +- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`. +- An attribute `description` is used to populate the agent's system prompt. +- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input. +- An `output_type` attribute, which specifies the output type. +- A `forward` method which contains the inference code to be executed. + +The types for both `inputs` and `output_type` should be amongst [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema). + +```python +from transformers import Tool +from huggingface_hub import list_models + +class HFModelDownloadsTool(Tool): + name = "model_download_counter" + description = """ + This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. + It returns the name of the checkpoint.""" + + inputs = { + "task": { + "type": "string", + "description": "the task category (such as text-classification, depth-estimation, etc)", + } + } + output_type = "string" + + def forward(self, task: str): + model = next(iter(list_models(filter=task, sort="downloads", direction=-1))) + return model.id +``` + +Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use. + + +```python +from model_downloads import HFModelDownloadsTool + +tool = HFModelDownloadsTool() +``` + +You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access. + +```python +tool.push_to_hub("{your_username}/hf-model-downloads") +``` + +Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent. + +```python +from transformers import load_tool, CodeAgent + +model_download_tool = load_tool("m-ric/hf-model-downloads") +``` ### Use gradio-tools diff --git a/docs/source/en/main_classes/agent.md b/docs/source/en/main_classes/agent.md index 8628785815..ed0486b601 100644 --- a/docs/source/en/main_classes/agent.md +++ b/docs/source/en/main_classes/agent.md @@ -60,6 +60,10 @@ We provide two types of agents, based on the main [`Agent`] class: [[autodoc]] load_tool +### tool + +[[autodoc]] tool + ### Tool [[autodoc]] Tool diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 36775d8454..bfd0d37916 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -70,6 +70,7 @@ _import_structure = { "launch_gradio_demo", "load_tool", "stream_to_gradio", + "tool", ], "audio_utils": [], "benchmark": [], @@ -4819,6 +4820,7 @@ if TYPE_CHECKING: launch_gradio_demo, load_tool, stream_to_gradio, + tool, ) from .configuration_utils import PretrainedConfig diff --git a/src/transformers/agents/__init__.py b/src/transformers/agents/__init__.py index d053e385cf..70762c252a 100644 --- a/src/transformers/agents/__init__.py +++ b/src/transformers/agents/__init__.py @@ -27,7 +27,7 @@ _import_structure = { "agents": ["Agent", "CodeAgent", "ManagedAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "llm_engine": ["HfApiEngine", "TransformersEngine"], "monitoring": ["stream_to_gradio"], - "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"], + "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool", "tool"], } try: @@ -48,7 +48,7 @@ if TYPE_CHECKING: from .agents import Agent, CodeAgent, ManagedAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .llm_engine import HfApiEngine, TransformersEngine from .monitoring import stream_to_gradio - from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool + from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool, tool try: if not is_torch_available(): diff --git a/src/transformers/agents/agent_types.py b/src/transformers/agents/agent_types.py index 4a36eaaee0..f5be746265 100644 --- a/src/transformers/agents/agent_types.py +++ b/src/transformers/agents/agent_types.py @@ -234,7 +234,7 @@ class AgentAudio(AgentType, str): return self._path -AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio} +AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio} INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage} if is_torch_available(): diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index 5a4aea28d9..73b7186d25 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union from .. import is_torch_available from ..utils import logging as transformers_logging from ..utils.import_utils import is_pygments_available -from .agent_types import AgentAudio, AgentImage, AgentText +from .agent_types import AgentAudio, AgentImage from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .llm_engine import HfApiEngine, MessageRole from .prompts import ( @@ -626,10 +626,9 @@ class CodeAgent(Agent): Example: ```py - from transformers.agents import CodeAgent, PythonInterpreterTool + from transformers.agents import CodeAgent - python_interpreter = PythonInterpreterTool() - agent = CodeAgent(tools=[python_interpreter]) + agent = CodeAgent(tools=[]) agent.run("What is the result of 2 power 3.7384?") ``` """ @@ -1019,20 +1018,17 @@ class ReactJsonAgent(ReactAgent): arguments = {} observation = self.execute_tool_call(tool_name, arguments) observation_type = type(observation) - if observation_type == AgentText: - updated_information = str(observation).strip() - else: - # TODO: observation naming could allow for different names of same type + if observation_type in [AgentImage, AgentAudio]: if observation_type == AgentImage: observation_name = "image.png" elif observation_type == AgentAudio: observation_name = "audio.mp3" - else: - observation_name = "object.object" + # TODO: observation naming could allow for different names of same type self.state[observation_name] = observation updated_information = f"Stored '{observation_name}' in memory." - + else: + updated_information = str(observation).strip() self.logger.info(updated_information) current_step_logs["observation"] = updated_information return current_step_logs diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index b02b12d528..3946aa9f87 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -152,8 +152,7 @@ class PythonInterpreterTool(Tool): name = "python_interpreter" description = "This is a tool that evaluates python code. It can be used to perform calculations." - output_type = "text" - available_tools = BASE_PYTHON_TOOLS.copy() + output_type = "string" def __init__(self, *args, authorized_imports=None, **kwargs): if authorized_imports is None: @@ -162,7 +161,7 @@ class PythonInterpreterTool(Tool): self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) self.inputs = { "code": { - "type": "text", + "type": "string", "description": ( "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " f"else you will get an error. This code can only import the following python libraries: {authorized_imports}." @@ -173,7 +172,7 @@ class PythonInterpreterTool(Tool): def forward(self, code): output = str( - evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports) + evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports) ) return output @@ -181,7 +180,7 @@ class PythonInterpreterTool(Tool): class FinalAnswerTool(Tool): name = "final_answer" description = "Provides a final answer to the given problem." - inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}} + inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}} output_type = "any" def forward(self, answer): diff --git a/src/transformers/agents/document_question_answering.py b/src/transformers/agents/document_question_answering.py index 030120ac6c..23ae5b0429 100644 --- a/src/transformers/agents/document_question_answering.py +++ b/src/transformers/agents/document_question_answering.py @@ -31,7 +31,7 @@ if is_vision_available(): class DocumentQuestionAnsweringTool(PipelineTool): default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa" - description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question." + description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question." name = "document_qa" pre_processor_class = AutoProcessor model_class = VisionEncoderDecoderModel @@ -41,9 +41,9 @@ class DocumentQuestionAnsweringTool(PipelineTool): "type": "image", "description": "The image containing the information. Can be a PIL Image or a string path to the image.", }, - "question": {"type": "text", "description": "The question in English"}, + "question": {"type": "string", "description": "The question in English"}, } - output_type = "text" + output_type = "string" def __init__(self, *args, **kwargs): if not is_vision_available(): diff --git a/src/transformers/agents/image_question_answering.py b/src/transformers/agents/image_question_answering.py index 020d22c47f..de0efb7b6f 100644 --- a/src/transformers/agents/image_question_answering.py +++ b/src/transformers/agents/image_question_answering.py @@ -38,9 +38,9 @@ class ImageQuestionAnsweringTool(PipelineTool): "type": "image", "description": "The image containing the information. Can be a PIL Image or a string path to the image.", }, - "question": {"type": "text", "description": "The question in English"}, + "question": {"type": "string", "description": "The question in English"}, } - output_type = "text" + output_type = "string" def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index de8ad1d284..7a84b1db44 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -199,7 +199,7 @@ Thought: I will now generate an image showcasing the oldest person. Action: { "action": "image_generator", - "action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""} + "action_input": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} } Observation: "image.png" diff --git a/src/transformers/agents/search.py b/src/transformers/agents/search.py index 5ce36bf99b..f50a7c6ab8 100644 --- a/src/transformers/agents/search.py +++ b/src/transformers/agents/search.py @@ -26,7 +26,7 @@ class DuckDuckGoSearchTool(Tool): name = "web_search" description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. Each result has keys 'title', 'href' and 'body'.""" - inputs = {"query": {"type": "text", "description": "The search query to perform."}} + inputs = {"query": {"type": "string", "description": "The search query to perform."}} output_type = "any" def forward(self, query: str) -> str: @@ -45,11 +45,11 @@ class VisitWebpageTool(Tool): description = "Visits a wbepage at the given url and returns its content as a markdown string." inputs = { "url": { - "type": "text", + "type": "string", "description": "The url of the webpage to visit.", } } - output_type = "text" + output_type = "string" def forward(self, url: str) -> str: try: diff --git a/src/transformers/agents/speech_to_text.py b/src/transformers/agents/speech_to_text.py index 817b6319e6..8061651a08 100644 --- a/src/transformers/agents/speech_to_text.py +++ b/src/transformers/agents/speech_to_text.py @@ -27,7 +27,7 @@ class SpeechToTextTool(PipelineTool): model_class = WhisperForConditionalGeneration inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}} - output_type = "text" + output_type = "string" def encode(self, audio): return self.pre_processor(audio, return_tensors="pt") diff --git a/src/transformers/agents/text_to_speech.py b/src/transformers/agents/text_to_speech.py index 3166fab802..ed41ef6017 100644 --- a/src/transformers/agents/text_to_speech.py +++ b/src/transformers/agents/text_to_speech.py @@ -36,7 +36,7 @@ class TextToSpeechTool(PipelineTool): model_class = SpeechT5ForTextToSpeech post_processor_class = SpeechT5HifiGan - inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}} + inputs = {"text": {"type": "string", "description": "The text to read out loud (in English)"}} output_type = "audio" def setup(self): diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index f97ccc2e10..cfb1e4cf95 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -16,12 +16,13 @@ # limitations under the License. import base64 import importlib +import inspect import io import json import os import tempfile -from functools import lru_cache -from typing import Any, Dict, List, Optional, Union +from functools import lru_cache, wraps +from typing import Any, Callable, Dict, List, Optional, Union from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session @@ -35,7 +36,9 @@ from ..dynamic_module_utils import ( from ..models.auto import AutoProcessor from ..utils import ( CONFIG_NAME, + TypeHintParsingException, cached_file, + get_json_schema, is_accelerate_available, is_torch_available, is_vision_available, @@ -84,6 +87,20 @@ launch_gradio_demo({class_name}) """ +def validate_after_init(cls): + original_init = cls.__init__ + + @wraps(original_init) + def new_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + if not isinstance(self, PipelineTool): + self.validate_arguments() + + cls.__init__ = new_init + return cls + + +@validate_after_init class Tool: """ A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the @@ -114,17 +131,35 @@ class Tool: def __init__(self, *args, **kwargs): self.is_initialized = False - def validate_attributes(self): + def validate_arguments(self): required_attributes = { "description": str, "name": str, "inputs": Dict, - "output_type": type, + "output_type": str, } + authorized_types = ["string", "integer", "number", "image", "audio", "any"] + for attr, expected_type in required_attributes.items(): attr_value = getattr(self, attr, None) if not isinstance(attr_value, expected_type): - raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}") + raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.") + for input_name, input_content in self.inputs.items(): + assert "type" in input_content, f"Input '{input_name}' should specify a type." + if input_content["type"] not in authorized_types: + raise Exception( + f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}." + ) + assert "description" in input_content, f"Input '{input_name}' should have a description." + + assert getattr(self, "output_type", None) in authorized_types + + if not isinstance(self, PipelineTool): + signature = inspect.signature(self.forward) + if not set(signature.parameters.keys()) == set(self.inputs.keys()): + raise Exception( + "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'." + ) def forward(self, *args, **kwargs): return NotImplemented("Write this method in your subclass of `Tool`.") @@ -382,7 +417,7 @@ class Tool: super().__init__() self.name = _gradio_tool.name self.description = _gradio_tool.description - self.output_type = "text" + self.output_type = "string" self._gradio_tool = _gradio_tool func_args = list(inspect.signature(_gradio_tool.run).parameters.keys()) self.inputs = {key: "" for key in func_args} @@ -404,7 +439,7 @@ class Tool: self.name = _langchain_tool.name.lower() self.description = _langchain_tool.description self.inputs = parse_langchain_args(_langchain_tool.args) - self.output_type = "text" + self.output_type = "string" self.langchain_tool = _langchain_tool def forward(self, *args, **kwargs): @@ -421,6 +456,7 @@ class Tool: DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """ - {{ tool.name }}: {{ tool.description }} Takes inputs: {{tool.inputs}} + Returns an output of type: {{tool.output_type}} """ @@ -621,18 +657,18 @@ def launch_gradio_demo(tool_class: Tool): gradio_inputs = [] for input_name, input_details in tool_class.inputs.items(): input_type = input_details["type"] - if input_type == "text": - gradio_inputs.append(gr.Textbox(label=input_name)) - elif input_type == "image": + if input_type == "image": gradio_inputs.append(gr.Image(label=input_name)) elif input_type == "audio": gradio_inputs.append(gr.Audio(label=input_name)) + elif input_type in ["string", "integer", "number"]: + gradio_inputs.append(gr.Textbox(label=input_name)) else: error_message = f"Input type '{input_type}' not supported." raise ValueError(error_message) gradio_output = tool_class.output_type - assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported." + assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported." gr.Interface( fn=fn, @@ -808,3 +844,37 @@ class ToolCollection: self._collection = get_collection(collection_slug, token=token) self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"} self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids} + + +def tool(tool_function: Callable) -> Tool: + """ + Converts a function into an instance of a Tool subclass. + + Args: + tool_function: Your function. Should have type hints for each input and a type hint for the output. + Should also have a docstring description including an 'Args:' part where each argument is described. + """ + parameters = get_json_schema(tool_function)["function"] + if "return" not in parameters: + raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!") + class_name = f"{parameters['name'].capitalize()}Tool" + + class SpecificTool(Tool): + name = parameters["name"] + description = parameters["description"] + inputs = parameters["parameters"]["properties"] + output_type = parameters["return"]["type"] + + @wraps(tool_function) + def forward(self, *args, **kwargs): + return tool_function(*args, **kwargs) + + original_signature = inspect.signature(tool_function) + new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list( + original_signature.parameters.values() + ) + new_signature = original_signature.replace(parameters=new_parameters) + SpecificTool.forward.__signature__ = new_signature + + SpecificTool.__name__ = class_name + return SpecificTool() diff --git a/src/transformers/agents/translation.py b/src/transformers/agents/translation.py index efc97c6e0b..7ae61f9679 100644 --- a/src/transformers/agents/translation.py +++ b/src/transformers/agents/translation.py @@ -249,17 +249,17 @@ class TranslationTool(PipelineTool): model_class = AutoModelForSeq2SeqLM inputs = { - "text": {"type": "text", "description": "The text to translate"}, + "text": {"type": "string", "description": "The text to translate"}, "src_lang": { - "type": "text", + "type": "string", "description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'", }, "tgt_lang": { - "type": "text", + "type": "string", "description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'", }, } - output_type = "text" + output_type = "string" def encode(self, text, src_lang, tgt_lang): if src_lang not in self.lang_to_code: diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index aabaf4a366..74912ce301 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, from packaging import version -from .import_utils import is_jinja_available +from .import_utils import is_jinja_available, is_torch_available, is_vision_available if is_jinja_available(): @@ -32,6 +32,12 @@ if is_jinja_available(): else: jinja2 = None +if is_vision_available(): + from PIL.Image import Image + +if is_torch_available(): + from torch import Tensor + BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) # Extracts the initial segment of the docstring, containing the function description @@ -70,6 +76,8 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]: float: {"type": "number"}, str: {"type": "string"}, bool: {"type": "boolean"}, + Image: {"type": "image"}, + Tensor: {"type": "audio"}, Any: {}, } return type_mapping.get(param_type, {"type": "object"}) diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 67cb31b7da..4f24abbeed 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -68,7 +68,6 @@ Thought: I should multiply 2 by 3.6452. special_marker Code: ```py result = 2**3.6452 -print(result) ``` """ else: # We're at step 2 @@ -181,7 +180,6 @@ Action: assert isinstance(output, float) assert output == 7.2904 assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" - assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6 assert agent.logs[2]["tool_call"] == { "tool_arguments": "final_answer(7.2904)", "tool_name": "code interpreter", @@ -234,7 +232,7 @@ Action: # check that python_interpreter base tool does not get added to code agents agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) - assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter) + assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter) def test_function_persistence_across_steps(self): agent = ReactCodeAgent( diff --git a/tests/agents/test_final_answer.py b/tests/agents/test_final_answer.py index 59d5dec84b..91bdd65e89 100644 --- a/tests/agents/test_final_answer.py +++ b/tests/agents/test_final_answer.py @@ -19,8 +19,9 @@ from pathlib import Path import numpy as np from PIL import Image -from transformers import is_torch_available, load_tool +from transformers import is_torch_available from transformers.agents.agent_types import AGENT_TYPE_MAPPING +from transformers.agents.default_tools import FinalAnswerTool from transformers.testing_utils import get_tests_dir, require_torch from .test_tools_common import ToolTesterMixin @@ -33,8 +34,7 @@ if is_torch_available(): class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): self.inputs = {"answer": "Final answer"} - self.tool = load_tool("final_answer") - self.tool.setup() + self.tool = FinalAnswerTool() def test_exact_match_arg(self): result = self.tool("Final answer") @@ -52,7 +52,7 @@ class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin): ) } inputs_audio = {"answer": torch.Tensor(np.ones(3000))} - return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio} + return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio} @require_torch def test_agent_type_output(self): diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 84710cfec6..15e5ad7bb3 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -391,8 +391,9 @@ else: code = """char='a' if char.isalpha(): print('2')""" - result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) - assert result == "2" + state = {} + evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) + assert state["print_outputs"] == "2\n" def test_imports(self): code = "import math\nmath.sqrt(4)" @@ -469,7 +470,7 @@ if char.isalpha(): code = "print('Hello world!')\nprint('Ok no one cares')" state = {} result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) - assert result == "Ok no one cares" + assert result is None assert state["print_outputs"] == "Hello world!\nOk no one cares\n" # test print in function @@ -593,8 +594,7 @@ except ValueError as e: def test_print(self): code = "print(min([1, 2, 3]))" state = {} - result = evaluate_python_code(code, {"min": min, "print": print}, state=state) - assert result == "1" + evaluate_python_code(code, {"min": min, "print": print}, state=state) assert state["print_outputs"] == "1\n" def test_types_as_objects(self): diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index 679473d0f2..8226e71098 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -12,13 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest from pathlib import Path from typing import Dict, Union import numpy as np +import pytest from transformers import is_torch_available, is_vision_available from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText +from transformers.agents.tools import Tool, tool from transformers.testing_utils import get_tests_dir, is_agent_test @@ -29,7 +32,7 @@ if is_vision_available(): from PIL import Image -AUTHORIZED_TYPES = ["text", "audio", "image", "any"] +AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"] def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): @@ -38,7 +41,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): for input_name, input_desc in tool_inputs.items(): input_type = input_desc["type"] - if input_type == "text": + if input_type == "string": inputs[input_name] = "Text input" elif input_type == "image": inputs[input_name] = Image.open( @@ -54,7 +57,7 @@ def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]): def output_type(output): if isinstance(output, (str, AgentText)): - return "text" + return "string" elif isinstance(output, (Image.Image, AgentImage)): return "image" elif isinstance(output, (torch.Tensor, AgentAudio)): @@ -100,3 +103,69 @@ class ToolTesterMixin: for _input, expected_input in zip(inputs, self.tool.inputs.values()): input_type = expected_input["type"] _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) + + +class ToolTests(unittest.TestCase): + def test_tool_init_with_decorator(self): + @tool + def coolfunc(a: str, b: int) -> float: + """Cool function + + Args: + a: The first argument + b: The second one + """ + return b + 2, a + + assert coolfunc.output_type == "number" + + def test_tool_init_vanilla(self): + class HFModelDownloadsTool(Tool): + name = "model_download_counter" + description = """ + This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. + It returns the name of the checkpoint.""" + + inputs = { + "task": { + "type": "string", + "description": "the task category (such as text-classification, depth-estimation, etc)", + } + } + output_type = "integer" + + def forward(self, task): + return "best model" + + tool = HFModelDownloadsTool() + assert list(tool.inputs.keys())[0] == "task" + + def test_tool_init_decorator_raises_issues(self): + with pytest.raises(Exception) as e: + + @tool + def coolfunc(a: str, b: int): + """Cool function + + Args: + a: The first argument + b: The second one + """ + return a + b + + assert coolfunc.output_type == "number" + assert "Tool return type not found" in str(e) + + with pytest.raises(Exception) as e: + + @tool + def coolfunc(a: str, b: int) -> int: + """Cool function + + Args: + a: The first argument + """ + return b + a + + assert coolfunc.output_type == "number" + assert "docstring has no description for the argument" in str(e)