# app.py (Strategic Agent Service for Hugging Face Spaces - CPU Only, Preload All Models, No ngrok)
"""Strategic Agent service for the pentest agent.

Exposes knowledge-base access (local RAG index + Firestore), deep web search and
LLM-driven next-command planning over FastAPI, running CPU-only on Hugging Face Spaces.
"""
import asyncio
import contextlib
import gc
import json
import logging
import os
import re
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlencode, urlparse

import firebase_admin
import nest_asyncio
import numpy as np
import requests
from bs4 import BeautifulSoup
from cachetools import TTLCache
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.responses import JSONResponse  # Added for explicit JSONResponse
from firebase_admin import credentials, firestore
from llama_cpp import Llama
from pydantic import BaseModel, Field, constr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Apply nest_asyncio to allow running asyncio.run() in environments with existing event loops
nest_asyncio.apply()

# --- Configuration ---
# Directory to store downloaded GGUF models within Hugging Face Space's writable space
DOWNLOAD_DIR = "./downloaded_models/"  # Changed to a local directory within the Space
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# Predefined Hugging Face GGUF model URLs for dynamic loading
HUGGINGFACE_MODELS = [
    {
        "name": "Foundation-Sec-8B-Q8_0",
        "url": "https://huggingface.co/fdtn-ai/Foundation-Sec-8B-Q8_0-GGUF/resolve/main/foundation-sec-8b-q8_0.gguf",
    },
    {
        "name": "Lily-Cybersecurity-7B-v0.2-Q8_0",
        "url": "https://huggingface.co/Nekuromento/Lily-Cybersecurity-7B-v0.2-Q8_0-GGUF/resolve/main/lily-cybersecurity-7b-v0.2-q8_0.gguf",
    },
    {
        "name": "SecurityLLM-GGUF (sarvam-m-q8_0)",
        "url": "https://huggingface.co/QuantFactory/SecurityLLM-GGUF/resolve/main/sarvam-m-q8_0.gguf",
    },
]

DATA_DIR = "./data"  # Local data for Hugging Face Space
DEEP_SEARCH_CACHE_TTL = 3600

# Timeout (seconds) for deep-search scraping requests; without a timeout a stalled
# remote server can block the calling request handler indefinitely.
HTTP_TIMEOUT_SECONDS = 15
# (connect, read) timeouts for model downloads. The read timeout bounds the wait
# between received bytes, not the duration of the whole multi-gigabyte transfer.
MODEL_DOWNLOAD_TIMEOUT = (10, 300)
MODEL_DOWNLOAD_CHUNK_SIZE = 1024 * 1024

# --- ngrok Configuration (Removed) ---
# NGROK_AUTH_TOKEN and NGROK_STRATEGIC_AGENT_TUNNEL_URL are removed

# --- Logging Setup ---
logging.basicConfig(
    level=logging.DEBUG,  # Changed from INFO to DEBUG
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
logger.info("Logging initialized with DEBUG level.")

# Initialize FastAPI app
app = FastAPI(
    title="Hugging Face Strategic Agent Service",
    description="Provides knowledge base access and strategic reasoning for the pentest agent on Hugging Face Spaces.",
    version="1.0.0"
)

# --- Firebase ---
firebase_creds_path = os.getenv("FIREBASE_CREDS_PATH", "cred.json")


def _initialize_firestore():
    """Return a Firestore client, or None when Firebase cannot be initialized.

    When a Firebase app already exists (e.g. this module is imported a second
    time), the existing app is reused; previously ``db`` was left as None then.
    """
    try:
        if not firebase_admin._apps:
            if not os.path.exists(firebase_creds_path):
                logger.warning(f"Firebase credentials file not found at {firebase_creds_path}. Firebase will not be initialized.")
                return None
            cred = credentials.Certificate(firebase_creds_path)
            firebase_admin.initialize_app(cred)
        client = firestore.client()
        logger.info("Firebase initialized successfully.")
        return client
    except Exception as e:
        logger.error(f"Failed to initialize Firebase: {e}. Ensure FIREBASE_CREDS_PATH is set correctly and the file exists.", exc_info=True)
        return None


db = _initialize_firestore()

# Global LLM instance for Strategic Agent
strategic_llm: Optional[Llama] = None
current_strategic_model_url: Optional[str] = None  # Now tracks URL, not local path

# Supported tools (Strategic Agent needs to know these for command generation)
SUPPORTED_TOOLS = [
    "nmap", "gobuster", "nikto", "sqlmap", "adb", "frida", "drozer", "apktool",
    "msfconsole", "mobsfscan", "burpsuite", "metasploit", "curl", "wget", "hydra",
    "john", "aircrack-ng"
]

# --- Deep Search Cache ---
# NOTE(review): not referenced by any code visible in this section; confirm it is
# used further down the file before relying on (or removing) it.
deep_search_cache = TTLCache(maxsize=100, ttl=DEEP_SEARCH_CACHE_TTL)

# --- Enhanced System Instruction (English) ---
SYSTEM_INSTRUCTION = (
    "You are an expert pentest agent. Strictly follow these rules:\n"
    "1. Output ONLY valid shell commands\n"
    "2. NEVER include timestamps, dates, or any text outside commands\n"
    "3. Never repeat previous commands\n"
    "4. Always verify command safety before execution\n\n"
    "Example valid response:\n"
    "nmap -sV 192.168.1.6\n\n"
    "Key Principles:\n"
    "- Never give up until the goal is achieved\n"
    "- Learn from failures and adapt strategies\n"
    "- Leverage all available knowledge and tools\n"
    "- Break complex tasks into smaller achievable steps\n"
    "- Always ensure actions are ethical and within scope\n\n"
    "Available Tools:\n"
    "- nmap: Network scanning and service detection\n"
    "- gobuster: Web directory brute-forcing\n"
    "- nikto: Web server vulnerability scanner\n"
    "- sqlmap: SQL injection testing\n"
    "- adb: Android Debug Bridge\n"
    "- metasploit: Exploitation framework\n\n"
    "Error Handling Examples:\n"
    "Example 1 (Command Failure):\n"
    " If nmap fails because host is down, try: nmap -Pn -sV 192.168.1.6\n"
    "Example 2 (Web Server Error):\n"
    " If web server returns 403, try: gobuster dir -u http://192.168.1.6 -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt\n"
    "Example 3 (ADB Connection Failed):\n"
    " If ADB connection fails, try: adb kill-server && adb start-server"
)

# Matches the first dotted-quad IPv4-looking token (octet ranges are not validated).
_IPV4_PATTERN = re.compile(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b')


def _timestamp_sort_value(value) -> float:
    """Map a metadata timestamp (number, datetime or ISO string) to a sortable float.

    Unparseable or missing values sort as 0.0, matching the old default of 0, and
    mixed types no longer raise TypeError during sorting.
    """
    if isinstance(value, (int, float)):
        return float(value)
    try:
        if isinstance(value, datetime):
            return value.timestamp()
        if isinstance(value, str):
            return datetime.fromisoformat(value).timestamp()
    except (ValueError, OverflowError, OSError):
        pass
    return 0.0


# --- Firebase Knowledge Base Integration ---
class FirebaseKnowledgeBase:
    """Keyword search over the Firestore ``knowledge_base`` collection."""

    PRIORITY_ORDER = {"high": 1, "medium": 2, "low": 3}

    def __init__(self):
        self.collection = db.collection('knowledge_base') if db else None

    @staticmethod
    def _metadata(record: dict) -> dict:
        """Return the record's metadata when it is a dict, otherwise an empty dict."""
        metadata = record.get('metadata')
        return metadata if isinstance(metadata, dict) else {}

    @staticmethod
    def _search_text(record: dict) -> str:
        """Lower-cased haystack of prompt, completion and metadata for keyword matching.

        Metadata may be a dict (it is read as one for priority/timestamp) or a
        plain string; calling ``.lower()`` on a dict used to abort every query.
        """
        metadata = record.get('metadata', '')
        if isinstance(metadata, dict):
            metadata_text = json.dumps(metadata, default=str)
        else:
            metadata_text = str(metadata or '')
        return f"{record.get('prompt') or ''} {record.get('completion') or ''} {metadata_text}".lower()

    @classmethod
    def _sort_key(cls, record: dict):
        """Sort by priority (high first), then by ascending timestamp."""
        metadata = cls._metadata(record)
        priority = str(metadata.get('priority', 'low')).lower()
        return (cls.PRIORITY_ORDER.get(priority, 3), _timestamp_sort_value(metadata.get('timestamp', 0)))

    def query(self, goal: str, phase: Optional[str] = None, limit: int = 10) -> list:
        """Return up to ``limit`` knowledge records matching the goal/phase keywords.

        Args:
            goal: Free-text goal; matched as a substring (case-insensitive).
            phase: Optional phase name added as an extra keyword.
            limit: Maximum number of records to return (<= 0 returns []).

        Returns:
            Matching Firestore documents as dicts, sorted by priority then timestamp.
            An empty list when Firebase is unavailable or the query fails.
        """
        if not db or not firebase_admin._apps:  # Check if Firebase is initialized
            logger.error("Firestore client not initialized. Cannot query knowledge base.")
            return []
        limit = 10 if limit is None else limit
        if limit <= 0:
            return []
        # Re-instantiate collection if it's None (e.g., if Firebase init failed initially)
        if self.collection is None:
            self.collection = db.collection('knowledge_base')

        # Empty keywords are dropped: '' is a substring of every document.
        keywords = [
            keyword for keyword in (
                (goal or '').lower(), 'android', 'pentest', 'mobile', 'device', (phase or '').lower()
            ) if keyword
        ]
        try:
            results = []
            # Stream and stop early once enough matches are collected.
            for doc in self.collection.stream():
                data = doc.to_dict() or {}
                text = self._search_text(data)
                if any(keyword in text for keyword in keywords):
                    results.append(data)
                    if len(results) >= limit:
                        break
            results.sort(key=self._sort_key)
            return results[:limit]
        except Exception as e:
            logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
            return []


# --- RAG Knowledge Index ---
class KnowledgeIndex:
    """In-memory embedding index over ``DATA_DIR/knowledge_base.json``."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        os.makedirs(DATA_DIR, exist_ok=True)
        self.model = SentenceTransformer(
            model_name,
            cache_folder=os.path.join(DATA_DIR, "hf_cache")  # Use local data dir for cache
        )
        self.knowledge_base = []
        self.load_knowledge_from_file(os.path.join(DATA_DIR, 'knowledge_base.json'))

    def load_knowledge_from_file(self, file_path):
        """Embed and append entries from a JSON list of strings or {'text', 'source'} dicts.

        A missing file is created as an empty JSON list so later runs find it.
        """
        logger.debug(f"Attempting to load knowledge from file: {file_path}")
        if not os.path.exists(file_path):
            logger.warning(f"Knowledge base file not found: {file_path}. RAG will operate on an empty knowledge base.")
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump([], f)
                logger.info(f"Created empty knowledge base file at: {file_path}")
            except Exception as e:
                logger.error(f"Error creating empty knowledge base file at {file_path}: {e}", exc_info=True)
            return

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                logger.error("Knowledge base file is not a list. Please check the file format.")
                return

            entries = []
            for item in data:
                if isinstance(item, dict):
                    text = item.get('text', '')
                    source = item.get('source', 'local')
                elif isinstance(item, str):
                    text = item
                    source = 'local'
                else:
                    logger.warning(f"Skipping unsupported item type: {type(item)}")
                    continue
                if text:
                    entries.append((text, source))

            if entries:
                # One batched encode call instead of one model call per entry.
                embeddings = self.model.encode([text for text, _ in entries])
                for (text, source), embedding in zip(entries, embeddings):
                    self.knowledge_base.append({'text': text, 'embedding': embedding.tolist(), 'source': source})
            logger.info(f"Loaded {len(self.knowledge_base)} items into RAG knowledge base.")
        except Exception as e:
            logger.error(f"Error loading knowledge from {file_path}: {e}", exc_info=True)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` most cosine-similar entries, best first.

        Similarities are plain Python floats so the result is JSON-serializable
        (numpy scalars are rejected by FastAPI's encoder).
        """
        if not self.knowledge_base:
            logger.debug("Knowledge base is empty, no RAG retrieval possible.")
            return []
        if top_k <= 0:
            # The old slice argsort()[-0:] returned *every* entry for top_k == 0.
            return []
        try:
            query_embedding = self.model.encode(query).reshape(1, -1)
            embeddings = np.array([item['embedding'] for item in self.knowledge_base])
            similarities = cosine_similarity(query_embedding, embeddings)[0]
            top_indices = np.argsort(similarities)[::-1][:top_k]
            results = [
                {
                    "text": self.knowledge_base[i]['text'],
                    "similarity": float(similarities[i]),
                    "source": self.knowledge_base[i].get('source', 'RAG'),
                }
                for i in top_indices
            ]
            logger.debug(f"RAG retrieved {len(results)} results for query: '{query}'")
            return results
        except Exception as e:
            logger.error(f"Error during RAG retrieval for query '{query}': {e}", exc_info=True)
            return []


# --- Deep Search Engine ---
class DeepSearchEngine:
    """Best-effort scraping of public sites for CVEs, exploits and advisories.

    Each search logs and swallows its own errors and returns an empty list, so a
    single unreachable site never fails the whole deep search.
    """

    def __init__(self):
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }

    def _fetch_soup(self, url: str) -> BeautifulSoup:
        """GET ``url`` with a timeout, raise on HTTP errors, and parse the HTML."""
        response = requests.get(url, headers=self.headers, timeout=HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()  # Raise an exception for HTTP errors
        return BeautifulSoup(response.text, 'html.parser')

    def search_device_info(self, device_info: str, os_version: str) -> dict:
        """Aggregate CVEs, exploits and vendor recommendations for a device/OS pair."""
        logger.debug(f"Performing deep search for device: {device_info}, OS: {os_version}")
        results = {
            "device": device_info,
            "os_version": os_version,
            "vulnerabilities": [],
            "exploits": [],
            "recommendations": []
        }
        try:
            results["vulnerabilities"] = self.search_cve(device_info, os_version)
            results["exploits"] = self.search_exploits(device_info, os_version)
            results["recommendations"] = self.get_security_recommendations(os_version)
            logger.debug("Deep search completed.")
        except Exception as e:
            logger.error(f"Deep search failed: {e}", exc_info=True)
        return results

    def search_cve(self, device: str, os_version: str) -> list:
        """Return up to 10 CVE entries from the MITRE keyword search."""
        cves = []
        try:
            query = f"{device} {os_version} CVE"
            # urlencode escapes spaces and reserved characters in the free-text query.
            search_url = f"https://cve.mitre.org/cgi-bin/cvekey.cgi?{urlencode({'keyword': query})}"
            logger.debug(f"Searching CVE Mitre: {search_url}")
            soup = self._fetch_soup(search_url)
            table = soup.find('div', id='TableWithRules')
            if table:
                for row in table.find_all('tr')[1:]:  # skip header row
                    cols = row.find_all('td')
                    if len(cols) >= 2:
                        cves.append({
                            "cve_id": cols[0].get_text(strip=True),
                            "description": cols[1].get_text(strip=True),
                            "source": "CVE Mitre"
                        })
            logger.debug(f"Found {len(cves)} CVEs.")
            return cves[:10]
        except Exception as e:
            logger.error(f"CVE search failed: {e}", exc_info=True)
            return []

    def search_exploits(self, device: str, os_version: str) -> list:
        """Return up to 10 ExploitDB search hits (title + absolute link)."""
        exploits = []
        try:
            query = f"{device} {os_version}"
            search_url = f"https://www.exploit-db.com/search?{urlencode({'q': query})}"
            logger.debug(f"Searching ExploitDB: {search_url}")
            soup = self._fetch_soup(search_url)
            for card in soup.select('.card .card-title'):
                anchor = card.find('a')
                # A card without a link used to raise TypeError and drop every result.
                if anchor is None or not anchor.get('href'):
                    continue
                link = anchor['href']
                if not link.startswith('http'):
                    link = f"https://www.exploit-db.com{link}"
                exploits.append({
                    "title": card.get_text(strip=True),
                    "link": link,
                    "source": "ExploitDB"
                })
            logger.debug(f"Found {len(exploits)} exploits.")
            return exploits[:10]
        except Exception as e:
            logger.error(f"Exploit search failed: {e}", exc_info=True)
            return []

    def get_security_recommendations(self, os_version: str) -> list:
        """Return up to 5 bullet points from the Android or Apple security pages."""
        recommendations = []
        try:
            logger.debug(f"Getting security recommendations for OS: {os_version}")
            os_lower = os_version.lower()
            if "android" in os_lower:
                soup = self._fetch_soup("https://source.android.com/docs/security/bulletin")
                for heading in soup.select('.devsite-article-body h2'):
                    if os_version in heading.get_text():
                        next_ul = heading.find_next('ul')
                        if next_ul:
                            recommendations.extend(li.get_text(strip=True) for li in next_ul.select('li'))
            elif "ios" in os_lower:
                soup = self._fetch_soup("https://support.apple.com/en-us/HT201222")
                for section in soup.select('#sections'):
                    if os_version in section.get_text():
                        recommendations.extend(li.get_text(strip=True) for li in section.select('li'))
            logger.debug(f"Found {len(recommendations)} recommendations.")
            return recommendations[:5]
        except Exception as e:
            logger.error(f"Security recommendations search failed: {e}", exc_info=True)
            return []

    def _search_github(self, device_info: str) -> list:
        """GitHub repository hits for '<device_info> pentest'; [] on failure."""
        resources = []
        try:
            github_url = f"https://github.com/search?{urlencode({'q': f'{device_info} pentest'})}"
            soup = self._fetch_soup(github_url)
            for repo in soup.select('.repo-list-item'):
                title_el = repo.select_one('.v-align-middle')
                if title_el is None or not title_el.get('href'):
                    continue
                description_el = repo.select_one('.mb-1')
                resources.append({
                    "title": title_el.get_text(strip=True),
                    "description": description_el.get_text(strip=True) if description_el else "",
                    "url": f"https://github.com{title_el['href']}",
                    "source": "GitHub"
                })
        except Exception as e:
            logger.error(f"Public resources search failed: {e}", exc_info=True)
        return resources

    def _search_hackforums(self, device_info: str) -> list:
        """Forum thread hits for ``device_info``; [] on failure."""
        resources = []
        try:
            forum_url = "https://hackforums.net/search.php?" + urlencode(
                {"action": "finduserthreads", "keywords": device_info}
            )
            soup = self._fetch_soup(forum_url)
            for thread in soup.select('.thread'):
                title_el = thread.select_one('.threadtitle')
                link_el = thread.select_one('.threadtitle a')
                if title_el is None or link_el is None or not link_el.get('href'):
                    continue
                resources.append({
                    "title": title_el.get_text(strip=True),
                    "description": "Forum discussion",
                    "url": f"https://hackforums.net{link_el['href']}",
                    "source": "HackForums"
                })
        except Exception as e:
            logger.error(f"Public resources search failed: {e}", exc_info=True)
        return resources

    def search_public_resources(self, device_info: str) -> list:
        """Return up to 10 public resources (GitHub first, then forums).

        Each source is searched independently, so a failure in one no longer
        discards the results already collected from the other.
        """
        logger.debug(f"Searching public resources for device: {device_info}")
        resources = self._search_github(device_info) + self._search_hackforums(device_info)
        logger.debug(f"Found {len(resources)} public resources.")
        return resources[:10]


# --- Initialize Services (Local to Strategic Agent) ---
firebase_kb = FirebaseKnowledgeBase()
rag_index = KnowledgeIndex()
deep_search_engine = DeepSearchEngine()


def _download_file(url: str, destination: str) -> None:
    """Stream ``url`` to ``destination`` atomically.

    Data is written to ``destination + '.part'`` and renamed only after the
    transfer completes, so an interrupted download never leaves a truncated file
    at ``destination`` (which the loader would otherwise treat as already cached).
    """
    partial_path = destination + ".part"
    try:
        with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as fh:
                for chunk in response.iter_content(chunk_size=MODEL_DOWNLOAD_CHUNK_SIZE):
                    if chunk:
                        fh.write(chunk)
        os.replace(partial_path, destination)
    finally:
        # Only present if something failed before the rename.
        with contextlib.suppress(OSError):
            os.remove(partial_path)


# --- Strategic Agent Brain (formerly SmartExecutionEngine logic) ---
class StrategicAgentBrain:
    """Plans pentest phases, builds LLM prompts and tracks observed command results."""

    _ALLOWED_COMMAND_CHARS = re.compile(r'^[a-zA-Z0-9_./:;=,+@%&~ \-\'"\s]+$')
    _REASONING_MARKERS = re.compile(r'(reason|thought|explanation|rationale|note|result):', re.IGNORECASE)

    def __init__(self):
        self.llm: Optional[Llama] = None
        self.current_goal: Optional[str] = None
        self.current_phase: str = "initial_reconnaissance"
        self.current_plan: List[Dict] = []
        self.current_phase_index: int = 0
        self.identified_vulnerabilities: List[Dict] = []
        self.gathered_info: List[str] = []
        self.command_retry_counts: Dict[str, int] = {}
        self.conversation_history: List[Dict] = []
        self.used_commands = set()
        # Insertion-ordered copy of used_commands: a set has no "most recent" order.
        self.recent_commands: List[str] = []
        self.execution_history = []
        self.goal_achieved = False
        self.no_progress_count = 0
        self.react_cycle_count = 0
        self.loaded_model_name: Optional[str] = None  # To store the name of the loaded model
        # Serializes load/unload; created lazily inside a coroutine so it belongs
        # to the event loop that actually serves requests.
        self._model_lock: Optional[asyncio.Lock] = None
        logger.info("StrategicAgentBrain initialized.")

    # ----- Model management -------------------------------------------------

    def _get_model_lock(self) -> asyncio.Lock:
        if self._model_lock is None:
            self._model_lock = asyncio.Lock()
        return self._model_lock

    @staticmethod
    def _model_filename_from_url(model_url: str) -> Optional[str]:
        """Last path segment of the URL (query string ignored); None if unusable."""
        filename = os.path.basename(urlparse(model_url).path)
        if filename in ("", ".", ".."):
            return None
        return filename

    @staticmethod
    def _create_llama(model_path: str) -> Llama:
        """Construct the CPU-only llama.cpp model (blocking; run off the event loop)."""
        return Llama(
            model_path=model_path,
            n_ctx=3096,
            n_gpu_layers=0,  # Explicitly set to 0 for CPU-only
            n_threads=os.cpu_count(),  # Use all available CPU threads
            n_batch=512,
            verbose=False
        )

    async def load_strategic_llm(self, model_url: str):
        """Download (if needed) and load the GGUF model at ``model_url``.

        Returns:
            ``(success, message)``.

        The download and model construction run in the default thread-pool
        executor so the event loop keeps serving other requests meanwhile.
        NOTE(review): any URL is accepted and downloaded to disk; consider
        restricting it to HUGGINGFACE_MODELS if this endpoint is exposed publicly.
        """
        async with self._get_model_lock():
            return await self._load_strategic_llm_locked(model_url)

    async def _load_strategic_llm_locked(self, model_url: str):
        global strategic_llm, current_strategic_model_url
        logger.info(f"Attempting to load strategic LLM from URL: {model_url}")

        model_filename = self._model_filename_from_url(model_url)
        if model_filename is None:
            message = f"Could not determine a model filename from URL: {model_url}"
            logger.error(message)
            return False, message
        local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)

        if strategic_llm is not None and current_strategic_model_url == model_url:
            logger.info(f"Strategic LLM model from {model_url} is already loaded.")
            self.llm = strategic_llm
            return True, f"Model '{self.loaded_model_name}' is already loaded."

        loop = asyncio.get_running_loop()

        # Download before unloading, so a failed download keeps the current model.
        if not os.path.exists(local_model_path):
            logger.info(f"Model not found locally. Attempting to download from {model_url} to {local_model_path}...")
            try:
                await loop.run_in_executor(None, _download_file, model_url, local_model_path)
                logger.info(f"Model downloaded successfully to {local_model_path}.")
            except Exception as e:
                logger.error(f"Failed to download model from {model_url}: {e}", exc_info=True)
                return False, f"Failed to download model: {str(e)}"

        # Free the current model's memory before loading the next one.
        if strategic_llm is not None:
            self._release_strategic_llm()

        try:
            logger.info(f"Loading Strategic LLM model from {local_model_path}...")
            llm = await loop.run_in_executor(None, self._create_llama, local_model_path)
            strategic_llm = llm
            current_strategic_model_url = model_url
            self.llm = llm
            self.loaded_model_name = model_filename  # Store the filename
            logger.info(f"Strategic LLM model {model_filename} loaded successfully (CPU-only).")
            return True, f"Model '{model_filename}' loaded successfully (CPU-only)."
        except Exception as e:
            logger.error(f"Failed to load Strategic LLM model from {local_model_path}: {e}", exc_info=True)
            strategic_llm = None
            current_strategic_model_url = None
            self.llm = None
            self.loaded_model_name = None
            return False, f"Failed to load model: {str(e)}"

    def _release_strategic_llm(self):
        """Drop all references to the loaded model and force a GC pass (no-op if none)."""
        global strategic_llm, current_strategic_model_url
        if strategic_llm is None:
            return
        logger.info("Unloading Strategic LLM model...")
        strategic_llm = None
        current_strategic_model_url = None
        self.llm = None
        self.loaded_model_name = None
        gc.collect()
        logger.info("Strategic LLM model unloaded.")

    async def unload_strategic_llm(self):
        """Unload the current model, waiting for any in-flight load to finish first."""
        async with self._get_model_lock():
            self._release_strategic_llm()

    # ----- Knowledge / context helpers ---------------------------------------

    def _target_ip(self, default: Optional[str] = "192.168.1.1") -> Optional[str]:
        """First IPv4-looking token in the current goal, else ``default``."""
        match = _IPV4_PATTERN.search(self.current_goal or "")
        return match.group(0) if match else default

    def _recent_commands_text(self, count: int = 3) -> str:
        """Comma-separated list of the last ``count`` accepted commands, or 'None'."""
        return ', '.join(self.recent_commands[-count:]) if self.recent_commands else 'None'

    def _get_rag_context(self, query: str) -> str:
        """Format RAG hits for ``query`` as a numbered block ('' when none)."""
        results = rag_index.retrieve(query)
        if not results:
            return ""
        rag_context = "Relevant Knowledge for Current Context:\n"
        for i, result in enumerate(results):
            text = result.get('text', '') or result.get('completion', '')
            source = result.get('source', 'RAG')
            rag_context += f"{i+1}. [{source}] {text}\n"
        return rag_context

    def _get_firebase_knowledge(self, goal: str, phase: str = None) -> str:
        """Format Firestore knowledge for the prompt ('' when unavailable or empty).

        Delegates to ``firebase_kb.query`` instead of duplicating it; the old copy
        returned a list despite the ``str`` contract and its use as prompt text.
        """
        results = firebase_kb.query(goal or "", phase, limit=10)
        if not results:
            return ""
        lines = ["Firebase Knowledge:"]
        for i, item in enumerate(results, 1):
            text = item.get('completion') or item.get('prompt') or ''
            lines.append(f"{i}. [Firebase] {text}")
        return "\n".join(lines) + "\n"

    def extract_device_info(self) -> str:
        """Device model from gathered output ('model: X' / 'device = X'), else target IP."""
        for info in self.gathered_info:
            lowered = info.lower()
            if "model" in lowered or "device" in lowered:
                match = re.search(r'(?:model|device)\s*[:=]\s*([^\n]+)', info, re.IGNORECASE)
                if match:
                    return match.group(1).strip()
        return self._target_ip(default="Unknown Device")

    def extract_os_version(self) -> str:
        """First Android / iOS / Linux-kernel version string found in gathered output."""
        for info in self.gathered_info:
            if "android" in info.lower() or "ios" in info.lower() or "os" in info.lower():
                android_match = re.search(r'android\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE)
                if android_match:
                    return f"Android {android_match.group(1)}"
                ios_match = re.search(r'ios\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE)
                if ios_match:
                    return f"iOS {ios_match.group(1)}"
                linux_match = re.search(r'linux\s+kernel\s+(\d+\.\d+\.\d+)', info, re.IGNORECASE)
                if linux_match:
                    return f"Linux {linux_match.group(1)}"
        return "Unknown OS Version"

    def format_deep_search_results(self, results: dict) -> str:
        """Render a ``search_device_info`` result (plus optional public_resources) as text."""
        context = "Deep Search Results:\n"
        context += f"Device: {results.get('device', 'Unknown')}\n"
        context += f"OS Version: {results.get('os_version', 'Unknown')}\n\n"
        if results.get('vulnerabilities'):
            context += "Discovered Vulnerabilities:\n"
            for i, vuln in enumerate(results['vulnerabilities'][:5], 1):
                context += f"{i}. {vuln.get('cve_id', 'CVE-XXXX-XXXX')}: {vuln.get('description', 'No description')}\n"
            context += "\n"
        if results.get('exploits'):
            context += "Available Exploits:\n"
            for i, exploit in enumerate(results['exploits'][:5], 1):
                context += f"{i}. {exploit.get('title', 'Untitled exploit')} [Source: {exploit.get('source', 'Unknown')}]\n"
            context += "\n"
        if results.get('recommendations'):
            context += "Security Recommendations:\n"
            for i, rec in enumerate(results['recommendations'][:3], 1):
                context += f"{i}. {rec}\n"
            context += "\n"
        if results.get('public_resources'):
            context += "Public Resources:\n"
            for i, res in enumerate(results['public_resources'][:3], 1):
                context += f"{i}. {res.get('title', 'Untitled resource')} [Source: {res.get('source', 'Unknown')}]\n"
        return context

    def generate_deep_search_prompt(self, context: str) -> str:
        """Prompt asking the LLM for one command based on deep-search ``context``."""
        return f"""
You are an expert pentester. Below are deep search results for the target device. Use this information to generate the next penetration testing command.

{context}
Current Goal: {self.current_goal}
Current Phase: {self.current_phase}
Recent Command History: {self._recent_commands_text()}

Based on this information, what is the SINGLE MOST EFFECTIVE shell command to execute next? Focus on exploiting the most critical vulnerabilities or gathering more information.

Response Format:
Command:
"""

    @staticmethod
    def _shorten(text: str, max_length: int) -> str:
        """Truncate ``text`` to ``max_length`` chars, marking the cut."""
        if len(text) > max_length:
            return text[:max_length] + "... [truncated]"
        return text

    def _generate_llm_prompt(self) -> str:
        """Build the main next-command prompt from goal, phase, knowledge and history."""
        rag_context = self._get_rag_context(f"{self.current_goal} {self.current_phase}")
        firebase_knowledge = self._get_firebase_knowledge(self.current_goal, self.current_phase)
        history_context = "\n".join(
            f"{entry['role']}: {entry['content']}" for entry in self.conversation_history[-2:]
        )
        if self.execution_history:
            execution_history = "\n".join(
                f"Command: {res['command']}\nResult: {res['output'][:100]}...\nSuccess: {res['success']}"
                for res in self.execution_history[-2:]
            )
        else:
            execution_history = "No previous results."
        strategic_advice = self._get_rag_context(self.current_phase)  # Using RAG for strategic advice too

        # Hard caps keep the prompt inside the model's small context window.
        rag_context = self._shorten(rag_context, 200)
        firebase_knowledge = self._shorten(firebase_knowledge, 200)
        strategic_advice = self._shorten(strategic_advice, 100)
        history_context = self._shorten(history_context, 150)
        execution_history = self._shorten(execution_history, 500)

        if self.current_plan and self.current_phase_index < len(self.current_plan):
            objective = self.current_plan[self.current_phase_index]['objective']
        else:
            objective = 'No objective'

        return f"""
System Instructions:
{SYSTEM_INSTRUCTION}

Current Goal: '{self.current_goal}'
Current Phase: {self.current_phase} - {objective}

Based on the following knowledge and previous results, generate the SINGLE, VALID SHELL COMMAND to advance the penetration testing process.

**Knowledge from External Services (RAG & Firebase):**
{rag_context}
{firebase_knowledge}

**Previous Execution Results:**
{execution_history}

**Recent Conversation History:**
{history_context}

**Strategic Advice for Current Phase:**
{strategic_advice}

***CRITICAL RULES FOR OUTPUT:***
1. **OUTPUT ONLY THE COMMAND.**
2. **DO NOT INCLUDE ANY REASONING, THOUGHTS, EXPLANATIONS, OR ANY OTHER TEXT.**
3. The command MUST be directly executable in a Linux terminal.
4. Avoid repeating these recent commands: {self._recent_commands_text()}
5. If the previous command failed, try a different approach or a related tool.
6. For the 'android_enumeration' phase, prioritize ADB commands.

Example valid commands for initial reconnaissance of an Android phone:
nmap -sV -Pn 192.168.1.14
adb devices
adb connect 192.168.1.14:5555

Command:
"""

    def _get_llm_response(self, custom_prompt: str = None) -> str:
        """Run the LLM on ``custom_prompt`` (or the default prompt) and return raw text.

        Falls back to a ``Command: ...`` nmap suggestion against the goal's IP when
        no model is loaded, the model returns nothing, or inference raises.
        """
        if not self.llm:
            logger.error("Strategic LLM instance is None. Cannot get response. Please load a model first.")
            target_ip = self._target_ip()
            return f"Command: echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {target_ip}'"

        prompt = custom_prompt if custom_prompt else self._generate_llm_prompt()
        logger.info(f"Sending prompt to Strategic LLM:\n{prompt[:500]}...")
        try:
            response = self.llm(
                prompt,
                max_tokens=512,
                temperature=0.3,
                stop=["\n"]  # a single command line is expected
            )
            llm_response = response['choices'][0]['text'].strip()
            logger.info(f"Strategic LLM raw response: {llm_response}")
            if not llm_response:
                logger.warning("Strategic LLM returned an empty response. Using fallback command.")
                return f"Command: nmap -sV -Pn {self._target_ip()}"
            return llm_response
        except Exception as e:
            logger.error(f"Error during Strategic LLM inference: {e}", exc_info=True)
            logger.warning("Strategic LLM inference failed. Using fallback command.")
            return f"Command: nmap -sV -Pn {self._target_ip()}"

    def parse_llm_response(self, response: str) -> str:
        """Extract one new, well-formed shell command from an LLM response.

        Tries, in order: a fenced code block, a 'Command:' line, then a bare line
        starting with a supported tool. Returns None when nothing valid is found,
        the command contains disallowed characters or looks like prose, or it was
        already suggested for the current goal.

        The character allow-list accepts ',', '&', '+', '@', '%' and '~' so that
        port lists (``nmap -p 80,443``) and the system prompt's own example
        (``adb kill-server && adb start-server``) pass; '$', backticks, '|', '<'
        and '>' are still rejected.
        """
        logger.info(f"Attempting to parse LLM response: '{response}'")
        command = None
        try:
            code_block = re.search(r'```(?:bash|sh)?\s*([\s\S]*?)```', response)
            if code_block:
                command = code_block.group(1).strip()
                logger.info(f"Command extracted from code block: '{command}'")

            if not command:
                command_match = re.search(r'^\s*Command\s*:\s*(.+)$', response, re.MULTILINE | re.IGNORECASE)
                if command_match:
                    command = command_match.group(1).strip()
                    logger.info(f"Command extracted from 'Command:' line: '{command}'")

            if not command:
                stripped_response = response.strip()
                if any(stripped_response.startswith(tool) for tool in SUPPORTED_TOOLS):
                    command = stripped_response
                    logger.info(f"Command extracted as direct supported tool command: '{command}'")

            if not command:
                logger.warning("No valid command could be extracted from LLM response based on strict rules.")
                return None

            original_command = command
            command = re.sub(r'^\s*(Command|Answer|Note|Result)\s*[:.-]?\s*', '', command, flags=re.IGNORECASE).strip()
            logger.info(f"Cleaned command: from '{original_command}' to '{command}'")

            if not self._ALLOWED_COMMAND_CHARS.match(command):
                logger.error(f"Invalid command characters detected after cleanup: '{command}'")
                return None

            if self._REASONING_MARKERS.search(command):
                logger.warning(f"Command '{command}' appears to be reasoning/explanation. Rejecting.")
                return None

            if command in self.used_commands:
                logger.warning(f"Command '{command}' already used. Skipping.")
                return None

            self.used_commands.add(command)
            self.recent_commands.append(command)
            logger.info(f"Returning valid and new command: '{command}'")
            return command
        except Exception as e:
            logger.error(f"Error parsing LLM response: {e}", exc_info=True)
            return None

    # ----- Planning & observation -------------------------------------------

    def set_goal(self, goal: str):
        """Start a new engagement: build a plan and reset all per-goal state."""
        self.current_goal = goal
        self.goal_achieved = False
        self.react_cycle_count = 0
        self.no_progress_count = 0
        self.current_plan = self._generate_strategic_plan(goal)
        self.current_phase_index = 0
        # Restart at the plan's first phase; previously the old goal's phase
        # (possibly "completed") carried over.
        self.current_phase = self.current_plan[0]["phase"] if self.current_plan else "initial_reconnaissance"
        self.identified_vulnerabilities = []
        self.gathered_info = []
        self.command_retry_counts = {}
        self.conversation_history = [{"role": "user", "content": f"New goal set: {goal}"}]
        self.used_commands.clear()
        self.recent_commands = []
        self.execution_history = []
        logger.info(f"Strategic Agent Goal set: {goal}. Starting initial reconnaissance.")

    def _generate_strategic_plan(self, goal: str) -> List[Dict]:
        """Ordered phases chosen by keywords in the goal (web / mobile / generic network)."""
        logger.debug(f"Generating strategic plan for goal: {goal}")
        goal_lower = goal.lower()
        plan = [{"phase": "initial_reconnaissance", "objective": f"Perform initial reconnaissance for {goal}"}]

        if "web" in goal_lower or "http" in goal_lower:
            plan += [
                {"phase": "web_enumeration", "objective": "Enumerate web server for directories and files"},
                {"phase": "web_vulnerability_analysis", "objective": "Analyze web vulnerabilities (SQLi, XSS, etc.)"},
                {"phase": "web_exploitation", "objective": "Attempt to exploit web vulnerabilities"},
                {"phase": "post_exploitation", "objective": "Perform post exploitation activities"},
            ]
        elif "android" in goal_lower or "mobile" in goal_lower or "adb" in goal_lower:
            plan += [
                {"phase": "android_enumeration", "objective": "Enumerate Android device via ADB"},
                {"phase": "android_app_analysis", "objective": "Analyze Android application for vulnerabilities"},
                {"phase": "android_exploitation", "objective": "Attempt to exploit Android vulnerabilities"},
                {"phase": "data_extraction", "objective": "Extract sensitive data from device"},
            ]
        else:
            plan += [
                {"phase": "network_scanning", "objective": "Perform detailed network scanning"},
                {"phase": "service_enumeration", "objective": "Enumerate services and identify versions"},
                {"phase": "vulnerability_analysis", "objective": "Analyze services for vulnerabilities"},
                {"phase": "exploitation", "objective": "Attempt to exploit vulnerabilities"},
                {"phase": "post_exploitation", "objective": "Perform post exploitation (privilege escalation, data exfiltration)"},
            ]

        plan.append({"phase": "reporting", "objective": "Generate pentest report"})
        logger.info(f"Generated strategic plan for goal '{goal}': {plan}")
        return plan

    def evaluate_phase_completion(self) -> float:
        """Fraction of commands recorded for the current phase that succeeded (0.0 if none)."""
        phase_commands = [cmd for cmd in self.execution_history if cmd.get('phase', '') == self.current_phase]
        if not phase_commands:
            return 0.0
        successful = sum(1 for cmd in phase_commands if cmd['success'])
        return successful / len(phase_commands)

    def advance_phase(self):
        """Move to the next planned phase, or mark the goal achieved after the last one."""
        if self.current_phase_index < len(self.current_plan) - 1:
            self.current_phase_index += 1
            self.current_phase = self.current_plan[self.current_phase_index]["phase"]
            logger.info(f"Strategic Agent advancing to new phase: {self.current_phase.replace('_', ' ').title()}")
            self.no_progress_count = 0
            self.react_cycle_count = 0
        else:
            self.current_phase = "completed"
            self.goal_achieved = True
            logger.info("Strategic Agent: All planned phases completed. Goal achieved!")

    def observe_result(self, command: str, output: str, success: bool):
        """Record a command result, analyze it, and advance the phase when >= 80% succeeded.

        Each record is tagged with the phase it ran in; without that tag
        ``evaluate_phase_completion`` always returned 0.0 and phases never advanced.
        The last phase may now complete as well, reaching ``advance_phase``'s
        "goal achieved" branch, which was previously unreachable from here.
        """
        logger.debug(f"Strategic Agent observing result for command '{command}': Success={success}")
        self.execution_history.append({
            "command": command,
            "output": output,
            "success": success,
            "phase": self.current_phase,
            "timestamp": datetime.now().isoformat()
        })
        self.gathered_info.append(output)
        self.analyze_command_output_strategic(command, output)

        if success:
            self.no_progress_count = 0
        else:
            self.no_progress_count += 1

        if success and self.current_plan and self.current_phase != "completed":
            if self.evaluate_phase_completion() >= 0.8:
                self.advance_phase()

    def analyze_command_output_strategic(self, command: str, output: str):
        """Strategic Agent performs deeper analysis of command output for vulnerabilities."""
        try:
            logger.debug(f"Analyzing strategic command output for: {command}")
            lowered = output.lower()
            if command.startswith("nmap"):
                if "open" in output and "vulnerable" in lowered:
                    self.ingest_vulnerability("Potential vulnerability found in NMAP scan", "Medium", "NMAP-SCAN")
                for port, service in re.findall(r'(\d+)/tcp\s+open\s+(\S+)', output):
                    self.gathered_info.append(f"Discovered open port {port} with service {service}")
            elif command.startswith("nikto"):
                if "OSVDB-" in output:
                    for vuln in re.findall(r'OSVDB-\d+:\s*(.+)', output)[:3]:
                        self.ingest_vulnerability(f"Nikto vulnerability: {vuln}", "High", "NIKTO-SCAN")
            elif command.startswith("sqlmap"):
                if "injection" in lowered:
                    self.ingest_vulnerability("SQL injection vulnerability detected", "Critical", "SQLMAP-SCAN")
            elif command.startswith("adb"):
                if "debuggable" in lowered:
                    self.ingest_vulnerability("Debuggable Android application found", "High", "ADB-DEBUG")
                if "permission" in lowered and "denied" in lowered:
                    self.ingest_vulnerability("Permission issue detected on Android device", "Medium", "ADB-PERMISSION")
        except Exception as e:
            logger.error(f"Strategic Agent: Error analyzing command output: {e}", exc_info=True)

    def ingest_vulnerability(self, description: str, severity: str, cve_id: Optional[str] = None, exploit_id: Optional[str] = None):
        """Append a timestamped vulnerability record (cve_id/exploit_id only when given)."""
        vulnerability = {
            "description": description,
            "severity": severity,
            "timestamp": datetime.now().isoformat()
        }
        if cve_id:
            vulnerability["cve_id"] = cve_id
        if exploit_id:
            vulnerability["exploit_id"] = exploit_id
        self.identified_vulnerabilities.append(vulnerability)
        logger.info(f"Strategic Agent identified vulnerability: {description} (Severity: {severity})")


# Instantiate the Strategic Agent Brain
strategic_brain = StrategicAgentBrain()


# --- Request Models for API Endpoints ---
class RAGRequest(BaseModel):
    query: constr(min_length=3, max_length=500)
    top_k: int = Field(5, gt=0, le=20)


class FirebaseQueryRequest(BaseModel):
    goal: str
    phase: Optional[str] = None
    limit: int = 10


class DeepSearchRequest(BaseModel):
    device_info: str
    os_version: str


class SetGoalRequest(BaseModel):
    goal: str


class GetNextCommandRequest(BaseModel):
    current_state: str
    last_command_output: str
    last_command_success: bool
    execution_history_summary: List[Dict] = []
    gathered_info_summary: List[str] = []
    identified_vulnerabilities_summary: List[Dict] = []


class ObserveResultRequest(BaseModel):
    command: str
    output: str
    success: bool


class LoadStrategicModelRequest(BaseModel):
    model_url: str  # Now expects a URL instead of a local path


# --- API Endpoints ---
@app.get("/health")
async def health_check():
    """Endpoint to check the health of the service."""
    logger.debug("Health check requested.")
    return {"status": "ok", "message": "Knowledge service is running."}


@app.post("/rag/retrieve")
async def rag_retrieve_endpoint(request: RAGRequest):
    """Return the top-k RAG matches for ``request.query``."""
    logger.debug(f"RAG retrieve endpoint called with query: {request.query}")
    try:
        results = rag_index.retrieve(request.query, request.top_k)
        return {"success": True, "data": {"results": results}, "error": None}
    except Exception as e:
        logger.error(f"RAG retrieval failed: {e}", exc_info=True)
        # NOTE: this handler's `raise HTTPException(...)` continues on the next line of the file.
raise HTTPException(status_code=500, detail=str(e)) @app.post("/firebase/query") async def firebase_query_endpoint(request: FirebaseQueryRequest): logger.debug(f"Firebase query endpoint called with goal: {request.goal}, phase: {request.phase}") try: results = firebase_kb.query(request.goal, request.phase, request.limit) return {"success": True, "data": {"results": results}, "error": None} except Exception as e: logger.error(f"Firebase query failed: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.post("/deep_search") async def deep_search_endpoint(request: DeepSearchRequest): logger.debug(f"Deep search endpoint called for device: {request.device_info}, OS: {request.os_version}") try: results = deep_search_engine.search_device_info(request.device_info, request.os_version) results["public_resources"] = deep_search_engine.search_public_resources(request.device_info) return {"success": True, "data": results, "error": None} except Exception as e: logger.error(f"Deep search failed: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.post("/strategic_agent/load_model") async def load_strategic_model(request: LoadStrategicModelRequest): logger.info(f"Request to load strategic model: {request.model_url}") success, message = await strategic_brain.load_strategic_llm(request.model_url) if success: logger.info(f"Strategic model loaded successfully: {message}") return {"status": "success", "message": message, "model": strategic_brain.loaded_model_name} else: logger.error(f"Failed to load strategic model: {message}") raise HTTPException(status_code=500, detail=message) @app.post("/strategic_agent/unload_model") async def unload_strategic_model(): logger.info("Request to unload strategic model.") await strategic_brain.unload_strategic_llm() return {"status": "success", "message": "Strategic LLM unloaded."} @app.post("/strategic_agent/set_goal") async def strategic_set_goal(request: SetGoalRequest): logger.info(f"Strategic 
Agent received new goal: {request.goal}") # Call the synchronous set_goal method strategic_brain.set_goal(request.goal) return {"status": "success", "message": f"Goal set to: {request.goal}"} @app.post("/strategic_agent/get_next_command") async def strategic_get_next_command(request: GetNextCommandRequest): logger.debug("Strategic Agent received request for next command.") # Update strategic brain's state with latest from execution agent strategic_brain.execution_history = request.execution_history_summary strategic_brain.gathered_info = request.gathered_info_summary strategic_brain.identified_vulnerabilities = request.identified_vulnerabilities_summary # Simulate agent's thinking process command = strategic_brain.parse_llm_response( strategic_brain._get_llm_response( strategic_brain._generate_llm_prompt() # Generate prompt based on updated state ) ) if command: strategic_brain.used_commands.add(command) # Ensure strategic agent tracks used commands logger.info(f"Strategic Agent generated command: {command}") return {"command": command, "status": "success"} else: # Fallback if strategic agent fails to generate a valid command target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', strategic_brain.current_goal or "") fallback_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1" logger.warning(f"Strategic Agent failed to generate command. Returning fallback: {fallback_ip}") # If no LLM is loaded, provide a more informative fallback if strategic_brain.llm is None: return {"command": f"echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {fallback_ip}'", "status": "fallback", "message": "No LLM loaded on Strategic Agent. 
Please load one from the frontend settings."} else: return {"command": f"nmap -sV -Pn {fallback_ip}", "status": "fallback", "message": "Strategic Agent could not determine a valid next command."} @app.post("/strategic_agent/observe_result") async def strategic_observe_result(request: ObserveResultRequest): logger.debug(f"Strategic Agent received observation for command: {request.command}, success: {request.success}") strategic_brain.observe_result(request.command, request.output, request.success) return {"status": "success", "message": "Observation received and processed."} @app.get("/strategic_agent/get_status") async def strategic_get_status(): logger.debug("Strategic Agent status requested.") return { "currentGoal": strategic_brain.current_goal, "currentPhase": strategic_brain.current_phase.replace('_', ' ').title(), "reactCycleCount": strategic_brain.react_cycle_count, "noProgressCount": strategic_brain.no_progress_count, "identifiedVulnerabilities": [v['description'] for v in strategic_brain.identified_vulnerabilities], "gatheredInfo": [info[:100] + "..." 
for info in strategic_brain.gathered_info[-5:]] if strategic_brain.gathered_info else [], "executionHistorySummary": [{ "command": e['command'], "success": e['success'], "timestamp": e['timestamp'] } for e in strategic_brain.execution_history[-10:]], "strategicPlan": strategic_brain.current_plan, "currentPhaseIndex": strategic_brain.current_phase_index, "goalAchieved": strategic_brain.goal_achieved, "strategicAgentStatus": "Running" if strategic_brain.current_goal and not strategic_brain.goal_achieved else "Idle", "loadedModel": strategic_brain.loaded_model_name # Return the name of the loaded model } @app.get("/api/models") async def get_available_models_strategic(): """List predefined Hugging Face models for strategic agent.""" logger.debug("Request for available strategic models received.") # Explicitly return JSONResponse to ensure correct content type return JSONResponse(content=json.dumps(HUGGINGFACE_MODELS), media_type="application/json") # --- Startup Event to Download All Models and Start ngrok Tunnel (Modified for HF Spaces) --- @app.on_event("startup") async def startup_event_download_models(): # Renamed function logger.info("Application startup event triggered. 
Attempting to download all predefined models.") # Download all models for model_info in HUGGINGFACE_MODELS: model_url = model_info["url"] model_name = model_info["name"] model_filename = model_url.split('/')[-1] local_model_path = os.path.join(DOWNLOAD_DIR, model_filename) if not os.path.exists(local_model_path): logger.info(f"Downloading model '{model_name}' from {model_url} to {local_model_path}...") try: response = requests.get(model_url, stream=True) response.raise_for_status() with open(local_model_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) logger.info(f"Model '{model_name}' downloaded successfully.") except Exception as e: logger.error(f"Failed to download model '{model_name}': {e}", exc_info=True) else: logger.info(f"Model '{model_name}' already exists at {local_model_path}. Skipping download.") logger.info("Finished attempting to download all predefined models.") # --- Shutdown Event (ngrok related parts removed) --- @app.on_event("shutdown") async def shutdown_event_cleanup(): # Renamed function logger.info("Application shutdown event triggered. Performing cleanup.") # No ngrok.kill() needed here as ngrok is not used if __name__ == "__main__": import uvicorn logger.info("Starting FastAPI application on Hugging Face Spaces (port 7860)...") uvicorn.run( app, host="0.0.0.0", port=7860, # Standard port for Hugging Face Spaces log_level="info" # Changed to info for less verbose default output )