|
|
|
|
|
import pytest |
|
import asyncio |
|
import time |
|
from unittest.mock import AsyncMock, MagicMock, patch |
|
|
|
from ankigen_core.agents.security import ( |
|
RateLimitConfig, |
|
SecurityConfig, |
|
RateLimiter, |
|
SecurityValidator, |
|
SecureAgentWrapper, |
|
SecurityError, |
|
get_rate_limiter, |
|
get_security_validator, |
|
create_secure_agent, |
|
strip_html_tags, |
|
validate_api_key_format, |
|
sanitize_for_logging |
|
) |
|
|
|
|
|
|
|
def test_rate_limit_config_defaults():
    """A bare RateLimitConfig() carries the expected default limits."""
    defaults = RateLimitConfig()
    expected = {
        "requests_per_minute": 60,
        "requests_per_hour": 1000,
        "burst_limit": 10,
        "cooldown_period": 300,
    }
    for field_name, value in expected.items():
        assert getattr(defaults, field_name) == value
|
|
|
|
|
def test_rate_limit_config_custom():
    """Explicit keyword arguments override every RateLimitConfig default."""
    overrides = {
        "requests_per_minute": 30,
        "requests_per_hour": 500,
        "burst_limit": 5,
        "cooldown_period": 600,
    }
    custom = RateLimitConfig(**overrides)
    for field_name, value in overrides.items():
        assert getattr(custom, field_name) == value
|
|
|
|
|
|
|
def test_security_config_defaults():
    """SecurityConfig() enables every guard and ships the expected size limits."""
    cfg = SecurityConfig()

    # All three protection layers are switched on by default.
    assert cfg.enable_input_validation is True
    assert cfg.enable_output_filtering is True
    assert cfg.enable_rate_limiting is True

    assert cfg.max_input_length == 10000
    assert cfg.max_output_length == 50000

    # At least one blocked pattern must be configured out of the box.
    assert len(cfg.blocked_patterns) >= 1
    assert ".txt" in cfg.allowed_file_extensions
|
|
|
|
|
def test_security_config_blocked_patterns():
    """Default blocked patterns mention API keys, secrets and passwords."""
    patterns = SecurityConfig().blocked_patterns
    for keyword in ("api", "secret", "password"):
        # Case-insensitive substring match against each configured pattern.
        assert any(keyword in p.lower() for p in patterns)
|
|
|
|
|
|
|
@pytest.fixture
def rate_limiter():
    """RateLimiter with tight limits (5/min, 50/h, burst 3) so tests hit them quickly."""
    tight_limits = RateLimitConfig(
        requests_per_minute=5,
        requests_per_hour=50,
        burst_limit=3,
    )
    limiter = RateLimiter(tight_limits)
    return limiter
|
|
|
|
|
async def test_rate_limiter_allows_requests_under_limit(rate_limiter):
    """Three requests (within the fixture's burst of 3) are all admitted."""
    user = "test_user"
    for _ in range(3):
        assert await rate_limiter.check_rate_limit(user) is True
|
|
|
|
|
async def test_rate_limiter_blocks_burst_limit(rate_limiter):
    """The request immediately after the burst allowance is refused."""
    user = "test_user"
    burst = 3  # matches burst_limit in the rate_limiter fixture

    for _ in range(burst):
        assert await rate_limiter.check_rate_limit(user) is True

    # One request past the burst allowance must be rejected.
    assert await rate_limiter.check_rate_limit(user) is False
|
|
|
|
|
async def test_rate_limiter_per_minute_limit():
    """Test that the per-minute limit rejects requests once it is exhausted.

    The shared ``rate_limiter`` fixture uses ``burst_limit=3``, which is lower
    than its ``requests_per_minute=5``. With the clock frozen, the burst limit
    is hit before the per-minute limit, so five consecutive successes could
    never be observed. A dedicated limiter whose burst limit exceeds the
    per-minute limit isolates the behaviour under test.
    """
    limiter = RateLimiter(
        RateLimitConfig(requests_per_minute=5, requests_per_hour=50, burst_limit=10)
    )
    identifier = "test_user"

    # Freeze the clock so every request falls inside the same minute.
    with patch("time.time", return_value=1000.0):
        for _ in range(5):
            assert await limiter.check_rate_limit(identifier) is True

        # The 6th request within the same minute exceeds requests_per_minute.
        assert await limiter.check_rate_limit(identifier) is False
|
|
|
|
|
async def test_rate_limiter_different_identifiers(rate_limiter):
    """Exhausting one identifier's allowance leaves other identifiers unaffected."""
    exhausted, fresh = "user1", "user2"

    for _ in range(3):
        assert await rate_limiter.check_rate_limit(exhausted) is True
    assert await rate_limiter.check_rate_limit(exhausted) is False

    # A different identifier keeps its own independent budget.
    assert await rate_limiter.check_rate_limit(fresh) is True
|
|
|
|
|
async def test_rate_limiter_reset_time(rate_limiter):
    """After three requests, get_reset_time reports a (non-None) reset time."""
    user = "test_user"

    # Consume the burst allowance; the individual results are not checked here.
    for _attempt in range(3):
        await rate_limiter.check_rate_limit(user)

    assert rate_limiter.get_reset_time(user) is not None
|
|
|
|
|
|
|
@pytest.fixture
def security_validator():
    """SecurityValidator with small length caps (input 100, output 200)."""
    return SecurityValidator(
        SecurityConfig(max_input_length=100, max_output_length=200)
    )
|
|
|
|
|
def test_security_validator_valid_input(security_validator):
    """Ordinary, harmless text passes input validation."""
    benign = "This is a normal, safe input text."
    outcome = security_validator.validate_input(benign, "test")
    assert outcome is True
|
|
|
|
|
def test_security_validator_input_too_long(security_validator):
    """Input longer than the fixture's max_input_length (100) is rejected."""
    oversized = "x" * 1000
    verdict = security_validator.validate_input(oversized, "test")
    assert verdict is False
|
|
|
|
|
def test_security_validator_blocked_patterns(security_validator):
    """Inputs containing credentials or script tags fail validation."""
    samples = (
        "Here is my API key: sk-1234567890abcdef",
        "My password is secret123",
        "The access_token is abc123",
        "<script>alert('xss')</script>",
    )
    for sample in samples:
        assert security_validator.validate_input(sample, "test") is False
|
|
|
|
|
def test_security_validator_output_validation(security_validator):
    """Output validation accepts clean text and rejects a leaked API key."""
    agent_name = "test_agent"
    clean = "This is a safe response with no sensitive information."
    leaky = "Here's your API key: sk-1234567890abcdef"

    assert security_validator.validate_output(clean, agent_name) is True
    assert security_validator.validate_output(leaky, agent_name) is False
|
|
|
|
|
def test_security_validator_sanitize_input(security_validator):
    """sanitize_input drops the script tag but keeps the surrounding text."""
    result = security_validator.sanitize_input(
        "<script>alert('xss')</script>Normal text"
    )
    assert "<script>" not in result
    assert "Normal text" in result
|
|
|
|
|
def test_security_validator_sanitize_output(security_validator):
    """sanitize_output replaces an API key with a [REDACTED] marker."""
    secret = "sk-1234567890abcdef"
    result = security_validator.sanitize_output(f"Response with API key {secret}")
    assert secret not in result
    assert "[REDACTED]" in result
|
|
|
|
|
def test_security_validator_disabled_validation():
    """With validation and filtering switched off, suspicious text is let through."""
    permissive = SecurityValidator(
        SecurityConfig(enable_input_validation=False, enable_output_filtering=False)
    )
    assert permissive.validate_input("api_key: sk-123", "test") is True
    assert permissive.validate_output("secret: password", "test") is True
|
|
|
|
|
|
|
@pytest.fixture
def mock_base_agent():
    """MagicMock stand-in for a base agent whose async ``execute`` returns "test response"."""
    fake = MagicMock()
    fake.config = {"name": "test_agent"}
    fake.execute = AsyncMock(return_value="test response")
    return fake
|
|
|
|
|
@pytest.fixture
def secure_agent_wrapper(mock_base_agent):
    """SecureAgentWrapper around the mock agent with a burst limit of two calls."""
    return SecureAgentWrapper(
        mock_base_agent,
        RateLimiter(RateLimitConfig(burst_limit=2)),
        SecurityValidator(SecurityConfig()),
    )
|
|
|
|
|
async def test_secure_agent_wrapper_successful_execution(secure_agent_wrapper, mock_base_agent):
    """Safe input reaches the base agent exactly once and its result is returned."""
    response = await secure_agent_wrapper.secure_execute("Safe input")

    assert response == "test response"
    mock_base_agent.execute.assert_called_once()
|
|
|
|
|
async def test_secure_agent_wrapper_rate_limit_exceeded(secure_agent_wrapper):
    """A third call exceeds the fixture's burst limit of two and is refused."""
    # Use up the two-call burst allowance.
    for payload in ("input1", "input2"):
        await secure_agent_wrapper.secure_execute(payload)

    with pytest.raises(SecurityError, match="Rate limit exceeded"):
        await secure_agent_wrapper.secure_execute("input3")
|
|
|
|
|
async def test_secure_agent_wrapper_input_validation_failed():
    """Test that input containing an API key is rejected before the agent runs."""
    rate_limiter = RateLimiter(RateLimitConfig())
    validator = SecurityValidator(SecurityConfig())

    # Mock configured like the other wrapper tests. ``execute`` is an AsyncMock
    # so that any (incorrect) call would be recorded and caught by the
    # assert_not_called check below.
    mock_agent = MagicMock()
    mock_agent.config = {"name": "test_agent"}
    mock_agent.execute = AsyncMock(return_value="test response")

    wrapper = SecureAgentWrapper(mock_agent, rate_limiter, validator)

    with pytest.raises(SecurityError, match="Input validation failed"):
        await wrapper.secure_execute("API key: sk-1234567890abcdef")

    # Rejected input must never reach the wrapped agent.
    mock_agent.execute.assert_not_called()
|
|
|
|
|
async def test_secure_agent_wrapper_output_validation_failed():
    """A base-agent response that leaks an API key is blocked by output validation."""
    leaking_agent = MagicMock()
    leaking_agent.execute = AsyncMock(
        return_value="Response with API key: sk-1234567890abcdef"
    )

    wrapper = SecureAgentWrapper(
        leaking_agent,
        RateLimiter(RateLimitConfig()),
        SecurityValidator(SecurityConfig()),
    )

    with pytest.raises(SecurityError, match="Output validation failed"):
        await wrapper.secure_execute("Safe input")
|
|
|
|
|
|
|
def test_strip_html_tags():
    """strip_html_tags removes markup while keeping the visible text."""
    stripped = strip_html_tags(
        "<p>Hello <b>World</b>!</p><script>alert('xss')</script>"
    )
    for tag in ("<p>", "<b>", "<script>"):
        assert tag not in stripped
    assert "Hello World!" in stripped
|
|
|
|
|
def test_validate_api_key_format():
    """A well-formed key passes; empty, short or malformed keys fail."""
    assert validate_api_key_format("sk-1234567890abcdef1234567890abcdef") is True

    rejected = ("", "invalid", "sk-test", "sk-fake1234567890abcdef")
    for candidate in rejected:
        assert validate_api_key_format(candidate) is False
|
|
|
|
|
def test_sanitize_for_logging():
    """Secrets are scrubbed and the output stays near the requested max_length."""
    limit = 50
    # Tolerance beyond max_length, presumably for truncation/redaction markers.
    slack = 20

    cleaned = sanitize_for_logging(
        "User input with API key sk-1234567890abcdef", max_length=limit
    )

    assert "sk-1234567890abcdef" not in cleaned
    assert len(cleaned) <= limit + slack
|
|
|
|
|
|
|
def test_get_rate_limiter():
    """get_rate_limiter hands back the same shared instance on every call."""
    first, second = get_rate_limiter(), get_rate_limiter()
    assert first is second
|
|
|
|
|
def test_get_security_validator():
    """get_security_validator hands back the same shared instance on every call."""
    first, second = get_security_validator(), get_security_validator()
    assert first is second
|
|
|
|
|
def test_create_secure_agent():
    """create_secure_agent wraps the given agent in a SecureAgentWrapper."""
    inner = MagicMock()
    wrapped = create_secure_agent(inner)

    assert isinstance(wrapped, SecureAgentWrapper)
    assert wrapped.base_agent is inner
|
|
|
|
|
|
|
async def test_rate_limiter_cleanup():
    """Requests older than an hour are dropped from tracking on the next check."""
    limiter = RateLimiter(
        RateLimitConfig(requests_per_minute=10, requests_per_hour=100)
    )
    user = "test_user"

    with patch("time.time") as fake_clock:
        # Five requests recorded at t = 1000s.
        fake_clock.return_value = 1000.0
        for _ in range(5):
            await limiter.check_rate_limit(user)

        # Jump 4000s ahead (more than an hour) so the earlier requests expire.
        fake_clock.return_value = 5000.0
        assert await limiter.check_rate_limit(user) is True

        # Only the request just made should still be tracked.
        assert len(limiter._requests[user]) == 1
|
|
|
|
|
def test_security_config_file_permissions():
    """Test that set_secure_file_permissions removes group/other access bits.

    The mode check only runs on POSIX. ``os.chmod`` also exists on Windows, but
    there it can only toggle the read-only flag, so ``st_mode`` still reports
    group/other bits and the assertion would fail spuriously.
    """
    import os
    import tempfile

    # delete=False so the file outlives the context manager; removed in finally.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        from ankigen_core.agents.security import set_secure_file_permissions

        set_secure_file_permissions(tmp_path)

        if os.name == "posix":
            mode = os.stat(tmp_path).st_mode
            # No read/write/execute bits for group or others (mask 0o077).
            assert (mode & 0o077) == 0
    finally:
        os.unlink(tmp_path)
|
|
|
|
|
|
|
async def test_rate_limiter_concurrent_access():
    """Test that concurrent checks admit exactly ``burst_limit`` requests.

    Ten checks for the same identifier run together via ``asyncio.gather``.
    More than ``burst_limit`` successes would mean a race lets requests slip
    past the limiter. An upper bound alone would also pass if every request
    were wrongly refused, so the count is pinned exactly.
    """
    burst_limit = 5
    limiter = RateLimiter(RateLimitConfig(burst_limit=burst_limit))
    identifier = "concurrent_user"

    tasks = [limiter.check_rate_limit(identifier) for _ in range(10)]
    results = await asyncio.gather(*tasks)

    success_count = sum(1 for result in results if result)
    assert success_count == burst_limit
|
|
|
|
|
def test_security_validator_error_handling():
    """None and extremely large inputs yield False from validate_input."""
    validator = SecurityValidator(SecurityConfig())

    assert validator.validate_input(None, "test") is False

    enormous = "x" * 1_000_000
    assert validator.validate_input(enormous, "test") is False
|
|
|
|
|
async def test_secure_agent_wrapper_base_agent_error():
    """An exception from the wrapped agent surfaces to the caller with its message."""
    failing_agent = MagicMock()
    failing_agent.config = {"name": "test_agent"}
    failing_agent.execute = AsyncMock(side_effect=Exception("Base agent failed"))

    wrapper = SecureAgentWrapper(
        failing_agent,
        RateLimiter(RateLimitConfig()),
        SecurityValidator(SecurityConfig()),
    )

    with pytest.raises(Exception, match="Base agent failed"):
        await wrapper.secure_execute("Safe input")