|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import ast |
|
import builtins |
|
import difflib |
|
import inspect |
|
import logging |
|
import math |
|
import re |
|
from collections.abc import Callable, Mapping |
|
from functools import wraps |
|
from importlib import import_module |
|
from types import BuiltinFunctionType, FunctionType, ModuleType |
|
from typing import Any |
|
|
|
from .tools import Tool |
|
from .utils import BASE_BUILTIN_MODULES, truncate_content |
|
|
|
|
|
# Module-level logger named after this module (standard `logging.getLogger(__name__)` pattern).
logger = logging.getLogger(__name__)
|
|
|
|
|
class InterpreterError(ValueError):
    """
    Raised when the interpreter cannot or will not evaluate a piece of Python code.

    Typical causes are syntax errors, AST constructs the interpreter does not implement, and
    operations the sandbox forbids. It subclasses ``ValueError`` so generic ``except ValueError``
    handlers still catch it.
    """
|
|
|
|
|
# Every exception class exposed by ``builtins``, keyed by its name (e.g. "ValueError" -> ValueError).
# Name lookups and calls fall back to this table, so interpreted code can raise/catch standard
# exceptions without importing anything.
ERRORS = {
    name: candidate
    for name, candidate in ((attr, getattr(builtins, attr)) for attr in dir(builtins))
    if isinstance(candidate, type) and issubclass(candidate, BaseException)
}

# Limits. NOTE(review): DEFAULT_MAX_LEN_OUTPUT and MAX_OPERATIONS are not used in this part of
# the module; presumably they cap captured output length and total evaluated operations.
DEFAULT_MAX_LEN_OUTPUT = 50_000
MAX_OPERATIONS = 10_000_000
# Upper bound on iterations of a single interpreted ``while`` loop.
MAX_WHILE_ITERATIONS = 1_000_000
|
|
|
|
|
def custom_print(*args):
    """
    Stand-in for ``print`` that discards its arguments and returns ``None``.

    Direct ``print(...)`` calls are intercepted by name in ``evaluate_call``. This function only
    runs when ``print`` is reached through another reference (e.g. ``p = print; p(x)``).
    """
    del args  # intentionally unused
    return None
|
|
|
|
|
# Sentinel distinguishing "no default given" from an explicit ``None`` default.
_GETATTR_NO_DEFAULT = object()


def nodunder_getattr(obj, name, default=_GETATTR_NO_DEFAULT):
    """
    Sandboxed replacement for the builtin ``getattr`` that refuses dunder attribute names.

    It mirrors the builtin's contract. When ``default`` is omitted, a missing attribute raises
    ``AttributeError``. When a default is given (including ``None``), that default is returned.

    Raises:
        InterpreterError: If ``name`` is a dunder (``__x__``) attribute.
        AttributeError: If the attribute is missing and no default was supplied.
    """
    if name.startswith("__") and name.endswith("__"):
        raise InterpreterError(f"Forbidden access to dunder attribute: {name}")
    if default is _GETATTR_NO_DEFAULT:
        return getattr(obj, name)
    return getattr(obj, name, default)
|
|
|
|
|
def nodunder_setattr(obj, name, value):
    """
    Sandboxed replacement for the builtin ``setattr`` that refuses dunder attribute names.

    This closes the write-side counterpart of the read restriction enforced by
    ``nodunder_getattr``.

    Raises:
        InterpreterError: If ``name`` is a dunder (``__x__``) attribute.
    """
    if isinstance(name, str) and name.startswith("__") and name.endswith("__"):
        raise InterpreterError(f"Forbidden access to dunder attribute: {name}")
    setattr(obj, name, value)


# Callables that interpreted code may use by bare name without importing anything.
BASE_PYTHON_TOOLS = {
    # Fallback print; direct `print(...)` calls are captured by evaluate_call instead.
    "print": custom_print,
    "isinstance": isinstance,
    "range": range,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
    "set": set,
    "list": list,
    "dict": dict,
    "tuple": tuple,
    "round": round,
    # `math` functions exposed under their bare names.
    "ceil": math.ceil,
    "floor": math.floor,
    "log": math.log,
    "exp": math.exp,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "asin": math.asin,
    "acos": math.acos,
    "atan": math.atan,
    "atan2": math.atan2,
    "degrees": math.degrees,
    "radians": math.radians,
    "pow": pow,
    "sqrt": math.sqrt,
    "len": len,
    "sum": sum,
    "max": max,
    "min": min,
    "abs": abs,
    "enumerate": enumerate,
    "zip": zip,
    "reversed": reversed,
    "sorted": sorted,
    "all": all,
    "any": any,
    "map": map,
    "filter": filter,
    "ord": ord,
    "chr": chr,
    "next": next,
    "iter": iter,
    "divmod": divmod,
    "callable": callable,
    # Attribute access helpers; getattr/setattr refuse dunder names to keep the sandbox closed.
    "getattr": nodunder_getattr,
    "hasattr": hasattr,
    "setattr": nodunder_setattr,
    "issubclass": issubclass,
    "type": type,
    "complex": complex,
}
|
|
|
|
|
# Module names that expose the host process, filesystem, network or interpreter internals.
# NOTE(review): not referenced in this part of the module; presumably consulted by import checks.
DANGEROUS_MODULES = [
    "builtins", "io", "multiprocessing", "os", "pathlib",
    "pty", "shutil", "socket", "subprocess", "sys",
]

# Fully qualified "module.function" names that check_safer_result refuses to hand back to
# interpreted code, unless the bare function name is registered as a static tool.
DANGEROUS_FUNCTIONS = [
    "builtins.compile", "builtins.eval", "builtins.exec",
    "builtins.globals", "builtins.locals", "builtins.__import__",
    "os.popen", "os.system", "posix.system",
]
|
|
|
|
|
def check_safer_result(
    result: Any,
    static_tools: dict[str, Callable] | None = None,
    authorized_imports: list[str] | None = None,
):
    """
    Check that a value produced by interpreted code is safe to hand back to it.

    Three kinds of value are rejected:

    * modules outside ``authorized_imports``;
    * dicts carrying ``__spec__`` (module namespaces, e.g. ``vars(module)``) of unauthorized
      modules, or without a usable ``__name__``;
    * functions listed in ``DANGEROUS_FUNCTIONS``, unless their bare name is a static tool.

    Args:
        result (Any): The value to check.
        static_tools (dict[str, Callable] | None): Static tools. A dangerous function whose name
            appears here is allowed.
        authorized_imports (list[str] | None): Authorized import paths. ``None`` means nothing is
            authorized.

    Raises:
        InterpreterError: If the result is not safe.
    """
    if authorized_imports is None:
        # Fail closed: with no allow-list every module is forbidden, instead of crashing while
        # building the import tree from None.
        authorized_imports = []
    if isinstance(result, ModuleType):
        if not check_import_authorized(result.__name__, authorized_imports):
            raise InterpreterError(f"Forbidden access to module: {result.__name__}")
    elif isinstance(result, dict) and result.get("__spec__"):
        # Use .get so a user dict with "__spec__" but no "__name__" is rejected, not a KeyError.
        module_name = result.get("__name__")
        if not isinstance(module_name, str) or not check_import_authorized(module_name, authorized_imports):
            raise InterpreterError(f"Forbidden access to module: {module_name}")
    elif isinstance(result, (FunctionType, BuiltinFunctionType)):
        for qualified_function_name in DANGEROUS_FUNCTIONS:
            module_name, function_name = qualified_function_name.rsplit(".", 1)
            if (
                (static_tools is None or function_name not in static_tools)
                and result.__name__ == function_name
                and result.__module__ == module_name
            ):
                raise InterpreterError(f"Forbidden access to function: {function_name}")
|
|
|
|
|
def safer_eval(func: Callable):
    """
    Wrap an ``evaluate_*`` function so that every value it produces is vetted.

    The wrapper forwards its arguments unchanged. It then passes the evaluated value, together
    with ``static_tools`` and ``authorized_imports``, to ``check_safer_result`` before returning it.

    Args:
        func (Callable): Evaluation function with the signature
            ``(expression, state, static_tools, custom_tools, authorized_imports=...)``.

    Returns:
        Callable: The wrapped evaluation function. It raises ``InterpreterError`` on unsafe results.
    """

    @wraps(func)
    def _vetted(expression, state, static_tools, custom_tools, authorized_imports=BASE_BUILTIN_MODULES):
        evaluated = func(
            expression,
            state,
            static_tools,
            custom_tools,
            authorized_imports=authorized_imports,
        )
        check_safer_result(evaluated, static_tools, authorized_imports)
        return evaluated

    return _vetted
|
|
|
|
|
def safer_func(
    func: Callable,
    static_tools: dict[str, Callable] = BASE_PYTHON_TOOLS,
    authorized_imports: list[str] = BASE_BUILTIN_MODULES,
):
    """
    Wrap a callable so that whatever it returns is vetted by ``check_safer_result``.

    Classes are handed back unwrapped, presumably so that ``isinstance`` checks and subclassing
    keep working on them. Every other callable gets a ``functools.wraps`` wrapper.

    Args:
        func (Callable): Callable to protect.
        static_tools (dict[str, Callable]): Static tools, forwarded to the safety check.
        authorized_imports (list[str]): Authorized import paths, forwarded to the safety check.

    Returns:
        Callable: ``func`` itself for classes, otherwise the checking wrapper.
    """
    if not isinstance(func, type):

        @wraps(func)
        def _guarded(*args, **kwargs):
            returned = func(*args, **kwargs)
            check_safer_result(returned, static_tools, authorized_imports)
            return returned

        return _guarded

    return func
|
|
|
|
|
class PrintContainer:
    """
    Mutable text buffer supporting ``+=``, ``append``, ``str()`` and ``len()``.

    Presumably used as the ``_print_outputs`` entry of the interpreter state, which
    ``evaluate_call`` extends with ``+=`` on every ``print``.
    """

    def __init__(self):
        self.value = ""

    def append(self, text):
        """Append ``text`` (which must be a ``str``) and return ``self`` so calls can be chained."""
        self.value = self.value + text
        return self

    def __iadd__(self, other):
        """Support ``buffer += obj``; ``obj`` is converted with ``str()`` first."""
        self.value = self.value + str(other)
        return self

    def __str__(self):
        """Return the accumulated text."""
        return self.value

    def __repr__(self):
        """Debug representation showing the accumulated text."""
        return "PrintContainer(" + self.value + ")"

    def __len__(self):
        """Return the number of characters accumulated so far."""
        return len(self.value)
|
|
|
|
|
class BreakException(Exception):
    """Control-flow signal for ``break``: loop evaluators catch it to stop the innermost loop."""
|
|
|
|
|
class ContinueException(Exception):
    """Control-flow signal for ``continue``: loop evaluators catch it to move to the next iteration."""
|
|
|
|
|
class ReturnException(Exception):
    """Control-flow signal that carries a returned value up to the enclosing interpreted function call."""

    def __init__(self, value):
        super().__init__(value)
        # The value being returned; read by the function wrapper that catches this signal.
        self.value = value
|
|
|
|
|
def get_iterable(obj):
    """
    Return ``obj`` as a list.

    Lists are returned as-is (same object). Any other object with ``__iter__`` is materialised
    with ``list()``.

    Raises:
        InterpreterError: If ``obj`` has no ``__iter__``.
    """
    if not isinstance(obj, list) and not hasattr(obj, "__iter__"):
        raise InterpreterError("Object is not iterable")
    return obj if isinstance(obj, list) else list(obj)
|
|
|
|
|
def fix_final_answer_code(code: str) -> str:
    """
    Rename variables called ``final_answer`` so they do not shadow the ``final_answer()`` tool.

    Sometimes an LLM assigns to a variable named ``final_answer``, which would break the
    ``final_answer()`` tool. When code both calls ``final_answer(`` and assigns to
    ``final_answer``, every non-call, non-attribute use of the name is rewritten to
    ``final_answer_variable``. Calls to ``final_answer(...)`` are left intact. Equality
    comparisons (``final_answer == x``) are not treated as assignments.

    Args:
        code (str): Source code to fix.

    Returns:
        str: The rewritten code, or ``code`` unchanged when no rewrite is needed.
    """
    # `(?<!\.)` skips attribute access (`obj.final_answer`); `=(?!=)` matches assignment but not `==`.
    assignment_pattern = r"(?<!\.)(?<!\w)\bfinal_answer\s*=(?!=)"
    if "final_answer(" not in code or not re.search(assignment_pattern, code):
        return code

    # Rename assignment targets, keeping the original spacing before `=`.
    assignment_regex = r"(?<!\.)(?<!\w)(\bfinal_answer)(\s*=)(?!=)"
    code = re.sub(assignment_regex, r"final_answer_variable\2", code)

    # Rename remaining variable uses, i.e. those not followed by `(`, which would be a tool call.
    variable_regex = r"(?<!\.)(?<!\w)(\bfinal_answer\b)(?!\s*\()"
    code = re.sub(variable_regex, "final_answer_variable", code)
    return code
|
|
|
|
|
def build_import_tree(authorized_imports: list[str]) -> dict[str, Any]:
    """
    Turn dotted import paths into a nested-dict trie keyed by path segment.

    Example: ``["a.b", "a.c", "d"]`` becomes ``{"a": {"b": {}, "c": {}}, "d": {}}``.
    """
    root: dict[str, Any] = {}
    for dotted in authorized_imports:
        node = root
        for segment in dotted.split("."):
            node = node.setdefault(segment, {})
    return root
|
|
|
|
|
def check_import_authorized(import_to_check: str, authorized_imports: list[str]) -> bool:
    """
    Decide whether a dotted module path is permitted by ``authorized_imports``.

    The allow-list is turned into a trie and walked one segment at a time:

    * a ``"*"`` entry at the current level authorizes everything below it;
    * a segment missing from the trie rejects the path;
    * consuming every segment accepts it. This means a parent of an authorized path is accepted
      too: ``"numpy"`` passes when only ``"numpy.linalg"`` is listed.
    """
    node = build_import_tree(authorized_imports)
    for segment in import_to_check.split("."):
        if "*" in node:
            return True
        node = node.get(segment)
        if node is None:
            return False
    return True
|
|
|
|
|
def evaluate_attribute(
    expression: ast.Attribute,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Evaluate an attribute read ``obj.attr``.

    Dunder attributes are refused before the owning object is even evaluated.

    Raises:
        InterpreterError: If ``attr`` is a dunder (``__x__``) name.
    """
    attr_name = expression.attr
    is_dunder = attr_name.startswith("__") and attr_name.endswith("__")
    if is_dunder:
        raise InterpreterError(f"Forbidden access to dunder attribute: {attr_name}")
    owner = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
    return getattr(owner, attr_name)
|
|
|
|
|
def evaluate_unaryop( |
|
expression: ast.UnaryOp, |
|
state: dict[str, Any], |
|
static_tools: dict[str, Callable], |
|
custom_tools: dict[str, Callable], |
|
authorized_imports: list[str], |
|
) -> Any: |
|
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports) |
|
if isinstance(expression.op, ast.USub): |
|
return -operand |
|
elif isinstance(expression.op, ast.UAdd): |
|
return operand |
|
elif isinstance(expression.op, ast.Not): |
|
return not operand |
|
elif isinstance(expression.op, ast.Invert): |
|
return ~operand |
|
else: |
|
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") |
|
|
|
|
|
def evaluate_lambda(
    lambda_expression: ast.Lambda,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Callable:
    """
    Turn a ``lambda`` expression into a Python callable that interprets its body.

    The following parameter kinds are supported:

    * positional(-only) and keyword parameters;
    * default values, evaluated once when the lambda expression itself is evaluated, as in Python;
    * keyword-only parameters;
    * ``*args`` and ``**kwargs``.

    Each call evaluates the body against a fresh shallow copy of ``state``.

    Returns:
        Callable: A function accepting positional and keyword arguments.

    Raises:
        TypeError: When the returned function is called with a keyword that matches no parameter
            and the lambda has no ``**kwargs`` parameter.
    """
    params = lambda_expression.args
    positional_names = [arg.arg for arg in params.posonlyargs + params.args]
    keyword_names = {arg.arg for arg in params.args + params.kwonlyargs}

    default_values = [
        evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in params.defaults
    ]
    # Positional defaults belong to the trailing positional parameters.
    first_defaulted = len(positional_names) - len(default_values)
    bound_defaults = dict(zip(positional_names[first_defaulted:], default_values))
    for arg, default in zip(params.kwonlyargs, params.kw_defaults):
        # kw_defaults holds None for keyword-only parameters without a default.
        if default is not None:
            bound_defaults[arg.arg] = evaluate_ast(default, state, static_tools, custom_tools, authorized_imports)

    def lambda_func(*values: Any, **kwargs: Any) -> Any:
        new_state = state.copy()
        # Defaults first, so explicitly passed arguments override them.
        new_state.update(bound_defaults)
        for name, value in zip(positional_names, values):
            new_state[name] = value
        if params.vararg:
            new_state[params.vararg.arg] = values[len(positional_names) :]
        extra_kwargs = {}
        for name, value in kwargs.items():
            if name in keyword_names:
                new_state[name] = value
            elif params.kwarg:
                extra_kwargs[name] = value
            else:
                raise TypeError(f"<lambda>() got an unexpected keyword argument '{name}'")
        if params.kwarg:
            new_state[params.kwarg.arg] = extra_kwargs
        return evaluate_ast(
            lambda_expression.body,
            new_state,
            static_tools,
            custom_tools,
            authorized_imports,
        )

    return lambda_func
|
|
|
|
|
def evaluate_while(
    while_loop: ast.While,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    """
    Interpret a ``while`` loop, including its optional ``else`` clause.

    ``break`` ends the loop without running ``else``. ``continue`` skips the rest of the body
    and re-tests the condition. The ``else`` clause runs once the condition becomes falsy.

    Raises:
        InterpreterError: If the body runs more than ``MAX_WHILE_ITERATIONS`` times.
    """
    iterations = 0
    while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
        try:
            for node in while_loop.body:
                evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
        except BreakException:
            return None
        except ContinueException:
            pass
        iterations += 1
        if iterations > MAX_WHILE_ITERATIONS:
            raise InterpreterError(f"Maximum number of {MAX_WHILE_ITERATIONS} iterations in While loop exceeded")
    # Only reached when the test turned falsy (never after `break`), matching Python's while/else.
    for node in while_loop.orelse:
        evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
    return None
|
|
|
|
|
def create_function(
    func_def: ast.FunctionDef,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Callable:
    """
    Build a Python callable that interprets the body of ``func_def`` when invoked.

    Each call binds its arguments into a shallow copy of ``state`` taken at call time, so
    assignments inside the function do not leak into the caller's namespace. Default values are
    evaluated on every call against ``state``.

    Argument binding follows Python:

    * positional(-only) parameters are filled in order;
    * ``*args`` receives only the surplus positional values;
    * ``**kwargs`` receives only keywords that do not name a parameter;
    * defaults (positional and keyword-only) apply only to parameters the caller did not supply.

    Returns:
        Callable: A function whose ``__name__`` is the definition's name, carrying the parsed
        definition in ``__ast__`` and its unparsed source in ``__source__``. A call returns the
        value of a ``ReturnException`` raised by the body (presumably from a ``return``
        statement), otherwise ``None``. ``__init__`` always returns ``None``.
    """
    source_code = ast.unparse(func_def)
    params = func_def.args
    # Parameters that can be filled positionally, in declaration order (positional-only first).
    positional_names = [arg.arg for arg in params.posonlyargs + params.args]
    # Parameters that can be filled by keyword; any other keyword belongs to **kwargs.
    keyword_names = {arg.arg for arg in params.args + params.kwonlyargs}

    def new_func(*args: Any, **kwargs: Any) -> Any:
        func_state = state.copy()

        default_values = [
            evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in params.defaults
        ]
        # Positional defaults belong to the trailing positional parameters.
        first_defaulted = len(positional_names) - len(default_values)
        defaults = dict(zip(positional_names[first_defaulted:], default_values))
        for arg, default in zip(params.kwonlyargs, params.kw_defaults):
            # kw_defaults holds None for keyword-only parameters without a default.
            if default is not None:
                defaults[arg.arg] = evaluate_ast(default, state, static_tools, custom_tools, authorized_imports)

        for name, value in zip(positional_names, args):
            func_state[name] = value

        extra_kwargs = {}
        for name, value in kwargs.items():
            if params.kwarg and name not in keyword_names:
                extra_kwargs[name] = value
            else:
                func_state[name] = value

        if params.vararg:
            func_state[params.vararg.arg] = args[len(positional_names) :]
        if params.kwarg:
            func_state[params.kwarg.arg] = extra_kwargs

        # Track what the caller supplied explicitly. Checking `name in func_state` would wrongly
        # skip a default whenever the enclosing state holds a variable with the same name.
        supplied = set(positional_names[: len(args)]) | set(kwargs)
        for name, value in defaults.items():
            if name not in supplied:
                func_state[name] = value

        if positional_names and positional_names[0] == "self" and args:
            func_state["self"] = args[0]
            # Used by zero-argument super(). NOTE(review): this is the instance's runtime class,
            # not the class defining the method, so super() in multi-level hierarchies may recurse.
            func_state["__class__"] = args[0].__class__

        try:
            for stmt in func_def.body:
                evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
        except ReturnException as e:
            if func_def.name == "__init__":
                return None
            return e.value
        # Falling off the end of a function returns None, as in Python.
        return None

    new_func.__ast__ = func_def
    new_func.__source__ = source_code
    new_func.__name__ = func_def.name
    return new_func
|
|
|
|
|
def evaluate_function_def(
    func_def: ast.FunctionDef,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Callable:
    """
    Interpret a ``def`` statement.

    The interpreted callable is built with ``create_function``, registered in ``custom_tools``
    under the function's name, and returned as stored there.
    """
    name = func_def.name
    built = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
    custom_tools[name] = built
    return custom_tools[name]
|
|
|
|
|
def evaluate_class_def(
    class_def: ast.ClassDef,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> type:
    """
    Interpret a ``class`` statement and bind the resulting class in ``state``.

    Base classes are evaluated against ``state``. The body may only contain method definitions,
    annotated assignments, plain assignments, ``pass`` and a leading docstring. The collected
    namespace is turned into a class with ``type(name, bases, namespace)``.

    Note that assigned values are evaluated against ``state`` (not the class namespace), so a
    class-level statement cannot refer to names defined earlier in the same class body.

    Returns:
        type: The new class, also stored as ``state[class_def.name]``.

    Raises:
        InterpreterError: For any other statement in the class body, or an annotated-assignment
            target that is not a name, attribute or subscript.
    """
    class_name = class_def.name
    bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
    # Namespace passed to type(); methods, attributes, __annotations__ and __doc__ go here.
    class_dict: dict[str, Any] = {}

    for stmt in class_def.body:
        if isinstance(stmt, ast.FunctionDef):
            # Methods go through the generic evaluator. NOTE(review): if that dispatches to
            # evaluate_function_def, each method is also registered in custom_tools by bare name.
            class_dict[stmt.name] = evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
        elif isinstance(stmt, ast.AnnAssign):
            if stmt.value:
                value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
            target = stmt.target

            if isinstance(target, ast.Name):
                # Record the evaluated annotation, as Python does for class-level annotations.
                annotation = evaluate_ast(stmt.annotation, state, static_tools, custom_tools, authorized_imports)
                class_dict.setdefault("__annotations__", {})[target.id] = annotation

                # A bare annotation (`x: int`) only records the annotation, binding nothing.
                if stmt.value:
                    class_dict[target.id] = value
            elif isinstance(target, ast.Attribute):
                # The owning object is resolved against the class namespace being built.
                obj = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)

                if stmt.value:
                    setattr(obj, target.attr, value)
            elif isinstance(target, ast.Subscript):
                # Container from the class namespace; the index expression from the outer state.
                container = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
                index = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)

                if stmt.value:
                    container[index] = value
            else:
                raise InterpreterError(f"Unsupported AnnAssign target in class body: {type(target).__name__}")
        elif isinstance(stmt, ast.Assign):
            value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
            # Only name and attribute targets are handled; other target kinds are silently skipped.
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    class_dict[target.id] = value
                elif isinstance(target, ast.Attribute):
                    obj = evaluate_ast(target.value, class_dict, static_tools, custom_tools, authorized_imports)
                    setattr(obj, target.attr, value)
        elif isinstance(stmt, ast.Pass):
            pass
        elif (
            isinstance(stmt, ast.Expr)
            and stmt == class_def.body[0]
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str)
        ):
            # A string literal as the first statement is the class docstring.
            class_dict["__doc__"] = stmt.value.value
        else:
            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")

    new_class = type(class_name, tuple(bases), class_dict)
    state[class_name] = new_class
    return new_class
|
|
|
|
|
def evaluate_annassign(
    annassign: ast.AnnAssign,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Interpret an annotated assignment such as ``x: int = 5``.

    The annotation itself is never evaluated. A bare annotation (``x: int``) binds nothing and
    yields ``None``. Otherwise the value is bound through ``set_value`` and returned.
    """
    if annassign.value is None:
        return None
    assigned = evaluate_ast(annassign.value, state, static_tools, custom_tools, authorized_imports)
    set_value(annassign.target, assigned, state, static_tools, custom_tools, authorized_imports)
    return assigned
|
|
|
|
|
def evaluate_augassign(
    expression: ast.AugAssign,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Interpret an augmented assignment such as ``x += 1`` or ``obj.attr @= m``.

    The target's current value is read and combined with the right-hand side using the in-place
    operator, so mutable objects such as lists are updated in place. The result is then written
    back through ``set_value``.

    Returns:
        Any: The updated value.

    Raises:
        InterpreterError: For unsupported targets or operators, or when ``+=`` would add a
            non-list value to a list.
    """

    def get_current_value(target: ast.AST) -> Any:
        if isinstance(target, ast.Name):
            # NOTE(review): an undefined name reads as 0 here, whereas Python raises NameError.
            return state.get(target.id, 0)
        elif isinstance(target, ast.Subscript):
            obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
            key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
            return obj[key]
        elif isinstance(target, ast.Attribute):
            obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
            return getattr(obj, target.attr)
        elif isinstance(target, ast.Tuple):
            return tuple(get_current_value(elt) for elt in target.elts)
        elif isinstance(target, ast.List):
            return [get_current_value(elt) for elt in target.elts]
        else:
            raise InterpreterError(f"AugAssign not supported for {type(target).__name__} targets.")

    current_value = get_current_value(expression.target)
    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)

    if isinstance(expression.op, ast.Add):
        if isinstance(current_value, list) and not isinstance(value_to_add, list):
            raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
        current_value += value_to_add
    elif isinstance(expression.op, ast.Sub):
        current_value -= value_to_add
    elif isinstance(expression.op, ast.Mult):
        current_value *= value_to_add
    elif isinstance(expression.op, ast.MatMult):
        current_value @= value_to_add
    elif isinstance(expression.op, ast.Div):
        current_value /= value_to_add
    elif isinstance(expression.op, ast.Mod):
        current_value %= value_to_add
    elif isinstance(expression.op, ast.Pow):
        current_value **= value_to_add
    elif isinstance(expression.op, ast.FloorDiv):
        current_value //= value_to_add
    elif isinstance(expression.op, ast.BitAnd):
        current_value &= value_to_add
    elif isinstance(expression.op, ast.BitOr):
        current_value |= value_to_add
    elif isinstance(expression.op, ast.BitXor):
        current_value ^= value_to_add
    elif isinstance(expression.op, ast.LShift):
        current_value <<= value_to_add
    elif isinstance(expression.op, ast.RShift):
        current_value >>= value_to_add
    else:
        raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")

    set_value(
        expression.target,
        current_value,
        state,
        static_tools,
        custom_tools,
        authorized_imports,
    )

    return current_value
|
|
|
|
|
def evaluate_boolop(
    node: ast.BoolOp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Evaluate ``a and b and ...`` or ``a or b or ...`` with Python's short-circuit rules.

    ``and`` returns the first falsy operand and ``or`` the first truthy one. If no operand
    short-circuits, the last operand's value is returned. Later operands are not evaluated once
    the result is decided.
    """
    stop_on_falsy = isinstance(node.op, ast.And)
    outcome = None
    for operand_node in node.values:
        outcome = evaluate_ast(operand_node, state, static_tools, custom_tools, authorized_imports)
        if bool(outcome) != stop_on_falsy:
            return outcome
    return outcome
|
|
|
|
|
def evaluate_binop(
    binop: ast.BinOp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Evaluate a binary operation, left operand first.

    Every Python binary operator is supported, including matrix multiplication (``@``).

    Raises:
        NotImplementedError: For an operator node this function does not recognise.
    """
    left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
    right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)

    if isinstance(binop.op, ast.Add):
        return left_val + right_val
    elif isinstance(binop.op, ast.Sub):
        return left_val - right_val
    elif isinstance(binop.op, ast.Mult):
        return left_val * right_val
    elif isinstance(binop.op, ast.MatMult):
        return left_val @ right_val
    elif isinstance(binop.op, ast.Div):
        return left_val / right_val
    elif isinstance(binop.op, ast.Mod):
        return left_val % right_val
    elif isinstance(binop.op, ast.Pow):
        return left_val**right_val
    elif isinstance(binop.op, ast.FloorDiv):
        return left_val // right_val
    elif isinstance(binop.op, ast.BitAnd):
        return left_val & right_val
    elif isinstance(binop.op, ast.BitOr):
        return left_val | right_val
    elif isinstance(binop.op, ast.BitXor):
        return left_val ^ right_val
    elif isinstance(binop.op, ast.LShift):
        return left_val << right_val
    elif isinstance(binop.op, ast.RShift):
        return left_val >> right_val
    else:
        raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
|
|
|
|
def evaluate_assign(
    assign: ast.Assign,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Interpret an assignment statement, including chained forms like ``a = b = value``.

    The right-hand side is evaluated once, and the same object is bound to every target, left
    to right, via ``set_value``.

    Returns:
        Any: The assigned value.
    """
    result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
    # A bare starred target (`*a = ...`) is a SyntaxError, so every target receives `result`.
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools, authorized_imports)
    return result
|
|
|
|
|
def set_value(
    target: ast.AST,
    value: Any,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    """
    Bind ``value`` to an assignment target.

    The following target kinds are supported:

    * names;
    * tuple and list targets, with at most one starred element, which receives a list;
    * subscripts;
    * attributes (dunder attributes are refused).

    Raises:
        InterpreterError: When assigning over a static tool name, unpacking a non-iterable (or
            a ``str``/``bytes``) value, unpacking the wrong number of values, using more than
            one starred target, writing a dunder attribute, or using an unsupported target.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        elements = list(target.elts)
        star_positions = [i for i, elt in enumerate(elements) if isinstance(elt, ast.Starred)]
        if len(star_positions) > 1:
            raise InterpreterError("Multiple starred expressions in assignment")
        if star_positions:
            star = star_positions[0]
            if len(value) < len(elements) - 1:
                raise InterpreterError("Cannot unpack tuple of wrong size")
            # Values left over between the leading and trailing targets go to the starred one.
            split = len(value) - (len(elements) - star - 1)
            value = value[:star] + (list(value[star:split]),) + value[split:]
            elements[star] = elements[star].value
        elif len(elements) != len(value):
            raise InterpreterError("Cannot unpack tuple of wrong size")
        for elem, item in zip(elements, value):
            set_value(elem, item, state, static_tools, custom_tools, authorized_imports)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        # Mirror the read-side restriction in evaluate_attribute for writes.
        if target.attr.startswith("__") and target.attr.endswith("__"):
            raise InterpreterError(f"Forbidden access to dunder attribute: {target.attr}")
        obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        setattr(obj, target.attr, value)
    else:
        raise InterpreterError(f"Unsupported assignment target: {type(target).__name__}")
|
|
|
|
|
def evaluate_call(
    call: ast.Call,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Interpret a function call.

    The callee is resolved from the call expression. Bare names are looked up in ``state``, then
    ``static_tools``, ``custom_tools`` and the builtin exception classes. Positional arguments
    support ``*iterable`` unpacking; keyword arguments support ``**mapping`` unpacking.

    Two callees are special-cased:

    * ``super()``, including the zero-argument form inside methods;
    * ``print()``, whose output (honouring ``sep`` and ``end``) is appended to
      ``state["_print_outputs"]``.

    Raises:
        InterpreterError: If the callee expression is unsupported or not callable, a name is not
            an allowed function, an attribute is missing, ``super()`` is misused, or a builtin
            that is not a registered tool is invoked.
    """
    if not isinstance(call.func, (ast.Call, ast.Lambda, ast.Attribute, ast.Name, ast.Subscript)):
        raise InterpreterError(f"This is not a correct function: {call.func}).")

    func, func_name = None, None

    if isinstance(call.func, (ast.Call, ast.Lambda)):
        # `f()(...)` or `(lambda x: x)(...)`: the callee is itself an expression.
        func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
    elif isinstance(call.func, ast.Attribute):
        obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
        func_name = call.func.attr
        if not hasattr(obj, func_name):
            raise InterpreterError(f"Object {obj} has no attribute {func_name}")
        # NOTE(review): method names are not dunder-checked here (unlike evaluate_attribute),
        # presumably so that `super().__init__(...)` keeps working.
        func = getattr(obj, func_name)
    elif isinstance(call.func, ast.Name):
        func_name = call.func.id
        if func_name in state:
            func = state[func_name]
        elif func_name in static_tools:
            func = static_tools[func_name]
        elif func_name in custom_tools:
            func = custom_tools[func_name]
        elif func_name in ERRORS:
            func = ERRORS[func_name]
        else:
            raise InterpreterError(
                f"Forbidden function evaluation: '{call.func.id}' is not among the explicitly allowed tools or defined/imported in the preceding code"
            )
    elif isinstance(call.func, ast.Subscript):
        func = evaluate_ast(call.func, state, static_tools, custom_tools, authorized_imports)
        if not callable(func):
            raise InterpreterError(f"This is not a correct function: {call.func}).")
        func_name = None

    args = []
    for arg in call.args:
        if isinstance(arg, ast.Starred):
            args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
        else:
            args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))

    kwargs = {}
    for keyword in call.keywords:
        keyword_value = evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
        if keyword.arg is None:
            # `f(**mapping)`: the AST marks dict unpacking with a keyword whose `arg` is None.
            kwargs.update(keyword_value)
        else:
            kwargs[keyword.arg] = keyword_value

    if func_name == "super":
        if not args:
            if "__class__" in state and "self" in state:
                return super(state["__class__"], state["self"])
            else:
                raise InterpreterError("super() needs at least one argument")
        cls = args[0]
        if not isinstance(cls, type):
            raise InterpreterError("super() argument 1 must be type")
        if len(args) == 1:
            return super(cls)
        elif len(args) == 2:
            instance = args[1]
            return super(cls, instance)
        else:
            raise InterpreterError("super() takes at most 2 arguments")
    elif func_name == "print":
        # Honour print's `sep`/`end` keywords; None selects the defaults, as in Python.
        sep = kwargs.get("sep")
        end = kwargs.get("end")
        separator = " " if sep is None else sep
        terminator = "\n" if end is None else end
        state["_print_outputs"] += separator.join(map(str, args)) + terminator
        return None
    else:
        # Builtins are only callable when registered as static tools.
        if (inspect.getmodule(func) == builtins) and inspect.isbuiltin(func) and (func not in static_tools.values()):
            raise InterpreterError(
                f"Invoking a builtin function that has not been explicitly added as a tool is not allowed ({func_name})."
            )
        return func(*args, **kwargs)
|
|
|
|
|
def evaluate_subscript(
    subscript: ast.Subscript,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Evaluate ``value[index]``. The index is evaluated before the container.

    Raises:
        InterpreterError: Wrapping any ``KeyError``, ``IndexError`` or ``TypeError`` from the
            lookup. For string keys into a mapping, close key matches are suggested.
    """
    index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)
    container = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
    try:
        return container[index]
    except (KeyError, IndexError, TypeError) as exc:
        message = f"Could not index {container} with '{index}': {type(exc).__name__}: {exc}"
        if isinstance(index, str) and isinstance(container, Mapping):
            suggestions = difflib.get_close_matches(index, list(container.keys()))
            if suggestions:
                message += f". Maybe you meant one of these indexes instead: {str(suggestions)}"
        raise InterpreterError(message) from exc
|
|
|
|
|
def evaluate_name(
    name: ast.Name,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Resolve a bare name.

    The name is looked up, in order, in the interpreter state, the static tools (returned wrapped
    by ``safer_func``), the custom tools and the builtin exception classes. If none match, the
    value of the closest-named state variable (per ``difflib``) is returned.

    Raises:
        InterpreterError: If nothing matches, not even approximately.
    """
    identifier = name.id
    if identifier in state:
        return state[identifier]
    if identifier in static_tools:
        return safer_func(static_tools[identifier], static_tools=static_tools, authorized_imports=authorized_imports)
    if identifier in custom_tools:
        return custom_tools[identifier]
    if identifier in ERRORS:
        return ERRORS[identifier]
    # NOTE(review): this silently substitutes a similarly named variable for an undefined one,
    # which forgives typos but can also yield an unrelated value.
    lookalikes = difflib.get_close_matches(identifier, list(state.keys()))
    if lookalikes:
        return state[lookalikes[0]]
    raise InterpreterError(f"The variable `{identifier}` is not defined.")
|
|
|
|
|
def evaluate_condition( |
|
condition: ast.Compare, |
|
state: dict[str, Any], |
|
static_tools: dict[str, Callable], |
|
custom_tools: dict[str, Callable], |
|
authorized_imports: list[str], |
|
) -> bool | object: |
|
result = True |
|
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports) |
|
for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)): |
|
op = type(op) |
|
right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports) |
|
if op == ast.Eq: |
|
current_result = left == right |
|
elif op == ast.NotEq: |
|
current_result = left != right |
|
elif op == ast.Lt: |
|
current_result = left < right |
|
elif op == ast.LtE: |
|
current_result = left <= right |
|
elif op == ast.Gt: |
|
current_result = left > right |
|
elif op == ast.GtE: |
|
current_result = left >= right |
|
elif op == ast.Is: |
|
current_result = left is right |
|
elif op == ast.IsNot: |
|
current_result = left is not right |
|
elif op == ast.In: |
|
current_result = left in right |
|
elif op == ast.NotIn: |
|
current_result = left not in right |
|
else: |
|
raise InterpreterError(f"Unsupported comparison operator: {op}") |
|
|
|
if current_result is False: |
|
return False |
|
result = current_result if i == 0 else (result and current_result) |
|
left = right |
|
return result |
|
|
|
|
|
def evaluate_if(
    if_statement: ast.If,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Interpret an ``if``/``else`` statement (``elif`` arrives as a nested ``If`` in ``orelse``).

    Returns:
        Any: The last non-``None`` value produced by a statement of the executed branch, or
        ``None``.
    """
    condition_holds = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
    branch = if_statement.body if condition_holds else if_statement.orelse
    result = None
    for statement in branch:
        produced = evaluate_ast(statement, state, static_tools, custom_tools, authorized_imports)
        if produced is not None:
            result = produced
    return result
|
|
|
|
|
def evaluate_for(
    for_loop: ast.For,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> Any:
    """
    Interpret a ``for`` loop, including its optional ``else`` clause.

    The loop target is bound with ``set_value`` on every iteration.

    * ``continue`` abandons the rest of the current iteration's body.
    * ``break`` leaves the loop without running ``else``.
    * ``else`` runs after the iterable is exhausted.

    Returns:
        Any: The last non-``None`` value produced by a body or ``else`` statement, or ``None``.
    """
    result = None
    iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
    for counter in iterator:
        set_value(
            for_loop.target,
            counter,
            state,
            static_tools,
            custom_tools,
            authorized_imports,
        )
        # The try wraps the whole body so `continue` skips every remaining statement.
        try:
            for node in for_loop.body:
                line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
                if line_result is not None:
                    result = line_result
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for node in for_loop.orelse:
            line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
            if line_result is not None:
                result = line_result
    return result
|
|
|
|
|
def evaluate_listcomp(
    listcomp: ast.ListComp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> list[Any]:
    """
    Interpret a list comprehension, including multiple ``for`` clauses and ``if`` filters.

    Generators are expanded left to right. Each iteration binds its target into a fresh shallow
    copy of the enclosing scope, so comprehension variables are not written into ``state``.
    Targets may be names, arbitrarily nested tuples/lists (``for i, (a, b) in ...``), or any
    other target ``set_value`` accepts.
    """

    def bind_target(target: ast.AST, value: Any, scope: dict[str, Any]) -> None:
        # Names and nested tuple/list targets are bound directly by indexing. Anything else
        # (attributes, subscripts, starred elements) is delegated to set_value.
        if isinstance(target, ast.Name):
            scope[target.id] = value
        elif isinstance(target, (ast.Tuple, ast.List)) and not any(
            isinstance(elt, ast.Starred) for elt in target.elts
        ):
            for idx, elem in enumerate(target.elts):
                bind_target(elem, value[idx], scope)
        else:
            set_value(target, value, scope, static_tools, custom_tools, authorized_imports)

    def inner_evaluate(generators: list[ast.comprehension], index: int, current_state: dict[str, Any]) -> list[Any]:
        if index >= len(generators):
            return [
                evaluate_ast(
                    listcomp.elt,
                    current_state,
                    static_tools,
                    custom_tools,
                    authorized_imports,
                )
            ]
        generator = generators[index]
        iter_value = evaluate_ast(
            generator.iter,
            current_state,
            static_tools,
            custom_tools,
            authorized_imports,
        )
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            bind_target(generator.target, value, new_state)
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in generator.ifs
            ):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)
|
|
|
|
|
def evaluate_setcomp(
    setcomp: ast.SetComp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> set[Any]:
    """
    Interpret a set comprehension, including multiple ``for`` clauses and ``if`` filters.

    Generators are nested left to right, as in Python: each later clause is evaluated with the
    targets of the earlier clauses in scope. The element expression is evaluated only once all
    clauses are bound. Iterations work on shallow copies of the enclosing scope.
    """
    result = set()

    def expand(index: int, current_state: dict[str, Any]) -> None:
        if index >= len(setcomp.generators):
            result.add(evaluate_ast(setcomp.elt, current_state, static_tools, custom_tools, authorized_imports))
            return
        gen = setcomp.generators[index]
        iter_value = evaluate_ast(gen.iter, current_state, static_tools, custom_tools, authorized_imports)
        for value in iter_value:
            new_state = current_state.copy()
            set_value(
                gen.target,
                value,
                new_state,
                static_tools,
                custom_tools,
                authorized_imports,
            )
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in gen.ifs
            ):
                expand(index + 1, new_state)

    expand(0, state)
    return result
|
|
|
|
|
def evaluate_try(
    try_node: ast.Try,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    """
    Evaluate a `try` / `except` / `else` / `finally` statement.

    The interpreter implements `break`, `continue`, `return` and `final_answer()` by raising
    `BreakException`, `ContinueException`, `ReturnException` and `FinalAnswerException`. Those
    signals are re-raised untouched so that a user-level `except Exception:` can never swallow
    them. Any other exception raised by the body is matched against the handlers in order: a
    handler with no type matches anything, otherwise the handler's type expression is evaluated
    and checked with `isinstance`. The first matching handler optionally binds the exception to
    its name in `state` and runs its body. If no handler matches, the exception propagates. The
    `else` body runs only when the `try` body completed without an exception, and the `finally`
    body always runs.

    Args:
        try_node: The `ast.Try` node to evaluate.
        state: Current variable state (updated in place by the evaluated statements).
        static_tools: Tools that may not be overwritten.
        custom_tools: Tools that may be overwritten.
        authorized_imports: Modules the evaluated code may import.

    Raises:
        Whatever the body raises when no handler matches, plus anything raised by a handler,
        the `else` body or the `finally` body.
    """
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    except (BreakException, ContinueException, ReturnException, FinalAnswerException):
        # Control-flow signals of the interpreter itself, not user-visible exceptions.
        raise
    except Exception as e:
        for handler in try_node.handlers:
            if handler.type is None or isinstance(
                e,
                evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
            ):
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
                break
        else:
            # No handler matched: propagate the original exception with its traceback.
            raise
    else:
        for stmt in try_node.orelse:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    finally:
        for stmt in try_node.finalbody:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
|
|
|
|
|
def evaluate_raise(
    raise_node: ast.Raise,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    """
    Evaluate a `raise` statement.

    `raise exc` and `raise exc from cause` are delegated to Python's own `raise`. The branch is
    chosen from the AST (whether a `from` clause is present), not from the evaluated value. As a
    result, `raise X from None` correctly suppresses the exception context, and raising an
    expression that evaluates to `None` produces Python's usual `TypeError`.

    A bare `raise` is rejected with `InterpreterError`: this evaluator does not track the
    exception currently being handled by the sandboxed code.

    Args:
        raise_node: The `ast.Raise` node to evaluate.
        state: Current variable state.
        static_tools: Tools that may not be overwritten.
        custom_tools: Tools that may be overwritten.
        authorized_imports: Modules the evaluated code may import.

    Raises:
        The evaluated exception, or `InterpreterError` for a bare `raise`.
    """
    if raise_node.exc is None:
        raise InterpreterError("Re-raise is not supported without an active exception")

    exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools, authorized_imports)
    if raise_node.cause is not None:
        # May evaluate to None (`raise X from None`), which Python accepts to suppress context.
        cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools, authorized_imports)
        raise exc from cause
    raise exc
|
|
|
|
|
def evaluate_assert(
    assert_node: ast.Assert,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    """
    Evaluate an `assert test[, msg]` statement.

    The test expression is evaluated once. If it is truthy, nothing happens. Otherwise an
    `AssertionError` is raised, carrying either the evaluated `msg` expression or, when no
    message was written, a text built from the source of the failing test.

    Args:
        assert_node: The `ast.Assert` node to evaluate.
        state: Current variable state.
        static_tools: Tools that may not be overwritten.
        custom_tools: Tools that may be overwritten.
        authorized_imports: Modules the evaluated code may import.

    Raises:
        AssertionError: When the test expression is falsy.
    """
    if evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports):
        return

    if assert_node.msg is None:
        # No explicit message: show the source of the failing condition instead.
        failed_condition = ast.unparse(assert_node.test)
        raise AssertionError(f"Assertion failed: {failed_condition}")

    message = evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports)
    raise AssertionError(message)
|
|
|
|
|
def evaluate_with(
    with_node: ast.With,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> None:
    """
    Evaluate a `with` statement with one or more context managers.

    Each item's context expression is evaluated and entered in order. The value returned by
    `__enter__` is bound to the `as` target (any assignable target, via `set_value`). The body
    then runs, and the managers are exited in reverse order through `contextlib.ExitStack`. This
    gives the same guarantees as Python's `with`:

    - `__exit__` is called on the manager itself, not on the value `__enter__` returned.
    - A truthy `__exit__` return value suppresses an exception raised by the body.
    - If entering a later manager fails, the managers already entered are still exited.

    The interpreter's control-flow signals (`break`, `continue`, `return`, `final_answer()`) leave
    the block as a normal exit: managers see `(None, None, None)` and cannot suppress them. The
    signal is re-raised once all managers have been exited.

    Args:
        with_node: The `ast.With` node to evaluate.
        state: Current variable state (updated in place by targets and body statements).
        static_tools: Tools that may not be overwritten.
        custom_tools: Tools that may be overwritten.
        authorized_imports: Modules the evaluated code may import.
    """
    import contextlib  # local import: only this evaluator needs it

    pending_signal = None
    with contextlib.ExitStack() as stack:
        for item in with_node.items:
            manager = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
            # enter_context registers the manager's own __exit__ for the unwind.
            entered = stack.enter_context(manager)
            if item.optional_vars is not None:
                set_value(item.optional_vars, entered, state, static_tools, custom_tools, authorized_imports)
        try:
            for stmt in with_node.body:
                evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
        except (BreakException, ContinueException, ReturnException, FinalAnswerException) as signal:
            # Defer the signal so the managers exit cleanly and cannot swallow it.
            pending_signal = signal
    if pending_signal is not None:
        raise pending_signal
|
|
|
|
|
def get_safe_module(raw_module, authorized_imports, visited=None, _copies=None):
    """Creates a safe copy of a module or returns the original if it's a function

    Non-module objects are returned unchanged. For a module, a fresh `ModuleType` with the same
    name is created and every attribute listed by `dir()` is copied onto it by reference. Nested
    modules are copied recursively. Each distinct module gets exactly one copy per top-level call:
    when a module is reached again, through a cycle or through two different parents, the copy
    already created for it is reused instead of exposing the original module object.

    Attributes whose lookup raises `ImportError` or `AttributeError` (e.g. lazily loaded
    members) are skipped and logged at INFO level.

    Args:
        raw_module: The object to copy; only `ModuleType` instances are copied.
        authorized_imports: Forwarded to recursive calls; not otherwise consulted here.
        visited: Set of `id()` values of modules already processed. Kept for backward
            compatibility with callers that pass it explicitly.
        _copies: Internal memo mapping `id(original module)` to its copy.

    Returns:
        The copied module, or `raw_module` itself when it is not a module.
    """
    if not isinstance(raw_module, ModuleType):
        return raw_module

    if visited is None:
        visited = set()
    if _copies is None:
        _copies = {}

    module_id = id(raw_module)
    if module_id in _copies:
        # Already copied (or being copied, for a cycle): reuse that copy.
        return _copies[module_id]
    if module_id in visited:
        # Marked as visited by a caller-supplied set with no copy on record; fall back to the
        # original rather than risk unbounded recursion.
        return raw_module
    visited.add(module_id)

    safe_module = ModuleType(raw_module.__name__)
    # Register before recursing so references back to this module resolve to the copy.
    _copies[module_id] = safe_module

    for attr_name in dir(raw_module):
        try:
            attr_value = getattr(raw_module, attr_name)
        except (ImportError, AttributeError) as e:
            # Lazily / dynamically loaded attributes may fail to resolve; skip them.
            logger.info(
                "Skipping import error while copying %s.%s: %s - %s",
                raw_module.__name__,
                attr_name,
                type(e).__name__,
                e,
            )
            continue

        if isinstance(attr_value, ModuleType):
            attr_value = get_safe_module(attr_value, authorized_imports, visited=visited, _copies=_copies)

        setattr(safe_module, attr_name, attr_value)

    return safe_module
|
|
|
|
|
def evaluate_import(expression, state, authorized_imports):
    """
    Evaluate an `import` or `from ... import ...` statement.

    Every module name is checked with `check_import_authorized` before anything is imported.
    Imported modules are copied with `get_safe_module` before being stored in `state`.

    - `import m [as n]` stores the module under `n`, or under the full dotted name `m`.
      NOTE(review): unlike CPython, `import a.b` binds the key `"a.b"` rather than `a`. This is
      left as-is because binding the top-level package could widen access beyond what was
      authorized.
    - `from m import *` copies `m.__all__` if present, otherwise every non-underscore name.
    - `from m import x [as y]` binds each requested attribute, erroring if it is missing.
    - Relative imports (`from . import x`, `from .m import x`) are rejected. Sandboxed code has no
      enclosing package, and `from .m import x` would otherwise silently import a top-level `m`.

    Args:
        expression: An `ast.Import` or `ast.ImportFrom` node.
        state: Variable state that receives the imported names.
        authorized_imports: Modules the evaluated code may import.

    Returns:
        None.

    Raises:
        InterpreterError: For unauthorized or relative imports, or a missing attribute.
    """
    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if check_import_authorized(alias.name, authorized_imports):
                raw_module = import_module(alias.name)
                state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports)
            else:
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
        return None
    elif isinstance(expression, ast.ImportFrom):
        if expression.level:
            raise InterpreterError(
                f"Relative imports are not supported: '{'.' * expression.level}{expression.module or ''}'"
            )
        if check_import_authorized(expression.module, authorized_imports):
            # fromlist makes __import__ load requested submodules and return the leaf module.
            raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
            module = get_safe_module(raw_module, authorized_imports)
            if expression.names[0].name == "*":
                if hasattr(module, "__all__"):
                    for name in module.__all__:
                        state[name] = getattr(module, name)
                else:
                    for name in dir(module):
                        if not name.startswith("_"):
                            state[name] = getattr(module, name)
            else:
                for alias in expression.names:
                    if hasattr(module, alias.name):
                        state[alias.asname or alias.name] = getattr(module, alias.name)
                    else:
                        raise InterpreterError(f"Module {expression.module} has no attribute {alias.name}")
        else:
            raise InterpreterError(
                f"Import from {expression.module} is not allowed. Authorized imports are: {str(authorized_imports)}"
            )
        return None
|
|
|
|
|
def evaluate_dictcomp(
    dictcomp: ast.DictComp,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str],
) -> dict[Any, Any]:
    """
    Evaluate a dict comprehension such as `{k: v for k in ks for v in vs if cond}`.

    The generators are treated as nested loops, left to right, exactly like Python does: the
    iterable of each generator is evaluated in a state where the targets of all *previous*
    generators are already bound. The key and then the value expression are evaluated only once
    every target is bound and every `if` clause along the way held. Targets are bound (via
    `set_value`) into shallow copies of `state`, so the caller's `state` does not receive the
    loop variables. Later duplicate keys overwrite earlier ones.

    Args:
        dictcomp: The `ast.DictComp` node to evaluate.
        state: Current variable state; read but not assigned to.
        static_tools: Tools that may not be overwritten.
        custom_tools: Tools that may be overwritten.
        authorized_imports: Modules the evaluated code may import.

    Returns:
        The resulting dictionary.
    """
    result = {}

    def _walk(index: int, scope: dict[str, Any]) -> None:
        # All generators have bound their targets: produce one key/value pair.
        if index == len(dictcomp.generators):
            key = evaluate_ast(dictcomp.key, scope, static_tools, custom_tools, authorized_imports)
            val = evaluate_ast(dictcomp.value, scope, static_tools, custom_tools, authorized_imports)
            result[key] = val
            return
        generator = dictcomp.generators[index]
        iter_value = evaluate_ast(generator.iter, scope, static_tools, custom_tools, authorized_imports)
        for value in iter_value:
            new_state = scope.copy()
            set_value(generator.target, value, new_state, static_tools, custom_tools, authorized_imports)
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in generator.ifs
            ):
                _walk(index + 1, new_state)

    _walk(0, state)
    return result
|
|
|
|
|
def evaluate_delete( |
|
delete_node: ast.Delete, |
|
state: dict[str, Any], |
|
static_tools: dict[str, Callable], |
|
custom_tools: dict[str, Callable], |
|
authorized_imports: list[str], |
|
) -> None: |
|
""" |
|
Evaluate a delete statement (del x, del x[y]). |
|
|
|
Args: |
|
delete_node: The AST Delete node to evaluate |
|
state: The current state dictionary |
|
static_tools: Dictionary of static tools |
|
custom_tools: Dictionary of custom tools |
|
authorized_imports: List of authorized imports |
|
""" |
|
for target in delete_node.targets: |
|
if isinstance(target, ast.Name): |
|
|
|
if target.id in state: |
|
del state[target.id] |
|
else: |
|
raise InterpreterError(f"Cannot delete name '{target.id}': name is not defined") |
|
elif isinstance(target, ast.Subscript): |
|
|
|
obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports) |
|
index = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports) |
|
try: |
|
del obj[index] |
|
except (TypeError, KeyError, IndexError) as e: |
|
raise InterpreterError(f"Cannot delete index/key: {str(e)}") |
|
else: |
|
raise InterpreterError(f"Deletion of {type(target).__name__} targets is not supported") |
|
|
|
|
|
def _evaluate_elements(elts: list[ast.expr], common_params: tuple) -> list[Any]:
    """Evaluate the elements of a list/tuple/set literal, splicing `*iterable` entries in place.

    `common_params` is the `(state, static_tools, custom_tools, authorized_imports)` tuple that
    `evaluate_ast` forwards to its recursive calls.
    """
    values = []
    for elt in elts:
        if isinstance(elt, ast.Starred):
            # evaluate_ast on a Starred node yields the underlying iterable; unpack it.
            values.extend(evaluate_ast(elt, *common_params))
        else:
            values.append(evaluate_ast(elt, *common_params))
    return values


@safer_eval
def evaluate_ast(
    expression: ast.AST,
    state: dict[str, Any],
    static_tools: dict[str, Callable],
    custom_tools: dict[str, Callable],
    authorized_imports: list[str] = BASE_BUILTIN_MODULES,
):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided. Every call counts as one
    operation against `MAX_OPERATIONS`, tracked in `state["_operations_count"]["counter"]`.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. These custom_tools can be overwritten.
        authorized_imports (`List[str]`):
            The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
            If it contains "*", it will authorize any import. Use this at your own risk!

    Returns:
        The value of the expression. Statements generally evaluate to None.

    Raises:
        InterpreterError: When the operation budget is exhausted or a node type is unsupported.
    """
    if state.setdefault("_operations_count", {"counter": 0})["counter"] >= MAX_OPERATIONS:
        raise InterpreterError(
            f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
        )
    state["_operations_count"]["counter"] += 1
    common_params = (state, static_tools, custom_tools, authorized_imports)
    if isinstance(expression, ast.Assign):
        # Assignment: `a = b` or `a, b = c, d`.
        return evaluate_assign(expression, *common_params)
    elif isinstance(expression, ast.AnnAssign):
        return evaluate_annassign(expression, *common_params)
    elif isinstance(expression, ast.AugAssign):
        return evaluate_augassign(expression, *common_params)
    elif isinstance(expression, ast.Call):
        # Function call: `f(x)`.
        return evaluate_call(expression, *common_params)
    elif isinstance(expression, ast.Constant):
        # Literal constant: `1`, `"a"`, `None`, ...
        return expression.value
    elif isinstance(expression, ast.Tuple):
        return tuple(_evaluate_elements(expression.elts, common_params))
    elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
        return evaluate_listcomp(expression, *common_params)
    elif isinstance(expression, ast.DictComp):
        return evaluate_dictcomp(expression, *common_params)
    elif isinstance(expression, ast.SetComp):
        return evaluate_setcomp(expression, *common_params)
    elif isinstance(expression, ast.UnaryOp):
        return evaluate_unaryop(expression, *common_params)
    elif isinstance(expression, ast.Starred):
        return evaluate_ast(expression.value, *common_params)
    elif isinstance(expression, ast.BoolOp):
        # `and` / `or` chains.
        return evaluate_boolop(expression, *common_params)
    elif isinstance(expression, ast.Break):
        raise BreakException()
    elif isinstance(expression, ast.Continue):
        raise ContinueException()
    elif isinstance(expression, ast.BinOp):
        # Binary operation: `a + b`.
        return evaluate_binop(expression, *common_params)
    elif isinstance(expression, ast.Compare):
        # Comparison: `a < b`.
        return evaluate_condition(expression, *common_params)
    elif isinstance(expression, ast.Lambda):
        return evaluate_lambda(expression, *common_params)
    elif isinstance(expression, ast.FunctionDef):
        return evaluate_function_def(expression, *common_params)
    elif isinstance(expression, ast.Dict):
        # Dict literal. A None key marks a `**mapping` entry that must be merged in.
        result = {}
        for key_node, value_node in zip(expression.keys, expression.values):
            if key_node is None:
                result.update(evaluate_ast(value_node, *common_params))
            else:
                key = evaluate_ast(key_node, *common_params)
                result[key] = evaluate_ast(value_node, *common_params)
        return result
    elif isinstance(expression, ast.Expr):
        # Expression statement: evaluate its value.
        return evaluate_ast(expression.value, *common_params)
    elif isinstance(expression, ast.For):
        return evaluate_for(expression, *common_params)
    elif isinstance(expression, ast.FormattedValue):
        # One `{...}` field of an f-string.
        value = evaluate_ast(expression.value, *common_params)
        # Optional !s / !r / !a conversion, stored as the ord() of the letter (-1 when absent).
        if expression.conversion == ord("s"):
            value = str(value)
        elif expression.conversion == ord("r"):
            value = repr(value)
        elif expression.conversion == ord("a"):
            value = ascii(value)
        if not expression.format_spec:
            return value
        format_spec = evaluate_ast(expression.format_spec, *common_params)
        return format(value, format_spec)
    elif isinstance(expression, ast.If):
        return evaluate_if(expression, *common_params)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        return evaluate_ast(expression.value, *common_params)
    elif isinstance(expression, ast.JoinedStr):
        return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
    elif isinstance(expression, ast.List):
        return _evaluate_elements(expression.elts, common_params)
    elif isinstance(expression, ast.Name):
        return evaluate_name(expression, *common_params)
    elif isinstance(expression, ast.Subscript):
        return evaluate_subscript(expression, *common_params)
    elif isinstance(expression, ast.IfExp):
        test_val = evaluate_ast(expression.test, *common_params)
        if test_val:
            return evaluate_ast(expression.body, *common_params)
        else:
            return evaluate_ast(expression.orelse, *common_params)
    elif isinstance(expression, ast.Attribute):
        return evaluate_attribute(expression, *common_params)
    elif isinstance(expression, ast.Slice):
        return slice(
            evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
            evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
            evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
        )
    elif isinstance(expression, ast.While):
        return evaluate_while(expression, *common_params)
    elif isinstance(expression, (ast.Import, ast.ImportFrom)):
        return evaluate_import(expression, state, authorized_imports)
    elif isinstance(expression, ast.ClassDef):
        return evaluate_class_def(expression, *common_params)
    elif isinstance(expression, ast.Try):
        return evaluate_try(expression, *common_params)
    elif isinstance(expression, ast.Raise):
        return evaluate_raise(expression, *common_params)
    elif isinstance(expression, ast.Assert):
        return evaluate_assert(expression, *common_params)
    elif isinstance(expression, ast.With):
        return evaluate_with(expression, *common_params)
    elif isinstance(expression, ast.Set):
        return set(_evaluate_elements(expression.elts, common_params))
    elif isinstance(expression, ast.Return):
        raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
    elif isinstance(expression, ast.Pass):
        return None
    elif isinstance(expression, ast.Delete):
        return evaluate_delete(expression, *common_params)
    else:
        # Any node type not handled above is rejected explicitly.
        raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
|
|
|
|
class FinalAnswerException(Exception):
    """Raised by the wrapped `final_answer` tool to stop code execution early.

    Attributes:
        value: The value produced by the original `final_answer` tool.
    """

    def __init__(self, value):
        # Forward to Exception so `args` holds the value; without it the exception cannot be
        # pickled (Exception.__reduce__ re-calls the class with `args`) and str() is empty.
        super().__init__(value)
        self.value = value
|
|
|
|
|
def evaluate_python_code(
    code: str,
    static_tools: dict[str, Callable] | None = None,
    custom_tools: dict[str, Callable] | None = None,
    state: dict[str, Any] | None = None,
    authorized_imports: list[str] = BASE_BUILTIN_MODULES,
    max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
):
    """
    Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        static_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. These can also be agents in a multiagent setting.
            These tools cannot be overwritten in the code: any assignment to their name will raise an error.
            The dict is copied, so the caller's mapping is not modified.
        custom_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.
            These tools can be overwritten in the code: any assignment to their name will overwrite them.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
            The print outputs will be stored in the state under the key "_print_outputs".
        authorized_imports (`List[str]`):
            Modules the code may import.
        max_print_outputs_length (`int`):
            Maximum length the captured print outputs are truncated to, however execution ends.

    Returns:
        `(value, is_final_answer)`: the value of the last top-level statement and `False`, or the
        value passed to `final_answer(...)` and `True` when that tool was called.

    Raises:
        InterpreterError: If the code cannot be parsed or its execution raises.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        # `e.text` may be None or lack a trailing newline; normalise it so the caret marker
        # always sits on its own line below the offending source line.
        source_line = e.text or ""
        if source_line and not source_line.endswith("\n"):
            source_line += "\n"
        raise InterpreterError(
            f"Code parsing failed on line {e.lineno} due to: {type(e).__name__}\n"
            f"{source_line}"
            f"{' ' * (e.offset or 0)}^\n"
            f"Error: {str(e)}"
        ) from e

    if state is None:
        state = {}
    static_tools = static_tools.copy() if static_tools is not None else {}
    custom_tools = custom_tools if custom_tools is not None else {}
    result = None
    state["_print_outputs"] = PrintContainer()
    state["_operations_count"] = {"counter": 0}

    if "final_answer" in static_tools:
        previous_final_answer = static_tools["final_answer"]

        # Wrap the tool so calling it aborts execution with its result.
        def final_answer(*args, **kwargs):
            raise FinalAnswerException(previous_final_answer(*args, **kwargs))

        static_tools["final_answer"] = final_answer

    try:
        for node in expression.body:
            result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
        return result, False
    except FinalAnswerException as e:
        return e.value, True
    except Exception as e:
        raise InterpreterError(
            f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {type(e).__name__}: {e}"
        ) from e
    finally:
        # Cap captured print output on every exit path (normal, final answer, or error).
        state["_print_outputs"].value = truncate_content(
            str(state["_print_outputs"]), max_length=max_print_outputs_length
        )
|
|
|
|
|
class PythonExecutor:
    """Common base type for Python code executors; it defines no behaviour of its own."""
|
|
|
|
|
class LocalPythonExecutor(PythonExecutor):
    """
    Run Python code through the restricted local AST interpreter.

    Code is evaluated with `evaluate_python_code`, limited to an allow-list of importable modules
    and to the tools/functions explicitly handed to the executor. Variables persist in
    `self.state` between calls, and printed output is returned separately from the value of
    the code.

    Args:
        additional_authorized_imports (`list[str]`):
            Modules allowed on top of `BASE_BUILTIN_MODULES`.
        max_print_outputs_length (`int`, defaults to `DEFAULT_MAX_LEN_OUTPUT=50_000`):
            Maximum length of the captured print outputs.
        additional_functions (`dict[str, Callable]`, *optional*):
            Extra Python functions exposed to the executed code.
    """

    def __init__(
        self,
        additional_authorized_imports: list[str],
        max_print_outputs_length: int | None = None,
        additional_functions: dict[str, Callable] | None = None,
    ):
        self.custom_tools = {}
        self.state = {"__name__": "__main__"}
        self.max_print_outputs_length = (
            DEFAULT_MAX_LEN_OUTPUT if max_print_outputs_length is None else max_print_outputs_length
        )
        self.additional_authorized_imports = additional_authorized_imports
        self.authorized_imports = list(set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports))
        # Populated by send_tools(); until then evaluate_python_code receives None.
        self.static_tools = None
        self.additional_functions = additional_functions or {}

    def __call__(self, code_action: str) -> tuple[Any, str, bool]:
        """Execute `code_action`; return `(output, print_logs, is_final_answer)`."""
        output, is_final_answer = evaluate_python_code(
            code_action,
            static_tools=self.static_tools,
            custom_tools=self.custom_tools,
            state=self.state,
            authorized_imports=self.authorized_imports,
            max_print_outputs_length=self.max_print_outputs_length,
        )
        return output, str(self.state["_print_outputs"]), is_final_answer

    def send_variables(self, variables: dict):
        """Merge `variables` into the persistent interpreter state."""
        self.state.update(variables)

    def send_tools(self, tools: dict[str, Tool]):
        """Install the static tools: user tools, then base builtins, then additional functions.

        Later sources win on name clashes, so base builtins shadow same-named tools.
        """
        self.static_tools = {**tools, **BASE_PYTHON_TOOLS, **self.additional_functions}
|
|
|
|
|
# Public API of this module: names exported by `from <module> import *`.
__all__ = ["evaluate_python_code", "LocalPythonExecutor"]
|
|