@@ -283,7 +283,8 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
|
|||||||
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
||||||
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
||||||
- **Translation**: translates a given sentence from source language to target language.
|
- **Translation**: translates a given sentence from source language to target language.
|
||||||
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based tools can already execute Python code
|
- **DuckDuckGo search***: performs a web search using DuckDuckGo browser.
|
||||||
|
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code
|
||||||
|
|
||||||
|
|
||||||
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ else:
|
|||||||
_import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
|
_import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
|
||||||
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
|
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
|
||||||
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
|
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
|
||||||
|
_import_structure["search"] = ["DuckDuckGoSearchTool"]
|
||||||
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
|
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
|
||||||
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
|
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
|
||||||
_import_structure["translation"] = ["TranslationTool"]
|
_import_structure["translation"] = ["TranslationTool"]
|
||||||
@@ -58,6 +59,7 @@ if TYPE_CHECKING:
|
|||||||
from .default_tools import FinalAnswerTool, PythonInterpreterTool
|
from .default_tools import FinalAnswerTool, PythonInterpreterTool
|
||||||
from .document_question_answering import DocumentQuestionAnsweringTool
|
from .document_question_answering import DocumentQuestionAnsweringTool
|
||||||
from .image_question_answering import ImageQuestionAnsweringTool
|
from .image_question_answering import ImageQuestionAnsweringTool
|
||||||
|
from .search import DuckDuckGoSearchTool
|
||||||
from .speech_to_text import SpeechToTextTool
|
from .speech_to_text import SpeechToTextTool
|
||||||
from .text_to_speech import TextToSpeechTool
|
from .text_to_speech import TextToSpeechTool
|
||||||
from .translation import TranslationTool
|
from .translation import TranslationTool
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from huggingface_hub import hf_hub_download, list_spaces
|
|||||||
|
|
||||||
from ..utils import is_offline_mode
|
from ..utils import is_offline_mode
|
||||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||||
from .tools import TASK_MAPPING, TOOL_CONFIG_FILE, Tool
|
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
|
||||||
|
|
||||||
|
|
||||||
def custom_print(*args):
|
def custom_print(*args):
|
||||||
@@ -133,7 +133,7 @@ def setup_default_tools(logger):
|
|||||||
main_module = importlib.import_module("transformers")
|
main_module = importlib.import_module("transformers")
|
||||||
tools_module = main_module.agents
|
tools_module = main_module.agents
|
||||||
|
|
||||||
for task_name, tool_class_name in TASK_MAPPING.items():
|
for task_name, tool_class_name in TOOL_MAPPING.items():
|
||||||
tool_class = getattr(tools_module, tool_class_name)
|
tool_class = getattr(tools_module, tool_class_name)
|
||||||
tool_instance = tool_class()
|
tool_instance = tool_class()
|
||||||
default_tools[tool_class.name] = PreTool(
|
default_tools[tool_class.name] = PreTool(
|
||||||
|
|||||||
@@ -444,6 +444,8 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
|||||||
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
||||||
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
||||||
|
|
||||||
|
if isinstance(value, str) and isinstance(index, str):
|
||||||
|
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
|
||||||
if isinstance(value, pd.core.indexing._LocIndexer):
|
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||||
parent_object = value.obj
|
parent_object = value.obj
|
||||||
return parent_object.loc[index]
|
return parent_object.loc[index]
|
||||||
|
|||||||
35
src/transformers/agents/search.py
Normal file
35
src/transformers/agents/search.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .tools import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class DuckDuckGoSearchTool(Tool):
|
||||||
|
name = "web_search"
|
||||||
|
description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
|
||||||
|
Each result has keys 'title', 'href' and 'body'."""
|
||||||
|
inputs = {"query": {"type": "text", "description": "The search query to perform."}}
|
||||||
|
output_type = "any"
|
||||||
|
|
||||||
|
def forward(self, query: str) -> str:
|
||||||
|
try:
|
||||||
|
from duckduckgo_search import DDGS
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"You must install package `duckduckgo_search`: for instance run `pip install duckduckgo-search`."
|
||||||
|
)
|
||||||
|
results = DDGS().text(query, max_results=7)
|
||||||
|
return results
|
||||||
@@ -643,13 +643,14 @@ def launch_gradio_demo(tool_class: Tool):
|
|||||||
).launch()
|
).launch()
|
||||||
|
|
||||||
|
|
||||||
TASK_MAPPING = {
|
TOOL_MAPPING = {
|
||||||
"document-question-answering": "DocumentQuestionAnsweringTool",
|
"document_question_answering": "DocumentQuestionAnsweringTool",
|
||||||
"image-question-answering": "ImageQuestionAnsweringTool",
|
"image_question_answering": "ImageQuestionAnsweringTool",
|
||||||
"speech-to-text": "SpeechToTextTool",
|
"speech_to_text": "SpeechToTextTool",
|
||||||
"text-to-speech": "TextToSpeechTool",
|
"text_to_speech": "TextToSpeechTool",
|
||||||
"translation": "TranslationTool",
|
"translation": "TranslationTool",
|
||||||
"python_interpreter": "PythonInterpreterTool",
|
"python_interpreter": "PythonInterpreterTool",
|
||||||
|
"web_search": "DuckDuckGoSearchTool",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -670,10 +671,10 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
|
|||||||
The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
|
The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
|
||||||
are:
|
are:
|
||||||
|
|
||||||
- `"document-question-answering"`
|
- `"document_question_answering"`
|
||||||
- `"image-question-answering"`
|
- `"image_question_answering"`
|
||||||
- `"speech-to-text"`
|
- `"speech_to_text"`
|
||||||
- `"text-to-speech"`
|
- `"text_to_speech"`
|
||||||
- `"translation"`
|
- `"translation"`
|
||||||
|
|
||||||
model_repo_id (`str`, *optional*):
|
model_repo_id (`str`, *optional*):
|
||||||
@@ -686,8 +687,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
|
|||||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
|
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
|
||||||
will be passed along to its init.
|
will be passed along to its init.
|
||||||
"""
|
"""
|
||||||
if task_or_repo_id in TASK_MAPPING:
|
if task_or_repo_id in TOOL_MAPPING:
|
||||||
tool_class_name = TASK_MAPPING[task_or_repo_id]
|
tool_class_name = TOOL_MAPPING[task_or_repo_id]
|
||||||
main_module = importlib.import_module("transformers")
|
main_module = importlib.import_module("transformers")
|
||||||
tools_module = main_module.agents
|
tools_module = main_module.agents
|
||||||
tool_class = getattr(tools_module, tool_class_name)
|
tool_class = getattr(tools_module, tool_class_name)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
|
|||||||
|
|
||||||
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("document-question-answering")
|
self.tool = load_tool("document_question_answering")
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ if is_vision_available():
|
|||||||
|
|
||||||
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("image-question-answering")
|
self.tool = load_tool("image_question_answering")
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
|
|||||||
@@ -176,6 +176,23 @@ class PythonInterpreterTester(unittest.TestCase):
|
|||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||||
|
|
||||||
|
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
||||||
|
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||||
|
|
||||||
|
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
||||||
|
code = """
|
||||||
|
search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
|
||||||
|
for result in search_results:
|
||||||
|
if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
|
||||||
|
current_weather_url = result['href']
|
||||||
|
print(current_weather_url)
|
||||||
|
break"""
|
||||||
|
with pytest.raises(InterpreterError) as e:
|
||||||
|
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert "You're trying to subscript a string with a string index" in e
|
||||||
|
|
||||||
def test_evaluate_for(self):
|
def test_evaluate_for(self):
|
||||||
code = "x = 0\nfor i in range(3):\n x = i"
|
code = "x = 0\nfor i in range(3):\n x = i"
|
||||||
state = {}
|
state = {}
|
||||||
@@ -573,13 +590,6 @@ except ValueError as e:
|
|||||||
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
||||||
assert state["exception_message"] == "An error occurred"
|
assert state["exception_message"] == "An error occurred"
|
||||||
|
|
||||||
def test_subscript(self):
|
|
||||||
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
|
||||||
|
|
||||||
state = {}
|
|
||||||
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
|
||||||
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
|
||||||
|
|
||||||
def test_print(self):
|
def test_print(self):
|
||||||
code = "print(min([1, 2, 3]))"
|
code = "print(min([1, 2, 3]))"
|
||||||
state = {}
|
state = {}
|
||||||
|
|||||||
30
tests/agents/test_search.py
Normal file
30
tests/agents/test_search.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import load_tool
|
||||||
|
|
||||||
|
from .test_tools_common import ToolTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
|
def setUp(self):
|
||||||
|
self.tool = load_tool("web_search")
|
||||||
|
self.tool.setup()
|
||||||
|
|
||||||
|
def test_exact_match_arg(self):
|
||||||
|
result = self.tool("Agents")
|
||||||
|
assert isinstance(result, list) and isinstance(result[0], dict)
|
||||||
@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
|
|||||||
|
|
||||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("speech-to-text")
|
self.tool = load_tool("speech_to_text")
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from .test_tools_common import ToolTesterMixin
|
|||||||
@require_torch
|
@require_torch
|
||||||
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("text-to-speech")
|
self.tool = load_tool("text_to_speech")
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class ToolTesterMixin:
|
|||||||
def test_agent_type_output(self):
|
def test_agent_type_output(self):
|
||||||
inputs = create_inputs(self.tool.inputs)
|
inputs = create_inputs(self.tool.inputs)
|
||||||
output = self.tool(**inputs)
|
output = self.tool(**inputs)
|
||||||
|
if self.tool.output_type != "any":
|
||||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||||
self.assertTrue(isinstance(output, agent_type))
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
|
||||||
@@ -99,9 +100,3 @@ class ToolTesterMixin:
|
|||||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||||
input_type = expected_input["type"]
|
input_type = expected_input["type"]
|
||||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||||
|
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
|
||||||
|
|
||||||
# Should not raise an error
|
|
||||||
output = self.tool(**inputs)
|
|
||||||
self.assertTrue(isinstance(output, output_type))
|
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
|||||||
def test_agent_type_output(self):
|
def test_agent_type_output(self):
|
||||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||||
output = self.tool(*inputs)
|
output = self.tool(*inputs)
|
||||||
|
|
||||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||||
self.assertTrue(isinstance(output, output_type))
|
self.assertTrue(isinstance(output, output_type))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user