Agents: Improve python interpreter (#31409)
* Improve Python interpreter * Add with and assert statements * Prevent overwriting existing tools * Check interpreter errors are well logged in code agent * Add lazy evaluation for and and or * Improve variable assignment * Fix early return statements in functions * Add small import fix on interpreter tool
This commit is contained in:
@@ -34,11 +34,16 @@ def custom_print(*args):
|
|||||||
|
|
||||||
BASE_PYTHON_TOOLS = {
|
BASE_PYTHON_TOOLS = {
|
||||||
"print": custom_print,
|
"print": custom_print,
|
||||||
|
"isinstance": isinstance,
|
||||||
"range": range,
|
"range": range,
|
||||||
"float": float,
|
"float": float,
|
||||||
"int": int,
|
"int": int,
|
||||||
"bool": bool,
|
"bool": bool,
|
||||||
"str": str,
|
"str": str,
|
||||||
|
"set": set,
|
||||||
|
"list": list,
|
||||||
|
"dict": dict,
|
||||||
|
"tuple": tuple,
|
||||||
"round": round,
|
"round": round,
|
||||||
"ceil": math.ceil,
|
"ceil": math.ceil,
|
||||||
"floor": math.floor,
|
"floor": math.floor,
|
||||||
@@ -60,10 +65,6 @@ BASE_PYTHON_TOOLS = {
|
|||||||
"max": max,
|
"max": max,
|
||||||
"min": min,
|
"min": min,
|
||||||
"abs": abs,
|
"abs": abs,
|
||||||
"list": list,
|
|
||||||
"dict": dict,
|
|
||||||
"tuple": tuple,
|
|
||||||
"set": set,
|
|
||||||
"enumerate": enumerate,
|
"enumerate": enumerate,
|
||||||
"zip": zip,
|
"zip": zip,
|
||||||
"reversed": reversed,
|
"reversed": reversed,
|
||||||
@@ -74,6 +75,15 @@ BASE_PYTHON_TOOLS = {
|
|||||||
"filter": filter,
|
"filter": filter,
|
||||||
"ord": ord,
|
"ord": ord,
|
||||||
"chr": chr,
|
"chr": chr,
|
||||||
|
"next": next,
|
||||||
|
"iter": iter,
|
||||||
|
"divmod": divmod,
|
||||||
|
"callable": callable,
|
||||||
|
"getattr": getattr,
|
||||||
|
"hasattr": hasattr,
|
||||||
|
"setattr": setattr,
|
||||||
|
"issubclass": issubclass,
|
||||||
|
"type": type,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -147,9 +157,9 @@ class PythonInterpreterTool(Tool):
|
|||||||
|
|
||||||
def __init__(self, *args, authorized_imports=None, **kwargs):
|
def __init__(self, *args, authorized_imports=None, **kwargs):
|
||||||
if authorized_imports is None:
|
if authorized_imports is None:
|
||||||
authorized_imports = list(set(LIST_SAFE_MODULES))
|
self.authorized_imports = list(set(LIST_SAFE_MODULES))
|
||||||
else:
|
else:
|
||||||
authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
|
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
|
||||||
self.inputs = {
|
self.inputs = {
|
||||||
"code": {
|
"code": {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
@@ -162,7 +172,9 @@ class PythonInterpreterTool(Tool):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def forward(self, code):
|
def forward(self, code):
|
||||||
output = str(evaluate_python_code(code, tools=self.available_tools))
|
output = str(
|
||||||
|
evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports)
|
||||||
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from .agents import BASE_PYTHON_TOOLS
|
from .agents import BASE_PYTHON_TOOLS
|
||||||
from .python_interpreter import InterpretorError, evaluate
|
from .python_interpreter import InterpreterError, evaluate
|
||||||
|
|
||||||
|
|
||||||
### Fake tools for test
|
### Fake tools for test
|
||||||
@@ -256,7 +256,7 @@ def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpret
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return evaluate(code, tools, state)
|
return evaluate(code, tools, state)
|
||||||
except InterpretorError as e:
|
except InterpreterError as e:
|
||||||
return str(e)
|
return str(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if verbose:
|
if verbose:
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
|
|||||||
message["role"] = role_conversions[role]
|
message["role"] = role_conversions[role]
|
||||||
|
|
||||||
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
|
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
|
||||||
final_message_list[-1]["content"] += "\n===\n" + message["content"]
|
final_message_list[-1]["content"] += "\n=======\n" + message["content"]
|
||||||
else:
|
else:
|
||||||
final_message_list.append(message)
|
final_message_list.append(message)
|
||||||
return final_message_list
|
return final_message_list
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from collections.abc import Mapping
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
class InterpretorError(ValueError):
|
class InterpreterError(ValueError):
|
||||||
"""
|
"""
|
||||||
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
|
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
|
||||||
operations.
|
operations.
|
||||||
@@ -50,6 +50,8 @@ LIST_SAFE_MODULES = [
|
|||||||
"unicodedata",
|
"unicodedata",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PRINT_OUTPUTS = ""
|
||||||
|
|
||||||
|
|
||||||
class BreakException(Exception):
|
class BreakException(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -59,13 +61,18 @@ class ContinueException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnException(Exception):
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
|
||||||
def get_iterable(obj):
|
def get_iterable(obj):
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return obj
|
return obj
|
||||||
elif hasattr(obj, "__iter__"):
|
elif hasattr(obj, "__iter__"):
|
||||||
return list(obj)
|
return list(obj)
|
||||||
else:
|
else:
|
||||||
raise InterpretorError("Object is not iterable")
|
raise InterpreterError("Object is not iterable")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_unaryop(expression, state, tools):
|
def evaluate_unaryop(expression, state, tools):
|
||||||
@@ -79,7 +86,7 @@ def evaluate_unaryop(expression, state, tools):
|
|||||||
elif isinstance(expression.op, ast.Invert):
|
elif isinstance(expression.op, ast.Invert):
|
||||||
return ~operand
|
return ~operand
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_lambda(lambda_expression, state, tools):
|
def evaluate_lambda(lambda_expression, state, tools):
|
||||||
@@ -99,10 +106,15 @@ def evaluate_while(while_loop, state, tools):
|
|||||||
iterations = 0
|
iterations = 0
|
||||||
while evaluate_ast(while_loop.test, state, tools):
|
while evaluate_ast(while_loop.test, state, tools):
|
||||||
for node in while_loop.body:
|
for node in while_loop.body:
|
||||||
|
try:
|
||||||
evaluate_ast(node, state, tools)
|
evaluate_ast(node, state, tools)
|
||||||
|
except BreakException:
|
||||||
|
return None
|
||||||
|
except ContinueException:
|
||||||
|
break
|
||||||
iterations += 1
|
iterations += 1
|
||||||
if iterations > max_iterations:
|
if iterations > max_iterations:
|
||||||
raise InterpretorError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -110,15 +122,33 @@ def create_function(func_def, state, tools):
|
|||||||
def new_func(*args, **kwargs):
|
def new_func(*args, **kwargs):
|
||||||
func_state = state.copy()
|
func_state = state.copy()
|
||||||
arg_names = [arg.arg for arg in func_def.args.args]
|
arg_names = [arg.arg for arg in func_def.args.args]
|
||||||
|
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults]
|
||||||
|
|
||||||
|
# Apply default values
|
||||||
|
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
||||||
|
|
||||||
|
# Set positional arguments
|
||||||
for name, value in zip(arg_names, args):
|
for name, value in zip(arg_names, args):
|
||||||
func_state[name] = value
|
func_state[name] = value
|
||||||
|
|
||||||
|
# # Set keyword arguments
|
||||||
|
for name, value in kwargs.items():
|
||||||
|
func_state[name] = value
|
||||||
|
|
||||||
|
# Handle variable arguments
|
||||||
if func_def.args.vararg:
|
if func_def.args.vararg:
|
||||||
vararg_name = func_def.args.vararg.arg
|
vararg_name = func_def.args.vararg.arg
|
||||||
func_state[vararg_name] = args
|
func_state[vararg_name] = args
|
||||||
|
|
||||||
if func_def.args.kwarg:
|
if func_def.args.kwarg:
|
||||||
kwarg_name = func_def.args.kwarg.arg
|
kwarg_name = func_def.args.kwarg.arg
|
||||||
func_state[kwarg_name] = kwargs
|
func_state[kwarg_name] = kwargs
|
||||||
|
|
||||||
|
# Set default values for arguments that were not provided
|
||||||
|
for name, value in defaults.items():
|
||||||
|
if name not in func_state:
|
||||||
|
func_state[name] = value
|
||||||
|
|
||||||
# Update function state with self and __class__
|
# Update function state with self and __class__
|
||||||
if func_def.args.args and func_def.args.args[0].arg == "self":
|
if func_def.args.args and func_def.args.args[0].arg == "self":
|
||||||
if args:
|
if args:
|
||||||
@@ -126,8 +156,11 @@ def create_function(func_def, state, tools):
|
|||||||
func_state["__class__"] = args[0].__class__
|
func_state["__class__"] = args[0].__class__
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
|
try:
|
||||||
for stmt in func_def.body:
|
for stmt in func_def.body:
|
||||||
result = evaluate_ast(stmt, func_state, tools)
|
result = evaluate_ast(stmt, func_state, tools)
|
||||||
|
except ReturnException as e:
|
||||||
|
result = e.value
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return new_func
|
return new_func
|
||||||
@@ -155,9 +188,12 @@ def evaluate_class_def(class_def, state, tools):
|
|||||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
|
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
|
||||||
elif isinstance(stmt, ast.Assign):
|
elif isinstance(stmt, ast.Assign):
|
||||||
for target in stmt.targets:
|
for target in stmt.targets:
|
||||||
|
if isinstance(target, ast.Name):
|
||||||
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
|
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
|
||||||
|
elif isinstance(target, ast.Attribute):
|
||||||
|
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools)
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||||
|
|
||||||
new_class = type(class_name, tuple(bases), class_dict)
|
new_class = type(class_name, tuple(bases), class_dict)
|
||||||
state[class_name] = new_class
|
state[class_name] = new_class
|
||||||
@@ -165,14 +201,34 @@ def evaluate_class_def(class_def, state, tools):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||||
# Extract the target variable name and the operation
|
# Helper function to get current value and set new value based on the target type
|
||||||
if isinstance(expression.target, ast.Name):
|
def get_current_value(target):
|
||||||
var_name = expression.target.id
|
if isinstance(target, ast.Name):
|
||||||
current_value = state.get(var_name, 0) # Assuming default of 0 if not in state
|
return state.get(target.id, 0)
|
||||||
|
elif isinstance(target, ast.Subscript):
|
||||||
|
obj = evaluate_ast(target.value, state, tools)
|
||||||
|
key = evaluate_ast(target.slice, state, tools)
|
||||||
|
return obj[key]
|
||||||
|
elif isinstance(target, ast.Attribute):
|
||||||
|
obj = evaluate_ast(target.value, state, tools)
|
||||||
|
return getattr(obj, target.attr)
|
||||||
|
elif isinstance(target, ast.Tuple):
|
||||||
|
return tuple(get_current_value(elt) for elt in target.elts)
|
||||||
|
elif isinstance(target, ast.List):
|
||||||
|
return [get_current_value(elt) for elt in target.elts]
|
||||||
|
else:
|
||||||
|
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
||||||
|
|
||||||
|
current_value = get_current_value(expression.target)
|
||||||
value_to_add = evaluate_ast(expression.value, state, tools)
|
value_to_add = evaluate_ast(expression.value, state, tools)
|
||||||
|
|
||||||
# Determine the operation and apply it
|
# Determine the operation and apply it
|
||||||
if isinstance(expression.op, ast.Add):
|
if isinstance(expression.op, ast.Add):
|
||||||
|
if isinstance(current_value, list):
|
||||||
|
if not isinstance(value_to_add, list):
|
||||||
|
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
|
||||||
|
updated_value = current_value + value_to_add
|
||||||
|
else:
|
||||||
updated_value = current_value + value_to_add
|
updated_value = current_value + value_to_add
|
||||||
elif isinstance(expression.op, ast.Sub):
|
elif isinstance(expression.op, ast.Sub):
|
||||||
updated_value = current_value - value_to_add
|
updated_value = current_value - value_to_add
|
||||||
@@ -180,22 +236,42 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
|||||||
updated_value = current_value * value_to_add
|
updated_value = current_value * value_to_add
|
||||||
elif isinstance(expression.op, ast.Div):
|
elif isinstance(expression.op, ast.Div):
|
||||||
updated_value = current_value / value_to_add
|
updated_value = current_value / value_to_add
|
||||||
# Add other operations as needed
|
elif isinstance(expression.op, ast.Mod):
|
||||||
|
updated_value = current_value % value_to_add
|
||||||
|
elif isinstance(expression.op, ast.Pow):
|
||||||
|
updated_value = current_value**value_to_add
|
||||||
|
elif isinstance(expression.op, ast.FloorDiv):
|
||||||
|
updated_value = current_value // value_to_add
|
||||||
|
elif isinstance(expression.op, ast.BitAnd):
|
||||||
|
updated_value = current_value & value_to_add
|
||||||
|
elif isinstance(expression.op, ast.BitOr):
|
||||||
|
updated_value = current_value | value_to_add
|
||||||
|
elif isinstance(expression.op, ast.BitXor):
|
||||||
|
updated_value = current_value ^ value_to_add
|
||||||
|
elif isinstance(expression.op, ast.LShift):
|
||||||
|
updated_value = current_value << value_to_add
|
||||||
|
elif isinstance(expression.op, ast.RShift):
|
||||||
|
updated_value = current_value >> value_to_add
|
||||||
|
else:
|
||||||
|
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
||||||
|
|
||||||
# Update the state
|
# Update the state
|
||||||
state[var_name] = updated_value
|
set_value(expression.target, updated_value, state, tools)
|
||||||
|
|
||||||
return updated_value
|
return updated_value
|
||||||
else:
|
|
||||||
raise InterpretorError("AugAssign not supported for non-simple variable targets.")
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_boolop(boolop, state, tools):
|
def evaluate_boolop(node, state, tools):
|
||||||
values = [evaluate_ast(val, state, tools) for val in boolop.values]
|
if isinstance(node.op, ast.And):
|
||||||
op = boolop.op
|
for value in node.values:
|
||||||
if isinstance(op, ast.And):
|
if not evaluate_ast(value, state, tools):
|
||||||
return all(values)
|
return False
|
||||||
elif isinstance(op, ast.Or):
|
return True
|
||||||
return any(values)
|
elif isinstance(node.op, ast.Or):
|
||||||
|
for value in node.values:
|
||||||
|
if evaluate_ast(value, state, tools):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def evaluate_binop(binop, state, tools):
|
def evaluate_binop(binop, state, tools):
|
||||||
@@ -233,41 +309,49 @@ def evaluate_binop(binop, state, tools):
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_assign(assign, state, tools):
|
def evaluate_assign(assign, state, tools):
|
||||||
var_names = assign.targets
|
|
||||||
result = evaluate_ast(assign.value, state, tools)
|
result = evaluate_ast(assign.value, state, tools)
|
||||||
if len(var_names) == 1:
|
if len(assign.targets) == 1:
|
||||||
target = var_names[0]
|
target = assign.targets[0]
|
||||||
if isinstance(target, ast.Tuple):
|
set_value(target, result, state, tools)
|
||||||
|
else:
|
||||||
|
if len(assign.targets) != len(result):
|
||||||
|
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
|
||||||
|
for tgt, val in zip(assign.targets, result):
|
||||||
|
set_value(tgt, val, state, tools)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def set_value(target, value, state, tools):
|
||||||
|
if isinstance(target, ast.Name):
|
||||||
|
if target.id in tools:
|
||||||
|
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
|
||||||
|
state[target.id] = value
|
||||||
|
elif isinstance(target, ast.Tuple):
|
||||||
|
if not isinstance(value, tuple):
|
||||||
|
raise InterpreterError("Cannot unpack non-tuple value")
|
||||||
|
if len(target.elts) != len(value):
|
||||||
|
raise InterpreterError("Cannot unpack tuple of wrong size")
|
||||||
for i, elem in enumerate(target.elts):
|
for i, elem in enumerate(target.elts):
|
||||||
state[elem.id] = result[i]
|
set_value(elem, value[i], state, tools)
|
||||||
elif isinstance(target, ast.Attribute):
|
|
||||||
obj = evaluate_ast(target.value, state, tools)
|
|
||||||
setattr(obj, target.attr, result)
|
|
||||||
elif isinstance(target, ast.Subscript):
|
elif isinstance(target, ast.Subscript):
|
||||||
obj = evaluate_ast(target.value, state, tools)
|
obj = evaluate_ast(target.value, state, tools)
|
||||||
key = evaluate_ast(target.slice, state, tools)
|
key = evaluate_ast(target.slice, state, tools)
|
||||||
obj[key] = result
|
obj[key] = value
|
||||||
else:
|
elif isinstance(target, ast.Attribute):
|
||||||
state[target.id] = result
|
obj = evaluate_ast(target.value, state, tools)
|
||||||
|
setattr(obj, target.attr, value)
|
||||||
else:
|
|
||||||
if len(result) != len(var_names):
|
|
||||||
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
|
|
||||||
for var_name, r in zip(var_names, result):
|
|
||||||
state[var_name.id] = r
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_call(call, state, tools):
|
def evaluate_call(call, state, tools):
|
||||||
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
||||||
raise InterpretorError(
|
raise InterpreterError(
|
||||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
||||||
)
|
)
|
||||||
if isinstance(call.func, ast.Attribute):
|
if isinstance(call.func, ast.Attribute):
|
||||||
obj = evaluate_ast(call.func.value, state, tools)
|
obj = evaluate_ast(call.func.value, state, tools)
|
||||||
func_name = call.func.attr
|
func_name = call.func.attr
|
||||||
if not hasattr(obj, func_name):
|
if not hasattr(obj, func_name):
|
||||||
raise InterpretorError(f"Object {obj} has no attribute {func_name}")
|
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
||||||
func = getattr(obj, func_name)
|
func = getattr(obj, func_name)
|
||||||
elif isinstance(call.func, ast.Name):
|
elif isinstance(call.func, ast.Name):
|
||||||
func_name = call.func.id
|
func_name = call.func.id
|
||||||
@@ -278,7 +362,7 @@ def evaluate_call(call, state, tools):
|
|||||||
elif func_name in ERRORS:
|
elif func_name in ERRORS:
|
||||||
func = ERRORS[func_name]
|
func = ERRORS[func_name]
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(
|
raise InterpreterError(
|
||||||
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -297,22 +381,22 @@ def evaluate_call(call, state, tools):
|
|||||||
if "__class__" in state and "self" in state:
|
if "__class__" in state and "self" in state:
|
||||||
return super(state["__class__"], state["self"])
|
return super(state["__class__"], state["self"])
|
||||||
else:
|
else:
|
||||||
raise InterpretorError("super() needs at least one argument")
|
raise InterpreterError("super() needs at least one argument")
|
||||||
cls = args[0]
|
cls = args[0]
|
||||||
if not isinstance(cls, type):
|
if not isinstance(cls, type):
|
||||||
raise InterpretorError("super() argument 1 must be type")
|
raise InterpreterError("super() argument 1 must be type")
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
return super(cls)
|
return super(cls)
|
||||||
elif len(args) == 2:
|
elif len(args) == 2:
|
||||||
instance = args[1]
|
instance = args[1]
|
||||||
return super(cls, instance)
|
return super(cls, instance)
|
||||||
else:
|
else:
|
||||||
raise InterpretorError("super() takes at most 2 arguments")
|
raise InterpreterError("super() takes at most 2 arguments")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if func_name == "print":
|
if func_name == "print":
|
||||||
output = " ".join(map(str, args))
|
output = " ".join(map(str, args))
|
||||||
state["print_outputs"] += output + "\n"
|
global PRINT_OUTPUTS
|
||||||
|
PRINT_OUTPUTS += output + "\n"
|
||||||
return output
|
return output
|
||||||
else: # Assume it's a callable object
|
else: # Assume it's a callable object
|
||||||
output = func(*args, **kwargs)
|
output = func(*args, **kwargs)
|
||||||
@@ -325,8 +409,14 @@ def evaluate_subscript(subscript, state, tools):
|
|||||||
if isinstance(index, slice):
|
if isinstance(index, slice):
|
||||||
return value[index]
|
return value[index]
|
||||||
elif isinstance(value, (list, tuple)):
|
elif isinstance(value, (list, tuple)):
|
||||||
|
# Ensure the index is within bounds
|
||||||
|
if not (-len(value) <= index < len(value)):
|
||||||
|
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
||||||
return value[int(index)]
|
return value[int(index)]
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
|
# Ensure the index is within bounds
|
||||||
|
if not (-len(value) <= index < len(value)):
|
||||||
|
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
||||||
return value[index]
|
return value[index]
|
||||||
elif index in value:
|
elif index in value:
|
||||||
return value[index]
|
return value[index]
|
||||||
@@ -334,7 +424,7 @@ def evaluate_subscript(subscript, state, tools):
|
|||||||
close_matches = difflib.get_close_matches(index, list(value.keys()))
|
close_matches = difflib.get_close_matches(index, list(value.keys()))
|
||||||
if len(close_matches) > 0:
|
if len(close_matches) > 0:
|
||||||
return value[close_matches[0]]
|
return value[close_matches[0]]
|
||||||
raise InterpretorError(f"Could not index {value} with '{index}'.")
|
raise InterpreterError(f"Could not index {value} with '{index}'.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_name(name, state, tools):
|
def evaluate_name(name, state, tools):
|
||||||
@@ -347,7 +437,7 @@ def evaluate_name(name, state, tools):
|
|||||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||||
if len(close_matches) > 0:
|
if len(close_matches) > 0:
|
||||||
return state[close_matches[0]]
|
return state[close_matches[0]]
|
||||||
raise InterpretorError(f"The variable `{name.id}` is not defined.")
|
raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_condition(condition, state, tools):
|
def evaluate_condition(condition, state, tools):
|
||||||
@@ -355,30 +445,36 @@ def evaluate_condition(condition, state, tools):
|
|||||||
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
|
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
|
||||||
ops = [type(op) for op in condition.ops]
|
ops = [type(op) for op in condition.ops]
|
||||||
|
|
||||||
result = left
|
result = True
|
||||||
|
current_left = left
|
||||||
|
|
||||||
for op, comparator in zip(ops, comparators):
|
for op, comparator in zip(ops, comparators):
|
||||||
if op == ast.Eq:
|
if op == ast.Eq:
|
||||||
result = result == comparator
|
result = result and (current_left == comparator)
|
||||||
elif op == ast.NotEq:
|
elif op == ast.NotEq:
|
||||||
result = result != comparator
|
result = result and (current_left != comparator)
|
||||||
elif op == ast.Lt:
|
elif op == ast.Lt:
|
||||||
result = result < comparator
|
result = result and (current_left < comparator)
|
||||||
elif op == ast.LtE:
|
elif op == ast.LtE:
|
||||||
result = result <= comparator
|
result = result and (current_left <= comparator)
|
||||||
elif op == ast.Gt:
|
elif op == ast.Gt:
|
||||||
result = result > comparator
|
result = result and (current_left > comparator)
|
||||||
elif op == ast.GtE:
|
elif op == ast.GtE:
|
||||||
result = result >= comparator
|
result = result and (current_left >= comparator)
|
||||||
elif op == ast.Is:
|
elif op == ast.Is:
|
||||||
result = result is comparator
|
result = result and (current_left is comparator)
|
||||||
elif op == ast.IsNot:
|
elif op == ast.IsNot:
|
||||||
result = result is not comparator
|
result = result and (current_left is not comparator)
|
||||||
elif op == ast.In:
|
elif op == ast.In:
|
||||||
result = result in comparator
|
result = result and (current_left in comparator)
|
||||||
elif op == ast.NotIn:
|
elif op == ast.NotIn:
|
||||||
result = result not in comparator
|
result = result and (current_left not in comparator)
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(f"Operator not supported: {op}")
|
raise InterpreterError(f"Operator not supported: {op}")
|
||||||
|
|
||||||
|
current_left = comparator
|
||||||
|
if not result:
|
||||||
|
break
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -425,15 +521,17 @@ def evaluate_for(for_loop, state, tools):
|
|||||||
|
|
||||||
def evaluate_listcomp(listcomp, state, tools):
|
def evaluate_listcomp(listcomp, state, tools):
|
||||||
result = []
|
result = []
|
||||||
vars = {}
|
|
||||||
for generator in listcomp.generators:
|
for generator in listcomp.generators:
|
||||||
var_name = generator.target.id
|
|
||||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
iter_value = evaluate_ast(generator.iter, state, tools)
|
||||||
for value in iter_value:
|
for value in iter_value:
|
||||||
vars[var_name] = value
|
new_state = state.copy()
|
||||||
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
|
if isinstance(generator.target, ast.Tuple):
|
||||||
elem = evaluate_ast(listcomp.elt, {**state, **vars}, tools)
|
for idx, elem in enumerate(generator.target.elts):
|
||||||
result.append(elem)
|
new_state[elem.id] = value[idx]
|
||||||
|
else:
|
||||||
|
new_state[generator.target.id] = value
|
||||||
|
if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs):
|
||||||
|
result.append(evaluate_ast(listcomp.elt, new_state, tools))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -478,7 +576,42 @@ def evaluate_raise(raise_node, state, tools):
|
|||||||
else:
|
else:
|
||||||
raise exc
|
raise exc
|
||||||
else:
|
else:
|
||||||
raise InterpretorError("Re-raise is not supported without an active exception")
|
raise InterpreterError("Re-raise is not supported without an active exception")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_assert(assert_node, state, tools):
|
||||||
|
test_result = evaluate_ast(assert_node.test, state, tools)
|
||||||
|
if not test_result:
|
||||||
|
if assert_node.msg:
|
||||||
|
msg = evaluate_ast(assert_node.msg, state, tools)
|
||||||
|
raise AssertionError(msg)
|
||||||
|
else:
|
||||||
|
# Include the failing condition in the assertion message
|
||||||
|
test_code = ast.unparse(assert_node.test)
|
||||||
|
raise AssertionError(f"Assertion failed: {test_code}")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_with(with_node, state, tools):
|
||||||
|
contexts = []
|
||||||
|
for item in with_node.items:
|
||||||
|
context_expr = evaluate_ast(item.context_expr, state, tools)
|
||||||
|
if item.optional_vars:
|
||||||
|
state[item.optional_vars.id] = context_expr.__enter__()
|
||||||
|
contexts.append(state[item.optional_vars.id])
|
||||||
|
else:
|
||||||
|
context_var = context_expr.__enter__()
|
||||||
|
contexts.append(context_var)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for stmt in with_node.body:
|
||||||
|
evaluate_ast(stmt, state, tools)
|
||||||
|
except Exception as e:
|
||||||
|
for context in reversed(contexts):
|
||||||
|
context.__exit__(type(e), e, e.__traceback__)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
for context in reversed(contexts):
|
||||||
|
context.__exit__(None, None, None)
|
||||||
|
|
||||||
|
|
||||||
def evaluate_ast(
|
def evaluate_ast(
|
||||||
@@ -501,7 +634,7 @@ def evaluate_ast(
|
|||||||
encounters assignements.
|
encounters assignements.
|
||||||
tools (`Dict[str, Callable]`):
|
tools (`Dict[str, Callable]`):
|
||||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||||
`InterpretorError`.
|
`InterpreterError`.
|
||||||
authorized_imports (`List[str]`):
|
authorized_imports (`List[str]`):
|
||||||
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
||||||
Add more at your own risk!
|
Add more at your own risk!
|
||||||
@@ -537,8 +670,6 @@ def evaluate_ast(
|
|||||||
elif isinstance(expression, ast.Compare):
|
elif isinstance(expression, ast.Compare):
|
||||||
# Comparison -> evaluate the comparison
|
# Comparison -> evaluate the comparison
|
||||||
return evaluate_condition(expression, state, tools)
|
return evaluate_condition(expression, state, tools)
|
||||||
elif isinstance(expression, ast.Return):
|
|
||||||
return evaluate_ast(expression.value, state, tools)
|
|
||||||
elif isinstance(expression, ast.Lambda):
|
elif isinstance(expression, ast.Lambda):
|
||||||
return evaluate_lambda(expression, state, tools)
|
return evaluate_lambda(expression, state, tools)
|
||||||
elif isinstance(expression, ast.FunctionDef):
|
elif isinstance(expression, ast.FunctionDef):
|
||||||
@@ -615,7 +746,7 @@ def evaluate_ast(
|
|||||||
module = __import__(alias.name)
|
module = __import__(alias.name)
|
||||||
state[alias.asname or alias.name] = module
|
state[alias.asname or alias.name] = module
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(f"Import of {alias.name} is not allowed.")
|
raise InterpreterError(f"Import of {alias.name} is not allowed.")
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.While):
|
elif isinstance(expression, ast.While):
|
||||||
return evaluate_while(expression, state, tools)
|
return evaluate_while(expression, state, tools)
|
||||||
@@ -625,7 +756,7 @@ def evaluate_ast(
|
|||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(f"Import from {expression.module} is not allowed.")
|
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
||||||
return None
|
return None
|
||||||
elif isinstance(expression, ast.ClassDef):
|
elif isinstance(expression, ast.ClassDef):
|
||||||
return evaluate_class_def(expression, state, tools)
|
return evaluate_class_def(expression, state, tools)
|
||||||
@@ -633,9 +764,17 @@ def evaluate_ast(
|
|||||||
return evaluate_try(expression, state, tools)
|
return evaluate_try(expression, state, tools)
|
||||||
elif isinstance(expression, ast.Raise):
|
elif isinstance(expression, ast.Raise):
|
||||||
return evaluate_raise(expression, state, tools)
|
return evaluate_raise(expression, state, tools)
|
||||||
|
elif isinstance(expression, ast.Assert):
|
||||||
|
return evaluate_assert(expression, state, tools)
|
||||||
|
elif isinstance(expression, ast.With):
|
||||||
|
return evaluate_with(expression, state, tools)
|
||||||
|
elif isinstance(expression, ast.Set):
|
||||||
|
return {evaluate_ast(elt, state, tools) for elt in expression.elts}
|
||||||
|
elif isinstance(expression, ast.Return):
|
||||||
|
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None)
|
||||||
else:
|
else:
|
||||||
# For now we refuse anything else. Let's add things as we need them.
|
# For now we refuse anything else. Let's add things as we need them.
|
||||||
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
|
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_python_code(
|
def evaluate_python_code(
|
||||||
@@ -652,7 +791,7 @@ def evaluate_python_code(
|
|||||||
The code to evaluate.
|
The code to evaluate.
|
||||||
tools (`Dict[str, Callable]`):
|
tools (`Dict[str, Callable]`):
|
||||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||||
`InterpretorError`.
|
`InterpreterError`.
|
||||||
state (`Dict[str, Any]`):
|
state (`Dict[str, Any]`):
|
||||||
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
||||||
updated by this function to contain all variables as they are evaluated.
|
updated by this function to contain all variables as they are evaluated.
|
||||||
@@ -665,17 +804,17 @@ def evaluate_python_code(
|
|||||||
if state is None:
|
if state is None:
|
||||||
state = {}
|
state = {}
|
||||||
result = None
|
result = None
|
||||||
state["print_outputs"] = ""
|
global PRINT_OUTPUTS
|
||||||
|
PRINT_OUTPUTS = ""
|
||||||
for idx, node in enumerate(expression.body):
|
for node in expression.body:
|
||||||
try:
|
try:
|
||||||
line_result = evaluate_ast(node, state, tools, authorized_imports)
|
result = evaluate_ast(node, state, tools, authorized_imports)
|
||||||
except InterpretorError as e:
|
except InterpreterError as e:
|
||||||
msg = f"You tried to execute the following code:\n{code}\n"
|
msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||||
msg += f"You got these outputs:\n{state['print_outputs']}\n"
|
if len(PRINT_OUTPUTS) > 0:
|
||||||
msg += f"Evaluation stopped at line '{node}' because of the following error:\n{e}"
|
msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n"
|
||||||
raise InterpretorError(msg)
|
raise InterpreterError(msg)
|
||||||
if line_result is not None:
|
finally:
|
||||||
result = line_result
|
state["print_outputs"] = PRINT_OUTPUTS
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -74,6 +74,26 @@ final_answer(7.2904)
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
|
||||||
|
prompt = str(messages)
|
||||||
|
if "special_marker" not in prompt:
|
||||||
|
return """
|
||||||
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
|
Code:
|
||||||
|
```py
|
||||||
|
print = 2
|
||||||
|
```<end_code>
|
||||||
|
"""
|
||||||
|
else: # We're at step 2
|
||||||
|
return """
|
||||||
|
Thought: I can now answer the initial question
|
||||||
|
Code:
|
||||||
|
```py
|
||||||
|
final_answer("got an error")
|
||||||
|
```<end_code>
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
||||||
return """
|
return """
|
||||||
Thought: I should multiply 2 by 3.6452. special_marker
|
Thought: I should multiply 2 by 3.6452. special_marker
|
||||||
@@ -124,6 +144,13 @@ Action:
|
|||||||
"tool_name": "code interpreter",
|
"tool_name": "code interpreter",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_react_code_agent_code_errors_show_offending_lines(self):
|
||||||
|
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm_error)
|
||||||
|
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||||
|
assert isinstance(output, AgentText)
|
||||||
|
assert output == "got an error"
|
||||||
|
assert "Evaluation stopped at line 'print = 2' because of" in str(agent.logs)
|
||||||
|
|
||||||
def test_setup_agent_with_empty_toolbox(self):
|
def test_setup_agent_with_empty_toolbox(self):
|
||||||
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import pytest
|
|||||||
from transformers import load_tool
|
from transformers import load_tool
|
||||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||||
from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code
|
from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin
|
from .test_tools_common import ToolTesterMixin
|
||||||
|
|
||||||
@@ -35,16 +35,6 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
|||||||
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
def test_exact_match_input_spec(self):
|
|
||||||
inputs_spec = self.tool.inputs
|
|
||||||
expected_description = (
|
|
||||||
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
|
||||||
"else you will get an error. This code can only import the following python libraries: "
|
|
||||||
"['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', "
|
|
||||||
"'random', 're']."
|
|
||||||
)
|
|
||||||
self.assertEqual(inputs_spec["code"]["description"], expected_description)
|
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
result = self.tool("(2 / 2) * 4")
|
result = self.tool("(2 / 2) * 4")
|
||||||
self.assertEqual(result, "4.0")
|
self.assertEqual(result, "4.0")
|
||||||
@@ -91,6 +81,17 @@ class PythonInterpreterTester(unittest.TestCase):
|
|||||||
assert result == 5
|
assert result == 5
|
||||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||||
|
|
||||||
|
code = "a=1;b=None"
|
||||||
|
result = evaluate_python_code(code, {}, state={})
|
||||||
|
# evaluate returns the value of the last assignment.
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_assignment_cannot_overwrite_tool(self):
|
||||||
|
code = "print = '3'"
|
||||||
|
with pytest.raises(InterpreterError) as e:
|
||||||
|
evaluate_python_code(code, {"print": print}, state={})
|
||||||
|
assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e)
|
||||||
|
|
||||||
def test_evaluate_call(self):
|
def test_evaluate_call(self):
|
||||||
code = "y = add_two(x)"
|
code = "y = add_two(x)"
|
||||||
state = {"x": 3}
|
state = {"x": 3}
|
||||||
@@ -99,7 +100,7 @@ class PythonInterpreterTester(unittest.TestCase):
|
|||||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||||
|
|
||||||
# Should not work without the tool
|
# Should not work without the tool
|
||||||
with pytest.raises(InterpretorError) as e:
|
with pytest.raises(InterpreterError) as e:
|
||||||
evaluate_python_code(code, {}, state=state)
|
evaluate_python_code(code, {}, state=state)
|
||||||
assert "tried to execute add_two" in str(e.value)
|
assert "tried to execute add_two" in str(e.value)
|
||||||
|
|
||||||
@@ -237,6 +238,12 @@ for block in text_block:
|
|||||||
result = evaluate_python_code(code, {}, state={})
|
result = evaluate_python_code(code, {}, state={})
|
||||||
assert result == 2
|
assert result == 2
|
||||||
|
|
||||||
|
code = """
|
||||||
|
digits, i = [1, 2, 3], 1
|
||||||
|
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
|
||||||
|
|
||||||
def test_listcomp(self):
|
def test_listcomp(self):
|
||||||
code = "x = [i for i in range(3)]"
|
code = "x = [i for i in range(3)]"
|
||||||
result = evaluate_python_code(code, {"range": range}, state={})
|
result = evaluate_python_code(code, {"range": range}, state={})
|
||||||
@@ -278,10 +285,20 @@ for block in text_block:
|
|||||||
|
|
||||||
# test infinite loop
|
# test infinite loop
|
||||||
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
||||||
with pytest.raises(InterpretorError) as e:
|
with pytest.raises(InterpreterError) as e:
|
||||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert "iterations in While loop exceeded" in str(e)
|
assert "iterations in While loop exceeded" in str(e)
|
||||||
|
|
||||||
|
# test lazy evaluation
|
||||||
|
code = """
|
||||||
|
house_positions = [0, 7, 10, 15, 18, 22, 22]
|
||||||
|
i, n, loc = 0, 7, 30
|
||||||
|
while i < n and house_positions[i] <= loc:
|
||||||
|
i += 1
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||||
|
|
||||||
def test_generator(self):
|
def test_generator(self):
|
||||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
@@ -353,7 +370,19 @@ if char.isalpha():
|
|||||||
assert result == "LATIN CAPITAL LETTER A"
|
assert result == "LATIN CAPITAL LETTER A"
|
||||||
|
|
||||||
def test_multiple_comparators(self):
|
def test_multiple_comparators(self):
|
||||||
code = "0x30A0 <= ord('a') <= 0x30FF"
|
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
||||||
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert not result
|
||||||
|
|
||||||
|
code = "0 <= 1 < 4 and 0 <= -5 < 4"
|
||||||
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert not result
|
||||||
|
|
||||||
|
code = "0 <= 4 < 4 and 0 <= 3 < 4"
|
||||||
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert not result
|
||||||
|
|
||||||
|
code = "0 <= 3 < 4 and 0 <= 3 < 4"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
assert result
|
assert result
|
||||||
|
|
||||||
@@ -364,6 +393,16 @@ if char.isalpha():
|
|||||||
assert result == "Ok no one cares"
|
assert result == "Ok no one cares"
|
||||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||||
|
|
||||||
|
# test print in function
|
||||||
|
code = """
|
||||||
|
print("1")
|
||||||
|
def function():
|
||||||
|
print("2")
|
||||||
|
function()"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {"print": print}, state)
|
||||||
|
assert state["print_outputs"] == "1\n2\n"
|
||||||
|
|
||||||
def test_tuple_target_in_iterator(self):
|
def test_tuple_target_in_iterator(self):
|
||||||
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
@@ -491,3 +530,147 @@ except ValueError as e:
|
|||||||
state = {}
|
state = {}
|
||||||
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
||||||
assert result == int
|
assert result == int
|
||||||
|
|
||||||
|
def test_tuple_id(self):
|
||||||
|
code = """
|
||||||
|
food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
|
||||||
|
unique_food_items = [item for item, count in food_item_counts.items() if count == 1]
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {}, state=state)
|
||||||
|
assert result == ["orange", "pear"]
|
||||||
|
|
||||||
|
def test_nonsimple_augassign(self):
|
||||||
|
code = """
|
||||||
|
counts_dict = {'a': 0}
|
||||||
|
counts_dict['a'] += 1
|
||||||
|
counts_list = [1, 2, 3]
|
||||||
|
counts_list += [4, 5, 6]
|
||||||
|
|
||||||
|
class Counter:
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
a = Counter()
|
||||||
|
a.count += 1
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {}, state=state)
|
||||||
|
assert state["counts_dict"] == {"a": 1}
|
||||||
|
assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
|
||||||
|
assert state["a"].count == 1
|
||||||
|
|
||||||
|
def test_adding_int_to_list_raises_error(self):
|
||||||
|
code = """
|
||||||
|
counts = [1, 2, 3]
|
||||||
|
counts += 1"""
|
||||||
|
with pytest.raises(InterpreterError) as e:
|
||||||
|
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert "Cannot add non-list value 1 to a list." in str(e)
|
||||||
|
|
||||||
|
def test_error_highlights_correct_line_of_code(self):
|
||||||
|
code = """# Ok this is a very long code
|
||||||
|
# It has many commented lines
|
||||||
|
a = 1
|
||||||
|
b = 2
|
||||||
|
|
||||||
|
# Here is another piece
|
||||||
|
counts = [1, 2, 3]
|
||||||
|
counts += 1
|
||||||
|
b += 1"""
|
||||||
|
with pytest.raises(InterpreterError) as e:
|
||||||
|
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert "Evaluation stopped at line 'counts += 1" in str(e)
|
||||||
|
|
||||||
|
def test_assert(self):
|
||||||
|
code = """
|
||||||
|
assert 1 == 1
|
||||||
|
assert 1 == 2
|
||||||
|
"""
|
||||||
|
with pytest.raises(AssertionError) as e:
|
||||||
|
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert "1 == 2" in str(e) and "1 == 1" not in str(e)
|
||||||
|
|
||||||
|
def test_with_context_manager(self):
|
||||||
|
code = """
|
||||||
|
class SimpleLock:
|
||||||
|
def __init__(self):
|
||||||
|
self.locked = False
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.locked = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.locked = False
|
||||||
|
|
||||||
|
lock = SimpleLock()
|
||||||
|
|
||||||
|
with lock as l:
|
||||||
|
assert l.locked == True
|
||||||
|
|
||||||
|
assert lock.locked == False
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
tools = {}
|
||||||
|
evaluate_python_code(code, tools, state)
|
||||||
|
|
||||||
|
def test_default_arg_in_function(self):
|
||||||
|
code = """
|
||||||
|
def f(a, b=333, n=1000):
|
||||||
|
return b + n
|
||||||
|
n = f(1, n=667)
|
||||||
|
"""
|
||||||
|
res = evaluate_python_code(code, {}, {})
|
||||||
|
assert res == 1000
|
||||||
|
|
||||||
|
def test_set(self):
|
||||||
|
code = """
|
||||||
|
S1 = {'a', 'b', 'c'}
|
||||||
|
S2 = {'b', 'c', 'd'}
|
||||||
|
S3 = S1.difference(S2)
|
||||||
|
S4 = S1.intersection(S2)
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {}, state=state)
|
||||||
|
assert state["S3"] == {"a"}
|
||||||
|
assert state["S4"] == {"b", "c"}
|
||||||
|
|
||||||
|
def test_break(self):
|
||||||
|
code = """
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
i+= 1
|
||||||
|
if i==3:
|
||||||
|
break
|
||||||
|
|
||||||
|
i"""
|
||||||
|
result = evaluate_python_code(code, {"print": print, "round": round}, state={})
|
||||||
|
assert result == 3
|
||||||
|
|
||||||
|
def test_return(self):
|
||||||
|
# test early returns
|
||||||
|
code = """
|
||||||
|
def add_one(n, shift):
|
||||||
|
if True:
|
||||||
|
return n + shift
|
||||||
|
return n
|
||||||
|
|
||||||
|
add_one(1, 1)
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||||
|
print(state)
|
||||||
|
assert result == 2
|
||||||
|
|
||||||
|
# test returning None
|
||||||
|
code = """
|
||||||
|
def returns_none(a):
|
||||||
|
return
|
||||||
|
|
||||||
|
returns_none(1)
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||||
|
print(state)
|
||||||
|
assert result is None
|
||||||
|
|||||||
Reference in New Issue
Block a user