data_science_agent / utils /code_sandbox.py
bpHigh's picture
terminate sandboxes
c209a97 verified
import ast
import modal
import io
import sys
def detect_dependencies(code_snippet: str) -> list[str]:
tree = ast.parse(code_snippet)
imports = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for n in node.names:
imports.add(n.name.split('.')[0])
elif isinstance(node, ast.ImportFrom):
if node.module:
imports.add(node.module.split('.')[0])
imports = list(imports)
if 'sklearn' in imports:
imports[imports.index('sklearn')] = 'scikit-learn'
return imports
def build_sandbox(
    requirements: list[str],
    app: modal.App,
    *,
    timeout: int = 600,
    gpu: str = "T4",
) -> tuple[modal.Sandbox, str]:
    """Create a Modal sandbox whose image has *requirements* pip-installed.

    While the image is built and the sandbox is created, ``sys.stdout`` is
    temporarily replaced by an in-memory buffer so that anything printed
    under ``modal.enable_output()`` is returned to the caller instead of
    being written to the real stdout.

    Args:
        requirements: pip package names to install into the image.
        app: Modal app the sandbox is attached to.
        timeout: Value forwarded to ``modal.Sandbox.create(timeout=...)``.
            Defaults to the previously hard-coded 600.
        gpu: Value forwarded to ``modal.Sandbox.create(gpu=...)``.
            Defaults to the previously hard-coded "T4".

    Returns:
        A ``(sandbox, logs)`` tuple: the created sandbox and the text that
        was written to stdout during the build.

    NOTE: swapping ``sys.stdout`` is process-global, so output printed by
    other threads during the build also lands in the captured logs.
    """
    captured = io.StringIO()
    original_stdout = sys.stdout
    sys.stdout = captured
    try:
        with modal.enable_output():
            image = modal.Image.debian_slim(python_version='3.10').pip_install(*requirements)
            sandbox = modal.Sandbox.create(app=app, image=image, timeout=timeout, gpu=gpu)
    finally:
        # Always restore the real stdout, even when the build fails.
        sys.stdout = original_stdout
    return sandbox, captured.getvalue()
def code_eval(code_snippet: str) -> tuple[dict, str]:
    """
    Run a Python code snippet in a sandbox environment.

    Args:
        code_snippet (str): The Python code to execute.

    Returns:
        tuple[dict, str]:
            - A dictionary containing execution results:
              'stdout', 'stderr', 'returncode', and 'error' (if any).
            - A string with the image build logs (empty if no dependencies were detected).
    """
    # Reject unparsable code locally so no sandbox is started for it.
    # ValueError covers source containing null bytes on Python < 3.12.
    try:
        ast.parse(code_snippet)
    except (SyntaxError, ValueError) as e:
        return {
            "error": str(e),
            "stdout": "",
            "stderr": "",
            "returncode": 1,
        }, ""

    app = modal.App.lookup("my-app", create_if_missing=True)
    requirements = detect_dependencies(code_snippet)
    if requirements:
        sb, build_logs = build_sandbox(requirements, app)
    else:
        # No third-party imports: use the default image, no build needed.
        build_logs = ""
        sb = modal.Sandbox.create(app=app, timeout=60)

    try:
        with sb.open("/tmp/solution.py", "w") as f:
            f.write(code_snippet)
        proc = sb.exec("python", "/tmp/solution.py")
        proc.wait()
        # Collect the output while the sandbox is still alive, before it is
        # terminated in the finally clause below.
        stdout = proc.stdout.read()
        stderr = proc.stderr.read()
        returncode = proc.returncode
    finally:
        # Terminate on every path so a failed write/exec does not leak a
        # running sandbox.
        sb.terminate()

    return {
        "error": "Script failed with non-zero exit code" if returncode != 0 else "",
        "stdout": stdout,
        "stderr": stderr,
        "returncode": returncode,
    }, build_logs