Reboot Agents (#30387)
* Create CodeAgent and ReactAgent * Fix formatting errors * Update documentation for agents * Add custom errors, improve logging * Support variable usage in ReactAgent * add messages * Add message passing format * Create React Code Agent * Update * Refactoring * Fix errors * Improve python interpreter * Only non-tensor inputs should be sent to device * Calculator tool slight refactor * Improve docstrings * Refactor * Fix tests * Fix more tests * Fix even more tests * Fix tests by replacing output and input types * Fix operand type issue * two small fixes * EM TTS * Fix agent running type errors * Change text to speech tests to allow changed outputs * Update doc with new agent types * Improve code interpreter * If max iterations reached, provide a real answer instead of an error * Add edge case in interpreter * Add safe imports to the interpreter * Interpreter tweaks: tuples and listcomp * Make style * Make quality * Add dictcomp to interpreter * Rename ReactJSONAgent to ReactJsonAgent * Misc changes * ToolCollection * Rename agent's logger to self.logger * Add while loops to interpreter * Update doc with new tools. still need to mention collections * Add collections to the doc * Small fixes on logs and interpretor * Fix toolbox return type * Docs + fixup * Skip doctests * Correct prompts with improved examples and formatting * Update prompt * Remove outdated docs * Change agent to accept Toolbox object for tools * Remove calculator tool * Propagate removal of calculator in doc * Fix 2 failing workflows * Simplify additional argument passing * AgentType audio * Minor changes: function name, types * Remove calculator tests * Fix test * Fix torch requirement * Fix final answer tests * Style fixes * Fix tests * Update docstrings with calculator removal * Small type hint fixes * Update tests/agents/test_translation.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_python_interpreter.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/default_tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_agents.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/bert/configuration_bert.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/speech_to_text.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_speech_to_text.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_tools_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * pygments * Answer comments * Cleaning up * Simplifying init for all agents * Improving prompts and making code nicer * Style fixes * Add multiple comparator test in interpreter * Style fixes * Improve BERT example in documentation * Add examples to doc * Fix python interpreter quality * Logging improvements * Change test flag to agents * Quality fix * Add example for HfEngine * Improve conversation example for HfEngine * typo fix * Verify doc * Update docs/source/en/agents.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/agents.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/prompts.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/python_interpreter.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/agents.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix style issues * local s2t tool --------- Co-authored-by: Cyril Kondratenko <kkn1993@gmail.com> Co-authored-by: Lysandre <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -18,8 +18,8 @@ import unittest
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
|
||||
from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
161
tests/agents/test_agents.py
Normal file
161
tests/agents/test_agents.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.agents.agent_types import AgentText
|
||||
from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from transformers.agents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||
|
||||
|
||||
def fake_react_json_llm(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Action:
|
||||
{
|
||||
"action": "python_interpreter",
|
||||
"action_input": {"code": "2*3.6452"}
|
||||
}
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_llm(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = 2**3.6452
|
||||
print(result)
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer(7.2904)
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
print(result)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AgentTests(unittest.TestCase):
|
||||
def test_fake_code_agent(self):
|
||||
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
|
||||
def test_fake_react_json_agent(self):
|
||||
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert agent.logs[1]["observation"] == "7.2904"
|
||||
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
|
||||
assert (
|
||||
agent.logs[2]["llm_output"]
|
||||
== """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_fake_react_code_agent(self):
|
||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, AgentText)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
|
||||
assert agent.logs[2]["tool_call"] == {
|
||||
"tool_arguments": "final_answer(7.2904)",
|
||||
"tool_name": "code interpreter",
|
||||
}
|
||||
|
||||
def test_setup_agent_with_empty_toolbox(self):
|
||||
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||
|
||||
def test_react_fails_max_iterations(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[PythonInterpreterTool()],
|
||||
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
|
||||
max_iterations=5,
|
||||
)
|
||||
agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert len(agent.logs) == 7
|
||||
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError
|
||||
|
||||
@require_torch
|
||||
def test_init_agent_with_different_toolsets(self):
|
||||
toolset_1 = []
|
||||
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
|
||||
|
||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
||||
|
||||
toolset_3 = Toolbox(toolset_2)
|
||||
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
||||
|
||||
# check that add_base_tools will not interfere with existing tools
|
||||
with pytest.raises(KeyError) as e:
|
||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
||||
assert "python_interpreter already exists in the toolbox" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
|
||||
@@ -26,7 +26,6 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("document-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("document-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
@@ -35,22 +34,8 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
result = self.tool(document, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
result = self.remote_tool(document, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
self.tool(document=document, question="When is the coffee break?")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
result = self.remote_tool(document=document, question="When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
71
tests/agents/test_final_answer.py
Normal file
71
tests/agents/test_final_answer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.inputs = {"answer": "Final answer"}
|
||||
self.tool = load_tool("final_answer")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Final answer")
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(answer=self.inputs["answer"])
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def create_inputs(self):
|
||||
inputs_text = {"answer": "Text input"}
|
||||
inputs_image = {
|
||||
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
|
||||
(512, 512)
|
||||
)
|
||||
}
|
||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||
return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||
|
||||
@require_torch
|
||||
def test_agent_type_output(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
@require_torch
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
@@ -30,24 +30,13 @@ class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
355
tests/agents/test_python_interpreter.py
Normal file
355
tests/agents/test_python_interpreter.py
Normal file
@@ -0,0 +1,355 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||
from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("python_interpreter")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["2 * 2"]
|
||||
_inputs = []
|
||||
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
# Should not work without the tool
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert "tried to execute add_two" in str(e.value)
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"range": range}, state=state)
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_binop(self):
|
||||
code = "y + x"
|
||||
state = {"x": 3, "y": 6}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 9
|
||||
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
|
||||
|
||||
def test_recursive_function(self):
|
||||
code = """
|
||||
def recur_fibo(n):
|
||||
if n <= 1:
|
||||
return n
|
||||
else:
|
||||
return(recur_fibo(n-1) + recur_fibo(n-2))
|
||||
recur_fibo(6)"""
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 8
|
||||
|
||||
def test_evaluate_string_methods(self):
|
||||
code = "'hello'.replace('h', 'o').split('e')"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == ["o", "llo"]
|
||||
|
||||
def test_evaluate_slicing(self):
|
||||
code = "'hello'[1:3][::-1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "le"
|
||||
|
||||
def test_access_attributes(self):
|
||||
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == int
|
||||
|
||||
def test_list_comprehension(self):
|
||||
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "t-h-e-s-e-a-g-u-l-l"
|
||||
|
||||
def test_string_indexing(self):
|
||||
code = """text_block = [
|
||||
"THESE",
|
||||
"AGULL"
|
||||
]
|
||||
sentence = ""
|
||||
for block in text_block:
|
||||
for col in range(len(text_block[0])):
|
||||
sentence += block[col]
|
||||
"""
|
||||
result = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
||||
assert result == "THESEAGULL"
|
||||
|
||||
def test_tuples(self):
|
||||
code = "x = (1, 2, 3)\nx[1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 2
|
||||
|
||||
def test_listcomp(self):
|
||||
code = "x = [i for i in range(3)]"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
def test_break_continue(self):
|
||||
code = "for i in range(10):\n if i == 5:\n break\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 5
|
||||
|
||||
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 9
|
||||
|
||||
def test_call_int(self):
|
||||
code = "import math\nstr(math.ceil(149))"
|
||||
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
||||
assert result == "149"
|
||||
|
||||
def test_lambda(self):
|
||||
code = "f = lambda x: x + 2\nf(3)"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 5
|
||||
|
||||
def test_dictcomp(self):
|
||||
code = "x = {i: i**2 for i in range(3)}"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == {0: 0, 1: 1, 2: 4}
|
||||
|
||||
def test_tuple_assignment(self):
|
||||
code = "a, b = 0, 1\nb"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
def test_while(self):
|
||||
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 3
|
||||
|
||||
# test infinite loop
|
||||
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "iterations in While loop exceeded" in str(e)
|
||||
|
||||
def test_generator(self):
|
||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [1, 4, 9, 16, 25]
|
||||
|
||||
def test_boolops(self):
|
||||
code = """if (not (a > b and a > c)) or d > e:
|
||||
best_city = "Brooklyn"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Brooklyn"
|
||||
|
||||
code = """if d > e and a < b:
|
||||
best_city = "Brooklyn"
|
||||
elif d < e and a < b:
|
||||
best_city = "Sacramento"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Sacramento"
|
||||
|
||||
def test_if_conditions(self):
|
||||
code = """char='a'
|
||||
if char.isalpha():
|
||||
print('2')"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "2"
|
||||
|
||||
def test_imports(self):
|
||||
code = "import math\nmath.sqrt(4)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.0
|
||||
|
||||
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "lose"
|
||||
|
||||
code = "import time\ntime.sleep(0.1)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result is None
|
||||
|
||||
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
code = "import re\nre.search('a', 'abc').group()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "a"
|
||||
|
||||
code = "import stat\nstat.S_ISREG(0o100644)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.8
|
||||
|
||||
code = "import unicodedata\nunicodedata.name('A')"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "LATIN CAPITAL LETTER A"
|
||||
|
||||
def test_multiple_comparators(self):
|
||||
code = "0x30A0 <= ord('a') <= 0x30FF"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
def test_print_output(self):
|
||||
code = "print('Hello world!')\nprint('Ok no one cares')"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert result == "Ok no one cares"
|
||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||
@@ -15,24 +15,22 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
import numpy as np
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("speech-to-text")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(torch.ones(3000))
|
||||
self.assertEqual(result, " you")
|
||||
result = self.tool(np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(audio=torch.ones(3000))
|
||||
self.assertEqual(result, " you")
|
||||
result = self.tool(audio=np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
@@ -38,21 +38,13 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
resulting_tensor[:3],
|
||||
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
|
||||
)
|
||||
)
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
resulting_tensor[:3],
|
||||
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
|
||||
)
|
||||
)
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
107
tests/agents/test_tools_common.py
Normal file
107
tests/agents/test_tools_common.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, is_agent_test
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = ["text", "audio", "image", "any"]
|
||||
|
||||
|
||||
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||
inputs = {}
|
||||
|
||||
for input_name, input_desc in tool_inputs.items():
|
||||
input_type = input_desc["type"]
|
||||
|
||||
if input_type == "text":
|
||||
inputs[input_name] = "Text input"
|
||||
elif input_type == "image":
|
||||
inputs[input_name] = Image.open(
|
||||
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
).resize((512, 512))
|
||||
elif input_type == "audio":
|
||||
inputs[input_name] = np.ones(3000)
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_type(output):
|
||||
if isinstance(output, (str, AgentText)):
|
||||
return "text"
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
return "image"
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
return "audio"
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
|
||||
|
||||
@is_agent_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_output(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
self.assertTrue(isinstance(inputs, dict))
|
||||
|
||||
for _, input_spec in inputs.items():
|
||||
self.assertTrue("type" in input_spec)
|
||||
self.assertTrue("description" in input_spec)
|
||||
self.assertTrue(input_spec["type"] in AUTHORIZED_TYPES)
|
||||
self.assertTrue(isinstance(input_spec["description"], str))
|
||||
|
||||
output_type = self.tool.output_type
|
||||
self.assertTrue(output_type in AUTHORIZED_TYPES)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "name"))
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
output = self.tool(**inputs)
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
_inputs = []
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**inputs)
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
68
tests/agents/test_translation.py
Normal file
68
tests/agents/test_translation.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_type
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
self.assertEqual(output_type(output), self.tool.output_type)
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
example_inputs = {
|
||||
"text": "Hey, what's up?",
|
||||
"src_lang": "English",
|
||||
"tgt_lang": "Spanish",
|
||||
}
|
||||
|
||||
_inputs = []
|
||||
for input_name in example_inputs.keys():
|
||||
example_input = example_inputs[input_name]
|
||||
input_description = self.tool.inputs[input_name]
|
||||
input_type = input_description["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**example_inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
@@ -1,53 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-captioning")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-captioning", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
@@ -1,53 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-segmentation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-segmentation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, label="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image, label="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
@@ -1,131 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import CaptureStdout
|
||||
from transformers.tools.python_interpreter import evaluate
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5})
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5})
|
||||
|
||||
# Won't work without the tool
|
||||
with CaptureStdout() as out:
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result is None
|
||||
assert "tried to execute add_two" in out.out
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
result = evaluate(code, {"range": range}, state=state)
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 2, "i": 2})
|
||||
@@ -1,43 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-classification")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("text-classification", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("That's quite cool", ["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool("That's quite cool", ["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="That's quite cool", labels=["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
@@ -1,52 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
TEXT = """
|
||||
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
|
||||
|
||||
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
|
||||
|
||||
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
|
||||
"""
|
||||
|
||||
|
||||
class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("text-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
@@ -1,64 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
TEXT = """
|
||||
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
|
||||
|
||||
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
|
||||
|
||||
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
|
||||
"""
|
||||
|
||||
|
||||
class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("summarization")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("summarization", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool(TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text=TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text=TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
@@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import get_tests_dir, is_tool_test
|
||||
from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
authorized_types = ["text", "image", "audio"]
|
||||
|
||||
|
||||
def create_inputs(input_types: List[str]):
|
||||
inputs = []
|
||||
|
||||
for input_type in input_types:
|
||||
if input_type == "text":
|
||||
inputs.append("Text input")
|
||||
elif input_type == "image":
|
||||
inputs.append(
|
||||
Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
)
|
||||
elif input_type == "audio":
|
||||
inputs.append(torch.ones(3000))
|
||||
elif isinstance(input_type, list):
|
||||
inputs.append(create_inputs(input_type))
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_types(outputs: List):
|
||||
output_types = []
|
||||
|
||||
for output in outputs:
|
||||
if isinstance(output, (str, AgentText)):
|
||||
output_types.append("text")
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
output_types.append("image")
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
output_types.append("audio")
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
|
||||
return output_types
|
||||
|
||||
|
||||
@is_tool_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_outputs(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "outputs"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
for _input in inputs:
|
||||
if isinstance(_input, list):
|
||||
for __input in _input:
|
||||
self.assertTrue(__input in authorized_types)
|
||||
else:
|
||||
self.assertTrue(_input in authorized_types)
|
||||
|
||||
outputs = self.tool.outputs
|
||||
for _output in outputs:
|
||||
self.assertTrue(_output in authorized_types)
|
||||
|
||||
def test_call(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
# There is a single output
|
||||
if len(self.tool.outputs) == 1:
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
|
||||
self.assertTrue(self.tool.description.startswith("This is a tool that"))
|
||||
|
||||
def test_agent_types_outputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
|
||||
for output, output_type in zip(outputs, self.tool.outputs):
|
||||
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
|
||||
_inputs = []
|
||||
|
||||
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
@@ -1,86 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.tools.agent_types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_types
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
# There is a single output
|
||||
if len(self.tool.outputs) == 1:
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||
|
||||
def test_agent_types_outputs(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
|
||||
for output, output_type in zip(outputs, self.tool.outputs):
|
||||
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
|
||||
_inputs = []
|
||||
|
||||
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
Reference in New Issue
Block a user