Add duckduckgo search tool (#32882)

* Add duckduckgo search tool
This commit is contained in:
Aymeric Roucher
2024-09-02 09:56:20 +02:00
committed by GitHub
parent b9bc691e8d
commit 1ca9ff5c91
14 changed files with 109 additions and 34 deletions

View File

@@ -283,7 +283,8 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper)) - **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5)) - **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
- **Translation**: translates a given sentence from source language to target language. - **Translation**: translates a given sentence from source language to target language.
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based tools can already execute Python code - **DuckDuckGo search**: performs a web search using the DuckDuckGo search engine.
- **Python code interpreter**: runs your LLM-generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agents can already natively execute Python code.
You can manually use a tool by calling the [`load_tool`] function and a task to perform. You can manually use a tool by calling the [`load_tool`] function and a task to perform.

View File

@@ -39,6 +39,7 @@ else:
_import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"] _import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"] _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"] _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
_import_structure["search"] = ["DuckDuckGoSearchTool"]
_import_structure["speech_to_text"] = ["SpeechToTextTool"] _import_structure["speech_to_text"] = ["SpeechToTextTool"]
_import_structure["text_to_speech"] = ["TextToSpeechTool"] _import_structure["text_to_speech"] = ["TextToSpeechTool"]
_import_structure["translation"] = ["TranslationTool"] _import_structure["translation"] = ["TranslationTool"]
@@ -58,6 +59,7 @@ if TYPE_CHECKING:
from .default_tools import FinalAnswerTool, PythonInterpreterTool from .default_tools import FinalAnswerTool, PythonInterpreterTool
from .document_question_answering import DocumentQuestionAnsweringTool from .document_question_answering import DocumentQuestionAnsweringTool
from .image_question_answering import ImageQuestionAnsweringTool from .image_question_answering import ImageQuestionAnsweringTool
from .search import DuckDuckGoSearchTool
from .speech_to_text import SpeechToTextTool from .speech_to_text import SpeechToTextTool
from .text_to_speech import TextToSpeechTool from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool from .translation import TranslationTool

View File

@@ -25,7 +25,7 @@ from huggingface_hub import hf_hub_download, list_spaces
from ..utils import is_offline_mode from ..utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TASK_MAPPING, TOOL_CONFIG_FILE, Tool from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
def custom_print(*args): def custom_print(*args):
@@ -133,7 +133,7 @@ def setup_default_tools(logger):
main_module = importlib.import_module("transformers") main_module = importlib.import_module("transformers")
tools_module = main_module.agents tools_module = main_module.agents
for task_name, tool_class_name in TASK_MAPPING.items(): for task_name, tool_class_name in TOOL_MAPPING.items():
tool_class = getattr(tools_module, tool_class_name) tool_class = getattr(tools_module, tool_class_name)
tool_instance = tool_class() tool_instance = tool_class()
default_tools[tool_class.name] = PreTool( default_tools[tool_class.name] = PreTool(

View File

@@ -444,6 +444,8 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools) index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
value = evaluate_ast(subscript.value, state, static_tools, custom_tools) value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
if isinstance(value, str) and isinstance(index, str):
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
if isinstance(value, pd.core.indexing._LocIndexer): if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj parent_object = value.obj
return parent_object.loc[index] return parent_object.loc[index]

View File

@@ -0,0 +1,35 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tools import Tool
class DuckDuckGoSearchTool(Tool):
    # Tool registered under the name "web_search": forwards a text query to
    # `duckduckgo_search.DDGS().text` and hands back the results unchanged.
    name = "web_search"
    description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'."""
    inputs = {"query": {"type": "text", "description": "The search query to perform."}}
    # "any" because the results are a plain Python list rather than one of the agent types.
    output_type = "any"

    def forward(self, query: str) -> list:
        """Run a DuckDuckGo text search for `query`.

        Args:
            query (`str`): The search query to perform.

        Returns:
            `list`: The search results (at most 7 are requested). Per the tool description, each
            result is a dict with keys 'title', 'href' and 'body'.

        Raises:
            ImportError: If the optional `duckduckgo_search` package is not installed.
        """
        # Imported lazily so the dependency is only required when the tool is actually called.
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            # Chain the original error so the real import failure stays visible in the traceback.
            raise ImportError(
                "You must install package `duckduckgo_search`: for instance run `pip install duckduckgo-search`."
            ) from e
        return DDGS().text(query, max_results=7)

View File

@@ -643,13 +643,14 @@ def launch_gradio_demo(tool_class: Tool):
).launch() ).launch()
TASK_MAPPING = { TOOL_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool", "document_question_answering": "DocumentQuestionAnsweringTool",
"image-question-answering": "ImageQuestionAnsweringTool", "image_question_answering": "ImageQuestionAnsweringTool",
"speech-to-text": "SpeechToTextTool", "speech_to_text": "SpeechToTextTool",
"text-to-speech": "TextToSpeechTool", "text_to_speech": "TextToSpeechTool",
"translation": "TranslationTool", "translation": "TranslationTool",
"python_interpreter": "PythonInterpreterTool", "python_interpreter": "PythonInterpreterTool",
"web_search": "DuckDuckGoSearchTool",
} }
@@ -670,10 +671,10 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
are: are:
- `"document-question-answering"` - `"document_question_answering"`
- `"image-question-answering"` - `"image_question_answering"`
- `"speech-to-text"` - `"speech_to_text"`
- `"text-to-speech"` - `"text_to_speech"`
- `"translation"` - `"translation"`
model_repo_id (`str`, *optional*): model_repo_id (`str`, *optional*):
@@ -686,8 +687,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
will be passed along to its init. will be passed along to its init.
""" """
if task_or_repo_id in TASK_MAPPING: if task_or_repo_id in TOOL_MAPPING:
tool_class_name = TASK_MAPPING[task_or_repo_id] tool_class_name = TOOL_MAPPING[task_or_repo_id]
main_module = importlib.import_module("transformers") main_module = importlib.import_module("transformers")
tools_module = main_module.agents tools_module = main_module.agents
tool_class = getattr(tools_module, tool_class_name) tool_class = getattr(tools_module, tool_class_name)

View File

@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("document-question-answering") self.tool = load_tool("document_question_answering")
self.tool.setup() self.tool.setup()
def test_exact_match_arg(self): def test_exact_match_arg(self):

View File

@@ -28,7 +28,7 @@ if is_vision_available():
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("image-question-answering") self.tool = load_tool("image_question_answering")
self.tool.setup() self.tool.setup()
def test_exact_match_arg(self): def test_exact_match_arg(self):

View File

@@ -176,6 +176,23 @@ class PythonInterpreterTester(unittest.TestCase):
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_subscript_string_with_string_index_raises_appropriate_error(self):
    """Subscripting a string with a string index must raise a descriptive `InterpreterError`.

    `search_results` is deliberately a *string* (not a list), so iterating over it yields
    single characters and `result['title']` is a str-with-str subscript.
    """
    code = """
search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
for result in search_results:
    if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
        current_weather_url = result['href']
        print(current_weather_url)
        break"""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # `e` is pytest's ExceptionInfo wrapper, not the exception itself: check the message of the
    # raised exception, and do it after the `with` block so the assertion actually runs.
    assert "You're trying to subscript a string with a string index" in str(e.value)
def test_evaluate_for(self): def test_evaluate_for(self):
code = "x = 0\nfor i in range(3):\n x = i" code = "x = 0\nfor i in range(3):\n x = i"
state = {} state = {}
@@ -573,13 +590,6 @@ except ValueError as e:
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state) evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
assert state["exception_message"] == "An error occurred" assert state["exception_message"] == "An error occurred"
def test_subscript(self):
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_print(self): def test_print(self):
code = "print(min([1, 2, 3]))" code = "print(min([1, 2, 3]))"
state = {} state = {}

View File

@@ -0,0 +1,30 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
    # Exercises the tool that `load_tool` returns for the "web_search" name,
    # on top of the shared checks inherited from ToolTesterMixin.

    def setUp(self):
        # Load the tool by its registered name, expose it to the tests, then run its setup step.
        web_search_tool = load_tool("web_search")
        self.tool = web_search_tool
        web_search_tool.setup()

    def test_exact_match_arg(self):
        # Calling the tool with a positional query should give a list whose first item is a dict.
        search_hits = self.tool("Agents")
        assert isinstance(search_hits, list)
        assert isinstance(search_hits[0], dict)

View File

@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin): class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("speech-to-text") self.tool = load_tool("speech_to_text")
self.tool.setup() self.tool.setup()
def test_exact_match_arg(self): def test_exact_match_arg(self):

View File

@@ -30,7 +30,7 @@ from .test_tools_common import ToolTesterMixin
@require_torch @require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin): class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("text-to-speech") self.tool = load_tool("text_to_speech")
self.tool.setup() self.tool.setup()
def test_exact_match_arg(self): def test_exact_match_arg(self):

View File

@@ -90,8 +90,9 @@ class ToolTesterMixin:
def test_agent_type_output(self): def test_agent_type_output(self):
inputs = create_inputs(self.tool.inputs) inputs = create_inputs(self.tool.inputs)
output = self.tool(**inputs) output = self.tool(**inputs)
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type] if self.tool.output_type != "any":
self.assertTrue(isinstance(output, agent_type)) agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))
def test_agent_types_inputs(self): def test_agent_types_inputs(self):
inputs = create_inputs(self.tool.inputs) inputs = create_inputs(self.tool.inputs)
@@ -99,9 +100,3 @@ class ToolTesterMixin:
for _input, expected_input in zip(inputs, self.tool.inputs.values()): for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"] input_type = expected_input["type"]
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
# Should not raise an error
output = self.tool(**inputs)
self.assertTrue(isinstance(output, output_type))

View File

@@ -44,7 +44,6 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
def test_agent_type_output(self): def test_agent_type_output(self):
inputs = ["Hey, what's up?", "English", "Spanish"] inputs = ["Hey, what's up?", "English", "Spanish"]
output = self.tool(*inputs) output = self.tool(*inputs)
output_type = AGENT_TYPE_MAPPING[self.tool.output_type] output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, output_type)) self.assertTrue(isinstance(output, output_type))