[agents] remove agents 🧹 (#37368)
This commit is contained in:
@@ -1,120 +0,0 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
|
||||
from transformers.utils import is_soundfile_available, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_soundfile_available():
|
||||
import soundfile as sf
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||
|
||||
|
||||
@require_soundfile
|
||||
@require_torch
|
||||
class AgentAudioTests(unittest.TestCase):
|
||||
def test_from_tensor(self):
|
||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||
agent_type = AgentAudio(tensor)
|
||||
path = str(agent_type.to_string())
|
||||
|
||||
# Ensure that the tensor and the agent_type's tensor are the same
|
||||
torch.testing.assert_close(tensor, agent_type.to_raw(), rtol=1e-4, atol=1e-4)
|
||||
|
||||
del agent_type
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
# Ensure that the file contains the same value as the original tensor
|
||||
new_tensor, _ = sf.read(path)
|
||||
torch.testing.assert_close(tensor, torch.tensor(new_tensor), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_from_string(self):
|
||||
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||
path = get_new_path(suffix=".wav")
|
||||
sf.write(path, tensor, 16000)
|
||||
|
||||
agent_type = AgentAudio(path)
|
||||
|
||||
torch.testing.assert_close(tensor, agent_type.to_raw(), rtol=1e-4, atol=1e-4)
|
||||
self.assertEqual(agent_type.to_string(), path)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class AgentImageTests(unittest.TestCase):
|
||||
def test_from_tensor(self):
|
||||
tensor = torch.randint(0, 256, (64, 64, 3))
|
||||
agent_type = AgentImage(tensor)
|
||||
path = str(agent_type.to_string())
|
||||
|
||||
# Ensure that the tensor and the agent_type's tensor are the same
|
||||
torch.testing.assert_close(tensor, agent_type._tensor, rtol=1e-4, atol=1e-4)
|
||||
|
||||
self.assertIsInstance(agent_type.to_raw(), Image.Image)
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
del agent_type
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
def test_from_string(self):
|
||||
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
image = Image.open(path)
|
||||
agent_type = AgentImage(path)
|
||||
|
||||
self.assertTrue(path.samefile(agent_type.to_string()))
|
||||
self.assertTrue(image == agent_type.to_raw())
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
del agent_type
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
def test_from_image(self):
|
||||
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
image = Image.open(path)
|
||||
agent_type = AgentImage(image)
|
||||
|
||||
self.assertFalse(path.samefile(agent_type.to_string()))
|
||||
self.assertTrue(image == agent_type.to_raw())
|
||||
|
||||
# Ensure the path remains even after the object deletion
|
||||
del agent_type
|
||||
self.assertTrue(os.path.exists(path))
|
||||
|
||||
|
||||
class AgentTextTests(unittest.TestCase):
|
||||
def test_from_string(self):
|
||||
string = "Hey!"
|
||||
agent_type = AgentText(string)
|
||||
|
||||
self.assertEqual(string, agent_type.to_string())
|
||||
self.assertEqual(string, agent_type.to_raw())
|
||||
self.assertEqual(string, agent_type)
|
||||
@@ -1,257 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.agents.agent_types import AgentText
|
||||
from transformers.agents.agents import (
|
||||
AgentMaxIterationsError,
|
||||
CodeAgent,
|
||||
ManagedAgent,
|
||||
ReactCodeAgent,
|
||||
ReactJsonAgent,
|
||||
Toolbox,
|
||||
)
|
||||
from transformers.agents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||
|
||||
|
||||
def fake_react_json_llm(messages, stop_sequences=None, grammar=None) -> str:
|
||||
prompt = str(messages)
|
||||
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Action:
|
||||
{
|
||||
"action": "python_interpreter",
|
||||
"action_input": {"code": "2*3.6452"}
|
||||
}
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_llm(messages, stop_sequences=None, grammar=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = 2**3.6452
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer(7.2904)
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
print = 2
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer("got an error")
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: Let's define the function. special_marker
|
||||
Code:
|
||||
```py
|
||||
import numpy as np
|
||||
|
||||
def moving_average(x, w):
|
||||
return np.convolve(x, np.ones(w), 'valid') / w
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
x, w = [0, 1, 2, 3, 4, 5], 2
|
||||
res = moving_average(x, w)
|
||||
final_answer(res)
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_oneshot(messages, stop_sequences=None, grammar=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
final_answer(result)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_no_return(messages, stop_sequences=None, grammar=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
print(result)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AgentTests(unittest.TestCase):
|
||||
def test_fake_code_agent(self):
|
||||
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
|
||||
def test_fake_react_json_agent(self):
|
||||
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert agent.logs[1]["observation"] == "7.2904"
|
||||
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
|
||||
assert (
|
||||
agent.logs[2]["llm_output"]
|
||||
== """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_fake_react_code_agent(self):
|
||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, float)
|
||||
assert output == 7.2904
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert agent.logs[2]["tool_call"] == {
|
||||
"tool_arguments": "final_answer(7.2904)",
|
||||
"tool_name": "code interpreter",
|
||||
}
|
||||
|
||||
def test_react_code_agent_code_errors_show_offending_lines(self):
|
||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, AgentText)
|
||||
assert output == "got an error"
|
||||
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
||||
|
||||
def test_setup_agent_with_empty_toolbox(self):
|
||||
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||
|
||||
def test_react_fails_max_iterations(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[PythonInterpreterTool()],
|
||||
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
|
||||
max_iterations=5,
|
||||
)
|
||||
agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert len(agent.logs) == 7
|
||||
assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError
|
||||
|
||||
@require_torch
|
||||
def test_init_agent_with_different_toolsets(self):
|
||||
toolset_1 = []
|
||||
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 1
|
||||
) # when no tools are provided, only the final_answer tool is added by default
|
||||
|
||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 2
|
||||
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||
|
||||
toolset_3 = Toolbox(toolset_2)
|
||||
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
||||
assert (
|
||||
len(agent.toolbox.tools) == 2
|
||||
) # same as previous one, where toolset_3 is an instantiation of previous one
|
||||
|
||||
# check that add_base_tools will not interfere with existing tools
|
||||
with pytest.raises(KeyError) as e:
|
||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
||||
assert "already exists in the toolbox" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
assert len(agent.toolbox.tools) == 7 # added final_answer tool + 6 base tools (excluding interpreter)
|
||||
|
||||
def test_function_persistence_across_steps(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
|
||||
)
|
||||
res = agent.run("ok")
|
||||
assert res[0] == 0.5
|
||||
|
||||
def test_init_managed_agent(self):
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
||||
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||
assert managed_agent.name == "managed_agent"
|
||||
assert managed_agent.description == "Empty"
|
||||
|
||||
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_functiondef)
|
||||
managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
|
||||
manager_agent = ReactCodeAgent(
|
||||
tools=[], llm_engine=fake_react_code_functiondef, managed_agents=[managed_agent]
|
||||
)
|
||||
assert "You can also give requests to team members." not in agent.system_prompt
|
||||
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
|
||||
assert "You can also give requests to team members." in manager_agent.system_prompt
|
||||
@@ -1,40 +0,0 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("document_question_answering")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
result = self.tool(document, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
self.tool(document=document, question="When is the coffee break?")
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import FinalAnswerTool
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.inputs = {"answer": "Final answer"}
|
||||
self.tool = FinalAnswerTool()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Final answer")
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(answer=self.inputs["answer"])
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def create_inputs(self):
|
||||
inputs_text = {"answer": "Text input"}
|
||||
inputs_image = {
|
||||
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
|
||||
(512, 512)
|
||||
)
|
||||
}
|
||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||
return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||
|
||||
@require_torch
|
||||
def test_agent_type_output(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
@require_torch
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
@@ -1,41 +0,0 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image_question_answering")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
@@ -1,165 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.agents.agent_types import AgentImage
|
||||
from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
|
||||
from transformers.agents.monitoring import stream_to_gradio
|
||||
|
||||
|
||||
class MonitoringTester(unittest.TestCase):
|
||||
def test_code_agent_metrics(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_json_agent_metrics(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||
|
||||
agent = ReactJsonAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 10)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 20)
|
||||
|
||||
def test_code_agent_metrics_max_iterations(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
return "Malformed answer"
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_code_agent_metrics_generation_error(self):
|
||||
class FakeLLMEngine:
|
||||
def __init__(self):
|
||||
self.last_input_token_count = 10
|
||||
self.last_output_token_count = 20
|
||||
|
||||
def __call__(self, prompt, **kwargs):
|
||||
raise AgentError
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=FakeLLMEngine(),
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
agent.run("Fake task")
|
||||
|
||||
self.assertEqual(agent.monitor.total_input_token_count, 20)
|
||||
self.assertEqual(agent.monitor.total_output_token_count, 40)
|
||||
|
||||
def test_streaming_agent_text_output(self):
|
||||
def dummy_llm_engine(prompt, **kwargs):
|
||||
return """
|
||||
Code:
|
||||
```py
|
||||
final_answer('This is the final answer.')
|
||||
```"""
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=dummy_llm_engine,
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
# Use stream_to_gradio to capture the output
|
||||
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
|
||||
|
||||
self.assertEqual(len(outputs), 3)
|
||||
final_message = outputs[-1]
|
||||
self.assertEqual(final_message.role, "assistant")
|
||||
self.assertIn("This is the final answer.", final_message.content)
|
||||
|
||||
def test_streaming_agent_image_output(self):
|
||||
def dummy_llm_engine(prompt, **kwargs):
|
||||
return 'Action:{"action": "final_answer", "action_input": {"answer": "image"}}'
|
||||
|
||||
agent = ReactJsonAgent(
|
||||
tools=[],
|
||||
llm_engine=dummy_llm_engine,
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
# Use stream_to_gradio to capture the output
|
||||
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))
|
||||
|
||||
self.assertEqual(len(outputs), 2)
|
||||
final_message = outputs[-1]
|
||||
self.assertEqual(final_message.role, "assistant")
|
||||
self.assertIsInstance(final_message.content, dict)
|
||||
self.assertEqual(final_message.content["path"], "path.png")
|
||||
self.assertEqual(final_message.content["mime_type"], "image/png")
|
||||
|
||||
def test_streaming_with_agent_error(self):
|
||||
def dummy_llm_engine(prompt, **kwargs):
|
||||
raise AgentError("Simulated agent error")
|
||||
|
||||
agent = ReactCodeAgent(
|
||||
tools=[],
|
||||
llm_engine=dummy_llm_engine,
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
# Use stream_to_gradio to capture the output
|
||||
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))
|
||||
|
||||
self.assertEqual(len(outputs), 3)
|
||||
final_message = outputs[-1]
|
||||
self.assertEqual(final_message.role, "assistant")
|
||||
self.assertIn("Simulated agent error", final_message.content)
|
||||
@@ -1,836 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["2 * 2"]
|
||||
_inputs = []
|
||||
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||
|
||||
code = "a=1;b=None"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result is None
|
||||
|
||||
def test_assignment_cannot_overwrite_tool(self):
|
||||
code = "print = '3'"
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, {"print": print}, state={})
|
||||
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
# Should not work without the tool
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert "tried to execute add_two" in str(e.value)
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||
state = {}
|
||||
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
||||
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||
|
||||
def test_subscript_string_with_string_index_raises_appropriate_error(self):
|
||||
code = """
|
||||
search_results = "[{'title': 'Paris, Ville de Paris, France Weather Forecast | AccuWeather', 'href': 'https://www.accuweather.com/en/fr/paris/623/weather-forecast/623', 'body': 'Get the latest weather forecast for Paris, Ville de Paris, France , including hourly, daily, and 10-day outlooks. AccuWeather provides you with reliable and accurate information on temperature ...'}]"
|
||||
for result in search_results:
|
||||
if 'current' in result['title'].lower() or 'temperature' in result['title'].lower():
|
||||
current_weather_url = result['href']
|
||||
print(current_weather_url)
|
||||
break"""
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "You're trying to subscript a string with a string index" in e
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"range": range}, state=state)
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_binop(self):
|
||||
code = "y + x"
|
||||
state = {"x": 3, "y": 6}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 9
|
||||
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
|
||||
|
||||
def test_recursive_function(self):
|
||||
code = """
|
||||
def recur_fibo(n):
|
||||
if n <= 1:
|
||||
return n
|
||||
else:
|
||||
return(recur_fibo(n-1) + recur_fibo(n-2))
|
||||
recur_fibo(6)"""
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 8
|
||||
|
||||
def test_evaluate_string_methods(self):
|
||||
code = "'hello'.replace('h', 'o').split('e')"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == ["o", "llo"]
|
||||
|
||||
def test_evaluate_slicing(self):
|
||||
code = "'hello'[1:3][::-1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "le"
|
||||
|
||||
def test_access_attributes(self):
|
||||
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result is int
|
||||
|
||||
def test_list_comprehension(self):
|
||||
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "t-h-e-s-e-a-g-u-l-l"
|
||||
|
||||
def test_string_indexing(self):
|
||||
code = """text_block = [
|
||||
"THESE",
|
||||
"AGULL"
|
||||
]
|
||||
sentence = ""
|
||||
for block in text_block:
|
||||
for col in range(len(text_block[0])):
|
||||
sentence += block[col]
|
||||
"""
|
||||
result = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
||||
assert result == "THESEAGULL"
|
||||
|
||||
def test_tuples(self):
|
||||
code = "x = (1, 2, 3)\nx[1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 2
|
||||
|
||||
code = """
|
||||
digits, i = [1, 2, 3], 1
|
||||
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
||||
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
|
||||
|
||||
code = """
|
||||
def calculate_isbn_10_check_digit(number):
|
||||
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
|
||||
remainder = total % 11
|
||||
check_digit = 11 - remainder
|
||||
if check_digit == 10:
|
||||
return 'X'
|
||||
elif check_digit == 11:
|
||||
return '0'
|
||||
else:
|
||||
return str(check_digit)
|
||||
|
||||
# Given 9-digit numbers
|
||||
numbers = [
|
||||
"478225952",
|
||||
"643485613",
|
||||
"739394228",
|
||||
"291726859",
|
||||
"875262394",
|
||||
"542617795",
|
||||
"031810713",
|
||||
"957007669",
|
||||
"871467426"
|
||||
]
|
||||
|
||||
# Calculate check digits for each number
|
||||
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
|
||||
print(check_digits)
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(
|
||||
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
|
||||
)
|
||||
|
||||
def test_listcomp(self):
|
||||
code = "x = [i for i in range(3)]"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
def test_break_continue(self):
|
||||
code = "for i in range(10):\n if i == 5:\n break\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 5
|
||||
|
||||
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 9
|
||||
|
||||
def test_call_int(self):
|
||||
code = "import math\nstr(math.ceil(149))"
|
||||
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
||||
assert result == "149"
|
||||
|
||||
def test_lambda(self):
|
||||
code = "f = lambda x: x + 2\nf(3)"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 5
|
||||
|
||||
def test_dictcomp(self):
|
||||
code = "x = {i: i**2 for i in range(3)}"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == {0: 0, 1: 1, 2: 4}
|
||||
|
||||
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
||||
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
||||
assert result == {102: "b"}
|
||||
|
||||
code = """
|
||||
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
|
||||
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
|
||||
"""
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == {"A": ("a", "b"), "B": ("a", "b")}
|
||||
|
||||
def test_tuple_assignment(self):
|
||||
code = "a, b = 0, 1\nb"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
def test_while(self):
|
||||
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 3
|
||||
|
||||
# test infinite loop
|
||||
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "iterations in While loop exceeded" in str(e)
|
||||
|
||||
# test lazy evaluation
|
||||
code = """
|
||||
house_positions = [0, 7, 10, 15, 18, 22, 22]
|
||||
i, n, loc = 0, 7, 30
|
||||
while i < n and house_positions[i] <= loc:
|
||||
i += 1
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
|
||||
def test_generator(self):
|
||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [1, 4, 9, 16, 25]
|
||||
|
||||
def test_boolops(self):
|
||||
code = """if (not (a > b and a > c)) or d > e:
|
||||
best_city = "Brooklyn"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Brooklyn"
|
||||
|
||||
code = """if d > e and a < b:
|
||||
best_city = "Brooklyn"
|
||||
elif d < e and a < b:
|
||||
best_city = "Sacramento"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Sacramento"
|
||||
|
||||
def test_if_conditions(self):
|
||||
code = """char='a'
|
||||
if char.isalpha():
|
||||
print('2')"""
|
||||
state = {}
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert state["print_outputs"] == "2\n"
|
||||
|
||||
def test_imports(self):
|
||||
code = "import math\nmath.sqrt(4)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.0
|
||||
|
||||
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "lose"
|
||||
|
||||
code = "import time, re\ntime.sleep(0.1)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result is None
|
||||
|
||||
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
code = "import re\nre.search('a', 'abc').group()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "a"
|
||||
|
||||
code = "import stat\nstat.S_ISREG(0o100644)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.8
|
||||
|
||||
code = "import unicodedata\nunicodedata.name('A')"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "LATIN CAPITAL LETTER A"
|
||||
|
||||
# Test submodules are handled properly, thus not raising error
|
||||
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||
|
||||
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||
|
||||
def test_additional_imports(self):
|
||||
code = "import numpy as np"
|
||||
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
||||
|
||||
code = "import numpy.random as rd"
|
||||
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
|
||||
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
||||
with pytest.raises(InterpreterError):
|
||||
evaluate_python_code(code, authorized_imports=["random"], state={})
|
||||
|
||||
def test_multiple_comparators(self):
|
||||
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert not result
|
||||
|
||||
code = "0 <= 1 < 4 and 0 <= -5 < 4"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert not result
|
||||
|
||||
code = "0 <= 4 < 4 and 0 <= 3 < 4"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert not result
|
||||
|
||||
code = "0 <= 3 < 4 and 0 <= 3 < 4"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
def test_print_output(self):
|
||||
code = "print('Hello world!')\nprint('Ok no one cares')"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert result is None
|
||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||
|
||||
# test print in function
|
||||
code = """
|
||||
print("1")
|
||||
def function():
|
||||
print("2")
|
||||
function()"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"print": print}, state=state)
|
||||
assert state["print_outputs"] == "1\n2\n"
|
||||
|
||||
def test_tuple_target_in_iterator(self):
|
||||
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "Samuel"
|
||||
|
||||
def test_classes(self):
|
||||
code = """
|
||||
class Animal:
|
||||
species = "Generic Animal"
|
||||
|
||||
def __init__(self, name, age):
|
||||
self.name = name
|
||||
self.age = age
|
||||
|
||||
def sound(self):
|
||||
return "The animal makes a sound."
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, {self.age} years old"
|
||||
|
||||
class Dog(Animal):
|
||||
species = "Canine"
|
||||
|
||||
def __init__(self, name, age, breed):
|
||||
super().__init__(name, age)
|
||||
self.breed = breed
|
||||
|
||||
def sound(self):
|
||||
return "The dog barks."
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, {self.age} years old, {self.breed}"
|
||||
|
||||
class Cat(Animal):
|
||||
def sound(self):
|
||||
return "The cat meows."
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}, {self.age} years old, {self.species}"
|
||||
|
||||
|
||||
# Testing multiple instances
|
||||
dog1 = Dog("Fido", 3, "Labrador")
|
||||
dog2 = Dog("Buddy", 5, "Golden Retriever")
|
||||
|
||||
# Testing method with built-in function
|
||||
animals = [dog1, dog2, Cat("Whiskers", 2)]
|
||||
num_animals = len(animals)
|
||||
|
||||
# Testing exceptions in methods
|
||||
class ExceptionTest:
|
||||
def method_that_raises(self):
|
||||
raise ValueError("An error occurred")
|
||||
|
||||
try:
|
||||
exc_test = ExceptionTest()
|
||||
exc_test.method_that_raises()
|
||||
except ValueError as e:
|
||||
exception_message = str(e)
|
||||
|
||||
|
||||
# Collecting results
|
||||
dog1_sound = dog1.sound()
|
||||
dog1_str = str(dog1)
|
||||
dog2_sound = dog2.sound()
|
||||
dog2_str = str(dog2)
|
||||
cat = Cat("Whiskers", 2)
|
||||
cat_sound = cat.sound()
|
||||
cat_str = str(cat)
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
||||
|
||||
# Assert results
|
||||
assert state["dog1_sound"] == "The dog barks."
|
||||
assert state["dog1_str"] == "Fido, 3 years old, Labrador"
|
||||
assert state["dog2_sound"] == "The dog barks."
|
||||
assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever"
|
||||
assert state["cat_sound"] == "The cat meows."
|
||||
assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal"
|
||||
assert state["num_animals"] == 3
|
||||
assert state["exception_message"] == "An error occurred"
|
||||
|
||||
def test_variable_args(self):
|
||||
code = """
|
||||
def var_args_method(self, *args, **kwargs):
|
||||
return sum(args) + sum(kwargs.values())
|
||||
|
||||
var_args_method(1, 2, 3, x=4, y=5)
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"sum": sum}, state=state)
|
||||
assert result == 15
|
||||
|
||||
def test_exceptions(self):
|
||||
code = """
|
||||
def method_that_raises(self):
|
||||
raise ValueError("An error occurred")
|
||||
|
||||
try:
|
||||
method_that_raises()
|
||||
except ValueError as e:
|
||||
exception_message = str(e)
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
||||
assert state["exception_message"] == "An error occurred"
|
||||
|
||||
def test_print(self):
|
||||
code = "print(min([1, 2, 3]))"
|
||||
state = {}
|
||||
evaluate_python_code(code, {"min": min, "print": print}, state=state)
|
||||
assert state["print_outputs"] == "1\n"
|
||||
|
||||
def test_types_as_objects(self):
|
||||
code = "type_a = float(2); type_b = str; type_c = int"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
||||
assert result is int
|
||||
|
||||
def test_tuple_id(self):
|
||||
code = """
|
||||
food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
|
||||
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == ["orange", "pear"]
|
||||
|
||||
def test_nonsimple_augassign(self):
|
||||
code = """
|
||||
counts_dict = {'a': 0}
|
||||
counts_dict['a'] += 1
|
||||
counts_list = [1, 2, 3]
|
||||
counts_list += [4, 5, 6]
|
||||
|
||||
class Counter:
|
||||
self.count = 0
|
||||
|
||||
a = Counter()
|
||||
a.count += 1
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert state["counts_dict"] == {"a": 1}
|
||||
assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
|
||||
assert state["a"].count == 1
|
||||
|
||||
def test_adding_int_to_list_raises_error(self):
|
||||
code = """
|
||||
counts = [1, 2, 3]
|
||||
counts += 1"""
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "Cannot add non-list value 1 to a list." in str(e)
|
||||
|
||||
def test_error_highlights_correct_line_of_code(self):
|
||||
code = """# Ok this is a very long code
|
||||
# It has many commented lines
|
||||
a = 1
|
||||
b = 2
|
||||
|
||||
# Here is another piece
|
||||
counts = [1, 2, 3]
|
||||
counts += 1
|
||||
b += 1"""
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "Evaluation stopped at line 'counts += 1" in str(e)
|
||||
|
||||
def test_assert(self):
|
||||
code = """
|
||||
assert 1 == 1
|
||||
assert 1 == 2
|
||||
"""
|
||||
with pytest.raises(AssertionError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "1 == 2" in str(e) and "1 == 1" not in str(e)
|
||||
|
||||
def test_with_context_manager(self):
|
||||
code = """
|
||||
class SimpleLock:
|
||||
def __init__(self):
|
||||
self.locked = False
|
||||
|
||||
def __enter__(self):
|
||||
self.locked = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.locked = False
|
||||
|
||||
lock = SimpleLock()
|
||||
|
||||
with lock as l:
|
||||
assert l.locked == True
|
||||
|
||||
assert lock.locked == False
|
||||
"""
|
||||
state = {}
|
||||
tools = {}
|
||||
evaluate_python_code(code, tools, state=state)
|
||||
|
||||
def test_default_arg_in_function(self):
|
||||
code = """
|
||||
def f(a, b=333, n=1000):
|
||||
return b + n
|
||||
n = f(1, n=667)
|
||||
"""
|
||||
res = evaluate_python_code(code, {}, {})
|
||||
assert res == 1000
|
||||
|
||||
def test_set(self):
|
||||
code = """
|
||||
S1 = {'a', 'b', 'c'}
|
||||
S2 = {'b', 'c', 'd'}
|
||||
S3 = S1.difference(S2)
|
||||
S4 = S1.intersection(S2)
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert state["S3"] == {"a"}
|
||||
assert state["S4"] == {"b", "c"}
|
||||
|
||||
def test_break(self):
|
||||
code = """
|
||||
i = 0
|
||||
|
||||
while True:
|
||||
i+= 1
|
||||
if i==3:
|
||||
break
|
||||
|
||||
i"""
|
||||
result = evaluate_python_code(code, {"print": print, "round": round}, state={})
|
||||
assert result == 3
|
||||
|
||||
def test_return(self):
|
||||
# test early returns
|
||||
code = """
|
||||
def add_one(n, shift):
|
||||
if True:
|
||||
return n + shift
|
||||
return n
|
||||
|
||||
add_one(1, 1)
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||
assert result == 2
|
||||
|
||||
# test returning None
|
||||
code = """
|
||||
def returns_none(a):
|
||||
return
|
||||
|
||||
returns_none(1)
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||
assert result is None
|
||||
|
||||
def test_nested_for_loop(self):
|
||||
code = """
|
||||
all_res = []
|
||||
for i in range(10):
|
||||
subres = []
|
||||
for j in range(i):
|
||||
subres.append(j)
|
||||
all_res.append(subres)
|
||||
|
||||
out = [i for sublist in all_res for i in sublist]
|
||||
out[:10]
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
|
||||
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||
|
||||
def test_pandas(self):
|
||||
code = """
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
|
||||
|
||||
df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
|
||||
|
||||
parts_with_5_set_count = df[df['SetCount'] == 5.0]
|
||||
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
|
||||
assert np.array_equal(result, [-1, 5])
|
||||
|
||||
code = """
|
||||
import pandas as pd
|
||||
|
||||
df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
|
||||
print("HH0")
|
||||
|
||||
# Filter the DataFrame to get only the rows with outdated atomic numbers
|
||||
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
|
||||
"""
|
||||
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
||||
assert np.array_equal(result.values[0], [104, 1])
|
||||
|
||||
code = """import pandas as pd
|
||||
data = pd.DataFrame.from_dict([
|
||||
{"Pclass": 1, "Survived": 1},
|
||||
{"Pclass": 2, "Survived": 0},
|
||||
{"Pclass": 2, "Survived": 1}
|
||||
])
|
||||
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
|
||||
"""
|
||||
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
|
||||
assert result.values[1] == 0.5
|
||||
|
||||
def test_starred(self):
|
||||
code = """
|
||||
from math import radians, sin, cos, sqrt, atan2
|
||||
|
||||
def haversine(lat1, lon1, lat2, lon2):
|
||||
R = 6371000 # Radius of the Earth in meters
|
||||
lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
|
||||
dlat = lat2 - lat1
|
||||
dlon = lon2 - lon1
|
||||
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
|
||||
c = 2 * atan2(sqrt(a), sqrt(1 - a))
|
||||
distance = R * c
|
||||
return distance
|
||||
|
||||
coords_geneva = (46.1978, 6.1342)
|
||||
coords_barcelona = (41.3869, 2.1660)
|
||||
|
||||
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
|
||||
"""
|
||||
result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
|
||||
assert round(result, 1) == 622395.4
|
||||
|
||||
def test_for(self):
|
||||
code = """
|
||||
shifts = {
|
||||
"Worker A": ("6:45 pm", "8:00 pm"),
|
||||
"Worker B": ("10:00 am", "11:45 am")
|
||||
}
|
||||
|
||||
shift_intervals = {}
|
||||
for worker, (start, end) in shifts.items():
|
||||
shift_intervals[worker] = end
|
||||
shift_intervals
|
||||
"""
|
||||
result = evaluate_python_code(code, {"print": print, "map": map}, state={})
|
||||
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("web_search")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Agents")
|
||||
assert isinstance(result, list) and isinstance(result[0], dict)
|
||||
@@ -1,35 +0,0 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("speech_to_text")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(audio=np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
@@ -1,49 +0,0 @@
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
@require_torch
|
||||
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text_to_speech")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
@@ -1,170 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
from transformers.agents.tools import Tool, tool
|
||||
from transformers.testing_utils import get_tests_dir, is_agent_test
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = ["string", "boolean", "integer", "number", "audio", "image", "any"]
|
||||
|
||||
|
||||
def create_inputs(tool_inputs: dict[str, dict[Union[str, type], str]]):
|
||||
inputs = {}
|
||||
|
||||
for input_name, input_desc in tool_inputs.items():
|
||||
input_type = input_desc["type"]
|
||||
|
||||
if input_type == "string":
|
||||
inputs[input_name] = "Text input"
|
||||
elif input_type == "image":
|
||||
inputs[input_name] = Image.open(
|
||||
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
).resize((512, 512))
|
||||
elif input_type == "audio":
|
||||
inputs[input_name] = np.ones(3000)
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_type(output):
|
||||
if isinstance(output, (str, AgentText)):
|
||||
return "string"
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
return "image"
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
return "audio"
|
||||
else:
|
||||
raise TypeError(f"Invalid output: {output}")
|
||||
|
||||
|
||||
@is_agent_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_output(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
self.assertTrue(isinstance(inputs, dict))
|
||||
|
||||
for _, input_spec in inputs.items():
|
||||
self.assertTrue("type" in input_spec)
|
||||
self.assertTrue("description" in input_spec)
|
||||
self.assertTrue(input_spec["type"] in AUTHORIZED_TYPES)
|
||||
self.assertTrue(isinstance(input_spec["description"], str))
|
||||
|
||||
output_type = self.tool.output_type
|
||||
self.assertTrue(output_type in AUTHORIZED_TYPES)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "name"))
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
output = self.tool(**inputs)
|
||||
if self.tool.output_type != "any":
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
_inputs = []
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
|
||||
class ToolTests(unittest.TestCase):
|
||||
def test_tool_init_with_decorator(self):
|
||||
@tool
|
||||
def coolfunc(a: str, b: int) -> float:
|
||||
"""Cool function
|
||||
|
||||
Args:
|
||||
a: The first argument
|
||||
b: The second one
|
||||
"""
|
||||
return b + 2, a
|
||||
|
||||
assert coolfunc.output_type == "number"
|
||||
|
||||
def test_tool_init_vanilla(self):
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = """
|
||||
This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
|
||||
It returns the name of the checkpoint."""
|
||||
|
||||
inputs = {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
||||
}
|
||||
}
|
||||
output_type = "integer"
|
||||
|
||||
def forward(self, task):
|
||||
return "best model"
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
assert list(tool.inputs.keys())[0] == "task"
|
||||
|
||||
def test_tool_init_decorator_raises_issues(self):
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
@tool
|
||||
def coolfunc(a: str, b: int):
|
||||
"""Cool function
|
||||
|
||||
Args:
|
||||
a: The first argument
|
||||
b: The second one
|
||||
"""
|
||||
return a + b
|
||||
|
||||
assert coolfunc.output_type == "number"
|
||||
assert "Tool return type not found" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
|
||||
@tool
|
||||
def coolfunc(a: str, b: int) -> int:
|
||||
"""Cool function
|
||||
|
||||
Args:
|
||||
a: The first argument
|
||||
"""
|
||||
return b + a
|
||||
|
||||
assert coolfunc.output_type == "number"
|
||||
assert "docstring has no description for the argument" in str(e)
|
||||
@@ -1,66 +0,0 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_type
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
self.assertEqual(output_type(output), self.tool.output_type)
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
example_inputs = {
|
||||
"text": "Hey, what's up?",
|
||||
"src_lang": "English",
|
||||
"tgt_lang": "Spanish",
|
||||
}
|
||||
|
||||
_inputs = []
|
||||
for input_name in example_inputs.keys():
|
||||
example_input = example_inputs[input_name]
|
||||
input_description = self.tool.inputs[input_name]
|
||||
input_type = input_description["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**example_inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
Reference in New Issue
Block a user