File size: 4,811 Bytes
8b21729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Custom Shell Toolkit with Base Directory Support

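This toolkit runs shell commands with a specific base directory as their working directory,
so files an agent creates via relative paths land in its assigned directory. It is not a
sandbox: because commands go through the shell, they can still reference paths outside it.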
"""

import os
import subprocess
from pathlib import Path
from typing import List, Optional
from agno.tools import Toolkit
from agno.utils.log import logger


class RestrictedShellTools(Toolkit):
    """
    Shell toolkit that runs every command with a fixed base directory as its working directory.

    Commands start in ``base_dir``, so relative paths used by an agent (for example,
    files it saves) resolve inside that directory instead of wherever the host process
    happens to be running.

    NOTE: this is not a sandbox. Commands are executed with ``shell=True``, so
    ``cd ..``, absolute paths, etc. can still reach outside ``base_dir``.
    Only ``run_shell_command`` is passed to the parent ``Toolkit`` via ``tools=``;
    the other public methods are plain helpers for Python callers.
    """

    def __init__(self, base_dir: Optional[Path] = None, **kwargs):
        """
        Initialize the restricted shell toolkit.

        Args:
            base_dir: Directory every command runs in (str or Path). A relative path is
                resolved against the process working directory at construction time.
                Defaults to the current working directory. Created if it does not exist.
            **kwargs: Additional arguments passed to parent Toolkit
        """
        # Resolve to an absolute path once: a relative base_dir would otherwise be
        # re-interpreted against whatever the process cwd is when a command runs.
        raw_dir = Path(base_dir) if base_dir is not None else Path.cwd()
        self.base_dir = raw_dir.resolve()

        # Ensure base directory exists
        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Only run_shell_command is exposed to the agent as a tool.
        super().__init__(
            name="restricted_shell_tools",
            tools=[self.run_shell_command],
            **kwargs,
        )

        logger.info(f"RestrictedShellTools initialized with base_dir: {self.base_dir}")

    def run_shell_command(self, command: str, timeout: int = 30) -> str:
        """
        Runs a shell command with the base directory as its working directory.

        Args:
            command (str): The shell command to execute (interpreted by the shell)
            timeout (int): Maximum execution time in seconds

        Returns:
            str: Stripped stdout on success. Otherwise a message describing the
            non-zero return code (with stderr and stdout), the timeout, or the
            exception that prevented execution. This method never raises for
            command failures.
        """
        logger.info(f"Executing shell command in {self.base_dir}: {command}")

        try:
            # Use cwd= only, never os.chdir(): chdir mutates process-global state
            # (racy when commands run from several threads), and combined with a
            # relative cwd= it made the child resolve the directory twice.
            result = subprocess.run(
                command,
                shell=True,
                capture_output=True,
                text=True,
                # Undecodable bytes (e.g. binary output) become U+FFFD instead of
                # raising UnicodeDecodeError and discarding the whole output.
                errors="replace",
                timeout=timeout,
                cwd=str(self.base_dir),
            )
        except subprocess.TimeoutExpired:
            error_msg = f"Command timed out after {timeout} seconds: {command}"
            logger.error(error_msg)
            return error_msg
        except Exception as e:
            # Tool boundary: report the failure to the agent as text rather than raising.
            error_msg = f"Error executing command '{command}': {str(e)}"
            logger.error(error_msg)
            return error_msg

        logger.debug(f"Command executed with return code: {result.returncode}")

        if result.returncode != 0:
            error_msg = f"Command failed with return code {result.returncode}\nSTDERR: {result.stderr}\nSTDOUT: {result.stdout}"
            logger.warning(error_msg)
            return error_msg

        output = result.stdout.strip()
        logger.debug(f"Command output: {output[:200]}{'...' if len(output) > 200 else ''}")
        return output

    def get_current_directory(self) -> str:
        """
        Returns the base directory path.

        Returns:
            str: Absolute path of the base directory
        """
        return str(self.base_dir.absolute())

    def list_directory_contents(self) -> str:
        """
        Lists the contents of the base directory via ``ls -la``.

        Returns:
            str: Directory listing, or an error message from run_shell_command
        """
        return self.run_shell_command("ls -la")

    def check_file_exists(self, filename: str) -> str:
        """
        Checks if a file exists in the base directory.

        Args:
            filename (str): Path of the file, relative to the base directory

        Returns:
            str: Result of the check. Paths that resolve outside the base directory
            (via ``..``, an absolute path, or a symlink) are reported as outside and
            are not probed.
        """
        base = self.base_dir.resolve()
        file_path = (base / filename).resolve()
        try:
            # relative_to raises ValueError when file_path is not under base.
            file_path.relative_to(base)
        except ValueError:
            return f"File '{filename}' is outside {self.base_dir}"

        if file_path.exists():
            return f"File '{filename}' exists in {self.base_dir}"
        return f"File '{filename}' does not exist in {self.base_dir}"