|
""" |
|
Custom Shell Toolkit with Base Directory Support |
|
|
|
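This toolkit runs shell commands with a specific base directory as their working
directory, so relative paths used by an agent resolve inside that directory. It pins
the starting directory only and is not a sandbox: commands can still reach other
locations through absolute paths or `cd`.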
|
""" |
|
|
|
import os |
|
import subprocess |
|
from pathlib import Path |
|
from typing import List, Optional |
|
from agno.tools import Toolkit |
|
from agno.utils.log import logger |
|
|
|
|
|
class RestrictedShellTools(Toolkit):
    """
    Shell toolkit that runs every command with a fixed base directory as its
    working directory.

    Commands are started with ``cwd`` set to ``base_dir``, so relative paths
    used by the agent (for example, files it writes) land inside that
    directory instead of wherever the host process happens to be running.

    NOTE: this pins the *starting* directory only. Commands are executed
    through the system shell (``shell=True``), so they can still ``cd``
    elsewhere or use absolute paths; this is not a security sandbox.
    """

    def __init__(self, base_dir: Optional[Path] = None, **kwargs):
        """
        Initialize the restricted shell toolkit.

        Args:
            base_dir: Base directory for all shell operations. Defaults to the
                current working directory at construction time. The directory
                is created (including parents) if it does not exist.
            **kwargs: Additional arguments passed to the parent Toolkit.
        """
        # Resolve to an absolute path once so the toolkit keeps pointing at the
        # same directory even if the process working directory changes later.
        self.base_dir = (Path(base_dir) if base_dir else Path.cwd()).resolve()

        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Only run_shell_command is registered with the Toolkit; the other
        # public methods below are plain helpers and are not passed in `tools`.
        super().__init__(
            name="restricted_shell_tools",
            tools=[self.run_shell_command],
            **kwargs,
        )

        logger.info(f"RestrictedShellTools initialized with base_dir: {self.base_dir}")

    def run_shell_command(self, command: str, timeout: int = 30) -> str:
        """
        Runs a shell command with the base directory as its working directory.

        Args:
            command (str): The shell command to execute (interpreted by the
                system shell).
            timeout (int): Maximum execution time in seconds.

        Returns:
            str: The command's stdout (stripped) on success; otherwise a
            message describing the failure (non-zero exit code with stderr
            and stdout, a timeout, or any other error).
        """
        logger.info(f"Executing shell command in {self.base_dir}: {command}")
        try:
            # The child gets its working directory via `cwd=`; we deliberately
            # do NOT os.chdir(), which would change the cwd of the whole
            # process and race with any other thread running concurrently.
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                # Replace undecodable bytes instead of failing the whole call
                # when a command writes non-text output.
                errors="replace",
                timeout=timeout,
                cwd=str(self.base_dir),
            )
        except subprocess.TimeoutExpired:
            error_msg = f"Command timed out after {timeout} seconds: {command}"
            logger.error(error_msg)
            return error_msg
        except Exception as e:
            # Tool boundary: report the failure to the caller as text rather
            # than raising into the agent loop.
            error_msg = f"Error executing command '{command}': {str(e)}"
            logger.error(error_msg)
            return error_msg

        logger.debug(f"Command executed with return code: {result.returncode}")

        if result.returncode != 0:
            error_msg = (
                f"Command failed with return code {result.returncode}\n"
                f"STDERR: {result.stderr}\nSTDOUT: {result.stdout}"
            )
            logger.warning(error_msg)
            return error_msg

        output = result.stdout.strip()
        logger.debug(f"Command output: {output[:200]}{'...' if len(output) > 200 else ''}")
        return output

    def get_current_directory(self) -> str:
        """
        Returns the base directory path.

        Returns:
            str: Absolute path of the base directory.
        """
        return str(self.base_dir.absolute())

    def list_directory_contents(self) -> str:
        """
        Lists the contents of the base directory using ``ls -la``.

        Returns:
            str: Directory listing, or an error message from run_shell_command.
        """
        return self.run_shell_command("ls -la")

    def check_file_exists(self, filename: str) -> str:
        """
        Checks if a file exists in the base directory.

        Args:
            filename (str): Path of the file to check, relative to the base
                directory.

        Returns:
            str: A message stating whether the file exists in the base
            directory, or that the given path points outside it.
        """
        file_path = (self.base_dir / filename).resolve()
        # Absolute paths, ".." components or symlinks can make the joined path
        # leave base_dir; never report such a file as being "in" base_dir.
        try:
            file_path.relative_to(self.base_dir.resolve())
        except ValueError:
            return f"File '{filename}' is outside {self.base_dir}"

        if file_path.exists():
            return f"File '{filename}' exists in {self.base_dir}"
        return f"File '{filename}' does not exist in {self.base_dir}"