Add duckduckgo search tool (#32882)

* Add duckduckgo search tool
This commit is contained in:
Aymeric Roucher
2024-09-02 09:56:20 +02:00
committed by GitHub
parent b9bc691e8d
commit 1ca9ff5c91
14 changed files with 109 additions and 34 deletions

View File

@@ -283,7 +283,8 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
- **Translation**: translates a given sentence from source language to target language.
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based tools can already execute Python code
- **DuckDuckGo search**: performs a web search using the DuckDuckGo search engine.
- **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you initialize it with `add_base_tools=True`, since code-based agents can already natively execute Python code.
You can manually use a tool by calling the [`load_tool`] function and a task to perform.

View File

@@ -39,6 +39,7 @@ else:
_import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
_import_structure["search"] = ["DuckDuckGoSearchTool"]
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
_import_structure["translation"] = ["TranslationTool"]
@@ -58,6 +59,7 @@ if TYPE_CHECKING:
from .default_tools import FinalAnswerTool, PythonInterpreterTool
from .document_question_answering import DocumentQuestionAnsweringTool
from .image_question_answering import ImageQuestionAnsweringTool
from .search import DuckDuckGoSearchTool
from .speech_to_text import SpeechToTextTool
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool

View File

@@ -25,7 +25,7 @@ from huggingface_hub import hf_hub_download, list_spaces
from ..utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TASK_MAPPING, TOOL_CONFIG_FILE, Tool
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
def custom_print(*args):
@@ -133,7 +133,7 @@ def setup_default_tools(logger):
main_module = importlib.import_module("transformers")
tools_module = main_module.agents
for task_name, tool_class_name in TASK_MAPPING.items():
for task_name, tool_class_name in TOOL_MAPPING.items():
tool_class = getattr(tools_module, tool_class_name)
tool_instance = tool_class()
default_tools[tool_class.name] = PreTool(

View File

@@ -444,6 +444,8 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
if isinstance(value, str) and isinstance(index, str):
raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]

View File

@@ -0,0 +1,35 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tools import Tool
class DuckDuckGoSearchTool(Tool):
    """Tool that runs a DuckDuckGo text search for a query and returns the results.

    The `duckduckgo_search` package is imported lazily inside `forward`, so this
    class can be defined even when that optional dependency is not installed.
    """

    name = "web_search"
    description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements.
Each result has keys 'title', 'href' and 'body'."""
    inputs = {"query": {"type": "text", "description": "The search query to perform."}}
    output_type = "any"

    def forward(self, query: str) -> list:
        """Search DuckDuckGo for `query` and return the top results.

        Args:
            query: The search query to perform.

        Returns:
            The value returned by `DDGS().text` for at most 7 results. The tool
            description advertises it as a list of dicts with keys 'title',
            'href' and 'body'.

        Raises:
            ImportError: If the `duckduckgo_search` package is not installed.
        """
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError(
                "You must install package `duckduckgo_search`: for instance run `pip install duckduckgo-search`."
            ) from e

        return DDGS().text(query, max_results=7)

View File

@@ -643,13 +643,14 @@ def launch_gradio_demo(tool_class: Tool):
).launch()
TASK_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool",
"image-question-answering": "ImageQuestionAnsweringTool",
"speech-to-text": "SpeechToTextTool",
"text-to-speech": "TextToSpeechTool",
TOOL_MAPPING = {
"document_question_answering": "DocumentQuestionAnsweringTool",
"image_question_answering": "ImageQuestionAnsweringTool",
"speech_to_text": "SpeechToTextTool",
"text_to_speech": "TextToSpeechTool",
"translation": "TranslationTool",
"python_interpreter": "PythonInterpreterTool",
"web_search": "DuckDuckGoSearchTool",
}
@@ -670,10 +671,10 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
are:
- `"document-question-answering"`
- `"image-question-answering"`
- `"speech-to-text"`
- `"text-to-speech"`
- `"document_question_answering"`
- `"image_question_answering"`
- `"speech_to_text"`
- `"text_to_speech"`
- `"translation"`
- `"python_interpreter"`
- `"web_search"`
model_repo_id (`str`, *optional*):
@@ -686,8 +687,8 @@ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
will be passed along to its init.
"""
if task_or_repo_id in TASK_MAPPING:
tool_class_name = TASK_MAPPING[task_or_repo_id]
if task_or_repo_id in TOOL_MAPPING:
tool_class_name = TOOL_MAPPING[task_or_repo_id]
main_module = importlib.import_module("transformers")
tools_module = main_module.agents
tool_class = getattr(tools_module, tool_class_name)

View File

@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("document-question-answering")
self.tool = load_tool("document_question_answering")
self.tool.setup()
def test_exact_match_arg(self):

View File

@@ -28,7 +28,7 @@ if is_vision_available():
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-question-answering")
self.tool = load_tool("image_question_answering")
self.tool.setup()
def test_exact_match_arg(self):

View File

@@ -176,6 +176,23 @@ class PythonInterpreterTester(unittest.TestCase):
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_subscript_string_with_string_index_raises_appropriate_error(self):
    """Subscripting a string with a string key must raise a descriptive `InterpreterError`.

    The snippet iterates over a *string* (not a list of dicts), so `result` is a
    single character and `result['title']` is a str-indexed-by-str subscript.
    """
    code = """
search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
for result in search_results:
    if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
        current_weather_url = result['href']
        print(current_weather_url)
        break"""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # `e` is a pytest `ExceptionInfo`, not a string: check the message of the
    # wrapped exception, and do it after the `with` block so the assertion runs.
    assert "You're trying to subscript a string with a string index" in str(e.value)
def test_evaluate_for(self):
code = "x = 0\nfor i in range(3):\n x = i"
state = {}
@@ -573,13 +590,6 @@ except ValueError as e:
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
assert state["exception_message"] == "An error occurred"
def test_subscript(self):
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
state = {}
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
def test_print(self):
code = "print(min([1, 2, 3]))"
state = {}

View File

@@ -0,0 +1,30 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
    """Tests for the `web_search` tool obtained through `load_tool`."""

    def setUp(self):
        # Load the tool by its task name and run its setup before every test.
        self.tool = load_tool("web_search")
        self.tool.setup()

    def test_exact_match_arg(self):
        # A query should come back as a list whose first entry is a dict.
        search_results = self.tool("Agents")
        assert isinstance(search_results, list)
        assert isinstance(search_results[0], dict)

View File

@@ -24,7 +24,7 @@ from .test_tools_common import ToolTesterMixin
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("speech-to-text")
self.tool = load_tool("speech_to_text")
self.tool.setup()
def test_exact_match_arg(self):

View File

@@ -30,7 +30,7 @@ from .test_tools_common import ToolTesterMixin
@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-to-speech")
self.tool = load_tool("text_to_speech")
self.tool.setup()
def test_exact_match_arg(self):

View File

@@ -90,8 +90,9 @@ class ToolTesterMixin:
def test_agent_type_output(self):
inputs = create_inputs(self.tool.inputs)
output = self.tool(**inputs)
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))
if self.tool.output_type != "any":
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
self.assertTrue(isinstance(output, agent_type))
def test_agent_types_inputs(self):
inputs = create_inputs(self.tool.inputs)
@@ -99,9 +100,3 @@ class ToolTesterMixin:
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
input_type = expected_input["type"]
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
# Should not raise an error
output = self.tool(**inputs)
self.assertTrue(isinstance(output, output_type))

View File

@@ -44,7 +44,6 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
def test_agent_type_output(self):
    """Check that the tool's output is an instance of its declared agent type."""
    inputs = ["Hey, what's up?", "English", "Spanish"]
    output = self.tool(*inputs)
    # Look up the agent type registered for the tool's declared `output_type`.
    output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
    # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
    self.assertIsInstance(output, output_type)