From 1ca9ff5c91203b836a8a1f97dae1fd041550b27f Mon Sep 17 00:00:00 2001 From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com> Date: Mon, 2 Sep 2024 09:56:20 +0200 Subject: [PATCH] Add duckduckgo search tool (#32882) * Add duckduckgo search tool --- docs/source/en/agents.md | 3 +- src/transformers/agents/__init__.py | 2 ++ src/transformers/agents/default_tools.py | 4 +-- src/transformers/agents/python_interpreter.py | 2 ++ src/transformers/agents/search.py | 35 +++++++++++++++++++ src/transformers/agents/tools.py | 23 ++++++------ .../test_document_question_answering.py | 2 +- tests/agents/test_image_question_answering.py | 2 +- tests/agents/test_python_interpreter.py | 24 +++++++++---- tests/agents/test_search.py | 30 ++++++++++++++++ tests/agents/test_speech_to_text.py | 2 +- tests/agents/test_text_to_speech.py | 2 +- tests/agents/test_tools_common.py | 11 ++---- tests/agents/test_translation.py | 1 - 14 files changed, 109 insertions(+), 34 deletions(-) create mode 100644 src/transformers/agents/search.py create mode 100644 tests/agents/test_search.py diff --git a/docs/source/en/agents.md b/docs/source/en/agents.md index 992e75ebe5..8495e1a854 100644 --- a/docs/source/en/agents.md +++ b/docs/source/en/agents.md @@ -283,7 +283,8 @@ Transformers comes with a default toolbox for empowering agents, that you can ad - **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper)) - **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5)) - **Translation**: translates a given sentence from source language to target language. -- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based tools can already execute Python code +- **DuckDuckGo search***: performs a web search using DuckDuckGo browser. +- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agent can already natively execute Python code You can manually use a tool by calling the [`load_tool`] function and a task to perform. diff --git a/src/transformers/agents/__init__.py b/src/transformers/agents/__init__.py index f447d16580..4235d4c0d7 100644 --- a/src/transformers/agents/__init__.py +++ b/src/transformers/agents/__init__.py @@ -39,6 +39,7 @@ else: _import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"] _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"] _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"] + _import_structure["search"] = ["DuckDuckGoSearchTool"] _import_structure["speech_to_text"] = ["SpeechToTextTool"] _import_structure["text_to_speech"] = ["TextToSpeechTool"] _import_structure["translation"] = ["TranslationTool"] @@ -58,6 +59,7 @@ if TYPE_CHECKING: from .default_tools import FinalAnswerTool, PythonInterpreterTool from .document_question_answering import DocumentQuestionAnsweringTool from .image_question_answering import ImageQuestionAnsweringTool + from .search import DuckDuckGoSearchTool from .speech_to_text import SpeechToTextTool from .text_to_speech import TextToSpeechTool from .translation import TranslationTool diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 4190977672..84bbf3a973 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -25,7 +25,7 @@ from huggingface_hub import hf_hub_download, list_spaces from ..utils import is_offline_mode from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code -from .tools import TASK_MAPPING, TOOL_CONFIG_FILE, Tool +from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool def custom_print(*args): @@ -133,7 +133,7 @@ def setup_default_tools(logger): main_module = importlib.import_module("transformers") tools_module = main_module.agents - for task_name, tool_class_name in TASK_MAPPING.items(): + for task_name, tool_class_name in TOOL_MAPPING.items(): tool_class = getattr(tools_module, tool_class_name) tool_instance = tool_class() default_tools[tool_class.name] = PreTool( diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index e641a8d0c1..702363a21e 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -444,6 +444,8 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools): index = evaluate_ast(subscript.slice, state, static_tools, custom_tools) value = evaluate_ast(subscript.value, state, static_tools, custom_tools) + if isinstance(value, str) and isinstance(index, str): + raise InterpreterError("You're trying to subscript a string with a string index, which is impossible") if isinstance(value, pd.core.indexing._LocIndexer): parent_object = value.obj return parent_object.loc[index] diff --git a/src/transformers/agents/search.py b/src/transformers/agents/search.py new file mode 100644 index 0000000000..0b33c92733 --- /dev/null +++ b/src/transformers/agents/search.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .tools import Tool + + +class DuckDuckGoSearchTool(Tool): + name = "web_search" + description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. + Each result has keys 'title', 'href' and 'body'.""" + inputs = {"query": {"type": "text", "description": "The search query to perform."}} + output_type = "any" + + def forward(self, query: str) -> str: + try: + from duckduckgo_search import DDGS + except ImportError: + raise ImportError( + "You must install package `duckduckgo_search`: for instance run `pip install duckduckgo-search`." + ) + results = DDGS().text(query, max_results=7) + return results diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index b8b20d8e0e..f97ccc2e10 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -643,13 +643,14 @@ def launch_gradio_demo(tool_class: Tool): ).launch() -TASK_MAPPING = { - "document-question-answering": "DocumentQuestionAnsweringTool", - "image-question-answering": "ImageQuestionAnsweringTool", - "speech-to-text": "SpeechToTextTool", - "text-to-speech": "TextToSpeechTool", +TOOL_MAPPING = { + "document_question_answering": "DocumentQuestionAnsweringTool", + "image_question_answering": "ImageQuestionAnsweringTool", + "speech_to_text": "SpeechToTextTool", + "text_to_speech": "TextToSpeechTool", "translation": "TranslationTool", "python_interpreter": "PythonInterpreterTool", + "web_search": "DuckDuckGoSearchTool", } @@ -670,10 +671,10 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs): The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers are: - - `"document-question-answering"` - - `"image-question-answering"` - - `"speech-to-text"` - - `"text-to-speech"` + - `"document_question_answering"` + - `"image_question_answering"` + - `"speech_to_text"` + - `"text_to_speech"` - `"translation"` model_repo_id (`str`, *optional*): @@ -686,8 +687,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs): `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others will be passed along to its init. """ - if task_or_repo_id in TASK_MAPPING: - tool_class_name = TASK_MAPPING[task_or_repo_id] + if task_or_repo_id in TOOL_MAPPING: + tool_class_name = TOOL_MAPPING[task_or_repo_id] main_module = importlib.import_module("transformers") tools_module = main_module.agents tool_class = getattr(tools_module, tool_class_name) diff --git a/tests/agents/test_document_question_answering.py b/tests/agents/test_document_question_answering.py index 60f816c559..d135551084 100644 --- a/tests/agents/test_document_question_answering.py +++ b/tests/agents/test_document_question_answering.py @@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("document-question-answering") + self.tool = load_tool("document_question_answering") self.tool.setup() def test_exact_match_arg(self): diff --git a/tests/agents/test_image_question_answering.py b/tests/agents/test_image_question_answering.py index 1792d436dc..405933e78a 100644 --- a/tests/agents/test_image_question_answering.py +++ b/tests/agents/test_image_question_answering.py @@ -28,7 +28,7 @@ if is_vision_available(): class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("image-question-answering") + self.tool = load_tool("image_question_answering") self.tool.setup() def test_exact_match_arg(self): diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index feb923af28..84710cfec6 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -176,6 +176,23 @@ class PythonInterpreterTester(unittest.TestCase): assert result == 5 self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) + code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" + state = {} + evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) + assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} + + def test_subscript_string_with_string_index_raises_appropriate_error(self): + code = """ +search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]" +for result in search_results: + if 'current' in result['title'].lower() or 'temperature' in result['title'].lower(): + current_weather_url = result['href'] + print(current_weather_url) + break""" + with pytest.raises(InterpreterError) as e: + evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert "You're trying to subscript a string with a string index" in e + def test_evaluate_for(self): code = "x = 0\nfor i in range(3):\n x = i" state = {} @@ -573,13 +590,6 @@ except ValueError as e: evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) assert state["exception_message"] == "An error occurred" - def test_subscript(self): - code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)" - - state = {} - evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state) - assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62} - def test_print(self): code = "print(min([1, 2, 3]))" state = {} diff --git a/tests/agents/test_search.py b/tests/agents/test_search.py new file mode 100644 index 0000000000..7e40e3ca29 --- /dev/null +++ b/tests/agents/test_search.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import load_tool + +from .test_tools_common import ToolTesterMixin + + +class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin): + def setUp(self): + self.tool = load_tool("web_search") + self.tool.setup() + + def test_exact_match_arg(self): + result = self.tool("Agents") + assert isinstance(result, list) and isinstance(result[0], dict) diff --git a/tests/agents/test_speech_to_text.py b/tests/agents/test_speech_to_text.py index 241cf9ef70..3d6e9a3929 100644 --- a/tests/agents/test_speech_to_text.py +++ b/tests/agents/test_speech_to_text.py @@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("speech-to-text") + self.tool = load_tool("speech_to_text") self.tool.setup() def test_exact_match_arg(self): diff --git a/tests/agents/test_text_to_speech.py b/tests/agents/test_text_to_speech.py index 572ec7d28d..d8ed9afcbf 100644 --- a/tests/agents/test_text_to_speech.py +++ b/tests/agents/test_text_to_speech.py @@ -30,7 +30,7 @@ from .test_tools_common import ToolTesterMixin @require_torch class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("text-to-speech") + self.tool = load_tool("text_to_speech") self.tool.setup() def test_exact_match_arg(self): diff --git a/tests/agents/test_tools_common.py b/tests/agents/test_tools_common.py index bb8881d92e..679473d0f2 100644 --- a/tests/agents/test_tools_common.py +++ b/tests/agents/test_tools_common.py @@ -90,8 +90,9 @@ class ToolTesterMixin: def test_agent_type_output(self): inputs = create_inputs(self.tool.inputs) output = self.tool(**inputs) - agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] - self.assertTrue(isinstance(output, agent_type)) + if self.tool.output_type != "any": + agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] + self.assertTrue(isinstance(output, agent_type)) def test_agent_types_inputs(self): inputs = create_inputs(self.tool.inputs) @@ -99,9 +100,3 @@ class ToolTesterMixin: for _input, expected_input in zip(inputs, self.tool.inputs.values()): input_type = expected_input["type"] _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) - - output_type = AGENT_TYPE_MAPPING[self.tool.output_type] - - # Should not raise an error - output = self.tool(**inputs) - self.assertTrue(isinstance(output, output_type)) diff --git a/tests/agents/test_translation.py b/tests/agents/test_translation.py index e80b4e62b0..9027dd1731 100644 --- a/tests/agents/test_translation.py +++ b/tests/agents/test_translation.py @@ -44,7 +44,6 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin): def test_agent_type_output(self): inputs = ["Hey, what's up?", "English", "Spanish"] output = self.tool(*inputs) - output_type = AGENT_TYPE_MAPPING[self.tool.output_type] self.assertTrue(isinstance(output, output_type))