@@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available
|
|||||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||||
from .llm_engine import HfEngine, MessageRole
|
from .llm_engine import HfEngine, MessageRole
|
||||||
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
from .prompts import (
|
||||||
|
DEFAULT_CODE_SYSTEM_PROMPT,
|
||||||
|
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
|
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||||
|
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
||||||
|
SYSTEM_PROMPT_FACTS,
|
||||||
|
SYSTEM_PROMPT_FACTS_UPDATE,
|
||||||
|
SYSTEM_PROMPT_PLAN,
|
||||||
|
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||||
|
USER_PROMPT_FACTS_UPDATE,
|
||||||
|
USER_PROMPT_PLAN,
|
||||||
|
USER_PROMPT_PLAN_UPDATE,
|
||||||
|
)
|
||||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||||
from .tools import (
|
from .tools import (
|
||||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
@@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
|||||||
|
|
||||||
def parse_code_blob(code_blob: str) -> str:
|
def parse_code_blob(code_blob: str) -> str:
|
||||||
try:
|
try:
|
||||||
pattern = r"```(?:py|python)?\n(.*?)```"
|
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||||
match = re.search(pattern, code_blob, re.DOTALL)
|
match = re.search(pattern, code_blob, re.DOTALL)
|
||||||
return match.group(1).strip()
|
return match.group(1).strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}"
|
f"""
|
||||||
|
The code blob you used is invalid: due to the following error: {e}
|
||||||
|
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
|
||||||
|
Thoughts: Your thoughts
|
||||||
|
Code:
|
||||||
|
```py
|
||||||
|
# Your python code here
|
||||||
|
```<end_action>"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
|||||||
tool_call = parse_json_blob(json_blob)
|
tool_call = parse_json_blob(json_blob)
|
||||||
if "action" in tool_call and "action_input" in tool_call:
|
if "action" in tool_call and "action_input" in tool_call:
|
||||||
return tool_call["action"], tool_call["action_input"]
|
return tool_call["action"], tool_call["action_input"]
|
||||||
|
elif "action" in tool_call:
|
||||||
|
return tool_call["action"], None
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
||||||
@@ -208,7 +229,7 @@ class Toolbox:
|
|||||||
The tool to add to the toolbox.
|
The tool to add to the toolbox.
|
||||||
"""
|
"""
|
||||||
if tool.name in self._tools:
|
if tool.name in self._tools:
|
||||||
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
|
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
|
||||||
self._tools[tool.name] = tool
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
def remove_tool(self, tool_name: str):
|
def remove_tool(self, tool_name: str):
|
||||||
@@ -359,12 +380,8 @@ class Agent:
|
|||||||
"""Get the toolbox currently available to the agent"""
|
"""Get the toolbox currently available to the agent"""
|
||||||
return self._toolbox
|
return self._toolbox
|
||||||
|
|
||||||
def initialize_for_run(self, task: str, **kwargs):
|
def initialize_for_run(self):
|
||||||
self.token_count = 0
|
self.token_count = 0
|
||||||
self.task = task
|
|
||||||
if len(kwargs) > 0:
|
|
||||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
|
||||||
self.state = kwargs.copy()
|
|
||||||
self.system_prompt = format_prompt_with_tools(
|
self.system_prompt = format_prompt_with_tools(
|
||||||
self._toolbox,
|
self._toolbox,
|
||||||
self.system_prompt_template,
|
self.system_prompt_template,
|
||||||
@@ -380,7 +397,7 @@ class Agent:
|
|||||||
self.logger.debug("System prompt is as follows:")
|
self.logger.debug("System prompt is as follows:")
|
||||||
self.logger.debug(self.system_prompt)
|
self.logger.debug(self.system_prompt)
|
||||||
|
|
||||||
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
|
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||||
that can be used as input to the LLM.
|
that can be used as input to the LLM.
|
||||||
@@ -390,43 +407,51 @@ class Agent:
|
|||||||
"role": MessageRole.USER,
|
"role": MessageRole.USER,
|
||||||
"content": "Task: " + self.logs[0]["task"],
|
"content": "Task: " + self.logs[0]["task"],
|
||||||
}
|
}
|
||||||
|
if summary_mode:
|
||||||
|
memory = [task_message]
|
||||||
|
else:
|
||||||
memory = [prompt_message, task_message]
|
memory = [prompt_message, task_message]
|
||||||
for i, step_log in enumerate(self.logs[1:]):
|
for i, step_log in enumerate(self.logs[1:]):
|
||||||
if "llm_output" in step_log:
|
if "llm_output" in step_log and not summary_mode:
|
||||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
|
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
|
||||||
|
memory.append(thought_message)
|
||||||
|
if "facts" in step_log:
|
||||||
|
thought_message = {
|
||||||
|
"role": MessageRole.ASSISTANT,
|
||||||
|
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
|
||||||
|
}
|
||||||
memory.append(thought_message)
|
memory.append(thought_message)
|
||||||
|
|
||||||
|
if "plan" in step_log and not summary_mode:
|
||||||
|
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
|
||||||
|
memory.append(thought_message)
|
||||||
|
|
||||||
|
if "tool_call" in step_log and summary_mode:
|
||||||
|
tool_call_message = {
|
||||||
|
"role": MessageRole.ASSISTANT,
|
||||||
|
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
|
||||||
|
}
|
||||||
|
memory.append(tool_call_message)
|
||||||
|
|
||||||
|
if "task" in step_log:
|
||||||
|
tool_call_message = {
|
||||||
|
"role": MessageRole.USER,
|
||||||
|
"content": "New task:\n" + step_log["task"],
|
||||||
|
}
|
||||||
|
memory.append(tool_call_message)
|
||||||
|
|
||||||
|
if "error" in step_log or "observation" in step_log:
|
||||||
if "error" in step_log:
|
if "error" in step_log:
|
||||||
message_content = (
|
message_content = (
|
||||||
"Error: "
|
f"[OUTPUT OF STEP {i}] Error: "
|
||||||
+ str(step_log["error"])
|
+ str(step_log["error"])
|
||||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||||
)
|
)
|
||||||
elif "observation" in step_log:
|
elif "observation" in step_log:
|
||||||
message_content = f"Observation: {step_log['observation']}"
|
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
|
||||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||||
memory.append(tool_response_message)
|
memory.append(tool_response_message)
|
||||||
|
|
||||||
if len(memory) % 3 == 0:
|
|
||||||
reminder_content = (
|
|
||||||
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
|
|
||||||
)
|
|
||||||
reminder_content += "\nHere is a summary of your past tool calls and their results:"
|
|
||||||
for j in range(i + 1):
|
|
||||||
reminder_content += "\nStep " + str(j + 1)
|
|
||||||
if "tool_call" in self.logs[j]:
|
|
||||||
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
|
|
||||||
if self.memory_verbose:
|
|
||||||
if "observation" in self.logs[j]:
|
|
||||||
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
|
|
||||||
if "error" in self.logs[j]:
|
|
||||||
reminder_content += "\nError:" + str(self.logs[j]["error"])
|
|
||||||
memory.append(
|
|
||||||
{
|
|
||||||
"role": MessageRole.USER,
|
|
||||||
"content": reminder_content,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def get_succinct_logs(self):
|
def get_succinct_logs(self):
|
||||||
@@ -459,7 +484,7 @@ class Agent:
|
|||||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox).
|
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||||
"""
|
"""
|
||||||
if tool_name not in self.toolbox.tools:
|
if tool_name not in self.toolbox.tools:
|
||||||
@@ -559,7 +584,11 @@ class CodeAgent(Agent):
|
|||||||
agent.run("What is the result of 2 power 3.7384?")
|
agent.run("What is the result of 2 power 3.7384?")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
self.initialize_for_run(task, **kwargs)
|
self.task = task
|
||||||
|
if len(kwargs) > 0:
|
||||||
|
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||||
|
self.state = kwargs.copy()
|
||||||
|
self.initialize_for_run()
|
||||||
|
|
||||||
# Run LLM
|
# Run LLM
|
||||||
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
|
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
|
||||||
@@ -598,7 +627,8 @@ class CodeAgent(Agent):
|
|||||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||||
output = self.python_evaluator(
|
output = self.python_evaluator(
|
||||||
code_action,
|
code_action,
|
||||||
available_tools,
|
static_tools=available_tools,
|
||||||
|
custom_tools={},
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
@@ -623,6 +653,7 @@ class ReactAgent(Agent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -632,6 +663,7 @@ class ReactAgent(Agent):
|
|||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
self.planning_interval = planning_interval
|
||||||
|
|
||||||
def provide_final_answer(self, task) -> str:
|
def provide_final_answer(self, task) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -655,11 +687,13 @@ class ReactAgent(Agent):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error in generating final llm output: {e}."
|
return f"Error in generating final llm output: {e}."
|
||||||
|
|
||||||
def run(self, task: str, stream: bool = False, **kwargs):
|
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Runs the agent for the given task.
|
Runs the agent for the given task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task (`str`): The task to perform
|
task (`str`): The task to perform
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```py
|
```py
|
||||||
from transformers.agents import ReactCodeAgent
|
from transformers.agents import ReactCodeAgent
|
||||||
@@ -667,14 +701,23 @@ class ReactAgent(Agent):
|
|||||||
agent.run("What is the result of 2 power 3.7384?")
|
agent.run("What is the result of 2 power 3.7384?")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
if stream:
|
self.task = task
|
||||||
return self.stream_run(task, **kwargs)
|
if len(kwargs) > 0:
|
||||||
|
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||||
|
self.state = kwargs.copy()
|
||||||
|
if reset:
|
||||||
|
self.initialize_for_run()
|
||||||
else:
|
else:
|
||||||
return self.direct_run(task, **kwargs)
|
self.logs.append({"task": task})
|
||||||
|
if stream:
|
||||||
def stream_run(self, task: str, **kwargs):
|
return self.stream_run(task)
|
||||||
self.initialize_for_run(task, **kwargs)
|
else:
|
||||||
|
return self.direct_run(task)
|
||||||
|
|
||||||
|
def stream_run(self, task: str):
|
||||||
|
"""
|
||||||
|
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
||||||
|
"""
|
||||||
final_answer = None
|
final_answer = None
|
||||||
iteration = 0
|
iteration = 0
|
||||||
while final_answer is None and iteration < self.max_iterations:
|
while final_answer is None and iteration < self.max_iterations:
|
||||||
@@ -700,13 +743,16 @@ class ReactAgent(Agent):
|
|||||||
|
|
||||||
yield final_answer
|
yield final_answer
|
||||||
|
|
||||||
def direct_run(self, task: str, **kwargs):
|
def direct_run(self, task: str):
|
||||||
self.initialize_for_run(task, **kwargs)
|
"""
|
||||||
|
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
||||||
|
"""
|
||||||
final_answer = None
|
final_answer = None
|
||||||
iteration = 0
|
iteration = 0
|
||||||
while final_answer is None and iteration < self.max_iterations:
|
while final_answer is None and iteration < self.max_iterations:
|
||||||
try:
|
try:
|
||||||
|
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
||||||
|
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
||||||
step_logs = self.step()
|
step_logs = self.step()
|
||||||
if "final_answer" in step_logs:
|
if "final_answer" in step_logs:
|
||||||
final_answer = step_logs["final_answer"]
|
final_answer = step_logs["final_answer"]
|
||||||
@@ -726,6 +772,96 @@ class ReactAgent(Agent):
|
|||||||
|
|
||||||
return final_answer
|
return final_answer
|
||||||
|
|
||||||
|
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
|
||||||
|
"""
|
||||||
|
Used periodically by the agent to plan the next steps to reach the objective.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (`str`): The task to perform
|
||||||
|
is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
|
||||||
|
iteration (`int`): The number of the current step, used as an indication for the LLM.
|
||||||
|
"""
|
||||||
|
if is_first_step:
|
||||||
|
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
|
||||||
|
message_prompt_task = {
|
||||||
|
"role": MessageRole.USER,
|
||||||
|
"content": f"""Here is the task:
|
||||||
|
```
|
||||||
|
{task}
|
||||||
|
```
|
||||||
|
Now begin!""",
|
||||||
|
}
|
||||||
|
|
||||||
|
answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
|
||||||
|
|
||||||
|
message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
|
||||||
|
message_user_prompt_plan = {
|
||||||
|
"role": MessageRole.USER,
|
||||||
|
"content": USER_PROMPT_PLAN.format(
|
||||||
|
task=task,
|
||||||
|
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
||||||
|
answer_facts=answer_facts,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
answer_plan = self.llm_engine(
|
||||||
|
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
|
||||||
|
)
|
||||||
|
|
||||||
|
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
|
||||||
|
```
|
||||||
|
{answer_plan}
|
||||||
|
```"""
|
||||||
|
final_facts_redaction = f"""Here are the facts that I know so far:
|
||||||
|
```
|
||||||
|
{answer_facts}
|
||||||
|
```""".strip()
|
||||||
|
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||||
|
self.logger.debug("===== Initial plan: =====")
|
||||||
|
self.logger.debug(final_plan_redaction)
|
||||||
|
else: # update plan
|
||||||
|
agent_memory = self.write_inner_memory_from_logs(
|
||||||
|
summary_mode=False
|
||||||
|
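) # NOTE(review): with summary_mode=False the returned memory includes both the facts and the plan messages - confirm whether the plan was meant to be excluded here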
|
||||||
|
|
||||||
|
# Redact updated facts
|
||||||
|
facts_update_system_prompt = {
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": SYSTEM_PROMPT_FACTS_UPDATE,
|
||||||
|
}
|
||||||
|
facts_update_message = {
|
||||||
|
"role": MessageRole.USER,
|
||||||
|
"content": USER_PROMPT_FACTS_UPDATE,
|
||||||
|
}
|
||||||
|
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
|
||||||
|
|
||||||
|
# Redact updated plan
|
||||||
|
plan_update_message = {
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
|
||||||
|
}
|
||||||
|
plan_update_message_user = {
|
||||||
|
"role": MessageRole.USER,
|
||||||
|
"content": USER_PROMPT_PLAN_UPDATE.format(
|
||||||
|
task=task,
|
||||||
|
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
|
||||||
|
facts_update=facts_update,
|
||||||
|
remaining_steps=(self.max_iterations - iteration),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
plan_update = self.llm_engine(
|
||||||
|
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log final facts and plan
|
||||||
|
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
|
||||||
|
final_facts_redaction = f"""Here is the updated list of the facts that I know:
|
||||||
|
```
|
||||||
|
{facts_update}
|
||||||
|
```"""
|
||||||
|
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
|
||||||
|
self.logger.debug("===== Updated plan: =====")
|
||||||
|
self.logger.debug(final_plan_redaction)
|
||||||
|
|
||||||
|
|
||||||
class ReactJsonAgent(ReactAgent):
|
class ReactJsonAgent(ReactAgent):
|
||||||
"""
|
"""
|
||||||
@@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
llm_engine=llm_engine,
|
llm_engine=llm_engine,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent):
|
|||||||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||||
if tool_name == "final_answer":
|
if tool_name == "final_answer":
|
||||||
if isinstance(arguments, dict):
|
if isinstance(arguments, dict):
|
||||||
|
if "answer" in arguments:
|
||||||
answer = arguments["answer"]
|
answer = arguments["answer"]
|
||||||
|
if (
|
||||||
|
isinstance(answer, str) and answer in self.state.keys()
|
||||||
|
): # if the answer is a state variable, return the value
|
||||||
|
answer = self.state[answer]
|
||||||
|
else:
|
||||||
|
answer = arguments
|
||||||
else:
|
else:
|
||||||
answer = arguments
|
answer = arguments
|
||||||
if answer in self.state: # if the answer is a state variable, return the value
|
|
||||||
answer = self.state[answer]
|
|
||||||
current_step_logs["final_answer"] = answer
|
current_step_logs["final_answer"] = answer
|
||||||
return current_step_logs
|
return current_step_logs
|
||||||
else:
|
else:
|
||||||
@@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
additional_authorized_imports: Optional[List[str]] = None,
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
|
planning_interval: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
llm_engine=llm_engine,
|
llm_engine=llm_engine,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
tool_description_template=tool_description_template,
|
tool_description_template=tool_description_template,
|
||||||
|
planning_interval=planning_interval,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||||
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||||
self.available_tools = {
|
self.custom_tools = {}
|
||||||
**BASE_PYTHON_TOOLS.copy(),
|
|
||||||
**self.toolbox.tools,
|
|
||||||
} # This list can be augmented by the code agent creating some new functions
|
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
"""
|
"""
|
||||||
@@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
try:
|
try:
|
||||||
result = self.python_evaluator(
|
result = self.python_evaluator(
|
||||||
code_action,
|
code_action,
|
||||||
tools=self.available_tools,
|
static_tools={
|
||||||
|
**BASE_PYTHON_TOOLS.copy(),
|
||||||
|
**self.toolbox.tools,
|
||||||
|
},
|
||||||
|
custom_tools=self.custom_tools,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=self.authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
@@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
self.logger.log(32, information)
|
self.logger.log(32, information)
|
||||||
current_step_logs["observation"] = information
|
current_step_logs["observation"] = information
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
|
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||||
if "'dict' object has no attribute 'read'" in str(e):
|
if "'dict' object has no attribute 'read'" in str(e):
|
||||||
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
|
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
|
||||||
raise AgentExecutionError(error_msg)
|
raise AgentExecutionError(error_msg)
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
|
|||||||
|
|
||||||
def forward(self, code):
|
def forward(self, code):
|
||||||
output = str(
|
output = str(
|
||||||
evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports)
|
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
|
|||||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||||
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
||||||
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
||||||
9. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||||
|
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||||
|
|
||||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
|
||||||
|
|
||||||
|
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
|
||||||
|
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
|
||||||
|
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
|
||||||
|
|
||||||
|
---
|
||||||
|
### 1. Facts given in the task
|
||||||
|
List here the specific facts given in the task that could help you (there might be nothing here).
|
||||||
|
|
||||||
|
### 2. Facts to look up
|
||||||
|
List here any facts that we may need to look up.
|
||||||
|
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
|
||||||
|
|
||||||
|
### 3. Facts to derive
|
||||||
|
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
|
||||||
|
|
||||||
|
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
|
||||||
|
### 1. Facts given in the task
|
||||||
|
### 2. Facts to look up
|
||||||
|
### 3. Facts to derive
|
||||||
|
Do not add anything else."""
|
||||||
|
|
||||||
|
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||||
|
|
||||||
|
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
|
||||||
|
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
|
||||||
|
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
|
||||||
|
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
|
||||||
|
|
||||||
|
USER_PROMPT_PLAN = """
|
||||||
|
Here is your task:
|
||||||
|
|
||||||
|
Task:
|
||||||
|
```
|
||||||
|
{task}
|
||||||
|
```
|
||||||
|
|
||||||
|
Your plan can leverage any of these tools:
|
||||||
|
{tool_descriptions}
|
||||||
|
|
||||||
|
List of facts that you know:
|
||||||
|
```
|
||||||
|
{answer_facts}
|
||||||
|
```
|
||||||
|
|
||||||
|
Now begin! Write your plan below."""
|
||||||
|
|
||||||
|
SYSTEM_PROMPT_FACTS_UPDATE = """
|
||||||
|
You are a world expert at gathering known and unknown facts based on a conversation.
|
||||||
|
Below you will find a task, and a history of attempts made to solve the task. You will have to produce a list of these:
|
||||||
|
### 1. Facts given in the task
|
||||||
|
### 2. Facts that we have learned
|
||||||
|
### 3. Facts still to look up
|
||||||
|
### 4. Facts still to derive
|
||||||
|
Find the task and history below."""
|
||||||
|
|
||||||
|
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
|
||||||
|
But in your previous steps you may have learned useful new facts or invalidated some false ones.
|
||||||
|
Please update your list of facts based on the previous history, and provide these headings:
|
||||||
|
### 1. Facts given in the task
|
||||||
|
### 2. Facts that we have learned
|
||||||
|
### 3. Facts still to look up
|
||||||
|
### 4. Facts still to derive
|
||||||
|
|
||||||
|
Now write your new list of facts below."""
|
||||||
|
|
||||||
|
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||||
|
|
||||||
|
You have been given a task:
|
||||||
|
```
|
||||||
|
{task}
|
||||||
|
```
|
||||||
|
|
||||||
|
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
|
||||||
|
If the previous tries so far have met some success, you can make an updated plan based on these actions.
|
||||||
|
If you are stalled, you can make a completely new plan starting from scratch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
|
||||||
|
```
|
||||||
|
{task}
|
||||||
|
```
|
||||||
|
|
||||||
|
You have access to these tools:
|
||||||
|
{tool_descriptions}
|
||||||
|
|
||||||
|
Here is the up to date list of facts that you know:
|
||||||
|
```
|
||||||
|
{facts_update}
|
||||||
|
```
|
||||||
|
|
||||||
|
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
|
||||||
|
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
|
||||||
|
Beware that you have {remaining_steps} steps remaining.
|
||||||
|
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
|
||||||
|
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
|
||||||
|
|
||||||
|
Now write your new plan below."""
|
||||||
|
|
||||||
|
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
|
||||||
|
```
|
||||||
|
{task}
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is my new/updated plan of action to solve the task:
|
||||||
|
```
|
||||||
|
{plan_update}
|
||||||
|
```"""
|
||||||
|
|||||||
@@ -18,8 +18,17 @@ import ast
|
|||||||
import builtins
|
import builtins
|
||||||
import difflib
|
import difflib
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from importlib import import_module
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..utils import is_pandas_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_pandas_available():
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
class InterpreterError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
"""
|
"""
|
||||||
@@ -50,7 +59,8 @@ LIST_SAFE_MODULES = [
|
|||||||
"unicodedata",
|
"unicodedata",
|
||||||
]
|
]
|
||||||
|
|
||||||
PRINT_OUTPUTS = ""
|
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
||||||
|
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||||
|
|
||||||
|
|
||||||
class BreakException(Exception):
|
class BreakException(Exception):
|
||||||
@@ -75,8 +85,8 @@ def get_iterable(obj):
|
|||||||
raise InterpreterError("Object is not iterable")
|
raise InterpreterError("Object is not iterable")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_unaryop(expression, state, tools):
|
def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
||||||
operand = evaluate_ast(expression.operand, state, tools)
|
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
|
||||||
if isinstance(expression.op, ast.USub):
|
if isinstance(expression.op, ast.USub):
|
||||||
return -operand
|
return -operand
|
||||||
elif isinstance(expression.op, ast.UAdd):
|
elif isinstance(expression.op, ast.UAdd):
|
||||||
@@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools):
|
|||||||
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_lambda(lambda_expression, state, tools):
|
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
|
||||||
args = [arg.arg for arg in lambda_expression.args.args]
|
args = [arg.arg for arg in lambda_expression.args.args]
|
||||||
|
|
||||||
def lambda_func(*values):
|
def lambda_func(*values):
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
for arg, value in zip(args, values):
|
for arg, value in zip(args, values):
|
||||||
new_state[arg] = value
|
new_state[arg] = value
|
||||||
return evaluate_ast(lambda_expression.body, new_state, tools)
|
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
|
||||||
|
|
||||||
return lambda_func
|
return lambda_func
|
||||||
|
|
||||||
|
|
||||||
def evaluate_while(while_loop, state, tools):
|
def evaluate_while(while_loop, state, static_tools, custom_tools):
|
||||||
max_iterations = 1000
|
max_iterations = 1000
|
||||||
iterations = 0
|
iterations = 0
|
||||||
while evaluate_ast(while_loop.test, state, tools):
|
while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
|
||||||
for node in while_loop.body:
|
for node in while_loop.body:
|
||||||
try:
|
try:
|
||||||
evaluate_ast(node, state, tools)
|
evaluate_ast(node, state, static_tools, custom_tools)
|
||||||
except BreakException:
|
except BreakException:
|
||||||
return None
|
return None
|
||||||
except ContinueException:
|
except ContinueException:
|
||||||
@@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def create_function(func_def, state, tools):
|
def create_function(func_def, state, static_tools, custom_tools):
|
||||||
def new_func(*args, **kwargs):
|
def new_func(*args, **kwargs):
|
||||||
func_state = state.copy()
|
func_state = state.copy()
|
||||||
arg_names = [arg.arg for arg in func_def.args.args]
|
arg_names = [arg.arg for arg in func_def.args.args]
|
||||||
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults]
|
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
|
||||||
|
|
||||||
# Apply default values
|
# Apply default values
|
||||||
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
||||||
@@ -158,7 +168,7 @@ def create_function(func_def, state, tools):
|
|||||||
result = None
|
result = None
|
||||||
try:
|
try:
|
||||||
for stmt in func_def.body:
|
for stmt in func_def.body:
|
||||||
result = evaluate_ast(stmt, func_state, tools)
|
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
|
||||||
except ReturnException as e:
|
except ReturnException as e:
|
||||||
result = e.value
|
result = e.value
|
||||||
return result
|
return result
|
||||||
@@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body):
|
|||||||
return type(class_name, tuple(class_bases), class_dict)
|
return type(class_name, tuple(class_bases), class_dict)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_function_def(func_def, state, tools):
|
def evaluate_function_def(func_def, state, static_tools, custom_tools):
|
||||||
tools[func_def.name] = create_function(func_def, state, tools)
|
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
|
||||||
return tools[func_def.name]
|
return custom_tools[func_def.name]
|
||||||
|
|
||||||
|
|
||||||
def evaluate_class_def(class_def, state, tools):
|
def evaluate_class_def(class_def, state, static_tools, custom_tools):
|
||||||
class_name = class_def.name
|
class_name = class_def.name
|
||||||
bases = [evaluate_ast(base, state, tools) for base in class_def.bases]
|
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
|
||||||
class_dict = {}
|
class_dict = {}
|
||||||
|
|
||||||
for stmt in class_def.body:
|
for stmt in class_def.body:
|
||||||
if isinstance(stmt, ast.FunctionDef):
|
if isinstance(stmt, ast.FunctionDef):
|
||||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
|
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
|
||||||
elif isinstance(stmt, ast.Assign):
|
elif isinstance(stmt, ast.Assign):
|
||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
|
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
||||||
elif isinstance(target, ast.Attribute):
|
elif isinstance(target, ast.Attribute):
|
||||||
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools)
|
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||||
|
|
||||||
@@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools):
|
|||||||
return new_class
|
return new_class
|
||||||
|
|
||||||
|
|
||||||
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
def evaluate_augassign(expression, state, static_tools, custom_tools):
|
||||||
# Helper function to get current value and set new value based on the target type
|
# Helper function to get current value and set new value based on the target type
|
||||||
def get_current_value(target):
|
def get_current_value(target):
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
return state.get(target.id, 0)
|
return state.get(target.id, 0)
|
||||||
elif isinstance(target, ast.Subscript):
|
elif isinstance(target, ast.Subscript):
|
||||||
obj = evaluate_ast(target.value, state, tools)
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||||
key = evaluate_ast(target.slice, state, tools)
|
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
||||||
return obj[key]
|
return obj[key]
|
||||||
elif isinstance(target, ast.Attribute):
|
elif isinstance(target, ast.Attribute):
|
||||||
obj = evaluate_ast(target.value, state, tools)
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||||
return getattr(obj, target.attr)
|
return getattr(obj, target.attr)
|
||||||
elif isinstance(target, ast.Tuple):
|
elif isinstance(target, ast.Tuple):
|
||||||
return tuple(get_current_value(elt) for elt in target.elts)
|
return tuple(get_current_value(elt) for elt in target.elts)
|
||||||
@@ -220,7 +230,7 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
|||||||
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
||||||
|
|
||||||
current_value = get_current_value(expression.target)
|
current_value = get_current_value(expression.target)
|
||||||
value_to_add = evaluate_ast(expression.value, state, tools)
|
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
|
|
||||||
# Determine the operation and apply it
|
# Determine the operation and apply it
|
||||||
if isinstance(expression.op, ast.Add):
|
if isinstance(expression.op, ast.Add):
|
||||||
@@ -256,28 +266,28 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
|||||||
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
||||||
|
|
||||||
# Update the state
|
# Update the state
|
||||||
set_value(expression.target, updated_value, state, tools)
|
set_value(expression.target, updated_value, state, static_tools, custom_tools)
|
||||||
|
|
||||||
return updated_value
|
return updated_value
|
||||||
|
|
||||||
|
|
||||||
def evaluate_boolop(node, state, tools):
|
def evaluate_boolop(node, state, static_tools, custom_tools):
|
||||||
if isinstance(node.op, ast.And):
|
if isinstance(node.op, ast.And):
|
||||||
for value in node.values:
|
for value in node.values:
|
||||||
if not evaluate_ast(value, state, tools):
|
if not evaluate_ast(value, state, static_tools, custom_tools):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
elif isinstance(node.op, ast.Or):
|
elif isinstance(node.op, ast.Or):
|
||||||
for value in node.values:
|
for value in node.values:
|
||||||
if evaluate_ast(value, state, tools):
|
if evaluate_ast(value, state, static_tools, custom_tools):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def evaluate_binop(binop, state, tools):
|
def evaluate_binop(binop, state, static_tools, custom_tools):
|
||||||
# Recursively evaluate the left and right operands
|
# Recursively evaluate the left and right operands
|
||||||
left_val = evaluate_ast(binop.left, state, tools)
|
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
|
||||||
right_val = evaluate_ast(binop.right, state, tools)
|
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
|
||||||
|
|
||||||
# Determine the operation based on the type of the operator in the BinOp
|
# Determine the operation based on the type of the operator in the BinOp
|
||||||
if isinstance(binop.op, ast.Add):
|
if isinstance(binop.op, ast.Add):
|
||||||
@@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools):
|
|||||||
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_assign(assign, state, tools):
|
def evaluate_assign(assign, state, static_tools, custom_tools):
|
||||||
result = evaluate_ast(assign.value, state, tools)
|
result = evaluate_ast(assign.value, state, static_tools, custom_tools)
|
||||||
if len(assign.targets) == 1:
|
if len(assign.targets) == 1:
|
||||||
target = assign.targets[0]
|
target = assign.targets[0]
|
||||||
set_value(target, result, state, tools)
|
set_value(target, result, state, static_tools, custom_tools)
|
||||||
else:
|
else:
|
||||||
if len(assign.targets) != len(result):
|
if len(assign.targets) != len(result):
|
||||||
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
||||||
for tgt, val in zip(assign.targets, result):
|
expanded_values = []
|
||||||
set_value(tgt, val, state, tools)
|
for tgt in assign.targets:
|
||||||
|
if isinstance(tgt, ast.Starred):
|
||||||
|
expanded_values.extend(result)
|
||||||
|
else:
|
||||||
|
expanded_values.append(result)
|
||||||
|
for tgt, val in zip(assign.targets, expanded_values):
|
||||||
|
set_value(tgt, val, state, static_tools, custom_tools)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def set_value(target, value, state, tools):
|
def set_value(target, value, state, static_tools, custom_tools):
|
||||||
if isinstance(target, ast.Name):
|
if isinstance(target, ast.Name):
|
||||||
if target.id in tools:
|
if target.id in static_tools:
|
||||||
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
||||||
state[target.id] = value
|
state[target.id] = value
|
||||||
elif isinstance(target, ast.Tuple):
|
elif isinstance(target, ast.Tuple):
|
||||||
if not isinstance(value, tuple):
|
if not isinstance(value, tuple):
|
||||||
|
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
|
||||||
|
value = tuple(value)
|
||||||
|
else:
|
||||||
raise InterpreterError("Cannot unpack non-tuple value")
|
raise InterpreterError("Cannot unpack non-tuple value")
|
||||||
if len(target.elts) != len(value):
|
if len(target.elts) != len(value):
|
||||||
raise InterpreterError("Cannot unpack tuple of wrong size")
|
raise InterpreterError("Cannot unpack tuple of wrong size")
|
||||||
for i, elem in enumerate(target.elts):
|
for i, elem in enumerate(target.elts):
|
||||||
set_value(elem, value[i], state, tools)
|
set_value(elem, value[i], state, static_tools, custom_tools)
|
||||||
elif isinstance(target, ast.Subscript):
|
elif isinstance(target, ast.Subscript):
|
||||||
obj = evaluate_ast(target.value, state, tools)
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||||
key = evaluate_ast(target.slice, state, tools)
|
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
||||||
obj[key] = value
|
obj[key] = value
|
||||||
elif isinstance(target, ast.Attribute):
|
elif isinstance(target, ast.Attribute):
|
||||||
obj = evaluate_ast(target.value, state, tools)
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||||
setattr(obj, target.attr, value)
|
setattr(obj, target.attr, value)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_call(call, state, tools):
|
def evaluate_call(call, state, static_tools, custom_tools):
|
||||||
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
||||||
raise InterpreterError(
|
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
|
||||||
)
|
|
||||||
if isinstance(call.func, ast.Attribute):
|
if isinstance(call.func, ast.Attribute):
|
||||||
obj = evaluate_ast(call.func.value, state, tools)
|
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
|
||||||
func_name = call.func.attr
|
func_name = call.func.attr
|
||||||
if not hasattr(obj, func_name):
|
if not hasattr(obj, func_name):
|
||||||
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
||||||
func = getattr(obj, func_name)
|
func = getattr(obj, func_name)
|
||||||
|
|
||||||
elif isinstance(call.func, ast.Name):
|
elif isinstance(call.func, ast.Name):
|
||||||
func_name = call.func.id
|
func_name = call.func.id
|
||||||
if func_name in state:
|
if func_name in state:
|
||||||
func = state[func_name]
|
func = state[func_name]
|
||||||
elif func_name in tools:
|
elif func_name in static_tools:
|
||||||
func = tools[func_name]
|
func = static_tools[func_name]
|
||||||
|
elif func_name in custom_tools:
|
||||||
|
func = custom_tools[func_name]
|
||||||
elif func_name in ERRORS:
|
elif func_name in ERRORS:
|
||||||
func = ERRORS[func_name]
|
func = ERRORS[func_name]
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(
|
raise InterpreterError(
|
||||||
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
||||||
)
|
)
|
||||||
|
|
||||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
args = []
|
||||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
for arg in call.args:
|
||||||
|
if isinstance(arg, ast.Starred):
|
||||||
|
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
|
||||||
|
else:
|
||||||
|
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||||
|
|
||||||
|
args = []
|
||||||
|
for arg in call.args:
|
||||||
|
if isinstance(arg, ast.Starred):
|
||||||
|
unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
|
||||||
|
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
|
||||||
|
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
|
||||||
|
args.extend(unpacked)
|
||||||
|
else:
|
||||||
|
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||||
|
|
||||||
|
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
|
||||||
|
|
||||||
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
||||||
# Instantiate the class using its constructor
|
# Instantiate the class using its constructor
|
||||||
@@ -397,24 +433,31 @@ def evaluate_call(call, state, tools):
|
|||||||
output = " ".join(map(str, args))
|
output = " ".join(map(str, args))
|
||||||
global PRINT_OUTPUTS
|
global PRINT_OUTPUTS
|
||||||
PRINT_OUTPUTS += output + "\n"
|
PRINT_OUTPUTS += output + "\n"
|
||||||
|
# cap the number of lines
|
||||||
return output
|
return output
|
||||||
else: # Assume it's a callable object
|
else: # Assume it's a callable object
|
||||||
output = func(*args, **kwargs)
|
output = func(*args, **kwargs)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def evaluate_subscript(subscript, state, tools):
|
def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||||
index = evaluate_ast(subscript.slice, state, tools)
|
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
||||||
value = evaluate_ast(subscript.value, state, tools)
|
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
||||||
if isinstance(index, slice):
|
|
||||||
|
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||||
|
parent_object = value.obj
|
||||||
|
return parent_object.loc[index]
|
||||||
|
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
|
||||||
|
return value[index]
|
||||||
|
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
|
||||||
|
return value[index]
|
||||||
|
elif isinstance(index, slice):
|
||||||
return value[index]
|
return value[index]
|
||||||
elif isinstance(value, (list, tuple)):
|
elif isinstance(value, (list, tuple)):
|
||||||
# Ensure the index is within bounds
|
|
||||||
if not (-len(value) <= index < len(value)):
|
if not (-len(value) <= index < len(value)):
|
||||||
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
||||||
return value[int(index)]
|
return value[int(index)]
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
# Ensure the index is within bounds
|
|
||||||
if not (-len(value) <= index < len(value)):
|
if not (-len(value) <= index < len(value)):
|
||||||
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
||||||
return value[index]
|
return value[index]
|
||||||
@@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools):
|
|||||||
raise InterpreterError(f"Could not index {value} with '{index}'.")
|
raise InterpreterError(f"Could not index {value} with '{index}'.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_name(name, state, tools):
|
def evaluate_name(name, state, static_tools, custom_tools):
|
||||||
if name.id in state:
|
if name.id in state:
|
||||||
return state[name.id]
|
return state[name.id]
|
||||||
elif name.id in tools:
|
elif name.id in static_tools:
|
||||||
return tools[name.id]
|
return static_tools[name.id]
|
||||||
elif name.id in ERRORS:
|
elif name.id in ERRORS:
|
||||||
return ERRORS[name.id]
|
return ERRORS[name.id]
|
||||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||||
@@ -440,9 +483,9 @@ def evaluate_name(name, state, tools):
|
|||||||
raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_condition(condition, state, tools):
|
def evaluate_condition(condition, state, static_tools, custom_tools):
|
||||||
left = evaluate_ast(condition.left, state, tools)
|
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
||||||
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
|
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
|
||||||
ops = [type(op) for op in condition.ops]
|
ops = [type(op) for op in condition.ops]
|
||||||
|
|
||||||
result = True
|
result = True
|
||||||
@@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools):
|
|||||||
|
|
||||||
for op, comparator in zip(ops, comparators):
|
for op, comparator in zip(ops, comparators):
|
||||||
if op == ast.Eq:
|
if op == ast.Eq:
|
||||||
result = result and (current_left == comparator)
|
current_result = current_left == comparator
|
||||||
elif op == ast.NotEq:
|
elif op == ast.NotEq:
|
||||||
result = result and (current_left != comparator)
|
current_result = current_left != comparator
|
||||||
elif op == ast.Lt:
|
elif op == ast.Lt:
|
||||||
result = result and (current_left < comparator)
|
current_result = current_left < comparator
|
||||||
elif op == ast.LtE:
|
elif op == ast.LtE:
|
||||||
result = result and (current_left <= comparator)
|
current_result = current_left <= comparator
|
||||||
elif op == ast.Gt:
|
elif op == ast.Gt:
|
||||||
result = result and (current_left > comparator)
|
current_result = current_left > comparator
|
||||||
elif op == ast.GtE:
|
elif op == ast.GtE:
|
||||||
result = result and (current_left >= comparator)
|
current_result = current_left >= comparator
|
||||||
elif op == ast.Is:
|
elif op == ast.Is:
|
||||||
result = result and (current_left is comparator)
|
current_result = current_left is comparator
|
||||||
elif op == ast.IsNot:
|
elif op == ast.IsNot:
|
||||||
result = result and (current_left is not comparator)
|
current_result = current_left is not comparator
|
||||||
elif op == ast.In:
|
elif op == ast.In:
|
||||||
result = result and (current_left in comparator)
|
current_result = current_left in comparator
|
||||||
elif op == ast.NotIn:
|
elif op == ast.NotIn:
|
||||||
result = result and (current_left not in comparator)
|
current_result = current_left not in comparator
|
||||||
else:
|
else:
|
||||||
raise InterpreterError(f"Operator not supported: {op}")
|
raise InterpreterError(f"Operator not supported: {op}")
|
||||||
|
|
||||||
|
result = result & current_result
|
||||||
current_left = comparator
|
current_left = comparator
|
||||||
if not result:
|
|
||||||
|
if isinstance(result, bool) and not result:
|
||||||
break
|
break
|
||||||
|
|
||||||
return result
|
return result if isinstance(result, (bool, pd.Series)) else result.all()
|
||||||
|
|
||||||
|
|
||||||
def evaluate_if(if_statement, state, tools):
|
def evaluate_if(if_statement, state, static_tools, custom_tools):
|
||||||
result = None
|
result = None
|
||||||
test_result = evaluate_ast(if_statement.test, state, tools)
|
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
|
||||||
if test_result:
|
if test_result:
|
||||||
for line in if_statement.body:
|
for line in if_statement.body:
|
||||||
line_result = evaluate_ast(line, state, tools)
|
line_result = evaluate_ast(line, state, static_tools, custom_tools)
|
||||||
if line_result is not None:
|
if line_result is not None:
|
||||||
result = line_result
|
result = line_result
|
||||||
else:
|
else:
|
||||||
for line in if_statement.orelse:
|
for line in if_statement.orelse:
|
||||||
line_result = evaluate_ast(line, state, tools)
|
line_result = evaluate_ast(line, state, static_tools, custom_tools)
|
||||||
if line_result is not None:
|
if line_result is not None:
|
||||||
result = line_result
|
result = line_result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def evaluate_for(for_loop, state, tools):
|
def evaluate_for(for_loop, state, static_tools, custom_tools):
|
||||||
result = None
|
result = None
|
||||||
iterator = evaluate_ast(for_loop.iter, state, tools)
|
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
|
||||||
for counter in iterator:
|
for counter in iterator:
|
||||||
if isinstance(for_loop.target, ast.Tuple):
|
set_value(for_loop.target, counter, state, static_tools, custom_tools)
|
||||||
for i, elem in enumerate(for_loop.target.elts):
|
|
||||||
state[elem.id] = counter[i]
|
|
||||||
else:
|
|
||||||
state[for_loop.target.id] = counter
|
|
||||||
for node in for_loop.body:
|
for node in for_loop.body:
|
||||||
try:
|
try:
|
||||||
line_result = evaluate_ast(node, state, tools)
|
line_result = evaluate_ast(node, state, static_tools, custom_tools)
|
||||||
if line_result is not None:
|
if line_result is not None:
|
||||||
result = line_result
|
result = line_result
|
||||||
except BreakException:
|
except BreakException:
|
||||||
@@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def evaluate_listcomp(listcomp, state, tools):
|
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
|
||||||
|
def inner_evaluate(generators, index, current_state):
|
||||||
|
if index >= len(generators):
|
||||||
|
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
|
||||||
|
generator = generators[index]
|
||||||
|
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
|
||||||
result = []
|
result = []
|
||||||
for generator in listcomp.generators:
|
|
||||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
|
||||||
for value in iter_value:
|
for value in iter_value:
|
||||||
new_state = state.copy()
|
new_state = current_state.copy()
|
||||||
if isinstance(generator.target, ast.Tuple):
|
if isinstance(generator.target, ast.Tuple):
|
||||||
for idx, elem in enumerate(generator.target.elts):
|
for idx, elem in enumerate(generator.target.elts):
|
||||||
new_state[elem.id] = value[idx]
|
new_state[elem.id] = value[idx]
|
||||||
else:
|
else:
|
||||||
new_state[generator.target.id] = value
|
new_state[generator.target.id] = value
|
||||||
if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs):
|
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
|
||||||
result.append(evaluate_ast(listcomp.elt, new_state, tools))
|
result.extend(inner_evaluate(generators, index + 1, new_state))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
return inner_evaluate(listcomp.generators, 0, state)
|
||||||
|
|
||||||
def evaluate_try(try_node, state, tools):
|
|
||||||
|
def evaluate_try(try_node, state, static_tools, custom_tools):
|
||||||
try:
|
try:
|
||||||
for stmt in try_node.body:
|
for stmt in try_node.body:
|
||||||
evaluate_ast(stmt, state, tools)
|
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
matched = False
|
matched = False
|
||||||
for handler in try_node.handlers:
|
for handler in try_node.handlers:
|
||||||
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)):
|
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
|
||||||
matched = True
|
matched = True
|
||||||
if handler.name:
|
if handler.name:
|
||||||
state[handler.name] = e
|
state[handler.name] = e
|
||||||
for stmt in handler.body:
|
for stmt in handler.body:
|
||||||
evaluate_ast(stmt, state, tools)
|
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||||
break
|
break
|
||||||
if not matched:
|
if not matched:
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
if try_node.orelse:
|
if try_node.orelse:
|
||||||
for stmt in try_node.orelse:
|
for stmt in try_node.orelse:
|
||||||
evaluate_ast(stmt, state, tools)
|
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||||
finally:
|
finally:
|
||||||
if try_node.finalbody:
|
if try_node.finalbody:
|
||||||
for stmt in try_node.finalbody:
|
for stmt in try_node.finalbody:
|
||||||
evaluate_ast(stmt, state, tools)
|
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_raise(raise_node, state, tools):
|
def evaluate_raise(raise_node, state, static_tools, custom_tools):
|
||||||
if raise_node.exc is not None:
|
if raise_node.exc is not None:
|
||||||
exc = evaluate_ast(raise_node.exc, state, tools)
|
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
|
||||||
else:
|
else:
|
||||||
exc = None
|
exc = None
|
||||||
if raise_node.cause is not None:
|
if raise_node.cause is not None:
|
||||||
cause = evaluate_ast(raise_node.cause, state, tools)
|
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
|
||||||
else:
|
else:
|
||||||
cause = None
|
cause = None
|
||||||
if exc is not None:
|
if exc is not None:
|
||||||
@@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools):
|
|||||||
raise InterpreterError("Re-raise is not supported without an active exception")
|
raise InterpreterError("Re-raise is not supported without an active exception")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_assert(assert_node, state, tools):
|
def evaluate_assert(assert_node, state, static_tools, custom_tools):
|
||||||
test_result = evaluate_ast(assert_node.test, state, tools)
|
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
|
||||||
if not test_result:
|
if not test_result:
|
||||||
if assert_node.msg:
|
if assert_node.msg:
|
||||||
msg = evaluate_ast(assert_node.msg, state, tools)
|
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg)
|
||||||
else:
|
else:
|
||||||
# Include the failing condition in the assertion message
|
# Include the failing condition in the assertion message
|
||||||
@@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools):
|
|||||||
raise AssertionError(f"Assertion failed: {test_code}")
|
raise AssertionError(f"Assertion failed: {test_code}")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_with(with_node, state, tools):
|
def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||||
contexts = []
|
contexts = []
|
||||||
for item in with_node.items:
|
for item in with_node.items:
|
||||||
context_expr = evaluate_ast(item.context_expr, state, tools)
|
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
|
||||||
if item.optional_vars:
|
if item.optional_vars:
|
||||||
state[item.optional_vars.id] = context_expr.__enter__()
|
state[item.optional_vars.id] = context_expr.__enter__()
|
||||||
contexts.append(state[item.optional_vars.id])
|
contexts.append(state[item.optional_vars.id])
|
||||||
@@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for stmt in with_node.body:
|
for stmt in with_node.body:
|
||||||
evaluate_ast(stmt, state, tools)
|
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
for context in reversed(contexts):
|
for context in reversed(contexts):
|
||||||
context.__exit__(type(e), e, e.__traceback__)
|
context.__exit__(type(e), e, e.__traceback__)
|
||||||
@@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools):
|
|||||||
context.__exit__(None, None, None)
|
context.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
|
def import_modules(expression, state, authorized_imports):
|
||||||
|
def check_module_authorized(module_name):
|
||||||
|
module_path = module_name.split(".")
|
||||||
|
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
|
||||||
|
return any(subpath in authorized_imports for subpath in module_subpaths)
|
||||||
|
|
||||||
|
if isinstance(expression, ast.Import):
|
||||||
|
for alias in expression.names:
|
||||||
|
if check_module_authorized(alias.name):
|
||||||
|
module = import_module(alias.name)
|
||||||
|
state[alias.asname or alias.name] = module
|
||||||
|
else:
|
||||||
|
raise InterpreterError(
|
||||||
|
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
|
if check_module_authorized(expression.module):
|
||||||
|
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
|
||||||
|
for alias in expression.names:
|
||||||
|
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||||
|
else:
|
||||||
|
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
|
||||||
|
result = {}
|
||||||
|
for gen in dictcomp.generators:
|
||||||
|
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
|
||||||
|
for value in iter_value:
|
||||||
|
new_state = state.copy()
|
||||||
|
set_value(gen.target, value, new_state, static_tools, custom_tools)
|
||||||
|
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
|
||||||
|
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
|
||||||
|
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
|
||||||
|
result[key] = val
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def evaluate_ast(
|
def evaluate_ast(
|
||||||
expression: ast.AST,
|
expression: ast.AST,
|
||||||
state: Dict[str, Any],
|
state: Dict[str, Any],
|
||||||
tools: Dict[str, Callable],
|
static_tools: Dict[str, Callable],
|
||||||
|
custom_tools: Dict[str, Callable],
|
||||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -632,146 +719,128 @@ def evaluate_ast(
|
|||||||
state (`Dict[str, Any]`):
|
state (`Dict[str, Any]`):
|
||||||
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
||||||
encounters assignements.
|
encounters assignements.
|
||||||
tools (`Dict[str, Callable]`):
|
static_tools (`Dict[str, Callable]`):
|
||||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
|
||||||
`InterpreterError`.
|
custom_tools (`Dict[str, Callable]`):
|
||||||
|
Functions that may be called during the evaluation. These static_tools can be overwritten.
|
||||||
authorized_imports (`List[str]`):
|
authorized_imports (`List[str]`):
|
||||||
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
||||||
Add more at your own risk!
|
Add more at your own risk!
|
||||||
"""
|
"""
|
||||||
|
global OPERATIONS_COUNT
|
||||||
|
if OPERATIONS_COUNT >= MAX_OPERATIONS:
|
||||||
|
raise InterpreterError(
|
||||||
|
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
|
||||||
|
)
|
||||||
|
OPERATIONS_COUNT += 1
|
||||||
if isinstance(expression, ast.Assign):
|
if isinstance(expression, ast.Assign):
|
||||||
# Assignement -> we evaluate the assignment which should update the state
|
# Assignement -> we evaluate the assignment which should update the state
|
||||||
# We return the variable assigned as it may be used to determine the final result.
|
# We return the variable assigned as it may be used to determine the final result.
|
||||||
return evaluate_assign(expression, state, tools)
|
return evaluate_assign(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.AugAssign):
|
elif isinstance(expression, ast.AugAssign):
|
||||||
return evaluate_augassign(expression, state, tools)
|
return evaluate_augassign(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Call):
|
elif isinstance(expression, ast.Call):
|
||||||
# Function call -> we return the value of the function call
|
# Function call -> we return the value of the function call
|
||||||
return evaluate_call(expression, state, tools)
|
return evaluate_call(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Constant):
|
elif isinstance(expression, ast.Constant):
|
||||||
# Constant -> just return the value
|
# Constant -> just return the value
|
||||||
return expression.value
|
return expression.value
|
||||||
elif isinstance(expression, ast.Tuple):
|
elif isinstance(expression, ast.Tuple):
|
||||||
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
|
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
|
||||||
elif isinstance(expression, ast.ListComp):
|
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
||||||
return evaluate_listcomp(expression, state, tools)
|
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.UnaryOp):
|
elif isinstance(expression, ast.UnaryOp):
|
||||||
return evaluate_unaryop(expression, state, tools)
|
return evaluate_unaryop(expression, state, static_tools, custom_tools)
|
||||||
|
elif isinstance(expression, ast.Starred):
|
||||||
|
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.BoolOp):
|
elif isinstance(expression, ast.BoolOp):
|
||||||
# Boolean operation -> evaluate the operation
|
# Boolean operation -> evaluate the operation
|
||||||
return evaluate_boolop(expression, state, tools)
|
return evaluate_boolop(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Break):
|
elif isinstance(expression, ast.Break):
|
||||||
raise BreakException()
|
raise BreakException()
|
||||||
elif isinstance(expression, ast.Continue):
|
elif isinstance(expression, ast.Continue):
|
||||||
raise ContinueException()
|
raise ContinueException()
|
||||||
elif isinstance(expression, ast.BinOp):
|
elif isinstance(expression, ast.BinOp):
|
||||||
# Binary operation -> execute operation
|
# Binary operation -> execute operation
|
||||||
return evaluate_binop(expression, state, tools)
|
return evaluate_binop(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Compare):
|
elif isinstance(expression, ast.Compare):
|
||||||
# Comparison -> evaluate the comparison
|
# Comparison -> evaluate the comparison
|
||||||
return evaluate_condition(expression, state, tools)
|
return evaluate_condition(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Lambda):
|
elif isinstance(expression, ast.Lambda):
|
||||||
return evaluate_lambda(expression, state, tools)
|
return evaluate_lambda(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.FunctionDef):
|
elif isinstance(expression, ast.FunctionDef):
|
||||||
return evaluate_function_def(expression, state, tools)
|
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Dict):
|
elif isinstance(expression, ast.Dict):
|
||||||
# Dict -> evaluate all keys and values
|
# Dict -> evaluate all keys and values
|
||||||
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
|
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
|
||||||
values = [evaluate_ast(v, state, tools) for v in expression.values]
|
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
|
||||||
return dict(zip(keys, values))
|
return dict(zip(keys, values))
|
||||||
elif isinstance(expression, ast.Expr):
|
elif isinstance(expression, ast.Expr):
|
||||||
# Expression -> evaluate the content
|
# Expression -> evaluate the content
|
||||||
return evaluate_ast(expression.value, state, tools)
|
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.For):
|
elif isinstance(expression, ast.For):
|
||||||
# For loop -> execute the loop
|
# For loop -> execute the loop
|
||||||
return evaluate_for(expression, state, tools)
|
return evaluate_for(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.FormattedValue):
|
elif isinstance(expression, ast.FormattedValue):
|
||||||
# Formatted value (part of f-string) -> evaluate the content and return
|
# Formatted value (part of f-string) -> evaluate the content and return
|
||||||
return evaluate_ast(expression.value, state, tools)
|
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.If):
|
elif isinstance(expression, ast.If):
|
||||||
# If -> execute the right branch
|
# If -> execute the right branch
|
||||||
return evaluate_if(expression, state, tools)
|
return evaluate_if(expression, state, static_tools, custom_tools)
|
||||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||||
return evaluate_ast(expression.value, state, tools)
|
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.JoinedStr):
|
elif isinstance(expression, ast.JoinedStr):
|
||||||
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
|
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
|
||||||
elif isinstance(expression, ast.List):
|
elif isinstance(expression, ast.List):
|
||||||
# List -> evaluate all elements
|
# List -> evaluate all elements
|
||||||
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
|
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
|
||||||
elif isinstance(expression, ast.Name):
|
elif isinstance(expression, ast.Name):
|
||||||
# Name -> pick up the value in the state
|
# Name -> pick up the value in the state
|
||||||
return evaluate_name(expression, state, tools)
|
return evaluate_name(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Subscript):
|
elif isinstance(expression, ast.Subscript):
|
||||||
# Subscript -> return the value of the indexing
|
# Subscript -> return the value of the indexing
|
||||||
return evaluate_subscript(expression, state, tools)
|
return evaluate_subscript(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.IfExp):
|
elif isinstance(expression, ast.IfExp):
|
||||||
test_val = evaluate_ast(expression.test, state, tools)
|
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
|
||||||
if test_val:
|
if test_val:
|
||||||
return evaluate_ast(expression.body, state, tools)
|
return evaluate_ast(expression.body, state, static_tools, custom_tools)
|
||||||
else:
|
else:
|
||||||
return evaluate_ast(expression.orelse, state, tools)
|
return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Attribute):
|
elif isinstance(expression, ast.Attribute):
|
||||||
obj = evaluate_ast(expression.value, state, tools)
|
value = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||||
return getattr(obj, expression.attr)
|
return getattr(value, expression.attr)
|
||||||
elif isinstance(expression, ast.Slice):
|
elif isinstance(expression, ast.Slice):
|
||||||
return slice(
|
return slice(
|
||||||
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None,
|
evaluate_ast(expression.lower, state, static_tools, custom_tools)
|
||||||
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None,
|
if expression.lower is not None
|
||||||
evaluate_ast(expression.step, state, tools) if expression.step is not None else None,
|
else None,
|
||||||
|
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
||||||
|
if expression.upper is not None
|
||||||
|
else None,
|
||||||
|
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
|
||||||
)
|
)
|
||||||
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
|
|
||||||
result = []
|
|
||||||
vars = {}
|
|
||||||
for generator in expression.generators:
|
|
||||||
var_name = generator.target.id
|
|
||||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
|
||||||
for value in iter_value:
|
|
||||||
vars[var_name] = value
|
|
||||||
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
|
|
||||||
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
|
|
||||||
result.append(elem)
|
|
||||||
return result
|
|
||||||
elif isinstance(expression, ast.DictComp):
|
elif isinstance(expression, ast.DictComp):
|
||||||
result = {}
|
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
||||||
for gen in expression.generators:
|
|
||||||
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
|
|
||||||
state[gen.target.id] = container
|
|
||||||
key = evaluate_ast(expression.key, state, tools)
|
|
||||||
value = evaluate_ast(expression.value, state, tools)
|
|
||||||
result[key] = value
|
|
||||||
return result
|
|
||||||
elif isinstance(expression, ast.Import):
|
|
||||||
for alias in expression.names:
|
|
||||||
if alias.name in authorized_imports:
|
|
||||||
module = __import__(alias.name)
|
|
||||||
state[alias.asname or alias.name] = module
|
|
||||||
else:
|
|
||||||
raise InterpreterError(f"Import of {alias.name} is not allowed.")
|
|
||||||
return None
|
|
||||||
elif isinstance(expression, ast.While):
|
elif isinstance(expression, ast.While):
|
||||||
return evaluate_while(expression, state, tools)
|
return evaluate_while(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
||||||
if expression.module in authorized_imports:
|
return import_modules(expression, state, authorized_imports)
|
||||||
module = __import__(expression.module)
|
|
||||||
for alias in expression.names:
|
|
||||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
|
||||||
else:
|
|
||||||
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
|
||||||
return None
|
|
||||||
elif isinstance(expression, ast.ClassDef):
|
elif isinstance(expression, ast.ClassDef):
|
||||||
return evaluate_class_def(expression, state, tools)
|
return evaluate_class_def(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Try):
|
elif isinstance(expression, ast.Try):
|
||||||
return evaluate_try(expression, state, tools)
|
return evaluate_try(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Raise):
|
elif isinstance(expression, ast.Raise):
|
||||||
return evaluate_raise(expression, state, tools)
|
return evaluate_raise(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Assert):
|
elif isinstance(expression, ast.Assert):
|
||||||
return evaluate_assert(expression, state, tools)
|
return evaluate_assert(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.With):
|
elif isinstance(expression, ast.With):
|
||||||
return evaluate_with(expression, state, tools)
|
return evaluate_with(expression, state, static_tools, custom_tools)
|
||||||
elif isinstance(expression, ast.Set):
|
elif isinstance(expression, ast.Set):
|
||||||
return {evaluate_ast(elt, state, tools) for elt in expression.elts}
|
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
|
||||||
elif isinstance(expression, ast.Return):
|
elif isinstance(expression, ast.Return):
|
||||||
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None)
|
raise ReturnException(
|
||||||
|
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# For now we refuse anything else. Let's add things as we need them.
|
# For now we refuse anything else. Let's add things as we need them.
|
||||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||||
@@ -779,7 +848,8 @@ def evaluate_ast(
|
|||||||
|
|
||||||
def evaluate_python_code(
|
def evaluate_python_code(
|
||||||
code: str,
|
code: str,
|
||||||
tools: Optional[Dict[str, Callable]] = None,
|
static_tools: Optional[Dict[str, Callable]] = None,
|
||||||
|
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||||
state: Optional[Dict[str, Any]] = None,
|
state: Optional[Dict[str, Any]] = None,
|
||||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||||
):
|
):
|
||||||
@@ -792,9 +862,12 @@ def evaluate_python_code(
|
|||||||
Args:
|
Args:
|
||||||
code (`str`):
|
code (`str`):
|
||||||
The code to evaluate.
|
The code to evaluate.
|
||||||
tools (`Dict[str, Callable]`):
|
static_tools (`Dict[str, Callable]`):
|
||||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
The functions that may be called during the evaluation.
|
||||||
`InterpreterError`.
|
These tools cannot be overwritten in the code: any assignment to their name will raise an error.
|
||||||
|
custom_tools (`Dict[str, Callable]`):
|
||||||
|
The functions that may be called during the evaluation.
|
||||||
|
These tools can be overwritten in the code: any assignment to their name will overwrite them.
|
||||||
state (`Dict[str, Any]`):
|
state (`Dict[str, Any]`):
|
||||||
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
||||||
updated by this function to contain all variables as they are evaluated.
|
updated by this function to contain all variables as they are evaluated.
|
||||||
@@ -806,20 +879,34 @@ def evaluate_python_code(
|
|||||||
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
||||||
if state is None:
|
if state is None:
|
||||||
state = {}
|
state = {}
|
||||||
if tools is None:
|
if static_tools is None:
|
||||||
tools = {}
|
static_tools = {}
|
||||||
|
if custom_tools is None:
|
||||||
|
custom_tools = {}
|
||||||
result = None
|
result = None
|
||||||
global PRINT_OUTPUTS
|
global PRINT_OUTPUTS
|
||||||
PRINT_OUTPUTS = ""
|
PRINT_OUTPUTS = ""
|
||||||
|
global OPERATIONS_COUNT
|
||||||
|
OPERATIONS_COUNT = 0
|
||||||
for node in expression.body:
|
for node in expression.body:
|
||||||
try:
|
try:
|
||||||
result = evaluate_ast(node, state, tools, authorized_imports)
|
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||||
except InterpreterError as e:
|
except InterpreterError as e:
|
||||||
msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
msg = ""
|
||||||
if len(PRINT_OUTPUTS) > 0:
|
if len(PRINT_OUTPUTS) > 0:
|
||||||
msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n"
|
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
||||||
|
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
|
||||||
|
else:
|
||||||
|
msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
|
||||||
|
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||||
raise InterpreterError(msg)
|
raise InterpreterError(msg)
|
||||||
finally:
|
finally:
|
||||||
|
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
||||||
state["print_outputs"] = PRINT_OUTPUTS
|
state["print_outputs"] = PRINT_OUTPUTS
|
||||||
|
else:
|
||||||
|
state["print_outputs"] = (
|
||||||
|
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
|
||||||
|
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ Action:
|
|||||||
# check that add_base_tools will not interfere with existing tools
|
# check that add_base_tools will not interfere with existing tools
|
||||||
with pytest.raises(KeyError) as e:
|
with pytest.raises(KeyError) as e:
|
||||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
||||||
assert "python_interpreter already exists in the toolbox" in str(e)
|
assert "already exists in the toolbox" in str(e)
|
||||||
|
|
||||||
# check that python_interpreter base tool does not get added to code agents
|
# check that python_interpreter base tool does not get added to code agents
|
||||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import load_tool
|
from transformers import load_tool
|
||||||
@@ -241,8 +242,41 @@ for block in text_block:
|
|||||||
code = """
|
code = """
|
||||||
digits, i = [1, 2, 3], 1
|
digits, i = [1, 2, 3], 1
|
||||||
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
||||||
|
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
|
||||||
|
|
||||||
|
code = """
|
||||||
|
def calculate_isbn_10_check_digit(number):
|
||||||
|
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
|
||||||
|
remainder = total % 11
|
||||||
|
check_digit = 11 - remainder
|
||||||
|
if check_digit == 10:
|
||||||
|
return 'X'
|
||||||
|
elif check_digit == 11:
|
||||||
|
return '0'
|
||||||
|
else:
|
||||||
|
return str(check_digit)
|
||||||
|
|
||||||
|
# Given 9-digit numbers
|
||||||
|
numbers = [
|
||||||
|
"478225952",
|
||||||
|
"643485613",
|
||||||
|
"739394228",
|
||||||
|
"291726859",
|
||||||
|
"875262394",
|
||||||
|
"542617795",
|
||||||
|
"031810713",
|
||||||
|
"957007669",
|
||||||
|
"871467426"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Calculate check digits for each number
|
||||||
|
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
|
||||||
|
print(check_digits)
|
||||||
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
|
evaluate_python_code(
|
||||||
|
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
|
||||||
|
)
|
||||||
|
|
||||||
def test_listcomp(self):
|
def test_listcomp(self):
|
||||||
code = "x = [i for i in range(3)]"
|
code = "x = [i for i in range(3)]"
|
||||||
@@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
|||||||
result = evaluate_python_code(code, {"range": range}, state={})
|
result = evaluate_python_code(code, {"range": range}, state={})
|
||||||
assert result == {0: 0, 1: 1, 2: 4}
|
assert result == {0: 0, 1: 1, 2: 4}
|
||||||
|
|
||||||
|
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
||||||
|
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
||||||
|
assert result == {102: "b"}
|
||||||
|
|
||||||
|
code = """
|
||||||
|
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
|
||||||
|
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
|
||||||
|
"""
|
||||||
|
result = evaluate_python_code(code, {}, state={})
|
||||||
|
assert result == {"A": ("a", "b"), "B": ("a", "b")}
|
||||||
|
|
||||||
def test_tuple_assignment(self):
|
def test_tuple_assignment(self):
|
||||||
code = "a, b = 0, 1\nb"
|
code = "a, b = 0, 1\nb"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
@@ -341,7 +386,7 @@ if char.isalpha():
|
|||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "lose"
|
assert result == "lose"
|
||||||
|
|
||||||
code = "import time\ntime.sleep(0.1)"
|
code = "import time, re\ntime.sleep(0.1)"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
@@ -369,6 +414,23 @@ if char.isalpha():
|
|||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result == "LATIN CAPITAL LETTER A"
|
assert result == "LATIN CAPITAL LETTER A"
|
||||||
|
|
||||||
|
# Test submodules are handled properly, thus not raising error
|
||||||
|
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
||||||
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||||
|
|
||||||
|
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
||||||
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||||
|
|
||||||
|
def test_additional_imports(self):
|
||||||
|
code = "import numpy as np"
|
||||||
|
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
||||||
|
|
||||||
|
code = "import numpy.random as rd"
|
||||||
|
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
|
||||||
|
evaluate_python_code(code, authorized_imports=["numpy"], state={})
|
||||||
|
with pytest.raises(InterpreterError):
|
||||||
|
evaluate_python_code(code, authorized_imports=["random"], state={})
|
||||||
|
|
||||||
def test_multiple_comparators(self):
|
def test_multiple_comparators(self):
|
||||||
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
@@ -400,7 +462,7 @@ def function():
|
|||||||
print("2")
|
print("2")
|
||||||
function()"""
|
function()"""
|
||||||
state = {}
|
state = {}
|
||||||
evaluate_python_code(code, {"print": print}, state)
|
evaluate_python_code(code, {"print": print}, state=state)
|
||||||
assert state["print_outputs"] == "1\n2\n"
|
assert state["print_outputs"] == "1\n2\n"
|
||||||
|
|
||||||
def test_tuple_target_in_iterator(self):
|
def test_tuple_target_in_iterator(self):
|
||||||
@@ -612,7 +674,7 @@ assert lock.locked == False
|
|||||||
"""
|
"""
|
||||||
state = {}
|
state = {}
|
||||||
tools = {}
|
tools = {}
|
||||||
evaluate_python_code(code, tools, state)
|
evaluate_python_code(code, tools, state=state)
|
||||||
|
|
||||||
def test_default_arg_in_function(self):
|
def test_default_arg_in_function(self):
|
||||||
code = """
|
code = """
|
||||||
@@ -672,3 +734,94 @@ returns_none(1)
|
|||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
def test_nested_for_loop(self):
|
||||||
|
code = """
|
||||||
|
all_res = []
|
||||||
|
for i in range(10):
|
||||||
|
subres = []
|
||||||
|
for j in range(i):
|
||||||
|
subres.append(j)
|
||||||
|
all_res.append(subres)
|
||||||
|
|
||||||
|
out = [i for sublist in all_res for i in sublist]
|
||||||
|
out[:10]
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
|
||||||
|
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||||
|
|
||||||
|
def test_pandas(self):
    """Evaluate three pandas snippets with ``pandas`` as an authorized import.

    Each scenario is independent: it rebinds ``code`` and evaluates it with a
    fresh state, then checks the value handed back by the evaluation.
    """
    # Scenario 1: numeric coercion, boolean-mask filtering and column selection.
    # The snippet ends on a bare expression selecting the second matching row.
    code = """
import pandas as pd

df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})

df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')

parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
    state = {}
    result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
    assert np.array_equal(result, [-1, 5])

    # Scenario 2: `.loc` with an `isin` mask. The snippet ends on an assignment,
    # and the test expects the assigned DataFrame as the result.
    code = """
import pandas as pd

df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
print("HH0")

# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
    result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
    assert np.array_equal(result.values[0], [104, 1])

    # Scenario 3: groupby + mean. Class 2 has survival values 0 and 1, so its
    # mean (second entry of the grouped result) is exactly 0.5.
    code = """import pandas as pd
data = pd.DataFrame.from_dict([
    {"Pclass": 1, "Survived": 1},
    {"Pclass": 2, "Survived": 0},
    {"Pclass": 2, "Survived": 1}
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
    result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
    assert result.values[1] == 0.5
|
||||||
|
|
||||||
|
def test_starred(self):
    """Evaluate a snippet that unpacks two tuples with ``*`` in one call.

    The snippet defines a haversine distance function and calls it as
    ``haversine(*coords_geneva, *coords_barcelona)``. The test expects the
    assigned distance as the evaluation result and compares it rounded to one
    decimal place.
    """
    code = """
from math import radians, sin, cos, sqrt, atan2

def haversine(lat1, lon1, lat2, lon2):
    R = 6371000  # Radius of the Earth in meters
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    distance = R * c
    return distance

coords_geneva = (46.1978, 6.1342)
coords_barcelona = (41.3869, 2.1660)

distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
    result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
    assert round(result, 1) == 622395.4
|
||||||
|
|
||||||
|
def test_for(self):
    """Evaluate a ``for`` loop whose target nests a tuple: ``worker, (start, end)``.

    The snippet iterates over ``shifts.items()``, keeps each worker's end
    time, and ends on the bare expression ``shift_intervals``; the test
    expects that dict as the evaluation result.
    """
    code = """
shifts = {
    "Worker A": ("6:45 pm", "8:00 pm"),
    "Worker B": ("10:00 am", "11:45 am")
}

shift_intervals = {}
for worker, (start, end) in shifts.items():
    shift_intervals[worker] = end
shift_intervals
"""
    result = evaluate_python_code(code, {"print": print, "map": map}, state={})
    assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
|
||||||
|
|||||||
Reference in New Issue
Block a user