MrA7A commited on
Commit
e26acf0
·
verified ·
1 Parent(s): d6ec34c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1107 -0
app.py ADDED
@@ -0,0 +1,1107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (Strategic Agent Service for Hugging Face Spaces - CPU Only, Preload All Models, No ngrok)
2
+ import os
3
+ import json
4
+ import logging
5
+ import numpy as np
6
+ import requests
7
+ from fastapi import FastAPI, HTTPException, Depends, status
8
+ from pydantic import BaseModel, Field, constr
9
+ from sentence_transformers import SentenceTransformer
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ from datetime import datetime
12
+ import firebase_admin
13
+ from firebase_admin import credentials, firestore
14
+ from bs4 import BeautifulSoup
15
+ import re
16
+ from typing import List, Dict, Optional, Tuple
17
+ from cachetools import TTLCache
18
+ import gc
19
+ from llama_cpp import Llama
20
+ import asyncio
21
+ import nest_asyncio
22
+ from fastapi.responses import JSONResponse # Added for explicit JSONResponse
23
+
24
+ # Apply nest_asyncio to allow running asyncio.run() in environments with existing event loops
25
+ nest_asyncio.apply()
26
+
27
+ # --- Configuration ---
28
+ # Directory to store downloaded GGUF models within Hugging Face Space's writable space
29
+ DOWNLOAD_DIR = "./downloaded_models/" # Changed to a local directory within the Space
30
+ os.makedirs(DOWNLOAD_DIR, exist_ok=True)
31
+
32
+ # Predefined Hugging Face GGUF model URLs for dynamic loading
33
+ HUGGINGFACE_MODELS = [
34
+ {
35
+ "name": "Foundation-Sec-8B-Q8_0",
36
+ "url": "https://huggingface.co/fdtn-ai/Foundation-Sec-8B-Q8_0-GGUF/resolve/main/foundation-sec-8b-q8_0.gguf"
37
+ },
38
+ {
39
+ "name": "Lily-Cybersecurity-7B-v0.2-Q8_0",
40
+ "url": "https://huggingface.co/Nekuromento/Lily-Cybersecurity-7B-v0.2-Q8_0-GGUF/resolve/main/lily-cybersecurity-7b-v0.2-q8_0.gguf"
41
+ },
42
+ {
43
+ "name": "SecurityLLM-GGUF (sarvam-m-q8_0)",
44
+ "url": "https://huggingface.co/QuantFactory/SecurityLLM-GGUF/resolve/main/sarvam-m-q8_0.gguf"
45
+ }
46
+ ]
47
+
48
+ DATA_DIR = "./data" # Local data for Hugging Face Space
49
+ DEEP_SEARCH_CACHE_TTL = 3600
50
+
51
+ # --- ngrok Configuration (Removed) ---
52
+ # NGROK_AUTH_TOKEN and NGROK_STRATEGIC_AGENT_TUNNEL_URL are removed
53
+
54
+ # --- Logging Setup ---
55
+ logging.basicConfig(
56
+ level=logging.DEBUG, # Changed from INFO to DEBUG
57
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
58
+ )
59
+ logger = logging.getLogger(__name__)
60
+ logger.info("Logging initialized with DEBUG level.")
61
+
62
+ # Initialize FastAPI app
63
+ app = FastAPI(
64
+ title="Hugging Face Strategic Agent Service",
65
+ description="Provides knowledge base access and strategic reasoning for the pentest agent on Hugging Face Spaces.",
66
+ version="1.0.0"
67
+ )
68
+
69
+ # Initialize Firebase
70
+ firebase_creds_path = os.getenv("FIREBASE_CREDS_PATH", "cred.json")
71
+ db = None
72
+ if not firebase_admin._apps:
73
+ try:
74
+ if os.path.exists(firebase_creds_path):
75
+ cred = credentials.Certificate(firebase_creds_path)
76
+ firebase_admin.initialize_app(cred)
77
+ db = firestore.client()
78
+ logger.info("Firebase initialized successfully.")
79
+ else:
80
+ logger.warning(f"Firebase credentials file not found at {firebase_creds_path}. Firebase will not be initialized.")
81
+ except Exception as e:
82
+ logger.error(f"Failed to initialize Firebase: {e}. Ensure FIREBASE_CREDS_PATH is set correctly and the file exists.", exc_info=True)
83
+
84
+ # Global LLM instance for Strategic Agent
85
+ strategic_llm: Optional[Llama] = None
86
+ current_strategic_model_url: Optional[str] = None # Now tracks URL, not local path
87
+
88
+ # Supported tools (Strategic Agent needs to know these for command generation)
89
+ SUPPORTED_TOOLS = [
90
+ "nmap", "gobuster", "nikto", "sqlmap", "adb", "frida",
91
+ "drozer", "apktool", "msfconsole", "mobsfscan", "burpsuite",
92
+ "metasploit", "curl", "wget", "hydra", "john", "aircrack-ng"
93
+ ]
94
+
95
+ # --- Deep Search Cache ---
96
+ deep_search_cache = TTLCache(maxsize=100, ttl=DEEP_SEARCH_CACHE_TTL)
97
+
98
+ # --- Enhanced System Instruction (English) ---
99
+ SYSTEM_INSTRUCTION = (
100
+ "You are an expert pentest agent. Strictly follow these rules:\n"
101
+ "1. Output ONLY valid shell commands\n"
102
+ "2. NEVER include timestamps, dates, or any text outside commands\n"
103
+ "3. Never repeat previous commands\n"
104
+ "4. Always verify command safety before execution\n\n"
105
+ "Example valid response:\n"
106
+ "nmap -sV 192.168.1.6\n\n"
107
+ "Key Principles:\n"
108
+ "- Never give up until the goal is achieved\n"
109
+ "- Learn from failures and adapt strategies\n"
110
+ "- Leverage all available knowledge and tools\n"
111
+ "- Break complex tasks into smaller achievable steps\n"
112
+ "- Always ensure actions are ethical and within scope\n\n"
113
+ "Available Tools:\n"
114
+ "- nmap: Network scanning and service detection\n"
115
+ "- gobuster: Web directory brute-forcing\n"
116
+ "- nikto: Web server vulnerability scanner\n"
117
+ "- sqlmap: SQL injection testing\n"
118
+ "- adb: Android Debug Bridge\n"
119
+ "- metasploit: Exploitation framework\n\n"
120
+ "Error Handling Examples:\n"
121
+ "Example 1 (Command Failure):\n"
122
+ " If nmap fails because host is down, try: nmap -Pn -sV 192.168.1.6\n"
123
+ "Example 2 (Web Server Error):\n"
124
+ " If web server returns 403, try: gobuster dir -u http://192.168.1.6 -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt\n"
125
+ "Example 3 (ADB Connection Failed):\n"
126
+ " If ADB connection fails, try: adb kill-server && adb start-server"
127
+ )
128
+
129
+ # --- Firebase Knowledge Base Integration ---
130
+ class FirebaseKnowledgeBase:
131
+ def __init__(self):
132
+ self.collection = db.collection('knowledge_base') if db else None
133
+
134
+ def query(self, goal: str, phase: str = None, limit: int = 10) -> list:
135
+ if not db or not firebase_admin._apps: # Check if Firebase is initialized
136
+ logger.error("Firestore client not initialized. Cannot query knowledge base.")
137
+ return []
138
+
139
+ # Re-instantiate collection if it's None (e.g., if Firebase init failed initially)
140
+ if not hasattr(self, 'collection') or self.collection is None:
141
+ self.collection = db.collection('knowledge_base')
142
+
143
+ keywords = [goal.lower(), 'android', 'pentest', 'mobile', 'device']
144
+ if phase:
145
+ keywords.append(phase.lower())
146
+
147
+ try:
148
+ query_ref = self.collection
149
+ results = []
150
+ docs = query_ref.stream() # Use query_ref instead of self.collection directly
151
+
152
+ for doc in docs:
153
+ data = doc.to_dict()
154
+ text = f"{data.get('prompt', '').lower()} {data.get('completion', '').lower()} {data.get('metadata', '').lower()}"
155
+ if any(keyword in text for keyword in keywords):
156
+ results.append(data)
157
+ if len(results) >= 10: # Use a fixed limit for stream
158
+ break
159
+
160
+ priority_order = {"high": 1, "medium": 2, "low": 3}
161
+ results.sort(key=lambda x: (
162
+ priority_order.get(x.get('metadata', {}).get('priority', 'low').lower(), 3),
163
+ x.get('metadata', {}).get('timestamp', 0)
164
+ ))
165
+
166
+ return results[:10] # Ensure limit is applied
167
+ except Exception as e:
168
+ logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
169
+ return []
170
+
171
+ # --- RAG Knowledge Index ---
172
+ class KnowledgeIndex:
173
+ def __init__(self, model_name="all-MiniLM-L6-v2"):
174
+ self.model = SentenceTransformer(
175
+ model_name,
176
+ cache_folder=os.path.join(DATA_DIR, "hf_cache") # Use local data dir for cache
177
+ )
178
+ self.knowledge_base = []
179
+ os.makedirs(DATA_DIR, exist_ok=True)
180
+ self.load_knowledge_from_file(os.path.join(DATA_DIR, 'knowledge_base.json'))
181
+
182
+ def load_knowledge_from_file(self, file_path):
183
+ logger.debug(f"Attempting to load knowledge from file: {file_path}")
184
+ if os.path.exists(file_path):
185
+ try:
186
+ with open(file_path, 'r', encoding='utf-8') as f:
187
+ data = json.load(f)
188
+ if not isinstance(data, list):
189
+ logger.error("Knowledge base file is not a list. Please check the file format.")
190
+ return
191
+ for item in data:
192
+ if isinstance(item, dict):
193
+ text = item.get('text', '')
194
+ source = item.get('source', 'local')
195
+ elif isinstance(item, str):
196
+ text = item
197
+ source = 'local'
198
+ else:
199
+ logger.warning(f"Skipping unsupported item type: {type(item)}")
200
+ continue
201
+ if text:
202
+ embedding = self.model.encode(text).tolist()
203
+ self.knowledge_base.append({'text': text, 'embedding': embedding, 'source': source})
204
+ logger.info(f"Loaded {len(self.knowledge_base)} items into RAG knowledge base.")
205
+ except Exception as e:
206
+ logger.error(f"Error loading knowledge from {file_path}: {e}", exc_info=True)
207
+ else:
208
+ logger.warning(f"Knowledge base file not found: {file_path}. RAG will operate on an empty knowledge base.")
209
+ try:
210
+ with open(file_path, 'w', encoding='utf-8') as f:
211
+ json.dump([], f)
212
+ logger.info(f"Created empty knowledge base file at: {file_path}")
213
+ except Exception as e:
214
+ logger.error(f"Error creating empty knowledge base file at {file_path}: {e}", exc_info=True)
215
+
216
+ def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
217
+ if not self.knowledge_base:
218
+ logger.debug("Knowledge base is empty, no RAG retrieval possible.")
219
+ return []
220
+
221
+ try:
222
+ query_embedding = self.model.encode(query).reshape(1, -1)
223
+ embeddings = np.array([item['embedding'] for item in self.knowledge_base])
224
+
225
+ similarities = cosine_similarity(query_embedding, embeddings)[0]
226
+ top_indices = similarities.argsort()[-top_k:][::-1]
227
+
228
+ results = []
229
+ for i in top_indices:
230
+ results.append({
231
+ "text": self.knowledge_base[i]['text'],
232
+ "similarity": similarities[i],
233
+ "source": self.knowledge_base[i].get('source', 'RAG')
234
+ })
235
+ logger.debug(f"RAG retrieved {len(results)} results for query: '{query}'")
236
+ return results
237
+ except Exception as e:
238
+ logger.error(f"Error during RAG retrieval for query '{query}': {e}", exc_info=True)
239
+ return []
240
+
241
+ # --- Deep Search Engine ---
242
+ class DeepSearchEngine:
243
+ def __init__(self):
244
+ self.headers = {
245
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
246
+ }
247
+
248
+ def search_device_info(self, device_info: str, os_version: str) -> dict:
249
+ logger.debug(f"Performing deep search for device: {device_info}, OS: {os_version}")
250
+ results = {
251
+ "device": device_info,
252
+ "os_version": os_version,
253
+ "vulnerabilities": [],
254
+ "exploits": [],
255
+ "recommendations": []
256
+ }
257
+ try:
258
+ cve_results = self.search_cve(device_info, os_version)
259
+ results["vulnerabilities"] = cve_results
260
+ exploit_results = self.search_exploits(device_info, os_version)
261
+ results["exploits"] = exploit_results
262
+ recommendations = self.get_security_recommendations(os_version)
263
+ results["recommendations"] = recommendations
264
+ logger.debug("Deep search completed.")
265
+ except Exception as e:
266
+ logger.error(f"Deep search failed: {e}", exc_info=True)
267
+ return results
268
+
269
+ def search_cve(self, device: str, os_version: str) -> list:
270
+ cves = []
271
+ try:
272
+ query = f"{device} {os_version} CVE"
273
+ search_url = f"https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword={query}"
274
+ logger.debug(f"Searching CVE Mitre: {search_url}")
275
+ response = requests.get(search_url, headers=self.headers)
276
+ response.raise_for_status() # Raise an exception for HTTP errors
277
+ if response.status_code == 200:
278
+ soup = BeautifulSoup(response.text, 'html.parser')
279
+ table = soup.find('div', id='TableWithRules')
280
+ if table:
281
+ rows = table.find_all('tr')[1:]
282
+ for row in rows:
283
+ cols = row.find_all('td')
284
+ if len(cols) >= 2:
285
+ cve_id = cols[0].get_text(strip=True)
286
+ description = cols[1].get_text(strip=True)
287
+ cves.append({
288
+ "cve_id": cve_id,
289
+ "description": description,
290
+ "source": "CVE Mitre"
291
+ })
292
+ logger.debug(f"Found {len(cves)} CVEs.")
293
+ return cves[:10]
294
+ except Exception as e:
295
+ logger.error(f"CVE search failed: {e}", exc_info=True)
296
+ return []
297
+
298
+ def search_exploits(self, device: str, os_version: str) -> list:
299
+ exploits = []
300
+ try:
301
+ query = f"{device} {os_version}"
302
+ search_url = f"https://www.exploit-db.com/search?q={query}"
303
+ logger.debug(f"Searching ExploitDB: {search_url}")
304
+ response = requests.get(search_url, headers=self.headers)
305
+ response.raise_for_status() # Raise an exception for HTTP errors
306
+ if response.status_code == 200:
307
+ soup = BeautifulSoup(response.text, 'html.parser')
308
+ cards = soup.select('.card .card-title')
309
+ for card in cards:
310
+ title = card.get_text(strip=True)
311
+ link = card.find('a')['href']
312
+ if not link.startswith('http'):
313
+ link = f"https://www.exploit-db.com{link}"
314
+ exploits.append({
315
+ "title": title,
316
+ "link": link,
317
+ "source": "ExploitDB"
318
+ })
319
+ logger.debug(f"Found {len(exploits)} exploits.")
320
+ return exploits[:10]
321
+ except Exception as e:
322
+ logger.error(f"Exploit search failed: {e}", exc_info=True)
323
+ return []
324
+
325
+ def get_security_recommendations(self, os_version: str) -> list:
326
+ recommendations = []
327
+ try:
328
+ logger.debug(f"Getting security recommendations for OS: {os_version}")
329
+ if "android" in os_version.lower():
330
+ url = "https://source.android.com/docs/security/bulletin"
331
+ response = requests.get(url, headers=self.headers)
332
+ response.raise_for_status()
333
+ if response.status_code == 200:
334
+ soup = BeautifulSoup(response.text, 'html.parser')
335
+ versions = soup.select('.devsite-article-body h2')
336
+ for version in versions:
337
+ if os_version in version.get_text():
338
+ next_ul = version.find_next('ul')
339
+ if next_ul:
340
+ items = next_ul.select('li')
341
+ for item in items:
342
+ recommendations.append(item.get_text(strip=True))
343
+ elif "ios" in os_version.lower():
344
+ url = "https://support.apple.com/en-us/HT201222"
345
+ response = requests.get(url, headers=self.headers)
346
+ response.raise_for_status()
347
+ if response.status_code == 200:
348
+ soup = BeautifulSoup(response.text, 'html.parser')
349
+ sections = soup.select('#sections')
350
+ for section in sections:
351
+ if os_version in section.get_text():
352
+ items = section.select('li')
353
+ for item in items:
354
+ recommendations.append(item.get_text(strip=True))
355
+ logger.debug(f"Found {len(recommendations)} recommendations.")
356
+ return recommendations[:5]
357
+ except Exception as e:
358
+ logger.error(f"Security recommendations search failed: {e}", exc_info=True)
359
+ return []
360
+
361
+ def search_public_resources(self, device_info: str) -> list:
362
+ resources = []
363
+ try:
364
+ logger.debug(f"Searching public resources for device: {device_info}")
365
+ github_url = f"https://github.com/search?q={device_info.replace(' ', '+')}+pentest"
366
+ response = requests.get(github_url, headers=self.headers)
367
+ response.raise_for_status()
368
+ if response.status_code == 200:
369
+ soup = BeautifulSoup(response.text, 'html.parser')
370
+ repos = soup.select('.repo-list-item')
371
+ for repo in repos:
372
+ title = repo.select_one('.v-align-middle').get_text(strip=True)
373
+ description = repo.select_one('.mb-1').get_text(strip=True) if repo.select_one('.mb-1') else ""
374
+ url = f"https://github.com{repo.select_one('.v-align-middle')['href']}"
375
+ resources.append({
376
+ "title": title,
377
+ "description": description,
378
+ "url": url,
379
+ "source": "GitHub"
380
+ })
381
+ forum_url = f"https://hackforums.net/search.php?action=finduserthreads&keywords={device_info.replace(' ', '+')}"
382
+ response = requests.get(forum_url, headers=self.headers)
383
+ response.raise_for_status()
384
+ if response.status_code == 200:
385
+ soup = BeautifulSoup(response.text, 'html.parser')
386
+ threads = soup.select('.thread')
387
+ for thread in threads:
388
+ title = thread.select_one('.threadtitle').get_text(strip=True)
389
+ url = f"https://hackforums.net{thread.select_one('.threadtitle a')['href']}"
390
+ resources.append({
391
+ "title": title,
392
+ "description": "Forum discussion",
393
+ "url": url,
394
+ "source": "HackForums"
395
+ })
396
+ logger.debug(f"Found {len(resources)} public resources.")
397
+ return resources[:10]
398
+ except Exception as e:
399
+ logger.error(f"Public resources search failed: {e}", exc_info=True)
400
+ return []
401
+
402
+ # --- Initialize Services (Local to Strategic Agent) ---
403
+ firebase_kb = FirebaseKnowledgeBase()
404
+ rag_index = KnowledgeIndex()
405
+ deep_search_engine = DeepSearchEngine()
406
+
407
+ # --- Strategic Agent Brain (formerly SmartExecutionEngine logic) ---
408
+ class StrategicAgentBrain:
409
+ def __init__(self):
410
+ self.llm: Optional[Llama] = None
411
+ self.current_goal: Optional[str] = None
412
+ self.current_phase: str = "initial_reconnaissance"
413
+ self.current_plan: List[Dict] = []
414
+ self.current_phase_index: int = 0
415
+ self.identified_vulnerabilities: List[Dict] = []
416
+ self.gathered_info: List[str] = []
417
+ self.command_retry_counts: Dict[str, int] = {}
418
+ self.conversation_history: List[Dict] = []
419
+ self.used_commands = set()
420
+ self.execution_history = []
421
+ self.goal_achieved = False
422
+ self.no_progress_count = 0
423
+ self.react_cycle_count = 0
424
+ self.loaded_model_name: Optional[str] = None # To store the name of the loaded model
425
+ logger.info("StrategicAgentBrain initialized.")
426
+
427
+ async def load_strategic_llm(self, model_url: str):
428
+ global strategic_llm, current_strategic_model_url
429
+ logger.info(f"Attempting to load strategic LLM from URL: {model_url}")
430
+
431
+ # Determine local path for the model
432
+ model_filename = model_url.split('/')[-1]
433
+ local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)
434
+
435
+ if strategic_llm and current_strategic_model_url == model_url:
436
+ logger.info(f"Strategic LLM model from {model_url} is already loaded.")
437
+ self.llm = strategic_llm
438
+ return True, f"Model '{self.loaded_model_name}' is already loaded."
439
+
440
+ # If a model is currently loaded, unload it first
441
+ if strategic_llm:
442
+ await self.unload_strategic_llm()
443
+
444
+ # Ensure model is downloaded before attempting to load
445
+ if not os.path.exists(local_model_path):
446
+ logger.info(f"Model not found locally. Attempting to download from {model_url} to {local_model_path}...")
447
+ try:
448
+ response = requests.get(model_url, stream=True)
449
+ response.raise_for_status()
450
+ with open(local_model_path, 'wb') as f:
451
+ for chunk in response.iter_content(chunk_size=8192):
452
+ f.write(chunk)
453
+ logger.info(f"Model downloaded successfully to {local_model_path}.")
454
+ except Exception as e:
455
+ logger.error(f"Failed to download model from {model_url}: {e}", exc_info=True)
456
+ return False, f"Failed to download model: {str(e)}"
457
+
458
+ try:
459
+ logger.info(f"Loading Strategic LLM model from {local_model_path}...")
460
+ strategic_llm = Llama(
461
+ model_path=local_model_path,
462
+ n_ctx=3096,
463
+ n_gpu_layers=0, # Explicitly set to 0 for CPU-only
464
+ n_threads=os.cpu_count(), # Use all available CPU threads
465
+ n_batch=512,
466
+ verbose=False
467
+ )
468
+ current_strategic_model_url = model_url
469
+ self.llm = strategic_llm
470
+ self.loaded_model_name = model_filename # Store the filename
471
+ logger.info(f"Strategic LLM model {model_filename} loaded successfully (CPU-only).")
472
+ return True, f"Model '{model_filename}' loaded successfully (CPU-only)."
473
+ except Exception as e:
474
+ logger.error(f"Failed to load Strategic LLM model from {local_model_path}: {e}", exc_info=True)
475
+ strategic_llm = None
476
+ current_strategic_model_url = None
477
+ self.llm = None
478
+ self.loaded_model_name = None
479
+ return False, f"Failed to load model: {str(e)}"
480
+
481
+ async def unload_strategic_llm(self):
482
+ global strategic_llm, current_strategic_model_url
483
+ if strategic_llm:
484
+ logger.info("Unloading Strategic LLM model...")
485
+ del strategic_llm
486
+ strategic_llm = None
487
+ current_strategic_model_url = None
488
+ self.llm = None
489
+ self.loaded_model_name = None
490
+ gc.collect()
491
+ logger.info("Strategic LLM model unloaded.")
492
+
493
+ def _get_rag_context(self, query: str) -> str:
494
+ results = rag_index.retrieve(query)
495
+ if not results:
496
+ return ""
497
+ rag_context = "Relevant Knowledge for Current Context:\n"
498
+ for i, result in enumerate(results):
499
+ text = result.get('text', '') or result.get('completion', '')
500
+ source = result.get('source', 'RAG')
501
+ rag_context += f"{i+1}. [{source}] {text}\n"
502
+ return rag_context
503
+
504
+ def _get_firebase_knowledge(self, goal: str, phase: str = None) -> str:
505
+ if not db or not firebase_admin._apps: # Check if Firebase is initialized
506
+ logger.error("Firestore client not initialized. Cannot query knowledge base.")
507
+ return ""
508
+
509
+ # Re-instantiate collection if it's None (e.g., if Firebase init failed initially)
510
+ if not hasattr(self, 'collection') or self.collection is None:
511
+ self.collection = db.collection('knowledge_base')
512
+
513
+ keywords = [goal.lower(), 'android', 'pentest', 'mobile', 'device']
514
+ if phase:
515
+ keywords.append(phase.lower())
516
+
517
+ try:
518
+ query_ref = self.collection
519
+ results = []
520
+ docs = query_ref.stream() # Use query_ref instead of self.collection directly
521
+
522
+ for doc in docs:
523
+ data = doc.to_dict()
524
+ text = f"{data.get('prompt', '').lower()} {data.get('completion', '').lower()} {data.get('metadata', '').lower()}"
525
+ if any(keyword in text for keyword in keywords):
526
+ results.append(data)
527
+ if len(results) >= 10: # Use a fixed limit for stream
528
+ break
529
+
530
+ priority_order = {"high": 1, "medium": 2, "low": 3}
531
+ results.sort(key=lambda x: (
532
+ priority_order.get(x.get('metadata', {}).get('priority', 'low').lower(), 3),
533
+ x.get('metadata', {}).get('timestamp', 0)
534
+ ))
535
+
536
+ return results[:10] # Ensure limit is applied
537
+ except Exception as e:
538
+ logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
539
+ return ""
540
+
541
+ def extract_device_info(self) -> str:
542
+ for info in self.gathered_info:
543
+ if "model" in info.lower() or "device" in info.lower():
544
+ match = re.search(r'(?:model|device)\s*[:=]\s*([^\n]+)', info, re.IGNORECASE)
545
+ if match:
546
+ return match.group(1).strip()
547
+ ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
548
+ return ip_match.group(0) if ip_match else "Unknown Device"
549
+
550
+ def extract_os_version(self) -> str:
551
+ for info in self.gathered_info:
552
+ if "android" in info.lower() or "ios" in info.lower() or "os" in info.lower():
553
+ android_match = re.search(r'android\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE)
554
+ if android_match:
555
+ return f"Android {android_match.group(1)}"
556
+ ios_match = re.search(r'ios\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE)
557
+ if ios_match:
558
+ return f"iOS {ios_match.group(1)}"
559
+ linux_match = re.search(r'linux\s+kernel\s+(\d+\.\d+\.\d+)', info, re.IGNORECASE)
560
+ if linux_match:
561
+ return f"Linux {linux_match.group(1)}"
562
+ return "Unknown OS Version"
563
+
564
+ def format_deep_search_results(self, results: dict) -> str:
565
+ context = "Deep Search Results:\n"
566
+ context += f"Device: {results.get('device', 'Unknown')}\n"
567
+ context += f"OS Version: {results.get('os_version', 'Unknown')}\n\n"
568
+ if results.get('vulnerabilities'):
569
+ context += "Discovered Vulnerabilities:\n"
570
+ for i, vuln in enumerate(results['vulnerabilities'][:5], 1):
571
+ context += f"{i}. {vuln.get('cve_id', 'CVE-XXXX-XXXX')}: {vuln.get('description', 'No description')}\n"
572
+ context += "\n"
573
+ if results.get('exploits'):
574
+ context += "Available Exploits:\n"
575
+ for i, exploit in enumerate(results['exploits'][:5], 1):
576
+ context += f"{i}. {exploit.get('title', 'Untitled exploit')} [Source: {exploit.get('source', 'Unknown')}]\n"
577
+ context += "\n"
578
+ if results.get('recommendations'):
579
+ context += "Security Recommendations:\n"
580
+ for i, rec in enumerate(results['recommendations'][:3], 1):
581
+ context += f"{i}. {rec}\n"
582
+ context += "\n"
583
+ if results.get('public_resources'):
584
+ context += "Public Resources:\n"
585
+ for i, res in enumerate(results['public_resources'][:3], 1):
586
+ context += f"{i}. {res.get('title', 'Untitled resource')} [Source: {res.get('source', 'Unknown')}]\n"
587
+ return context
588
+
589
+ def generate_deep_search_prompt(self, context: str) -> str:
590
+ return f"""
591
+ You are an expert pentester. Below are deep search results for the target device.
592
+ Use this information to generate the next penetration testing command.{context}
593
+
594
+ Current Goal: {self.current_goal}
595
+ Current Phase: {self.current_phase}
596
+
597
+ Recent Command History:{', '.join(list(self.used_commands)[-3:]) if self.used_commands else 'None'}
598
+
599
+ Based on this information, what is the SINGLE MOST EFFECTIVE shell command to execute next?
600
+ Focus on exploiting the most critical vulnerabilities or gathering more information.
601
+
602
+ Response Format:
603
+ Command: <your_command_here>
604
+ """
605
+
606
+ def _generate_llm_prompt(self) -> str:
607
+ rag_context = self._get_rag_context(f"{self.current_goal} {self.current_phase}")
608
+ firebase_knowledge = self._get_firebase_knowledge(self.current_goal, self.current_phase)
609
+
610
+ history_context = "\n".join(
611
+ f"{entry['role']}: {entry['content']}" for entry in self.conversation_history[-2:]
612
+ )
613
+
614
+ execution_history = "\n".join(
615
+ f"Command: {res['command']}\nResult: {res['output'][:100]}...\nSuccess: {res['success']}"
616
+ for res in self.execution_history[-2:]
617
+ ) if self.execution_history else "No previous results."
618
+
619
+ strategic_advice = self._get_rag_context(self.current_phase) # Using RAG for strategic advice too
620
+
621
+ def shorten_text(text, max_length=300):
622
+ if len(text) > max_length:
623
+ return text[:max_length] + "... [truncated]"
624
+ return text
625
+
626
+ rag_context = shorten_text(rag_context, max_length=200)
627
+ firebase_knowledge = shorten_text(firebase_knowledge, max_length=200)
628
+ strategic_advice = shorten_text(strategic_advice, max_length=100)
629
+ history_context = shorten_text(history_context, max_length=150)
630
+ execution_history = shorten_text(execution_history, max_length=500)
631
+
632
+ prompt = f"""
633
+ System Instructions: {SYSTEM_INSTRUCTION}
634
+
635
+ Current Goal: '{self.current_goal}'
636
+ Current Phase: {self.current_phase} - {self.current_plan[self.current_phase_index]['objective'] if self.current_plan and self.current_phase_index < len(self.current_plan) else 'No objective'}
637
+
638
+ Based on the following knowledge and previous results, generate the SINGLE, VALID SHELL COMMAND to advance the penetration testing process.
639
+
640
+ **Knowledge from External Services (RAG & Firebase):**
641
+ {rag_context}
642
+ {firebase_knowledge}
643
+
644
+ **Previous Execution Results:**
645
+ {execution_history}
646
+
647
+ **Recent Conversation History:**
648
+ {history_context}
649
+
650
+ **Strategic Advice for Current Phase:**
651
+ {strategic_advice}
652
+
653
+ ***CRITICAL RULES FOR OUTPUT:***
654
+ 1. **OUTPUT ONLY THE COMMAND.**
655
+ 2. **DO NOT INCLUDE ANY REASONING, THOUGHTS, EXPLANATIONS, OR ANY OTHER TEXT.**
656
+ 3. The command MUST be directly executable in a Linux terminal.
657
+ 4. Avoid repeating these recent commands: {', '.join(list(self.used_commands)[-3:]) if self.used_commands else 'None'}
658
+ 5. If the previous command failed, try a different approach or a related tool.
659
+ 6. For the 'android_enumeration' phase, prioritize ADB commands.
660
+
661
+ Example valid commands for initial reconnaissance of an Android phone:
662
+ nmap -sV -Pn 192.168.1.14
663
+ adb devices
664
+ adb connect 192.168.1.14:5555
665
+
666
+ Command:
667
+ """
668
+ return prompt
669
+
670
+ def _get_llm_response(self, custom_prompt: str = None) -> str:
671
+ if not self.llm:
672
+ logger.error("Strategic LLM instance is None. Cannot get response. Please load a model first.")
673
+ target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
674
+ target_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
675
+ return f"Command: echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {target_ip}'"
676
+
677
+ prompt = custom_prompt if custom_prompt else self._generate_llm_prompt()
678
+
679
+ logger.info(f"Sending prompt to Strategic LLM:\n{prompt[:500]}...")
680
+
681
+ try:
682
+ response = self.llm(
683
+ prompt,
684
+ max_tokens=512,
685
+ temperature=0.3,
686
+ stop=["\n"]
687
+ )
688
+ llm_response = response['choices'][0]['text'].strip()
689
+ logger.info(f"Strategic LLM raw response: {llm_response}")
690
+
691
+ if not llm_response:
692
+ logger.warning("Strategic LLM returned an empty response. Using fallback command.")
693
+ target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
694
+ target_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
695
+ return f"Command: nmap -sV -Pn {target_ip}"
696
+
697
+ return llm_response
698
+ except Exception as e:
699
+ logger.error(f"Error during Strategic LLM inference: {e}", exc_info=True)
700
+ logger.warning("Strategic LLM inference failed. Using fallback command.")
701
+ target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
702
+ target_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
703
+ return f"Command: nmap -sV -Pn {target_ip}"
704
+
705
+ def parse_llm_response(self, response: str) -> str:
706
+ logger.info(f"Attempting to parse LLM response: '{response}'")
707
+ command = None
708
+ try:
709
+ code_block = re.search(r'```(?:bash|sh)?\s*([\s\S]*?)```', response)
710
+ if code_block:
711
+ command = code_block.group(1).strip()
712
+ logger.info(f"Command extracted from code block: '{command}'")
713
+
714
+ if not command:
715
+ command_match = re.search(r'^\s*Command\s*:\s*(.+)$', response, re.MULTILINE | re.IGNORECASE)
716
+ if command_match:
717
+ command = command_match.group(1).strip()
718
+ logger.info(f"Command extracted from 'Command:' line: '{command}'")
719
+
720
+ if not command:
721
+ stripped_response = response.strip()
722
+ if any(stripped_response.startswith(tool) for tool in SUPPORTED_TOOLS):
723
+ command = stripped_response
724
+ logger.info(f"Command extracted as direct supported tool command: '{command}'")
725
+
726
+ if command:
727
+ original_command = command
728
+ command = re.sub(r'^\s*(Command|Answer|Note|Result)\s*[:.-]?\s*', '', command, flags=re.IGNORECASE).strip()
729
+ logger.info(f"Cleaned command: from '{original_command}' to '{command}'")
730
+
731
+ if not re.match(r'^[a-zA-Z0-9_./:;= \-\'"\s]+$', command):
732
+ logger.error(f"Invalid command characters detected after cleanup: '{command}'")
733
+ return None
734
+
735
+ if re.search(r'(reason|thought|explanation|rationale|note|result):', command, re.IGNORECASE):
736
+ logger.warning(f"Command '{command}' appears to be reasoning/explanation. Rejecting.")
737
+ return None
738
+
739
+ if command not in self.used_commands:
740
+ self.used_commands.add(command)
741
+ logger.info(f"Returning valid and new command: '{command}'")
742
+ return command
743
+ else:
744
+ logger.warning(f"Command '{command}' already used. Skipping.")
745
+ return None
746
+ else:
747
+ logger.warning("No valid command could be extracted from LLM response based on strict rules.")
748
+ return None
749
+ except Exception as e:
750
+ logger.error(f"Error parsing LLM response: {e}", exc_info=True)
751
+ return None
752
+
753
+ def set_goal(self, goal: str):
754
+ self.current_goal = goal
755
+ self.goal_achieved = False
756
+ self.react_cycle_count = 0
757
+ self.no_progress_count = 0
758
+
759
+ self.current_plan = self._generate_strategic_plan(goal)
760
+ self.current_phase_index = 0
761
+ self.identified_vulnerabilities = []
762
+ self.gathered_info = []
763
+ self.command_retry_counts = {}
764
+ self.conversation_history = [{"role": "user", "content": f"New goal set: {goal}"}]
765
+ self.used_commands.clear()
766
+ self.execution_history = []
767
+ self.goal_achieved = False
768
+ logger.info(f"Strategic Agent Goal set: {goal}. Starting initial reconnaissance.")
769
+
770
+
771
+ def _generate_strategic_plan(self, goal: str) -> List[Dict]:
772
+ logger.debug(f"Generating strategic plan for goal: {goal}")
773
+ plan = []
774
+ goal_lower = goal.lower()
775
+
776
+ plan.append({"phase": "initial_reconnaissance", "objective": f"Perform initial reconnaissance for {goal}"})
777
+
778
+ if "web" in goal_lower or "http" in goal_lower:
779
+ plan.append({"phase": "web_enumeration", "objective": "Enumerate web server for directories and files"})
780
+ plan.append({"phase": "web_vulnerability_analysis", "objective": "Analyze web vulnerabilities (SQLi, XSS, etc.)"})
781
+ plan.append({"phase": "web_exploitation", "objective": "Attempt to exploit web vulnerabilities"})
782
+ plan.append({"phase": "post_exploitation", "objective": "Perform post exploitation activities"})
783
+
784
+ elif "android" in goal_lower or "mobile" in goal_lower or "adb" in goal_lower:
785
+ plan.append({"phase": "android_enumeration", "objective": "Enumerate Android device via ADB"})
786
+ plan.append({"phase": "android_app_analysis", "objective": "Analyze Android application for vulnerabilities"})
787
+ plan.append({"phase": "android_exploitation", "objective": "Attempt to exploit Android vulnerabilities"})
788
+ plan.append({"phase": "data_extraction", "objective": "Extract sensitive data from device"})
789
+
790
+ else:
791
+ plan.append({"phase": "network_scanning", "objective": "Perform detailed network scanning"})
792
+ plan.append({"phase": "service_enumeration", "objective": "Enumerate services and identify versions"})
793
+ plan.append({"phase": "vulnerability_analysis", "objective": "Analyze services for vulnerabilities"})
794
+ plan.append({"phase": "exploitation", "objective": "Attempt to exploit vulnerabilities"})
795
+ plan.append({"phase": "post_exploitation", "objective": "Perform post exploitation (privilege escalation, data exfiltration)"})
796
+
797
+ plan.append({"phase": "reporting", "objective": "Generate pentest report"})
798
+
799
+ logger.info(f"Generated strategic plan for goal '{goal}': {plan}")
800
+ return plan
801
+
802
+ def evaluate_phase_completion(self) -> float:
803
+ phase_commands = [cmd for cmd in self.execution_history
804
+ if cmd.get('phase', '') == self.current_phase]
805
+ if not phase_commands:
806
+ return 0.0
807
+ successful = sum(1 for cmd in phase_commands if cmd['success'])
808
+ return successful / len(phase_commands)
809
+
810
+ def advance_phase(self):
811
+ if self.current_phase_index < len(self.current_plan) - 1:
812
+ self.current_phase_index += 1
813
+ self.current_phase = self.current_plan[self.current_phase_index]["phase"]
814
+ logger.info(f"Strategic Agent advancing to new phase: {self.current_phase.replace('_', ' ').title()}")
815
+ self.no_progress_count = 0
816
+ self.react_cycle_count = 0
817
+ else:
818
+ self.current_phase = "completed"
819
+ self.goal_achieved = True
820
+ logger.info("Strategic Agent: All planned phases completed. Goal achieved!")
821
+
822
+ def observe_result(self, command: str, output: str, success: bool):
823
+ logger.debug(f"Strategic Agent observing result for command '{command}': Success={success}")
824
+ self.execution_history.append({"command": command, "output": output, "success": success, "timestamp": datetime.now().isoformat()})
825
+ self.gathered_info.append(output)
826
+
827
+ self.analyze_command_output_strategic(command, output)
828
+
829
+ if not success:
830
+ self.no_progress_count += 1
831
+ else:
832
+ self.no_progress_count = 0
833
+
834
+ if success and self.current_phase_index < len(self.current_plan) - 1:
835
+ phase_completion = self.evaluate_phase_completion()
836
+ if phase_completion >= 0.8:
837
+ self.advance_phase()
838
+
839
+ def analyze_command_output_strategic(self, command: str, output: str):
840
+ """Strategic Agent performs deeper analysis of command output for vulnerabilities."""
841
+ try:
842
+ logger.debug(f"Analyzing strategic command output for: {command}")
843
+ if command.startswith("nmap"):
844
+ if "open" in output and "vulnerable" in output.lower():
845
+ self.ingest_vulnerability(
846
+ "Potential vulnerability found in NMAP scan",
847
+ "Medium",
848
+ "NMAP-SCAN"
849
+ )
850
+ port_matches = re.findall(r'(\d+)/tcp\s+open\s+(\S+)', output)
851
+ for port, service in port_matches:
852
+ self.gathered_info.append(f"Discovered open port {port} with service {service}")
853
+
854
+ elif command.startswith("nikto"):
855
+ if "OSVDB-" in output:
856
+ vuln_matches = re.findall(r'OSVDB-\d+:\s*(.+)', output)
857
+ for vuln in vuln_matches[:3]:
858
+ self.ingest_vulnerability(
859
+ f"Nikto vulnerability: {vuln}",
860
+ "High",
861
+ "NIKTO-SCAN"
862
+ )
863
+
864
+ elif command.startswith("sqlmap"):
865
+ if "injection" in output.lower():
866
+ self.ingest_vulnerability(
867
+ "SQL injection vulnerability detected",
868
+ "Critical",
869
+ "SQLMAP-SCAN"
870
+ )
871
+
872
+ elif command.startswith("adb"):
873
+ if "debuggable" in output.lower():
874
+ self.ingest_vulnerability(
875
+ "Debuggable Android application found",
876
+ "High",
877
+ "ADB-DEBUG"
878
+ )
879
+ if "permission" in output.lower() and "denied" in output.lower():
880
+ self.ingest_vulnerability(
881
+ "Permission issue detected on Android device",
882
+ "Medium",
883
+ "ADB-PERMISSION"
884
+ )
885
+ except Exception as e:
886
+ logger.error(f"Strategic Agent: Error analyzing command output: {e}", exc_info=True)
887
+
888
+ def ingest_vulnerability(self, description: str, severity: str, cve_id: Optional[str] = None, exploit_id: Optional[str] = None):
889
+ vulnerability = {
890
+ "description": description,
891
+ "severity": severity,
892
+ "timestamp": datetime.now().isoformat()
893
+ }
894
+ if cve_id:
895
+ vulnerability["cve_id"] = cve_id
896
+ if exploit_id:
897
+ vulnerability["exploit_id"] = exploit_id
898
+
899
+ self.identified_vulnerabilities.append(vulnerability)
900
+ logger.info(f"Strategic Agent identified vulnerability: {description} (Severity: {severity})")
901
+
902
+ # Instantiate the Strategic Agent Brain
903
+ strategic_brain = StrategicAgentBrain()
904
+
905
+ # --- Request Models for API Endpoints ---
906
+ class RAGRequest(BaseModel):
907
+ query: constr(min_length=3, max_length=500)
908
+ top_k: int = Field(5, gt=0, le=20)
909
+
910
+ class FirebaseQueryRequest(BaseModel):
911
+ goal: str
912
+ phase: str = None
913
+ limit: int = 10
914
+
915
+ class DeepSearchRequest(BaseModel):
916
+ device_info: str
917
+ os_version: str
918
+
919
+ class SetGoalRequest(BaseModel):
920
+ goal: str
921
+
922
+ class GetNextCommandRequest(BaseModel):
923
+ current_state: str
924
+ last_command_output: str
925
+ last_command_success: bool
926
+ execution_history_summary: List[Dict] = []
927
+ gathered_info_summary: List[str] = []
928
+ identified_vulnerabilities_summary: List[Dict] = []
929
+
930
+ class ObserveResultRequest(BaseModel):
931
+ command: str
932
+ output: str
933
+ success: bool
934
+
935
+ class LoadStrategicModelRequest(BaseModel):
936
+ model_url: str # Now expects a URL instead of a local path
937
+
938
+ # --- API Endpoints ---
939
+ @app.get("/health")
940
+ async def health_check():
941
+ """Endpoint to check the health of the service."""
942
+ logger.debug("Health check requested.")
943
+ return {"status": "ok", "message": "Knowledge service is running."}
944
+
945
+ @app.post("/rag/retrieve")
946
+ async def rag_retrieve_endpoint(request: RAGRequest):
947
+ logger.debug(f"RAG retrieve endpoint called with query: {request.query}")
948
+ try:
949
+ results = rag_index.retrieve(request.query, request.top_k)
950
+ return {"success": True, "data": {"results": results}, "error": None}
951
+ except Exception as e:
952
+ logger.error(f"RAG retrieval failed: {e}", exc_info=True)
953
+ raise HTTPException(status_code=500, detail=str(e))
954
+
955
+ @app.post("/firebase/query")
956
+ async def firebase_query_endpoint(request: FirebaseQueryRequest):
957
+ logger.debug(f"Firebase query endpoint called with goal: {request.goal}, phase: {request.phase}")
958
+ try:
959
+ results = firebase_kb.query(request.goal, request.phase, request.limit)
960
+ return {"success": True, "data": {"results": results}, "error": None}
961
+ except Exception as e:
962
+ logger.error(f"Firebase query failed: {e}", exc_info=True)
963
+ raise HTTPException(status_code=500, detail=str(e))
964
+
965
+ @app.post("/deep_search")
966
+ async def deep_search_endpoint(request: DeepSearchRequest):
967
+ logger.debug(f"Deep search endpoint called for device: {request.device_info}, OS: {request.os_version}")
968
+ try:
969
+ results = deep_search_engine.search_device_info(request.device_info, request.os_version)
970
+ results["public_resources"] = deep_search_engine.search_public_resources(request.device_info)
971
+ return {"success": True, "data": results, "error": None}
972
+ except Exception as e:
973
+ logger.error(f"Deep search failed: {e}", exc_info=True)
974
+ raise HTTPException(status_code=500, detail=str(e))
975
+
976
+ @app.post("/strategic_agent/load_model")
977
+ async def load_strategic_model(request: LoadStrategicModelRequest):
978
+ logger.info(f"Request to load strategic model: {request.model_url}")
979
+ success, message = await strategic_brain.load_strategic_llm(request.model_url)
980
+ if success:
981
+ logger.info(f"Strategic model loaded successfully: {message}")
982
+ return {"status": "success", "message": message, "model": strategic_brain.loaded_model_name}
983
+ else:
984
+ logger.error(f"Failed to load strategic model: {message}")
985
+ raise HTTPException(status_code=500, detail=message)
986
+
987
+ @app.post("/strategic_agent/unload_model")
988
+ async def unload_strategic_model():
989
+ logger.info("Request to unload strategic model.")
990
+ await strategic_brain.unload_strategic_llm()
991
+ return {"status": "success", "message": "Strategic LLM unloaded."}
992
+
993
+ @app.post("/strategic_agent/set_goal")
994
+ async def strategic_set_goal(request: SetGoalRequest):
995
+ logger.info(f"Strategic Agent received new goal: {request.goal}")
996
+ # Call the synchronous set_goal method
997
+ strategic_brain.set_goal(request.goal)
998
+ return {"status": "success", "message": f"Goal set to: {request.goal}"}
999
+
1000
+ @app.post("/strategic_agent/get_next_command")
1001
+ async def strategic_get_next_command(request: GetNextCommandRequest):
1002
+ logger.debug("Strategic Agent received request for next command.")
1003
+ # Update strategic brain's state with latest from execution agent
1004
+ strategic_brain.execution_history = request.execution_history_summary
1005
+ strategic_brain.gathered_info = request.gathered_info_summary
1006
+ strategic_brain.identified_vulnerabilities = request.identified_vulnerabilities_summary
1007
+
1008
+ # Simulate agent's thinking process
1009
+ command = strategic_brain.parse_llm_response(
1010
+ strategic_brain._get_llm_response(
1011
+ strategic_brain._generate_llm_prompt() # Generate prompt based on updated state
1012
+ )
1013
+ )
1014
+
1015
+ if command:
1016
+ strategic_brain.used_commands.add(command) # Ensure strategic agent tracks used commands
1017
+ logger.info(f"Strategic Agent generated command: {command}")
1018
+ return {"command": command, "status": "success"}
1019
+ else:
1020
+ # Fallback if strategic agent fails to generate a valid command
1021
+ target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', strategic_brain.current_goal or "")
1022
+ fallback_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
1023
+ logger.warning(f"Strategic Agent failed to generate command. Returning fallback: {fallback_ip}")
1024
+ # If no LLM is loaded, provide a more informative fallback
1025
+ if strategic_brain.llm is None:
1026
+ return {"command": f"echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {fallback_ip}'", "status": "fallback", "message": "No LLM loaded on Strategic Agent. Please load one from the frontend settings."}
1027
+ else:
1028
+ return {"command": f"nmap -sV -Pn {fallback_ip}", "status": "fallback", "message": "Strategic Agent could not determine a valid next command."}
1029
+
1030
+
1031
+ @app.post("/strategic_agent/observe_result")
1032
+ async def strategic_observe_result(request: ObserveResultRequest):
1033
+ logger.debug(f"Strategic Agent received observation for command: {request.command}, success: {request.success}")
1034
+ strategic_brain.observe_result(request.command, request.output, request.success)
1035
+ return {"status": "success", "message": "Observation received and processed."}
1036
+
1037
+ @app.get("/strategic_agent/get_status")
1038
+ async def strategic_get_status():
1039
+ logger.debug("Strategic Agent status requested.")
1040
+ return {
1041
+ "currentGoal": strategic_brain.current_goal,
1042
+ "currentPhase": strategic_brain.current_phase.replace('_', ' ').title(),
1043
+ "reactCycleCount": strategic_brain.react_cycle_count,
1044
+ "noProgressCount": strategic_brain.no_progress_count,
1045
+ "identifiedVulnerabilities": [v['description'] for v in strategic_brain.identified_vulnerabilities],
1046
+ "gatheredInfo": [info[:100] + "..." for info in strategic_brain.gathered_info[-5:]] if strategic_brain.gathered_info else [],
1047
+ "executionHistorySummary": [{
1048
+ "command": e['command'],
1049
+ "success": e['success'],
1050
+ "timestamp": e['timestamp']
1051
+ } for e in strategic_brain.execution_history[-10:]],
1052
+ "strategicPlan": strategic_brain.current_plan,
1053
+ "currentPhaseIndex": strategic_brain.current_phase_index,
1054
+ "goalAchieved": strategic_brain.goal_achieved,
1055
+ "strategicAgentStatus": "Running" if strategic_brain.current_goal and not strategic_brain.goal_achieved else "Idle",
1056
+ "loadedModel": strategic_brain.loaded_model_name # Return the name of the loaded model
1057
+ }
1058
+
1059
+ @app.get("/api/models")
1060
+ async def get_available_models_strategic():
1061
+ """List predefined Hugging Face models for strategic agent."""
1062
+ logger.debug("Request for available strategic models received.")
1063
+ # Explicitly return JSONResponse to ensure correct content type
1064
+ return JSONResponse(content=json.dumps(HUGGINGFACE_MODELS), media_type="application/json")
1065
+
1066
+ # --- Startup Event to Download All Models and Start ngrok Tunnel (Modified for HF Spaces) ---
1067
+ @app.on_event("startup")
1068
+ async def startup_event_download_models(): # Renamed function
1069
+ logger.info("Application startup event triggered. Attempting to download all predefined models.")
1070
+
1071
+ # Download all models
1072
+ for model_info in HUGGINGFACE_MODELS:
1073
+ model_url = model_info["url"]
1074
+ model_name = model_info["name"]
1075
+ model_filename = model_url.split('/')[-1]
1076
+ local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)
1077
+
1078
+ if not os.path.exists(local_model_path):
1079
+ logger.info(f"Downloading model '{model_name}' from {model_url} to {local_model_path}...")
1080
+ try:
1081
+ response = requests.get(model_url, stream=True)
1082
+ response.raise_for_status()
1083
+ with open(local_model_path, 'wb') as f:
1084
+ for chunk in response.iter_content(chunk_size=8192):
1085
+ f.write(chunk)
1086
+ logger.info(f"Model '{model_name}' downloaded successfully.")
1087
+ except Exception as e:
1088
+ logger.error(f"Failed to download model '{model_name}': {e}", exc_info=True)
1089
+ else:
1090
+ logger.info(f"Model '{model_name}' already exists at {local_model_path}. Skipping download.")
1091
+ logger.info("Finished attempting to download all predefined models.")
1092
+
1093
+ # --- Shutdown Event (ngrok related parts removed) ---
1094
+ @app.on_event("shutdown")
1095
+ async def shutdown_event_cleanup(): # Renamed function
1096
+ logger.info("Application shutdown event triggered. Performing cleanup.")
1097
+ # No ngrok.kill() needed here as ngrok is not used
1098
+
1099
+ if __name__ == "__main__":
1100
+ import uvicorn
1101
+ logger.info("Starting FastAPI application on Hugging Face Spaces (port 7860)...")
1102
+ uvicorn.run(
1103
+ app,
1104
+ host="0.0.0.0",
1105
+ port=7860, # Standard port for Hugging Face Spaces
1106
+ log_level="info" # Changed to info for less verbose default output
1107
+ )