# Tests for ankigen_core/agents/security.py

import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from ankigen_core.agents.security import (
    RateLimitConfig,
    SecurityConfig,
    RateLimiter,
    SecurityValidator,
    SecureAgentWrapper,
    SecurityError,
    get_rate_limiter,
    get_security_validator,
    create_secure_agent,
    strip_html_tags,
    validate_api_key_format,
    sanitize_for_logging,
)


# Test RateLimitConfig
def test_rate_limit_config_defaults():
    """Test RateLimitConfig default values"""
    config = RateLimitConfig()
    assert config.requests_per_minute == 60
    assert config.requests_per_hour == 1000
    assert config.burst_limit == 10
    assert config.cooldown_period == 300


def test_rate_limit_config_custom():
    """Test RateLimitConfig with custom values"""
    config = RateLimitConfig(
        requests_per_minute=30,
        requests_per_hour=500,
        burst_limit=5,
        cooldown_period=600,
    )
    assert config.requests_per_minute == 30
    assert config.requests_per_hour == 500
    assert config.burst_limit == 5
    assert config.cooldown_period == 600


# Test SecurityConfig
def test_security_config_defaults():
    """Test SecurityConfig default values"""
    config = SecurityConfig()
    assert config.enable_input_validation is True
    assert config.enable_output_filtering is True
    assert config.enable_rate_limiting is True
    assert config.max_input_length == 10000
    assert config.max_output_length == 50000
    assert len(config.blocked_patterns) > 0
    assert '.txt' in config.allowed_file_extensions


def test_security_config_blocked_patterns():
    """Test SecurityConfig blocked patterns"""
    config = SecurityConfig()

    # Should have common sensitive patterns
    patterns = config.blocked_patterns
    assert any('api' in pattern.lower() for pattern in patterns)
    assert any('secret' in pattern.lower() for pattern in patterns)
    assert any('password' in pattern.lower() for pattern in patterns)


# Test RateLimiter
# NOTE(review): the coroutine tests below carry no @pytest.mark.asyncio
# marker; presumably pytest-asyncio runs with asyncio_mode = "auto" in the
# project's pytest config — confirm, otherwise they are never awaited.
@pytest.fixture
def rate_limiter():
    """Rate limiter with test configuration (burst limit 3, 5/min, 50/h)."""
    config = RateLimitConfig(
        requests_per_minute=5,
        requests_per_hour=50,
        burst_limit=3,
    )
    return RateLimiter(config)


async def test_rate_limiter_allows_requests_under_limit(rate_limiter):
    """Test rate limiter allows requests under limits"""
    identifier = "test_user"

    # Should allow first few requests
    assert await rate_limiter.check_rate_limit(identifier) is True
    assert await rate_limiter.check_rate_limit(identifier) is True
    assert await rate_limiter.check_rate_limit(identifier) is True


async def test_rate_limiter_blocks_burst_limit(rate_limiter):
    """Test rate limiter blocks requests exceeding burst limit"""
    identifier = "test_user"

    # Use up burst limit
    for _ in range(3):
        assert await rate_limiter.check_rate_limit(identifier) is True

    # Next request should be blocked
    assert await rate_limiter.check_rate_limit(identifier) is False


async def test_rate_limiter_per_minute_limit():
    """Test rate limiter per-minute limit.

    Uses a dedicated limiter instead of the ``rate_limiter`` fixture. The
    fixture's burst_limit (3) is lower than its requests_per_minute (5), so
    the 4th rapid request would be rejected by the burst check (as
    ``test_rate_limiter_blocks_burst_limit`` asserts) before the per-minute
    limit is ever reached. With burst_limit=10, the only configured limit
    that can reject the 6th request is the per-minute one.
    """
    limiter = RateLimiter(
        RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=50,
            burst_limit=10,
        )
    )
    identifier = "test_user"

    # Freeze the clock so every request lands in the same minute window.
    # NOTE(review): this only takes effect if the limiter reads the clock via
    # the ``time.time`` module attribute — confirm in security.py.
    with patch('time.time') as mock_time:
        mock_time.return_value = 1000.0

        # Use up per-minute limit
        for _ in range(5):
            assert await limiter.check_rate_limit(identifier) is True

        # Next request should be blocked by the per-minute limit
        assert await limiter.check_rate_limit(identifier) is False


async def test_rate_limiter_different_identifiers(rate_limiter):
    """Test rate limiter handles different identifiers separately"""
    user1 = "user1"
    user2 = "user2"

    # Use up limit for user1
    for _ in range(3):
        assert await rate_limiter.check_rate_limit(user1) is True
    assert await rate_limiter.check_rate_limit(user1) is False

    # user2 should still be allowed
    assert await rate_limiter.check_rate_limit(user2) is True


async def test_rate_limiter_reset_time(rate_limiter):
    """Test rate limiter reset time calculation"""
    identifier = "test_user"

    # Use up burst limit
    for _ in range(3):
        await rate_limiter.check_rate_limit(identifier)

    # Should have reset time
    reset_time = rate_limiter.get_reset_time(identifier)
    assert reset_time is not None


# Test
# NOTE(review): the SecurityValidator tests in this section look partially
# corrupted, as if HTML/script literals were stripped from the file. The last
# entry of `dangerous_inputs` is an empty string. The sanitize_input test has
# a bare `assert ""` with no comparison, and it calls
# `strip_html_tags(html_text)` although `html_text` is never assigned here.
# An empty string is unlikely to match any blocked pattern, so
# `validate_input("", "test") is False` may not hold as written. Restore the
# original payloads and confirm.
SecurityValidator @pytest.fixture def security_validator(): """Security validator with test configuration""" config = SecurityConfig( max_input_length=100, max_output_length=200 ) return SecurityValidator(config) def test_security_validator_valid_input(security_validator): """Test security validator allows valid input""" valid_input = "This is a normal, safe input text." assert security_validator.validate_input(valid_input, "test") is True def test_security_validator_input_too_long(security_validator): """Test security validator rejects input that's too long""" long_input = "x" * 1000 # Exceeds max_input_length of 100 assert security_validator.validate_input(long_input, "test") is False def test_security_validator_blocked_patterns(security_validator): """Test security validator blocks dangerous patterns""" dangerous_inputs = [ "Here is my API key: sk-1234567890abcdef", "My password is secret123", "The access_token is abc123", "" ] for dangerous_input in dangerous_inputs: assert security_validator.validate_input(dangerous_input, "test") is False def test_security_validator_output_validation(security_validator): """Test security validator validates output""" safe_output = "This is a safe response with no sensitive information." assert security_validator.validate_output(safe_output, "test_agent") is True dangerous_output = "Here's your API key: sk-1234567890abcdef" assert security_validator.validate_output(dangerous_output, "test_agent") is False def test_security_validator_sanitize_input(security_validator): """Test input sanitization""" dirty_input = "Normal text" sanitized = security_validator.sanitize_input(dirty_input) assert "" clean_text = strip_html_tags(html_text) assert "

" not in clean_text assert "" not in clean_text assert "