diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 9adf55289d..6ab971a480 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -34,11 +34,16 @@ def custom_print(*args): BASE_PYTHON_TOOLS = { "print": custom_print, + "isinstance": isinstance, "range": range, "float": float, "int": int, "bool": bool, "str": str, + "set": set, + "list": list, + "dict": dict, + "tuple": tuple, "round": round, "ceil": math.ceil, "floor": math.floor, @@ -60,10 +65,6 @@ BASE_PYTHON_TOOLS = { "max": max, "min": min, "abs": abs, - "list": list, - "dict": dict, - "tuple": tuple, - "set": set, "enumerate": enumerate, "zip": zip, "reversed": reversed, @@ -74,6 +75,15 @@ BASE_PYTHON_TOOLS = { "filter": filter, "ord": ord, "chr": chr, + "next": next, + "iter": iter, + "divmod": divmod, + "callable": callable, + "getattr": getattr, + "hasattr": hasattr, + "setattr": setattr, + "issubclass": issubclass, + "type": type, } @@ -147,9 +157,9 @@ class PythonInterpreterTool(Tool): def __init__(self, *args, authorized_imports=None, **kwargs): if authorized_imports is None: - authorized_imports = list(set(LIST_SAFE_MODULES)) + self.authorized_imports = list(set(LIST_SAFE_MODULES)) else: - authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) + self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) self.inputs = { "code": { "type": "text", @@ -162,7 +172,9 @@ class PythonInterpreterTool(Tool): super().__init__(*args, **kwargs) def forward(self, code): - output = str(evaluate_python_code(code, tools=self.available_tools)) + output = str( + evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports) + ) return output diff --git a/src/transformers/agents/evaluate_agent.py b/src/transformers/agents/evaluate_agent.py index 4948dce283..66f734be5b 100644 --- a/src/transformers/agents/evaluate_agent.py +++ b/src/transformers/agents/evaluate_agent.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .agents import BASE_PYTHON_TOOLS -from .python_interpreter import InterpretorError, evaluate +from .python_interpreter import InterpreterError, evaluate ### Fake tools for test @@ -256,7 +256,7 @@ def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpret try: return evaluate(code, tools, state) - except InterpretorError as e: + except InterpreterError as e: return str(e) except Exception as e: if verbose: diff --git a/src/transformers/agents/llm_engine.py b/src/transformers/agents/llm_engine.py index 76458b0267..eb5edf7515 100644 --- a/src/transformers/agents/llm_engine.py +++ b/src/transformers/agents/llm_engine.py @@ -54,7 +54,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: message["role"] = role_conversions[role] if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]: - final_message_list[-1]["content"] += "\n===\n" + message["content"] + final_message_list[-1]["content"] += "\n=======\n" + message["content"] else: final_message_list.append(message) return final_message_list diff --git a/src/transformers/agents/python_interpreter.py b/src/transformers/agents/python_interpreter.py index 992e9d14f1..39814daa7f 100644 --- a/src/transformers/agents/python_interpreter.py +++ b/src/transformers/agents/python_interpreter.py @@ -21,7 +21,7 @@ from collections.abc import Mapping from typing import Any, Callable, Dict, List, Optional -class InterpretorError(ValueError): +class InterpreterError(ValueError): """ An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported operations. @@ -50,6 +50,8 @@ LIST_SAFE_MODULES = [ "unicodedata", ] +PRINT_OUTPUTS = "" + class BreakException(Exception): pass @@ -59,13 +61,18 @@ class ContinueException(Exception): pass +class ReturnException(Exception): + def __init__(self, value): + self.value = value + + def get_iterable(obj): if isinstance(obj, list): return obj elif hasattr(obj, "__iter__"): return list(obj) else: - raise InterpretorError("Object is not iterable") + raise InterpreterError("Object is not iterable") def evaluate_unaryop(expression, state, tools): @@ -79,7 +86,7 @@ def evaluate_unaryop(expression, state, tools): elif isinstance(expression.op, ast.Invert): return ~operand else: - raise InterpretorError(f"Unary operation {expression.op.__class__.__name__} is not supported.") + raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") def evaluate_lambda(lambda_expression, state, tools): @@ -99,10 +106,15 @@ def evaluate_while(while_loop, state, tools): iterations = 0 while evaluate_ast(while_loop.test, state, tools): for node in while_loop.body: - evaluate_ast(node, state, tools) + try: + evaluate_ast(node, state, tools) + except BreakException: + return None + except ContinueException: + break iterations += 1 if iterations > max_iterations: - raise InterpretorError(f"Maximum number of {max_iterations} iterations in While loop exceeded") + raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded") return None @@ -110,15 +122,33 @@ def create_function(func_def, state, tools): def new_func(*args, **kwargs): func_state = state.copy() arg_names = [arg.arg for arg in func_def.args.args] + default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults] + + # Apply default values + defaults = dict(zip(arg_names[-len(default_values) :], default_values)) + + # Set positional arguments for name, value in zip(arg_names, args): func_state[name] = value + + # # Set keyword arguments + for name, value in kwargs.items(): + func_state[name] = value + + # Handle variable arguments if func_def.args.vararg: vararg_name = func_def.args.vararg.arg func_state[vararg_name] = args + if func_def.args.kwarg: kwarg_name = func_def.args.kwarg.arg func_state[kwarg_name] = kwargs + # Set default values for arguments that were not provided + for name, value in defaults.items(): + if name not in func_state: + func_state[name] = value + # Update function state with self and __class__ if func_def.args.args and func_def.args.args[0].arg == "self": if args: @@ -126,8 +156,11 @@ def create_function(func_def, state, tools): func_state["__class__"] = args[0].__class__ result = None - for stmt in func_def.body: - result = evaluate_ast(stmt, func_state, tools) + try: + for stmt in func_def.body: + result = evaluate_ast(stmt, func_state, tools) + except ReturnException as e: + result = e.value return result return new_func @@ -155,9 +188,12 @@ def evaluate_class_def(class_def, state, tools): class_dict[stmt.name] = evaluate_function_def(stmt, state, tools) elif isinstance(stmt, ast.Assign): for target in stmt.targets: - class_dict[target.id] = evaluate_ast(stmt.value, state, tools) + if isinstance(target, ast.Name): + class_dict[target.id] = evaluate_ast(stmt.value, state, tools) + elif isinstance(target, ast.Attribute): + class_dict[target.attr] = evaluate_ast(stmt.value, state, tools) else: - raise InterpretorError(f"Unsupported statement in class body: {stmt.__class__.__name__}") + raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") new_class = type(class_name, tuple(bases), class_dict) state[class_name] = new_class @@ -165,37 +201,77 @@ def evaluate_class_def(class_def, state, tools): def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): - # Extract the target variable name and the operation - if isinstance(expression.target, ast.Name): - var_name = expression.target.id - current_value = state.get(var_name, 0) # Assuming default of 0 if not in state - value_to_add = evaluate_ast(expression.value, state, tools) + # Helper function to get current value and set new value based on the target type + def get_current_value(target): + if isinstance(target, ast.Name): + return state.get(target.id, 0) + elif isinstance(target, ast.Subscript): + obj = evaluate_ast(target.value, state, tools) + key = evaluate_ast(target.slice, state, tools) + return obj[key] + elif isinstance(target, ast.Attribute): + obj = evaluate_ast(target.value, state, tools) + return getattr(obj, target.attr) + elif isinstance(target, ast.Tuple): + return tuple(get_current_value(elt) for elt in target.elts) + elif isinstance(target, ast.List): + return [get_current_value(elt) for elt in target.elts] + else: + raise InterpreterError("AugAssign not supported for {type(target)} targets.") - # Determine the operation and apply it - if isinstance(expression.op, ast.Add): + current_value = get_current_value(expression.target) + value_to_add = evaluate_ast(expression.value, state, tools) + + # Determine the operation and apply it + if isinstance(expression.op, ast.Add): + if isinstance(current_value, list): + if not isinstance(value_to_add, list): + raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.") updated_value = current_value + value_to_add - elif isinstance(expression.op, ast.Sub): - updated_value = current_value - value_to_add - elif isinstance(expression.op, ast.Mult): - updated_value = current_value * value_to_add - elif isinstance(expression.op, ast.Div): - updated_value = current_value / value_to_add - # Add other operations as needed - - # Update the state - state[var_name] = updated_value - return updated_value + else: + updated_value = current_value + value_to_add + elif isinstance(expression.op, ast.Sub): + updated_value = current_value - value_to_add + elif isinstance(expression.op, ast.Mult): + updated_value = current_value * value_to_add + elif isinstance(expression.op, ast.Div): + updated_value = current_value / value_to_add + elif isinstance(expression.op, ast.Mod): + updated_value = current_value % value_to_add + elif isinstance(expression.op, ast.Pow): + updated_value = current_value**value_to_add + elif isinstance(expression.op, ast.FloorDiv): + updated_value = current_value // value_to_add + elif isinstance(expression.op, ast.BitAnd): + updated_value = current_value & value_to_add + elif isinstance(expression.op, ast.BitOr): + updated_value = current_value | value_to_add + elif isinstance(expression.op, ast.BitXor): + updated_value = current_value ^ value_to_add + elif isinstance(expression.op, ast.LShift): + updated_value = current_value << value_to_add + elif isinstance(expression.op, ast.RShift): + updated_value = current_value >> value_to_add else: - raise InterpretorError("AugAssign not supported for non-simple variable targets.") + raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") + + # Update the state + set_value(expression.target, updated_value, state, tools) + + return updated_value -def evaluate_boolop(boolop, state, tools): - values = [evaluate_ast(val, state, tools) for val in boolop.values] - op = boolop.op - if isinstance(op, ast.And): - return all(values) - elif isinstance(op, ast.Or): - return any(values) +def evaluate_boolop(node, state, tools): + if isinstance(node.op, ast.And): + for value in node.values: + if not evaluate_ast(value, state, tools): + return False + return True + elif isinstance(node.op, ast.Or): + for value in node.values: + if evaluate_ast(value, state, tools): + return True + return False def evaluate_binop(binop, state, tools): @@ -233,41 +309,49 @@ def evaluate_binop(binop, state, tools): def evaluate_assign(assign, state, tools): - var_names = assign.targets result = evaluate_ast(assign.value, state, tools) - if len(var_names) == 1: - target = var_names[0] - if isinstance(target, ast.Tuple): - for i, elem in enumerate(target.elts): - state[elem.id] = result[i] - elif isinstance(target, ast.Attribute): - obj = evaluate_ast(target.value, state, tools) - setattr(obj, target.attr, result) - elif isinstance(target, ast.Subscript): - obj = evaluate_ast(target.value, state, tools) - key = evaluate_ast(target.slice, state, tools) - obj[key] = result - else: - state[target.id] = result - + if len(assign.targets) == 1: + target = assign.targets[0] + set_value(target, result, state, tools) else: - if len(result) != len(var_names): - raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.") - for var_name, r in zip(var_names, result): - state[var_name.id] = r + if len(assign.targets) != len(result): + raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.") + for tgt, val in zip(assign.targets, result): + set_value(tgt, val, state, tools) return result +def set_value(target, value, state, tools): + if isinstance(target, ast.Name): + if target.id in tools: + raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!") + state[target.id] = value + elif isinstance(target, ast.Tuple): + if not isinstance(value, tuple): + raise InterpreterError("Cannot unpack non-tuple value") + if len(target.elts) != len(value): + raise InterpreterError("Cannot unpack tuple of wrong size") + for i, elem in enumerate(target.elts): + set_value(elem, value[i], state, tools) + elif isinstance(target, ast.Subscript): + obj = evaluate_ast(target.value, state, tools) + key = evaluate_ast(target.slice, state, tools) + obj[key] = value + elif isinstance(target, ast.Attribute): + obj = evaluate_ast(target.value, state, tools) + setattr(obj, target.attr, value) + + def evaluate_call(call, state, tools): if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): - raise InterpretorError( + raise InterpreterError( f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})." ) if isinstance(call.func, ast.Attribute): obj = evaluate_ast(call.func.value, state, tools) func_name = call.func.attr if not hasattr(obj, func_name): - raise InterpretorError(f"Object {obj} has no attribute {func_name}") + raise InterpreterError(f"Object {obj} has no attribute {func_name}") func = getattr(obj, func_name) elif isinstance(call.func, ast.Name): func_name = call.func.id @@ -278,7 +362,7 @@ def evaluate_call(call, state, tools): elif func_name in ERRORS: func = ERRORS[func_name] else: - raise InterpretorError( + raise InterpreterError( f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})." ) @@ -297,22 +381,22 @@ def evaluate_call(call, state, tools): if "__class__" in state and "self" in state: return super(state["__class__"], state["self"]) else: - raise InterpretorError("super() needs at least one argument") + raise InterpreterError("super() needs at least one argument") cls = args[0] if not isinstance(cls, type): - raise InterpretorError("super() argument 1 must be type") + raise InterpreterError("super() argument 1 must be type") if len(args) == 1: return super(cls) elif len(args) == 2: instance = args[1] return super(cls, instance) else: - raise InterpretorError("super() takes at most 2 arguments") - + raise InterpreterError("super() takes at most 2 arguments") else: if func_name == "print": output = " ".join(map(str, args)) - state["print_outputs"] += output + "\n" + global PRINT_OUTPUTS + PRINT_OUTPUTS += output + "\n" return output else: # Assume it's a callable object output = func(*args, **kwargs) @@ -325,8 +409,14 @@ def evaluate_subscript(subscript, state, tools): if isinstance(index, slice): return value[index] elif isinstance(value, (list, tuple)): + # Ensure the index is within bounds + if not (-len(value) <= index < len(value)): + raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}") return value[int(index)] elif isinstance(value, str): + # Ensure the index is within bounds + if not (-len(value) <= index < len(value)): + raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}") return value[index] elif index in value: return value[index] @@ -334,7 +424,7 @@ def evaluate_subscript(subscript, state, tools): close_matches = difflib.get_close_matches(index, list(value.keys())) if len(close_matches) > 0: return value[close_matches[0]] - raise InterpretorError(f"Could not index {value} with '{index}'.") + raise InterpreterError(f"Could not index {value} with '{index}'.") def evaluate_name(name, state, tools): @@ -347,7 +437,7 @@ def evaluate_name(name, state, tools): close_matches = difflib.get_close_matches(name.id, list(state.keys())) if len(close_matches) > 0: return state[close_matches[0]] - raise InterpretorError(f"The variable `{name.id}` is not defined.") + raise InterpreterError(f"The variable `{name.id}` is not defined.") def evaluate_condition(condition, state, tools): @@ -355,30 +445,36 @@ def evaluate_condition(condition, state, tools): comparators = [evaluate_ast(c, state, tools) for c in condition.comparators] ops = [type(op) for op in condition.ops] - result = left + result = True + current_left = left + for op, comparator in zip(ops, comparators): if op == ast.Eq: - result = result == comparator + result = result and (current_left == comparator) elif op == ast.NotEq: - result = result != comparator + result = result and (current_left != comparator) elif op == ast.Lt: - result = result < comparator + result = result and (current_left < comparator) elif op == ast.LtE: - result = result <= comparator + result = result and (current_left <= comparator) elif op == ast.Gt: - result = result > comparator + result = result and (current_left > comparator) elif op == ast.GtE: - result = result >= comparator + result = result and (current_left >= comparator) elif op == ast.Is: - result = result is comparator + result = result and (current_left is comparator) elif op == ast.IsNot: - result = result is not comparator + result = result and (current_left is not comparator) elif op == ast.In: - result = result in comparator + result = result and (current_left in comparator) elif op == ast.NotIn: - result = result not in comparator + result = result and (current_left not in comparator) else: - raise InterpretorError(f"Operator not supported: {op}") + raise InterpreterError(f"Operator not supported: {op}") + + current_left = comparator + if not result: + break return result @@ -425,15 +521,17 @@ def evaluate_for(for_loop, state, tools): def evaluate_listcomp(listcomp, state, tools): result = [] - vars = {} for generator in listcomp.generators: - var_name = generator.target.id iter_value = evaluate_ast(generator.iter, state, tools) for value in iter_value: - vars[var_name] = value - if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs): - elem = evaluate_ast(listcomp.elt, {**state, **vars}, tools) - result.append(elem) + new_state = state.copy() + if isinstance(generator.target, ast.Tuple): + for idx, elem in enumerate(generator.target.elts): + new_state[elem.id] = value[idx] + else: + new_state[generator.target.id] = value + if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs): + result.append(evaluate_ast(listcomp.elt, new_state, tools)) return result @@ -478,7 +576,42 @@ def evaluate_raise(raise_node, state, tools): else: raise exc else: - raise InterpretorError("Re-raise is not supported without an active exception") + raise InterpreterError("Re-raise is not supported without an active exception") + + +def evaluate_assert(assert_node, state, tools): + test_result = evaluate_ast(assert_node.test, state, tools) + if not test_result: + if assert_node.msg: + msg = evaluate_ast(assert_node.msg, state, tools) + raise AssertionError(msg) + else: + # Include the failing condition in the assertion message + test_code = ast.unparse(assert_node.test) + raise AssertionError(f"Assertion failed: {test_code}") + + +def evaluate_with(with_node, state, tools): + contexts = [] + for item in with_node.items: + context_expr = evaluate_ast(item.context_expr, state, tools) + if item.optional_vars: + state[item.optional_vars.id] = context_expr.__enter__() + contexts.append(state[item.optional_vars.id]) + else: + context_var = context_expr.__enter__() + contexts.append(context_var) + + try: + for stmt in with_node.body: + evaluate_ast(stmt, state, tools) + except Exception as e: + for context in reversed(contexts): + context.__exit__(type(e), e, e.__traceback__) + raise + else: + for context in reversed(contexts): + context.__exit__(None, None, None) def evaluate_ast( @@ -501,7 +634,7 @@ def evaluate_ast( encounters assignements. tools (`Dict[str, Callable]`): The functions that may be called during the evaluation. Any call to another function will fail with an - `InterpretorError`. + `InterpreterError`. authorized_imports (`List[str]`): The list of modules that can be imported by the code. By default, only a few safe modules are allowed. Add more at your own risk! @@ -537,8 +670,6 @@ def evaluate_ast( elif isinstance(expression, ast.Compare): # Comparison -> evaluate the comparison return evaluate_condition(expression, state, tools) - elif isinstance(expression, ast.Return): - return evaluate_ast(expression.value, state, tools) elif isinstance(expression, ast.Lambda): return evaluate_lambda(expression, state, tools) elif isinstance(expression, ast.FunctionDef): @@ -615,7 +746,7 @@ def evaluate_ast( module = __import__(alias.name) state[alias.asname or alias.name] = module else: - raise InterpretorError(f"Import of {alias.name} is not allowed.") + raise InterpreterError(f"Import of {alias.name} is not allowed.") return None elif isinstance(expression, ast.While): return evaluate_while(expression, state, tools) @@ -625,7 +756,7 @@ def evaluate_ast( for alias in expression.names: state[alias.asname or alias.name] = getattr(module, alias.name) else: - raise InterpretorError(f"Import from {expression.module} is not allowed.") + raise InterpreterError(f"Import from {expression.module} is not allowed.") return None elif isinstance(expression, ast.ClassDef): return evaluate_class_def(expression, state, tools) @@ -633,9 +764,17 @@ def evaluate_ast( return evaluate_try(expression, state, tools) elif isinstance(expression, ast.Raise): return evaluate_raise(expression, state, tools) + elif isinstance(expression, ast.Assert): + return evaluate_assert(expression, state, tools) + elif isinstance(expression, ast.With): + return evaluate_with(expression, state, tools) + elif isinstance(expression, ast.Set): + return {evaluate_ast(elt, state, tools) for elt in expression.elts} + elif isinstance(expression, ast.Return): + raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None) else: # For now we refuse anything else. Let's add things as we need them. - raise InterpretorError(f"{expression.__class__.__name__} is not supported.") + raise InterpreterError(f"{expression.__class__.__name__} is not supported.") def evaluate_python_code( @@ -652,7 +791,7 @@ def evaluate_python_code( The code to evaluate. tools (`Dict[str, Callable]`): The functions that may be called during the evaluation. Any call to another function will fail with an - `InterpretorError`. + `InterpreterError`. state (`Dict[str, Any]`): A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be updated by this function to contain all variables as they are evaluated. @@ -665,17 +804,17 @@ def evaluate_python_code( if state is None: state = {} result = None - state["print_outputs"] = "" - - for idx, node in enumerate(expression.body): + global PRINT_OUTPUTS + PRINT_OUTPUTS = "" + for node in expression.body: try: - line_result = evaluate_ast(node, state, tools, authorized_imports) - except InterpretorError as e: - msg = f"You tried to execute the following code:\n{code}\n" - msg += f"You got these outputs:\n{state['print_outputs']}\n" - msg += f"Evaluation stopped at line '{node}' because of the following error:\n{e}" - raise InterpretorError(msg) - if line_result is not None: - result = line_result + result = evaluate_ast(node, state, tools, authorized_imports) + except InterpreterError as e: + msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" + if len(PRINT_OUTPUTS) > 0: + msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n" + raise InterpreterError(msg) + finally: + state["print_outputs"] = PRINT_OUTPUTS return result diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 79e55bf652..062b98abd4 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -74,6 +74,26 @@ final_answer(7.2904) """ +def fake_react_code_llm_error(messages, stop_sequences=None) -> str: + prompt = str(messages) + if "special_marker" not in prompt: + return """ +Thought: I should multiply 2 by 3.6452. special_marker +Code: +```py +print = 2 +``` +""" + else: # We're at step 2 + return """ +Thought: I can now answer the initial question +Code: +```py +final_answer("got an error") +``` +""" + + def fake_code_llm_oneshot(messages, stop_sequences=None) -> str: return """ Thought: I should multiply 2 by 3.6452. special_marker @@ -124,6 +144,13 @@ Action: "tool_name": "code interpreter", } + def test_react_code_agent_code_errors_show_offending_lines(self): + agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error) + output = agent.run("What is 2 multiplied by 3.6452?") + assert isinstance(output, AgentText) + assert output == "got an error" + assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs) + def test_setup_agent_with_empty_toolbox(self): ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[]) diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index 51775e31e7..6f5907e27b 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -20,7 +20,7 @@ import pytest from transformers import load_tool from transformers.agents.agent_types import AGENT_TYPE_MAPPING from transformers.agents.default_tools import BASE_PYTHON_TOOLS -from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code +from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code from .test_tools_common import ToolTesterMixin @@ -35,16 +35,6 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"]) self.tool.setup() - def test_exact_match_input_spec(self): - inputs_spec = self.tool.inputs - expected_description = ( - "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " - "else you will get an error. This code can only import the following python libraries: " - "['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', " - "'random', 're']." - ) - self.assertEqual(inputs_spec["code"]["description"], expected_description) - def test_exact_match_arg(self): result = self.tool("(2 / 2) * 4") self.assertEqual(result, "4.0") @@ -91,6 +81,17 @@ class PythonInterpreterTester(unittest.TestCase): assert result == 5 self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""}) + code = "a=1;b=None" + result = evaluate_python_code(code, {}, state={}) + # evaluate returns the value of the last assignment. + assert result is None + + def test_assignment_cannot_overwrite_tool(self): + code = "print = '3'" + with pytest.raises(InterpreterError) as e: + evaluate_python_code(code, {"print": print}, state={}) + assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e) + def test_evaluate_call(self): code = "y = add_two(x)" state = {"x": 3} @@ -99,7 +100,7 @@ class PythonInterpreterTester(unittest.TestCase): self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) # Should not work without the tool - with pytest.raises(InterpretorError) as e: + with pytest.raises(InterpreterError) as e: evaluate_python_code(code, {}, state=state) assert "tried to execute add_two" in str(e.value) @@ -237,6 +238,12 @@ for block in text_block: result = evaluate_python_code(code, {}, state={}) assert result == 2 + code = """ +digits, i = [1, 2, 3], 1 +digits[i], digits[i + 1] = digits[i + 1], digits[i]""" + state = {} + evaluate_python_code(code, {"range": range, "print": print, "int": int}, state) + def test_listcomp(self): code = "x = [i for i in range(3)]" result = evaluate_python_code(code, {"range": range}, state={}) @@ -278,10 +285,20 @@ for block in text_block: # test infinite loop code = "i = 0\nwhile i < 3:\n i -= 1\ni" - with pytest.raises(InterpretorError) as e: + with pytest.raises(InterpreterError) as e: evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert "iterations in While loop exceeded" in str(e) + # test lazy evaluation + code = """ +house_positions = [0, 7, 10, 15, 18, 22, 22] +i, n, loc = 0, 7, 30 +while i < n and house_positions[i] <= loc: + i += 1 +""" + state = {} + evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state) + def test_generator(self): code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)" result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) @@ -353,7 +370,19 @@ if char.isalpha(): assert result == "LATIN CAPITAL LETTER A" def test_multiple_comparators(self): - code = "0x30A0 <= ord('a') <= 0x30FF" + code = "0 <= -1 < 4 and 0 <= -5 < 4" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert not result + + code = "0 <= 1 < 4 and 0 <= -5 < 4" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert not result + + code = "0 <= 4 < 4 and 0 <= 3 < 4" + result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert not result + + code = "0 <= 3 < 4 and 0 <= 3 < 4" result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) assert result @@ -364,6 +393,16 @@ if char.isalpha(): assert result == "Ok no one cares" assert state["print_outputs"] == "Hello world!\nOk no one cares\n" + # test print in function + code = """ +print("1") +def function(): + print("2") +function()""" + state = {} + evaluate_python_code(code, {"print": print}, state) + assert state["print_outputs"] == "1\n2\n" + def test_tuple_target_in_iterator(self): code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) @@ -491,3 +530,147 @@ except ValueError as e: state = {} result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) assert result == int + + def test_tuple_id(self): + code = """ +food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1} +unique_food_items = [item for item, count in food_item_counts.items() if count == 1] +""" + state = {} + result = evaluate_python_code(code, {}, state=state) + assert result == ["orange", "pear"] + + def test_nonsimple_augassign(self): + code = """ +counts_dict = {'a': 0} +counts_dict['a'] += 1 +counts_list = [1, 2, 3] +counts_list += [4, 5, 6] + +class Counter: + self.count = 0 + +a = Counter() +a.count += 1 +""" + state = {} + evaluate_python_code(code, {}, state=state) + assert state["counts_dict"] == {"a": 1} + assert state["counts_list"] == [1, 2, 3, 4, 5, 6] + assert state["a"].count == 1 + + def test_adding_int_to_list_raises_error(self): + code = """ +counts = [1, 2, 3] +counts += 1""" + with pytest.raises(InterpreterError) as e: + evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert "Cannot add non-list value 1 to a list." in str(e) + + def test_error_highlights_correct_line_of_code(self): + code = """# Ok this is a very long code +# It has many commented lines +a = 1 +b = 2 + +# Here is another piece +counts = [1, 2, 3] +counts += 1 +b += 1""" + with pytest.raises(InterpreterError) as e: + evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert "Evaluation stopped at line 'counts += 1" in str(e) + + def test_assert(self): + code = """ +assert 1 == 1 +assert 1 == 2 +""" + with pytest.raises(AssertionError) as e: + evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) + assert "1 == 2" in str(e) and "1 == 1" not in str(e) + + def test_with_context_manager(self): + code = """ +class SimpleLock: + def __init__(self): + self.locked = False + + def __enter__(self): + self.locked = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.locked = False + +lock = SimpleLock() + +with lock as l: + assert l.locked == True + +assert lock.locked == False + """ + state = {} + tools = {} + evaluate_python_code(code, tools, state) + + def test_default_arg_in_function(self): + code = """ +def f(a, b=333, n=1000): + return b + n +n = f(1, n=667) +""" + res = evaluate_python_code(code, {}, {}) + assert res == 1000 + + def test_set(self): + code = """ +S1 = {'a', 'b', 'c'} +S2 = {'b', 'c', 'd'} +S3 = S1.difference(S2) +S4 = S1.intersection(S2) +""" + state = {} + evaluate_python_code(code, {}, state=state) + assert state["S3"] == {"a"} + assert state["S4"] == {"b", "c"} + + def test_break(self): + code = """ +i = 0 + +while True: + i+= 1 + if i==3: + break + +i""" + result = evaluate_python_code(code, {"print": print, "round": round}, state={}) + assert result == 3 + + def test_return(self): + # test early returns + code = """ +def add_one(n, shift): + if True: + return n + shift + return n + +add_one(1, 1) +""" + state = {} + result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) + print(state) + assert result == 2 + + # test returning None + code = """ +def returns_none(a): + return + +returns_none(1) +""" + state = {} + result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) + print(state) + assert result is None