Code agent: allow function persistence between steps (#31769)

* Code agent: allow function persistence between steps
This commit is contained in:
Aymeric Roucher
2024-07-05 11:09:11 +02:00
committed by GitHub
parent eef0507f3d
commit 1556025271
5 changed files with 63 additions and 11 deletions

View File

@@ -94,12 +94,48 @@ final_answer("got an error")
"""
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
Thought: Let's define the function. special_marker
Code:
```py
import numpy as np
def moving_average(x, w):
return np.convolve(x, np.ones(w), 'valid') / w
```<end_code>
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Code:
```py
x, w = [0, 1, 2, 3, 4, 5], 2
res = moving_average(x, w)
final_answer(res)
```<end_code>
"""
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
final_answer(result)
```
"""
def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result)
```
"""
@@ -135,8 +171,8 @@ Action:
def test_fake_react_code_agent(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "7.2904"
assert isinstance(output, float)
assert output == 7.2904
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
assert agent.logs[2]["tool_call"] == {
@@ -157,7 +193,7 @@ Action:
def test_react_fails_max_iterations(self):
agent = ReactCodeAgent(
tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
max_iterations=5,
)
agent.run("What is 2 multiplied by 3.6452?")
@@ -192,3 +228,10 @@ Action:
# check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
def test_function_persistence_across_steps(self):
agent = ReactCodeAgent(
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
)
res = agent.run("ok")
assert res[0] == 0.5

View File

@@ -660,7 +660,6 @@ add_one(1, 1)
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
print(state)
assert result == 2
# test returning None
@@ -672,5 +671,4 @@ returns_none(1)
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
print(state)
assert result is None