@@ -283,7 +283,8 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
|
||||
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
||||
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
||||
- **Translation**: translates a given sentence from source language to target language.
|
||||
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based tools can already execute Python code
|
||||
- **DuckDuckGo search***: performs a web search using DuckDuckGo browser.
|
||||
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
|
||||
|
||||
|
||||
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
||||
|
||||
@@ -39,6 +39,7 @@ else:
|
||||
_import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
|
||||
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
|
||||
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
|
||||
_import_structure["search"] = ["DuckDuckGoSearchTool"]
|
||||
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
|
||||
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
|
||||
_import_structure["translation"] = ["TranslationTool"]
|
||||
@@ -58,6 +59,7 @@ if TYPE_CHECKING:
|
||||
from .default_tools import FinalAnswerTool, PythonInterpreterTool
|
||||
from .document_question_answering import DocumentQuestionAnsweringTool
|
||||
from .image_question_answering import ImageQuestionAnsweringTool
|
||||
from .search import DuckDuckGoSearchTool
|
||||
from .speech_to_text import SpeechToTextTool
|
||||
from .text_to_speech import TextToSpeechTool
|
||||
from .translation import TranslationTool
|
||||
|
||||
@@ -25,7 +25,7 @@ from huggingface_hub import hf_hub_download, list_spaces
|
||||
|
||||
from ..utils import is_offline_mode
|
||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import TASK_MAPPING, TOOL_CONFIG_FILE, Tool
|
||||
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
|
||||
|
||||
|
||||
def custom_print(*args):
|
||||
@@ -133,7 +133,7 @@ def setup_default_tools(logger):
|
||||
main_module = importlib.import_module("transformers")
|
||||
tools_module = main_module.agents
|
||||
|
||||
for task_name, tool_class_name in TASK_MAPPING.items():
|
||||
for task_name, tool_class_name in TOOL_MAPPING.items():
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
tool_instance = tool_class()
|
||||
default_tools[tool_class.name] = PreTool(
|
||||
|
||||
@@ -444,6 +444,8 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
||||
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
||||
|
||||
if isinstance(value, str) and isinstance(index, str):
|
||||
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
|
||||
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||
parent_object = value.obj
|
||||
return parent_object.loc[index]
|
||||
|
||||
35
src/transformers/agents/search.py
Normal file
35
src/transformers/agents/search.py
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .tools import Tool
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
||||
name = "web_search"
|
||||
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||
Each result has keys 'title', 'href' and 'body'."""
|
||||
inputs = {"query": {"type": "text", "description": "The search query to perform."}}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, query: str) -> str:
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"You must install package `duckduckgo_search`: for instance run `pip install duckduckgo-search`."
|
||||
)
|
||||
results = DDGS().text(query, max_results=7)
|
||||
return results
|
||||
@@ -643,13 +643,14 @@ def launch_gradio_demo(tool_class: Tool):
|
||||
).launch()
|
||||
|
||||
|
||||
TASK_MAPPING = {
|
||||
"document-question-answering": "DocumentQuestionAnsweringTool",
|
||||
"image-question-answering": "ImageQuestionAnsweringTool",
|
||||
"speech-to-text": "SpeechToTextTool",
|
||||
"text-to-speech": "TextToSpeechTool",
|
||||
TOOL_MAPPING = {
|
||||
"document_question_answering": "DocumentQuestionAnsweringTool",
|
||||
"image_question_answering": "ImageQuestionAnsweringTool",
|
||||
"speech_to_text": "SpeechToTextTool",
|
||||
"text_to_speech": "TextToSpeechTool",
|
||||
"translation": "TranslationTool",
|
||||
"python_interpreter": "PythonInterpreterTool",
|
||||
"web_search": "DuckDuckGoSearchTool",
|
||||
}
|
||||
|
||||
|
||||
@@ -670,10 +671,10 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
|
||||
The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
|
||||
are:
|
||||
|
||||
- `"document-question-answering"`
|
||||
- `"image-question-answering"`
|
||||
- `"speech-to-text"`
|
||||
- `"text-to-speech"`
|
||||
- `"document_question_answering"`
|
||||
- `"image_question_answering"`
|
||||
- `"speech_to_text"`
|
||||
- `"text_to_speech"`
|
||||
- `"translation"`
|
||||
|
||||
model_repo_id (`str`, *optional*):
|
||||
@@ -686,8 +687,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
|
||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
|
||||
will be passed along to its init.
|
||||
"""
|
||||
if task_or_repo_id in TASK_MAPPING:
|
||||
tool_class_name = TASK_MAPPING[task_or_repo_id]
|
||||
if task_or_repo_id in TOOL_MAPPING:
|
||||
tool_class_name = TOOL_MAPPING[task_or_repo_id]
|
||||
main_module = importlib.import_module("transformers")
|
||||
tools_module = main_module.agents
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
|
||||
@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
|
||||
|
||||
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("document-question-answering")
|
||||
self.tool = load_tool("document_question_answering")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
|
||||
@@ -28,7 +28,7 @@ if is_vision_available():
|
||||
|
||||
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-question-answering")
|
||||
self.tool = load_tool("image_question_answering")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
|
||||
@@ -176,6 +176,23 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||
state = {}
|
||||
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
||||
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||
|
||||
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
||||
code = """
|
||||
search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
|
||||
for result in search_results:
|
||||
if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
|
||||
current_weather_url = result['href']
|
||||
print(current_weather_url)
|
||||
break"""
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "You're trying to subscript a string with a string index" in e
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
@@ -573,13 +590,6 @@ except ValueError as e:
|
||||
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
||||
assert state["exception_message"] == "An error occurred"
|
||||
|
||||
def test_subscript(self):
|
||||
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||
|
||||
state = {}
|
||||
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
||||
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||
|
||||
def test_print(self):
|
||||
code = "print(min([1, 2, 3]))"
|
||||
state = {}
|
||||
|
||||
30
tests/agents/test_search.py
Normal file
30
tests/agents/test_search.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("web_search")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Agents")
|
||||
assert isinstance(result, list) and isinstance(result[0], dict)
|
||||
@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
|
||||
|
||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("speech-to-text")
|
||||
self.tool = load_tool("speech_to_text")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
|
||||
@@ -30,7 +30,7 @@ from .test_tools_common import ToolTesterMixin
|
||||
@require_torch
|
||||
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-to-speech")
|
||||
self.tool = load_tool("text_to_speech")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
|
||||
@@ -90,8 +90,9 @@ class ToolTesterMixin:
|
||||
def test_agent_type_output(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
output = self.tool(**inputs)
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
if self.tool.output_type != "any":
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
@@ -99,9 +100,3 @@ class ToolTesterMixin:
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**inputs)
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
@@ -44,7 +44,6 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user