"""Security utilities for agent execution: rate limiting, input/output
validation, and sanitization helpers."""

import asyncio
import hashlib
import re
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

from ankigen_core.logging import logger

@dataclass
class RateLimitConfig:
    """Configuration for rate limiting"""

    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    burst_limit: int = 10
    cooldown_period: int = 300  # seconds; not yet enforced by RateLimiter


@dataclass
class SecurityConfig:
    """Security configuration for agents"""

    enable_input_validation: bool = True
    enable_output_filtering: bool = True
    enable_rate_limiting: bool = True
    max_input_length: int = 10000
    max_output_length: int = 50000
    blocked_patterns: List[str] = field(default_factory=list)
    allowed_file_extensions: List[str] = field(
        default_factory=lambda: [".txt", ".md", ".json", ".yaml"]
    )

    def __post_init__(self):
        # Sensible defaults when no patterns are supplied.
        if not self.blocked_patterns:
            self.blocked_patterns = [
                # Credential-like keywords
                r"(?i)(api[_\-]?key|secret|password|token|credential)",
                # OpenAI-style secret keys
                r"(?i)(sk-[a-zA-Z0-9]{48,})",
                r"(?i)(access[_\-]?token)",
                r"(?i)(private[_\-]?key)",
                # Inline <script> blocks
                r"(?i)(<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>)",
                # Dangerous URL schemes
                r"(?i)(javascript:|data:|vbscript:)",
            ]

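# Config sketch (illustrative values, not recommendations): tighter limits
# for an untrusted deployment.
#
#     strict_security = SecurityConfig(max_input_length=2000)
#     strict_limits = RateLimitConfig(requests_per_minute=20, burst_limit=5)
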
class RateLimiter:
    """Rate limiter for API calls and agent executions"""

    def __init__(self, config: RateLimitConfig):
        self.config = config
        self._requests: Dict[str, List[float]] = defaultdict(list)
        self._locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def check_rate_limit(self, identifier: str) -> bool:
        """Check whether a request is within rate limits; record it if so."""
        async with self._locks[identifier]:
            now = time.time()

            # Prune timestamps older than an hour; only the trailing hour matters.
            self._requests[identifier] = [
                req_time
                for req_time in self._requests[identifier]
                if now - req_time < 3600
            ]

            recent_requests = self._requests[identifier]

            # Burst check over a short 10-second window. (Checking bursts over
            # a full minute would shadow the per-minute limit entirely, since
            # burst_limit < requests_per_minute with the defaults.)
            burst_window = [req for req in recent_requests if now - req < 10]
            if len(burst_window) >= self.config.burst_limit:
                logger.warning(f"Burst limit exceeded for {identifier}")
                return False

            last_minute = [req for req in recent_requests if now - req < 60]
            if len(last_minute) >= self.config.requests_per_minute:
                logger.warning(f"Per-minute rate limit exceeded for {identifier}")
                return False

            if len(recent_requests) >= self.config.requests_per_hour:
                logger.warning(f"Per-hour rate limit exceeded for {identifier}")
                return False

            # Within all limits: record this request.
            self._requests[identifier].append(now)
            return True

    def get_reset_time(self, identifier: str) -> Optional[datetime]:
        """Get when the per-minute rate limit will reset for identifier."""
        if identifier not in self._requests:
            return None

        now = time.time()
        recent_requests = [req for req in self._requests[identifier] if now - req < 60]

        if len(recent_requests) >= self.config.requests_per_minute:
            # The window clears one minute after its oldest request.
            oldest_request = min(recent_requests)
            return datetime.fromtimestamp(oldest_request + 60)

        return None
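# Usage sketch ("client-123" is an arbitrary identifier; the call must run
# inside an event loop):
#
#     limiter = RateLimiter(RateLimitConfig(requests_per_minute=30))
#     if not await limiter.check_rate_limit("client-123"):
#         retry_at = limiter.get_reset_time("client-123")
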

class SecurityValidator:
    """Security validator for agent inputs and outputs"""

    def __init__(self, config: SecurityConfig):
        self.config = config
        # Precompile patterns once for repeated use.
        self._blocked_patterns = [
            re.compile(pattern) for pattern in config.blocked_patterns
        ]

    def validate_input(self, input_text: str, source: str = "unknown") -> bool:
        """Validate input for security issues"""
        if not self.config.enable_input_validation:
            return True

        try:
            # Length check
            if len(input_text) > self.config.max_input_length:
                logger.warning(f"Input too long from {source}: {len(input_text)} chars")
                return False

            # Blocked-pattern check (credentials, script tags, dangerous schemes)
            for pattern in self._blocked_patterns:
                if pattern.search(input_text):
                    logger.warning(f"Blocked pattern detected in input from {source}")
                    return False

            # Heuristic check for code-injection indicators
            if self._contains_suspicious_content(input_text):
                logger.warning(f"Suspicious content detected in input from {source}")
                return False

            return True

        except Exception as e:
            logger.error(f"Error validating input from {source}: {e}")
            # Fail closed: a validation error counts as rejection.
            return False

    def validate_output(self, output_text: str, agent_name: str = "unknown") -> bool:
        """Validate output for security issues"""
        if not self.config.enable_output_filtering:
            return True

        try:
            # Length check
            if len(output_text) > self.config.max_output_length:
                logger.warning(
                    f"Output too long from {agent_name}: {len(output_text)} chars"
                )
                return False

            # A blocked pattern in output suggests leaked secrets or markup.
            for pattern in self._blocked_patterns:
                if pattern.search(output_text):
                    logger.warning(
                        f"Potential data leak detected in output from {agent_name}"
                    )
                    return False

            return True

        except Exception as e:
            logger.error(f"Error validating output from {agent_name}: {e}")
            return False

    def sanitize_input(self, input_text: str) -> str:
        """Sanitize input by removing potentially dangerous content"""
        try:
            # Strip HTML tags
            sanitized = re.sub(r"<[^>]+>", "", input_text)

            # Neutralize dangerous URL schemes
            sanitized = re.sub(
                r"(?i)(javascript:|data:|vbscript:)[^\s]*", "[URL_REMOVED]", sanitized
            )

            # Enforce the maximum input length
            if len(sanitized) > self.config.max_input_length:
                sanitized = sanitized[: self.config.max_input_length] + "...[TRUNCATED]"

            return sanitized

        except Exception as e:
            logger.error(f"Error sanitizing input: {e}")
            # Fall back to a hard truncation of the raw input.
            return input_text[:1000]

    def sanitize_output(self, output_text: str) -> str:
        """Sanitize output by removing sensitive information"""
        try:
            sanitized = output_text

            # Redact anything matching a blocked pattern
            for pattern in self._blocked_patterns:
                sanitized = pattern.sub("[REDACTED]", sanitized)

            # Enforce the maximum output length
            if len(sanitized) > self.config.max_output_length:
                sanitized = (
                    sanitized[: self.config.max_output_length] + "...[TRUNCATED]"
                )

            return sanitized

        except Exception as e:
            logger.error(f"Error sanitizing output: {e}")
            return output_text[:5000]

    def _contains_suspicious_content(self, text: str) -> bool:
        """Check for suspicious content patterns"""
        suspicious_patterns = [
            r"(?i)(\beval\s*\()",  # eval(...)
            r"(?i)(\bexec\s*\()",  # exec(...)
            r"(?i)(__import__)",  # dynamic imports
            r"(?i)(subprocess|os\.system)",  # shell execution
            r"(?i)(file://|ftp://)",  # non-HTTP URL schemes
            r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b",  # IPv4-like (coarse: also matches version strings)
        ]

        for pattern in suspicious_patterns:
            if re.search(pattern, text):
                return True

        return False
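# Usage sketch:
#
#     validator = SecurityValidator(SecurityConfig())
#     validator.validate_input("eval(payload)")  # False: suspicious content
#     validator.sanitize_output("token: sk-" + "x" * 48)  # "[REDACTED]: [REDACTED]"
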

class SecureAgentWrapper:
    """Secure wrapper for agent execution with rate limiting and validation"""

    def __init__(
        self, base_agent, rate_limiter: RateLimiter, validator: SecurityValidator
    ):
        self.base_agent = base_agent
        self.rate_limiter = rate_limiter
        self.validator = validator
        self._identifier = self._generate_identifier()

    def _generate_identifier(self) -> str:
        """Generate a unique identifier for rate limiting."""
        # The agent's config may be a dict or an object; handle both.
        config = getattr(self.base_agent, "config", None)
        if isinstance(config, dict):
            agent_name = config.get("name", "unknown")
        else:
            agent_name = getattr(config, "name", "unknown")
        # md5 is acceptable here: the digest is only a bucketing key, not a secret.
        return hashlib.md5(
            f"{agent_name}_{id(self.base_agent)}".encode()
        ).hexdigest()[:16]

    async def secure_execute(
        self, user_input: str, context: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Execute agent with security checks and rate limiting"""
        # 1. Rate limiting
        if not await self.rate_limiter.check_rate_limit(self._identifier):
            reset_time = self.rate_limiter.get_reset_time(self._identifier)
            raise SecurityError(f"Rate limit exceeded. Reset at: {reset_time}")

        # 2. Input validation
        if not self.validator.validate_input(user_input, self._identifier):
            raise SecurityError("Input validation failed")

        # 3. Input sanitization
        sanitized_input = self.validator.sanitize_input(user_input)

        try:
            # 4. Execute the wrapped agent on the sanitized input
            result = await self.base_agent.execute(sanitized_input, context)

            # 5. Validate and sanitize string outputs
            if isinstance(result, str):
                if not self.validator.validate_output(result, self._identifier):
                    raise SecurityError("Output validation failed")
                result = self.validator.sanitize_output(result)

            return result

        except Exception as e:
            logger.error(f"Secure execution failed for {self._identifier}: {e}")
            raise


class SecurityError(Exception):
    """Custom exception for security-related errors"""

# Module-level singletons, shared across all agents in the process.
_global_rate_limiter: Optional[RateLimiter] = None
_global_validator: Optional[SecurityValidator] = None


def get_rate_limiter(config: Optional[RateLimitConfig] = None) -> RateLimiter:
    """Get the global rate limiter instance (config applies only on first call)."""
    global _global_rate_limiter
    if _global_rate_limiter is None:
        _global_rate_limiter = RateLimiter(config or RateLimitConfig())
    return _global_rate_limiter


def get_security_validator(
    config: Optional[SecurityConfig] = None,
) -> SecurityValidator:
    """Get the global security validator instance (config applies only on first call)."""
    global _global_validator
    if _global_validator is None:
        _global_validator = SecurityValidator(config or SecurityConfig())
    return _global_validator


def create_secure_agent(
    base_agent,
    rate_config: Optional[RateLimitConfig] = None,
    security_config: Optional[SecurityConfig] = None,
) -> SecureAgentWrapper:
    """Create a secure wrapper for an agent"""
    rate_limiter = get_rate_limiter(rate_config)
    validator = get_security_validator(security_config)
    return SecureAgentWrapper(base_agent, rate_limiter, validator)
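# End-to-end sketch (the agent object and its async `execute(input, context)`
# method are assumptions about the caller, not part of this module):
#
#     secure_agent = create_secure_agent(my_agent)
#     try:
#         answer = await secure_agent.secure_execute("Summarize this note")
#     except SecurityError as err:
#         logger.warning(f"Request blocked: {err}")
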

def set_secure_file_permissions(file_path: str):
    """Set secure permissions for configuration files"""
    try:
        import os
        import stat

        # Owner read/write only (0o600); no group or world access.
        os.chmod(file_path, stat.S_IRUSR | stat.S_IWUSR)
        logger.info(f"Set secure permissions for {file_path}")

    except Exception as e:
        logger.warning(f"Could not set secure permissions for {file_path}: {e}")

def strip_html_tags(text: str) -> str:
    """Strip HTML tags and entities from text."""
    import html

    # Decode entities first so escaped tags are also stripped below.
    text = html.unescape(text)

    # Remove tags
    text = re.sub(r"<[^>]+>", "", text)

    # Remove any remaining (e.g. double-escaped) entities
    text = re.sub(r"&[a-zA-Z0-9#]+;", "", text)

    # Collapse whitespace
    text = re.sub(r"\s+", " ", text).strip()

    return text
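# e.g. strip_html_tags("<p>Caf&eacute; &amp; tea</p>") == "Café & tea"
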

def validate_api_key_format(api_key: str) -> bool:
    """Validate OpenAI API key format without logging it"""
    if not api_key:
        return False

    # OpenAI secret keys start with "sk-"
    if not api_key.startswith("sk-"):
        return False

    # Far too short to be a real key
    if len(api_key) < 20:
        return False

    # Reject obvious placeholder values
    fake_patterns = ["test", "fake", "demo", "example", "placeholder"]
    lower_key = api_key.lower()
    if any(pattern in lower_key for pattern in fake_patterns):
        return False

    return True
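# Format-check examples (illustrative values, not real keys):
#
#     validate_api_key_format("sk-" + "a" * 45)      # True: plausible shape
#     validate_api_key_format("sk-demo" + "a" * 40)  # False: placeholder word "demo"
#     validate_api_key_format("pk-" + "a" * 45)      # False: wrong prefix
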

def sanitize_for_logging(text: str, max_length: int = 100) -> str:
    """Sanitize text for safe logging"""
    if not text:
        return "[EMPTY]"

    # Redact secrets using the global validator's blocked patterns.
    validator = get_security_validator()
    sanitized = validator.sanitize_output(text)

    # Keep log lines short
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length] + "...[TRUNCATED]"

    return sanitized
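# A minimal self-check (runs only when this file is executed directly; the
# _EchoAgent below is a stand-in, not a real ankigen agent):
if __name__ == "__main__":

    class _EchoAgent:
        config = {"name": "echo"}

        async def execute(self, user_input: str, context=None) -> str:
            return f"echo: {user_input}"

    async def _demo() -> None:
        agent = create_secure_agent(_EchoAgent())
        print(await agent.secure_execute("hello <b>world</b>"))  # "echo: hello world"
        print(validate_api_key_format("sk-demo-key"))  # False: not a plausible key

    asyncio.run(_demo())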
|