Agents: Improve python interpreter (#31409)
* Improve Python interpreter
* Add `with` and `assert` statements
* Prevent overwriting existing tools
* Check that interpreter errors are well logged in the code agent
* Add lazy evaluation for `and` and `or`
* Improve variable assignment
* Fix early return statements in functions
* Add small import fix on interpreter tool
This commit is contained in:
@@ -74,6 +74,26 @@ final_answer(7.2904)
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
    """Fake LLM engine whose first step emits code that assigns to ``print``.

    The first reply embeds the token ``special_marker``; once that token shows
    up in the stringified ``messages``, the second reply is returned instead.

    Args:
        messages: Conversation so far; only its ``str()`` form is inspected.
        stop_sequences: Accepted for signature compatibility and ignored.

    Returns:
        The scripted reply for the current step.
    """
    prompt = str(messages)
    if "special_marker" in prompt:
        # We're at step 2: the first reply is already part of the conversation.
        return """
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```<end_code>
"""
    return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
print = 2
```<end_code>
"""
|
||||
|
||||
|
||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
@@ -124,6 +144,13 @@ Action:
|
||||
"tool_name": "code interpreter",
|
||||
}
|
||||
|
||||
def test_react_code_agent_code_errors_show_offending_lines(self):
    """A failing code step must leave the offending source line in the agent logs."""
    agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
    output = agent.run("What is 2 multiplied by 3.6452?")

    # The fake engine's second reply calls final_answer("got an error").
    assert isinstance(output, AgentText)
    assert output == "got an error"
    # The fake engine's first reply runs `print = 2`; the logs must quote that line.
    assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
||||
|
||||
def test_setup_agent_with_empty_toolbox(self):
    """Constructing a ``ReactJsonAgent`` with an empty tool list must not raise."""
    ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import pytest
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||
from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code
|
||||
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
@@ -35,16 +35,6 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_input_spec(self):
|
||||
inputs_spec = self.tool.inputs
|
||||
expected_description = (
|
||||
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
||||
"else you will get an error. This code can only import the following python libraries: "
|
||||
"['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', "
|
||||
"'random', 're']."
|
||||
)
|
||||
self.assertEqual(inputs_spec["code"]["description"], expected_description)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
@@ -91,6 +81,17 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||
|
||||
code = "a=1;b=None"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result is None
|
||||
|
||||
def test_assignment_cannot_overwrite_tool(self):
    """Assigning to a name that is registered as a tool must raise ``InterpreterError``."""
    code = "print = '3'"
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, {"print": print}, state={})
    # Check the raised exception's own message, not the ExceptionInfo wrapper's repr.
    assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e.value)
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
@@ -99,7 +100,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
# Should not work without the tool
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert "tried to execute add_two" in str(e.value)
|
||||
|
||||
@@ -237,6 +238,12 @@ for block in text_block:
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 2
|
||||
|
||||
code = """
|
||||
digits, i = [1, 2, 3], 1
|
||||
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
|
||||
|
||||
def test_listcomp(self):
|
||||
code = "x = [i for i in range(3)]"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
@@ -278,10 +285,20 @@ for block in text_block:
|
||||
|
||||
# test infinite loop
|
||||
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
with pytest.raises(InterpreterError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "iterations in While loop exceeded" in str(e)
|
||||
|
||||
# test lazy evaluation
|
||||
code = """
|
||||
house_positions = [0, 7, 10, 15, 18, 22, 22]
|
||||
i, n, loc = 0, 7, 30
|
||||
while i < n and house_positions[i] <= loc:
|
||||
i += 1
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
|
||||
def test_generator(self):
|
||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
@@ -353,7 +370,19 @@ if char.isalpha():
|
||||
assert result == "LATIN CAPITAL LETTER A"
|
||||
|
||||
def test_multiple_comparators(self):
    """Two chained comparisons joined by ``and`` are truthy only when both chains hold."""
    falsy_snippets = (
        "0 <= -1 < 4 and 0 <= -5 < 4",  # both chains fail
        "0 <= 1 < 4 and 0 <= -5 < 4",  # second chain fails
        "0 <= 4 < 4 and 0 <= 3 < 4",  # first chain fails on its upper bound
    )
    for code in falsy_snippets:
        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
        assert not result

    code = "0 <= 3 < 4 and 0 <= 3 < 4"
    result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    assert result
|
||||
|
||||
@@ -364,6 +393,16 @@ if char.isalpha():
|
||||
assert result == "Ok no one cares"
|
||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||
|
||||
# test print in function
|
||||
code = """
|
||||
print("1")
|
||||
def function():
|
||||
print("2")
|
||||
function()"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"print": print}, state)
|
||||
assert state["print_outputs"] == "1\n2\n"
|
||||
|
||||
def test_tuple_target_in_iterator(self):
|
||||
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
@@ -491,3 +530,147 @@ except ValueError as e:
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
||||
assert result == int
|
||||
|
||||
def test_tuple_id(self):
    """A list comprehension with a tuple target ``(item, count)`` filters dict items."""
    # NOTE(review): the snippet reads `food_item_counts` although only `food_items`
    # is assigned. Confirm whether this relies on the interpreter resolving
    # near-miss variable names before "correcting" it.
    code = """
food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
"""
    state = {}
    result = evaluate_python_code(code, {}, state=state)
    assert result == ["orange", "pear"]
|
||||
|
||||
def test_nonsimple_augassign(self):
    """Augmented assignment works on a subscript, a plain name and an attribute target."""
    # NOTE(review): the class body below uses `self` directly at class level; this
    # is kept as captured. Confirm against the interpreter's class handling.
    code = """
counts_dict = {'a': 0}
counts_dict['a'] += 1
counts_list = [1, 2, 3]
counts_list += [4, 5, 6]

class Counter:
    self.count = 0

a = Counter()
a.count += 1
"""
    state = {}
    evaluate_python_code(code, {}, state=state)
    assert state["counts_dict"] == {"a": 1}
    assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
    assert state["a"].count == 1
|
||||
|
||||
def test_adding_int_to_list_raises_error(self):
    """``list += int`` must raise ``InterpreterError`` with an explicit message."""
    code = """
counts = [1, 2, 3]
counts += 1"""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # Check the raised exception's own message, not the ExceptionInfo wrapper's repr.
    assert "Cannot add non-list value 1 to a list." in str(e.value)
|
||||
|
||||
def test_error_highlights_correct_line_of_code(self):
    """The error message must quote the failing statement, not a neighbouring line."""
    code = """# Ok this is a very long code
# It has many commented lines
a = 1
b = 2

# Here is another piece
counts = [1, 2, 3]
counts += 1
b += 1"""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # Check the raised exception's own message, not the ExceptionInfo wrapper's repr.
    assert "Evaluation stopped at line 'counts += 1" in str(e.value)
|
||||
|
||||
def test_assert(self):
    """A failing ``assert`` in the snippet surfaces as ``AssertionError`` naming its condition."""
    code = """
assert 1 == 1
assert 1 == 2
"""
    with pytest.raises(AssertionError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    message = str(e.value)
    # Two separate asserts so a failure shows which expectation broke.
    assert "1 == 2" in message
    assert "1 == 1" not in message
|
||||
|
||||
def test_with_context_manager(self):
    """``with`` must call ``__enter__`` on entry and ``__exit__`` on leaving the block.

    The checks are ``assert`` statements inside the evaluated snippet; the test
    passes when evaluation completes without raising.
    """
    code = """
class SimpleLock:
    def __init__(self):
        self.locked = False

    def __enter__(self):
        self.locked = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.locked = False

lock = SimpleLock()

with lock as l:
    assert l.locked == True

assert lock.locked == False
    """
    state = {}
    tools = {}
    evaluate_python_code(code, tools, state)
|
||||
|
||||
def test_default_arg_in_function(self):
    """A keyword argument overrides its default while untouched defaults still apply."""
    code = """
def f(a, b=333, n=1000):
    return b + n
n = f(1, n=667)
"""
    res = evaluate_python_code(code, {}, {})
    # b keeps its default (333) and n is overridden (667): 333 + 667.
    assert res == 1000
|
||||
|
||||
def test_set(self):
    """Set literals and the ``difference`` / ``intersection`` methods are evaluated."""
    code = """
S1 = {'a', 'b', 'c'}
S2 = {'b', 'c', 'd'}
S3 = S1.difference(S2)
S4 = S1.intersection(S2)
"""
    state = {}
    evaluate_python_code(code, {}, state=state)
    assert state["S3"] == {"a"}
    assert state["S4"] == {"b", "c"}
|
||||
|
||||
def test_break(self):
    """``break`` must leave a ``while True`` loop once its guard condition is met."""
    code = """
i = 0

while True:
    i+= 1
    if i==3:
        break

i"""
    result = evaluate_python_code(code, {"print": print, "round": round}, state={})
    assert result == 3
|
||||
|
||||
def test_return(self):
    """``return`` inside a nested block exits the function; a bare ``return`` yields ``None``."""
    tools = {"print": print, "range": range, "ord": ord, "chr": chr}

    # test early returns
    code = """
def add_one(n, shift):
    if True:
        return n + shift
    return n

add_one(1, 1)
"""
    result = evaluate_python_code(code, tools, state={})
    assert result == 2

    # test returning None
    code = """
def returns_none(a):
    return

returns_none(1)
"""
    result = evaluate_python_code(code, tools, state={})
    assert result is None
|
||||
|
||||
Reference in New Issue
Block a user