Reboot Agents (#30387)
* Create CodeAgent and ReactAgent * Fix formatting errors * Update documentation for agents * Add custom errors, improve logging * Support variable usage in ReactAgent * add messages * Add message passing format * Create React Code Agent * Update * Refactoring * Fix errors * Improve python interpreter * Only non-tensor inputs should be sent to device * Calculator tool slight refactor * Improve docstrings * Refactor * Fix tests * Fix more tests * Fix even more tests * Fix tests by replacing output and input types * Fix operand type issue * two small fixes * EM TTS * Fix agent running type errors * Change text to speech tests to allow changed outputs * Update doc with new agent types * Improve code interpreter * If max iterations reached, provide a real answer instead of an error * Add edge case in interpreter * Add safe imports to the interpreter * Interpreter tweaks: tuples and listcomp * Make style * Make quality * Add dictcomp to interpreter * Rename ReactJSONAgent to ReactJsonAgent * Misc changes * ToolCollection * Rename agent's logger to self.logger * Add while loops to interpreter * Update doc with new tools. still need to mention collections * Add collections to the doc * Small fixes on logs and interpretor * Fix toolbox return type * Docs + fixup * Skip doctests * Correct prompts with improved examples and formatting * Update prompt * Remove outdated docs * Change agent to accept Toolbox object for tools * Remove calculator tool * Propagate removal of calculator in doc * Fix 2 failing workflows * Simplify additional argument passing * AgentType audio * Minor changes: function name, types * Remove calculator tests * Fix test * Fix torch requirement * Fix final answer tests * Style fixes * Fix tests * Update docstrings with calculator removal * Small type hint fixes * Update tests/agents/test_translation.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_python_interpreter.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/default_tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_agents.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/bert/configuration_bert.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/speech_to_text.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_speech_to_text.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_tools_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * pygments * Answer comments * Cleaning up * Simplifying init for all agents * Improving prompts and making code nicer * Style fixes * Add multiple comparator test in interpreter * Style fixes * Improve BERT example in documentation * Add examples to doc * Fix python interpreter quality * Logging improvements * Change test flag to agents * Quality fix * Add example for HfEngine * Improve conversation example for HfEngine * typo fix * Verify doc * Update docs/source/en/agents.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/agents.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/prompts.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/python_interpreter.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/agents.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix style issues * local s2t tool --------- Co-authored-by: Cyril Kondratenko <kkn1993@gmail.com> Co-authored-by: Lysandre <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
0
tests/agents/__init__.py
Normal file
0
tests/agents/__init__.py
Normal file
121
tests/agents/test_agent_types.py
Normal file
121
tests/agents/test_agent_types.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
|
||||
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_soundfile_availble():
|
||||
import soundfile as sf
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||
|
||||
|
||||
@require_soundfile
|
||||
@require_torch
|
||||
class AgentAudioTests(unittest.TestCase):
|
||||
def test_from_tensor(self):
|
||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||
agent_type = AgentAudio(tensor)
|
||||
path = str(agent_type.to_string())
|
||||
|
||||
# Ensure that the tensor and the agent_type's tensor are the same
|
||||
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
|
||||
|
||||
del agent_type
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
# Ensure that the file contains the same value as the original tensor
|
||||
new_tensor, _ = sf.read(path)
|
||||
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
||||
|
||||
def test_from_string(self):
|
||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||
path = get_new_path(suffix=".wav")
|
||||
sf.write(path, tensor, 16000)
|
||||
|
||||
agent_type = AgentAudio(path)
|
||||
|
||||
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
|
||||
self.assertEqual(agent_type.to_string(), path)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class AgentImageTests(unittest.TestCase):
|
||||
def test_from_tensor(self):
|
||||
tensor = torch.randint(0, 256, (64, 64, 3))
|
||||
agent_type = AgentImage(tensor)
|
||||
path = str(agent_type.to_string())
|
||||
|
||||
# Ensure that the tensor and the agent_type's tensor are the same
|
||||
self.assertTrue(torch.allclose(tensor, agent_type._tensor, atol=1e-4))
|
||||
|
||||
self.assertIsInstance(agent_type.to_raw(), Image.Image)
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
del agent_type
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
def test_from_string(self):
|
||||
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
image = Image.open(path)
|
||||
agent_type = AgentImage(path)
|
||||
|
||||
self.assertTrue(path.samefile(agent_type.to_string()))
|
||||
self.assertTrue(image == agent_type.to_raw())
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
del agent_type
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
def test_from_image(self):
|
||||
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
image = Image.open(path)
|
||||
agent_type = AgentImage(image)
|
||||
|
||||
self.assertFalse(path.samefile(agent_type.to_string()))
|
||||
self.assertTrue(image == agent_type.to_raw())
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
del agent_type
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
|
||||
class AgentTextTests(unittest.TestCase):
|
||||
def test_from_string(self):
|
||||
string = "Hey!"
|
||||
agent_type = AgentText(string)
|
||||
|
||||
self.assertEqual(string, agent_type.to_string())
|
||||
self.assertEqual(string, agent_type.to_raw())
|
||||
self.assertEqual(string, agent_type)
|
||||
161
tests/agents/test_agents.py
Normal file
161
tests/agents/test_agents.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.agents.agent_types import AgentText
|
||||
from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from transformers.agents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||
|
||||
|
||||
def fake_react_json_llm(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Action:
|
||||
{
|
||||
"action": "python_interpreter",
|
||||
"action_input": {"code": "2*3.6452"}
|
||||
}
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_llm(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = 2**3.6452
|
||||
print(result)
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer(7.2904)
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
print(result)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AgentTests(unittest.TestCase):
|
||||
def test_fake_code_agent(self):
|
||||
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
|
||||
def test_fake_react_json_agent(self):
|
||||
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert agent.logs[1]["observation"] == "7.2904"
|
||||
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
|
||||
assert (
|
||||
agent.logs[2]["llm_output"]
|
||||
== """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_fake_react_code_agent(self):
|
||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, AgentText)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
|
||||
assert agent.logs[2]["tool_call"] == {
|
||||
"tool_arguments": "final_answer(7.2904)",
|
||||
"tool_name": "code interpreter",
|
||||
}
|
||||
|
||||
def test_setup_agent_with_empty_toolbox(self):
|
||||
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||
|
||||
def test_react_fails_max_iterations(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[PythonInterpreterTool()],
|
||||
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
|
||||
max_iterations=5,
|
||||
)
|
||||
agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert len(agent.logs) == 7
|
||||
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError
|
||||
|
||||
@require_torch
|
||||
def test_init_agent_with_different_toolsets(self):
|
||||
toolset_1 = []
|
||||
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
|
||||
|
||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
||||
|
||||
toolset_3 = Toolbox(toolset_2)
|
||||
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
||||
|
||||
# check that add_base_tools will not interfere with existing tools
|
||||
with pytest.raises(KeyError) as e:
|
||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
||||
assert "python_interpreter already exists in the toolbox" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
|
||||
41
tests/agents/test_document_question_answering.py
Normal file
41
tests/agents/test_document_question_answering.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("document-question-answering")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
result = self.tool(document, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
self.tool(document=document, question="When is the coffee break?")
|
||||
71
tests/agents/test_final_answer.py
Normal file
71
tests/agents/test_final_answer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.inputs = {"answer": "Final answer"}
|
||||
self.tool = load_tool("final_answer")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Final answer")
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(answer=self.inputs["answer"])
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def create_inputs(self):
|
||||
inputs_text = {"answer": "Text input"}
|
||||
inputs_image = {
|
||||
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
|
||||
(512, 512)
|
||||
)
|
||||
}
|
||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||
return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||
|
||||
@require_torch
|
||||
def test_agent_type_output(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
@require_torch
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
42
tests/agents/test_image_question_answering.py
Normal file
42
tests/agents/test_image_question_answering.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-question-answering")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
355
tests/agents/test_python_interpreter.py
Normal file
355
tests/agents/test_python_interpreter.py
Normal file
@@ -0,0 +1,355 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||
from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("python_interpreter")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["2 * 2"]
|
||||
_inputs = []
|
||||
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
# Should not work without the tool
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert "tried to execute add_two" in str(e.value)
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"range": range}, state=state)
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_binop(self):
|
||||
code = "y + x"
|
||||
state = {"x": 3, "y": 6}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 9
|
||||
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
|
||||
|
||||
def test_recursive_function(self):
|
||||
code = """
|
||||
def recur_fibo(n):
|
||||
if n <= 1:
|
||||
return n
|
||||
else:
|
||||
return(recur_fibo(n-1) + recur_fibo(n-2))
|
||||
recur_fibo(6)"""
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 8
|
||||
|
||||
def test_evaluate_string_methods(self):
|
||||
code = "'hello'.replace('h', 'o').split('e')"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == ["o", "llo"]
|
||||
|
||||
def test_evaluate_slicing(self):
|
||||
code = "'hello'[1:3][::-1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "le"
|
||||
|
||||
def test_access_attributes(self):
|
||||
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == int
|
||||
|
||||
def test_list_comprehension(self):
|
||||
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "t-h-e-s-e-a-g-u-l-l"
|
||||
|
||||
def test_string_indexing(self):
|
||||
code = """text_block = [
|
||||
"THESE",
|
||||
"AGULL"
|
||||
]
|
||||
sentence = ""
|
||||
for block in text_block:
|
||||
for col in range(len(text_block[0])):
|
||||
sentence += block[col]
|
||||
"""
|
||||
result = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
||||
assert result == "THESEAGULL"
|
||||
|
||||
def test_tuples(self):
|
||||
code = "x = (1, 2, 3)\nx[1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 2
|
||||
|
||||
def test_listcomp(self):
|
||||
code = "x = [i for i in range(3)]"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
def test_break_continue(self):
|
||||
code = "for i in range(10):\n if i == 5:\n break\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 5
|
||||
|
||||
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 9
|
||||
|
||||
def test_call_int(self):
|
||||
code = "import math\nstr(math.ceil(149))"
|
||||
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
||||
assert result == "149"
|
||||
|
||||
def test_lambda(self):
|
||||
code = "f = lambda x: x + 2\nf(3)"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 5
|
||||
|
||||
def test_dictcomp(self):
|
||||
code = "x = {i: i**2 for i in range(3)}"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == {0: 0, 1: 1, 2: 4}
|
||||
|
||||
def test_tuple_assignment(self):
|
||||
code = "a, b = 0, 1\nb"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
def test_while(self):
|
||||
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 3
|
||||
|
||||
# test infinite loop
|
||||
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "iterations in While loop exceeded" in str(e)
|
||||
|
||||
def test_generator(self):
|
||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [1, 4, 9, 16, 25]
|
||||
|
||||
def test_boolops(self):
|
||||
code = """if (not (a > b and a > c)) or d > e:
|
||||
best_city = "Brooklyn"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Brooklyn"
|
||||
|
||||
code = """if d > e and a < b:
|
||||
best_city = "Brooklyn"
|
||||
elif d < e and a < b:
|
||||
best_city = "Sacramento"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Sacramento"
|
||||
|
||||
def test_if_conditions(self):
|
||||
code = """char='a'
|
||||
if char.isalpha():
|
||||
print('2')"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "2"
|
||||
|
||||
def test_imports(self):
|
||||
code = "import math\nmath.sqrt(4)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.0
|
||||
|
||||
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "lose"
|
||||
|
||||
code = "import time\ntime.sleep(0.1)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result is None
|
||||
|
||||
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
code = "import re\nre.search('a', 'abc').group()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "a"
|
||||
|
||||
code = "import stat\nstat.S_ISREG(0o100644)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.8
|
||||
|
||||
code = "import unicodedata\nunicodedata.name('A')"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "LATIN CAPITAL LETTER A"
|
||||
|
||||
def test_multiple_comparators(self):
|
||||
code = "0x30A0 <= ord('a') <= 0x30FF"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
def test_print_output(self):
|
||||
code = "print('Hello world!')\nprint('Ok no one cares')"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert result == "Ok no one cares"
|
||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||
36
tests/agents/test_speech_to_text.py
Normal file
36
tests/agents/test_speech_to_text.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("speech-to-text")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(audio=np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
50
tests/agents/test_text_to_speech.py
Normal file
50
tests/agents/test_text_to_speech.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
@require_torch
|
||||
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-to-speech")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
107
tests/agents/test_tools_common.py
Normal file
107
tests/agents/test_tools_common.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, is_agent_test
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = ["text", "audio", "image", "any"]
|
||||
|
||||
|
||||
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||
inputs = {}
|
||||
|
||||
for input_name, input_desc in tool_inputs.items():
|
||||
input_type = input_desc["type"]
|
||||
|
||||
if input_type == "text":
|
||||
inputs[input_name] = "Text input"
|
||||
elif input_type == "image":
|
||||
inputs[input_name] = Image.open(
|
||||
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
).resize((512, 512))
|
||||
elif input_type == "audio":
|
||||
inputs[input_name] = np.ones(3000)
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_type(output):
|
||||
if isinstance(output, (str, AgentText)):
|
||||
return "text"
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
return "image"
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
return "audio"
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
|
||||
|
||||
@is_agent_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_output(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
self.assertTrue(isinstance(inputs, dict))
|
||||
|
||||
for _, input_spec in inputs.items():
|
||||
self.assertTrue("type" in input_spec)
|
||||
self.assertTrue("description" in input_spec)
|
||||
self.assertTrue(input_spec["type"] in AUTHORIZED_TYPES)
|
||||
self.assertTrue(isinstance(input_spec["description"], str))
|
||||
|
||||
output_type = self.tool.output_type
|
||||
self.assertTrue(output_type in AUTHORIZED_TYPES)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "name"))
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
output = self.tool(**inputs)
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
_inputs = []
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**inputs)
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
68
tests/agents/test_translation.py
Normal file
68
tests/agents/test_translation.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_type
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
self.assertEqual(output_type(output), self.tool.output_type)
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
example_inputs = {
|
||||
"text": "Hey, what's up?",
|
||||
"src_lang": "English",
|
||||
"tgt_lang": "Spanish",
|
||||
}
|
||||
|
||||
_inputs = []
|
||||
for input_name in example_inputs.keys():
|
||||
example_input = example_inputs[input_name]
|
||||
input_description = self.tool.inputs[input_name]
|
||||
input_type = input_description["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**example_inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
Reference in New Issue
Block a user