File size: 52,638 Bytes
e26acf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 |
# app.py (Strategic Agent Service for Hugging Face Spaces - CPU Only, Preload All Models, No ngrok)
import os
import json
import logging
import numpy as np
import requests
from fastapi import FastAPI, HTTPException, Depends, status
from pydantic import BaseModel, Field, constr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime
import firebase_admin
from firebase_admin import credentials, firestore
from bs4 import BeautifulSoup
import re
from typing import List, Dict, Optional, Tuple
from cachetools import TTLCache
import gc
from llama_cpp import Llama
import asyncio
import nest_asyncio
from fastapi.responses import JSONResponse # Added for explicit JSONResponse
# Apply nest_asyncio to allow running asyncio.run() in environments with existing event loops
nest_asyncio.apply()
# --- Configuration ---
# Directory where downloaded GGUF model files are stored. It is created at
# import time so later downloads can write into it directly.
DOWNLOAD_DIR = "./downloaded_models/"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
# Predefined Hugging Face GGUF models. Each entry pairs a display name with
# the direct download URL of the model file.
HUGGINGFACE_MODELS = [
    {
        "name": "Foundation-Sec-8B-Q8_0",
        "url": "https://huggingface.co/fdtn-ai/Foundation-Sec-8B-Q8_0-GGUF/resolve/main/foundation-sec-8b-q8_0.gguf"
    },
    {
        "name": "Lily-Cybersecurity-7B-v0.2-Q8_0",
        "url": "https://huggingface.co/Nekuromento/Lily-Cybersecurity-7B-v0.2-Q8_0-GGUF/resolve/main/lily-cybersecurity-7b-v0.2-q8_0.gguf"
    },
    {
        "name": "SecurityLLM-GGUF (sarvam-m-q8_0)",
        "url": "https://huggingface.co/QuantFactory/SecurityLLM-GGUF/resolve/main/sarvam-m-q8_0.gguf"
    }
]
# Local data directory: holds knowledge_base.json and the embedding-model cache
# used by KnowledgeIndex.
DATA_DIR = "./data"
# TTL handed to deep_search_cache (seconds with cachetools' default timer).
DEEP_SEARCH_CACHE_TTL = 3600
# --- ngrok Configuration (Removed) ---
# NGROK_AUTH_TOKEN and NGROK_STRATEGIC_AGENT_TUNNEL_URL are removed
# --- Logging Setup ---
# The root logger is configured at DEBUG, so the logger.debug() calls in this
# module are emitted.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every class and function below.
logger = logging.getLogger(__name__)
logger.info("Logging initialized with DEBUG level.")
# FastAPI application object for this service.
app = FastAPI(
    title="Hugging Face Strategic Agent Service",
    description="Provides knowledge base access and strategic reasoning for the pentest agent on Hugging Face Spaces.",
    version="1.0.0"
)
# Initialize Firebase.
# The service-account file path comes from FIREBASE_CREDS_PATH (default
# "cred.json"). `db` stays None whenever Firestore cannot be set up; consumers
# check it before touching the knowledge base.
firebase_creds_path = os.getenv("FIREBASE_CREDS_PATH", "cred.json")
db = None
try:
    if firebase_admin._apps:
        # A Firebase app already exists in this process (e.g. this module was
        # executed a second time). Reuse it; previously `db` was reset to None
        # and never re-obtained, silently disabling the knowledge base.
        db = firestore.client()
        logger.info("Firebase app already initialized; reusing existing Firestore client.")
    elif os.path.exists(firebase_creds_path):
        cred = credentials.Certificate(firebase_creds_path)
        firebase_admin.initialize_app(cred)
        db = firestore.client()
        logger.info("Firebase initialized successfully.")
    else:
        logger.warning(f"Firebase credentials file not found at {firebase_creds_path}. Firebase will not be initialized.")
except Exception as e:
    # Best-effort: the service keeps running without Firestore.
    logger.error(f"Failed to initialize Firebase: {e}. Ensure FIREBASE_CREDS_PATH is set correctly and the file exists.", exc_info=True)
# Global LLM instance for the Strategic Agent. It is rebound through `global`
# statements in StrategicAgentBrain.load_strategic_llm / unload_strategic_llm.
strategic_llm: Optional[Llama] = None
# URL the currently loaded model was requested from (None when nothing is loaded).
current_strategic_model_url: Optional[str] = None
# Supported tools (Strategic Agent needs to know these for command generation).
# parse_llm_response accepts a bare reply as a command only if it starts with
# one of these names.
SUPPORTED_TOOLS = [
    "nmap", "gobuster", "nikto", "sqlmap", "adb", "frida",
    "drozer", "apktool", "msfconsole", "mobsfscan", "burpsuite",
    "metasploit", "curl", "wget", "hydra", "john", "aircrack-ng"
]
# --- Deep Search Cache ---
# NOTE(review): not referenced in the visible part of this file; confirm where
# it is read and written.
deep_search_cache = TTLCache(maxsize=100, ttl=DEEP_SEARCH_CACHE_TTL)
# --- Enhanced System Instruction (English) ---
# Static system prompt that _generate_llm_prompt places at the top of every
# prompt. The adjacent string literals are concatenated into a single string.
# NOTE(review): Example 3 suggests a command containing "&&", but the character
# whitelist in parse_llm_response does not allow "&" -- confirm which is intended.
SYSTEM_INSTRUCTION = (
    "You are an expert pentest agent. Strictly follow these rules:\n"
    "1. Output ONLY valid shell commands\n"
    "2. NEVER include timestamps, dates, or any text outside commands\n"
    "3. Never repeat previous commands\n"
    "4. Always verify command safety before execution\n\n"
    "Example valid response:\n"
    "nmap -sV 192.168.1.6\n\n"
    "Key Principles:\n"
    "- Never give up until the goal is achieved\n"
    "- Learn from failures and adapt strategies\n"
    "- Leverage all available knowledge and tools\n"
    "- Break complex tasks into smaller achievable steps\n"
    "- Always ensure actions are ethical and within scope\n\n"
    "Available Tools:\n"
    "- nmap: Network scanning and service detection\n"
    "- gobuster: Web directory brute-forcing\n"
    "- nikto: Web server vulnerability scanner\n"
    "- sqlmap: SQL injection testing\n"
    "- adb: Android Debug Bridge\n"
    "- metasploit: Exploitation framework\n\n"
    "Error Handling Examples:\n"
    "Example 1 (Command Failure):\n"
    " If nmap fails because host is down, try: nmap -Pn -sV 192.168.1.6\n"
    "Example 2 (Web Server Error):\n"
    " If web server returns 403, try: gobuster dir -u http://192.168.1.6 -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt\n"
    "Example 3 (ADB Connection Failed):\n"
    " If ADB connection fails, try: adb kill-server && adb start-server"
)
# --- Firebase Knowledge Base Integration ---
class FirebaseKnowledgeBase:
    """Keyword search over the Firestore ``knowledge_base`` collection."""

    # Terms that are always part of the keyword match, besides goal and phase.
    BASE_KEYWORDS = ('android', 'pentest', 'mobile', 'device')
    # Lower rank sorts first; unknown priorities are treated like "low".
    PRIORITY_ORDER = {"high": 1, "medium": 2, "low": 3}

    def __init__(self):
        # None when Firestore was not available at construction time; query()
        # re-resolves the collection lazily.
        self.collection = db.collection('knowledge_base') if db else None

    @staticmethod
    def _metadata_of(data: dict) -> dict:
        """Return the document's metadata mapping, or {} if absent or not a dict."""
        metadata = data.get('metadata')
        return metadata if isinstance(metadata, dict) else {}

    @staticmethod
    def _searchable_text(data: dict) -> str:
        """Return prompt, completion and metadata joined and lower-cased.

        Each field is passed through ``str`` first, so a dict-valued metadata
        (which the sort key expects) or a missing/None field no longer raises.
        """
        parts = (data.get('prompt'), data.get('completion'), data.get('metadata'))
        return " ".join(str(part or '') for part in parts).lower()

    @classmethod
    def _sort_key(cls, data: dict) -> tuple:
        """Return the sort key for a document: priority rank, then timestamp."""
        metadata = cls._metadata_of(data)
        priority = str(metadata.get('priority') or 'low').lower()
        return (cls.PRIORITY_ORDER.get(priority, 3), metadata.get('timestamp', 0))

    def query(self, goal: str, phase: Optional[str] = None, limit: int = 10) -> list:
        """Return up to *limit* documents whose text contains any keyword.

        Keywords are the lower-cased goal (if non-empty), the lower-cased
        phase (if given) and ``BASE_KEYWORDS``. The collection is streamed and
        scanning stops once *limit* matches are collected; the matches are
        then ordered by priority and timestamp.

        Returns an empty list when Firestore is unavailable, when *limit* is
        not positive, or when the query fails (the error is logged).
        """
        if not db or not firebase_admin._apps:  # Check if Firebase is initialized
            logger.error("Firestore client not initialized. Cannot query knowledge base.")
            return []
        if limit <= 0:
            return []
        # Re-resolve the collection if Firebase was unavailable at __init__ time.
        if getattr(self, 'collection', None) is None:
            self.collection = db.collection('knowledge_base')
        keywords = list(self.BASE_KEYWORDS)
        if goal:
            # An empty keyword would match every document, so it is skipped.
            keywords.insert(0, goal.lower())
        if phase:
            keywords.append(phase.lower())
        try:
            results = []
            for doc in self.collection.stream():
                data = doc.to_dict() or {}
                text = self._searchable_text(data)
                if any(keyword in text for keyword in keywords):
                    results.append(data)
                    if len(results) >= limit:
                        break
            results.sort(key=self._sort_key)
            return results[:limit]
        except Exception as e:
            logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
            return []
# --- RAG Knowledge Index ---
class KnowledgeIndex:
    """In-memory semantic index over the local ``knowledge_base.json`` file.

    Each entry is stored as ``{'text', 'embedding', 'source'}`` and retrieval
    ranks entries by cosine similarity to the query embedding.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Make sure the data directory exists before anything is cached in it
        # or read from it.
        os.makedirs(DATA_DIR, exist_ok=True)
        self.model = SentenceTransformer(
            model_name,
            cache_folder=os.path.join(DATA_DIR, "hf_cache")  # Use local data dir for cache
        )
        self.knowledge_base = []
        self.load_knowledge_from_file(os.path.join(DATA_DIR, 'knowledge_base.json'))

    def load_knowledge_from_file(self, file_path):
        """Load and embed the entries of the JSON list stored at *file_path*.

        Items may be plain strings or dicts with ``text`` and optional
        ``source`` keys; other item types and empty texts are skipped. If the
        file does not exist, an empty list file is created in its place. All
        failures are logged rather than raised.
        """
        logger.debug(f"Attempting to load knowledge from file: {file_path}")
        if not os.path.exists(file_path):
            logger.warning(f"Knowledge base file not found: {file_path}. RAG will operate on an empty knowledge base.")
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump([], f)
                logger.info(f"Created empty knowledge base file at: {file_path}")
            except Exception as e:
                logger.error(f"Error creating empty knowledge base file at {file_path}: {e}", exc_info=True)
            return
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                logger.error("Knowledge base file is not a list. Please check the file format.")
                return
            pending = []
            for item in data:
                if isinstance(item, dict):
                    text = item.get('text', '')
                    source = item.get('source', 'local')
                elif isinstance(item, str):
                    text = item
                    source = 'local'
                else:
                    logger.warning(f"Skipping unsupported item type: {type(item)}")
                    continue
                if text:
                    pending.append((text, source))
            if pending:
                # Encode all texts in one call instead of one model call per item.
                vectors = self.model.encode([text for text, _ in pending])
                for (text, source), vector in zip(pending, vectors):
                    self.knowledge_base.append({'text': text, 'embedding': vector.tolist(), 'source': source})
            logger.info(f"Loaded {len(self.knowledge_base)} items into RAG knowledge base.")
        except Exception as e:
            logger.error(f"Error loading knowledge from {file_path}: {e}", exc_info=True)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the *top_k* entries most similar to *query*, best first.

        Each result is ``{'text', 'similarity', 'source'}`` with ``similarity``
        as a plain Python float. Returns an empty list when *top_k* is not
        positive, when the index is empty, or when retrieval fails (logged).
        """
        if top_k <= 0:
            # A slice of [-0:] would select every entry, so handle it explicitly.
            return []
        if not self.knowledge_base:
            logger.debug("Knowledge base is empty, no RAG retrieval possible.")
            return []
        try:
            query_embedding = self.model.encode(query).reshape(1, -1)
            embeddings = np.array([item['embedding'] for item in self.knowledge_base])
            similarities = cosine_similarity(query_embedding, embeddings)[0]
            top_indices = similarities.argsort()[-top_k:][::-1]
            results = [
                {
                    "text": self.knowledge_base[i]['text'],
                    # Cast from a numpy scalar so the value is JSON-serialisable.
                    "similarity": float(similarities[i]),
                    "source": self.knowledge_base[i].get('source', 'RAG')
                }
                for i in top_indices
            ]
            logger.debug(f"RAG retrieved {len(results)} results for query: '{query}'")
            return results
        except Exception as e:
            logger.error(f"Error during RAG retrieval for query '{query}': {e}", exc_info=True)
            return []
# --- Deep Search Engine ---
class DeepSearchEngine:
    """Best-effort web scraper that collects public security data for a device.

    Every public search method logs its own failures and returns an empty or
    partial list, so callers never have to handle network errors.
    """

    # Upper bound, in seconds, for each HTTP request. requests applies no
    # timeout by default, so a stalled server would otherwise block forever.
    REQUEST_TIMEOUT = 15

    def __init__(self):
        # Browser-like User-Agent sent with every request.
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }

    def _fetch_soup(self, url: str, params: dict = None):
        """Fetch *url* and parse the body as HTML.

        *params* is handed to requests so query values are URL-encoded rather
        than interpolated raw into the URL.

        Returns the parsed BeautifulSoup document, or None when the final
        status code is not 200. Raises the ``requests`` exception for HTTP
        error statuses, transport failures and timeouts.
        """
        response = requests.get(url, headers=self.headers, params=params, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()  # Raise an exception for HTTP errors
        if response.status_code != 200:
            return None
        return BeautifulSoup(response.text, 'html.parser')

    def search_device_info(self, device_info: str, os_version: str) -> dict:
        """Run the CVE, exploit and recommendation searches for one device.

        Returns a dict with ``device``, ``os_version``, ``vulnerabilities``,
        ``exploits`` and ``recommendations``; lists stay empty on failure.
        NOTE(review): search_public_resources is not called here, so no
        ``public_resources`` key is produced -- confirm whether that is intended.
        """
        logger.debug(f"Performing deep search for device: {device_info}, OS: {os_version}")
        results = {
            "device": device_info,
            "os_version": os_version,
            "vulnerabilities": [],
            "exploits": [],
            "recommendations": []
        }
        try:
            results["vulnerabilities"] = self.search_cve(device_info, os_version)
            results["exploits"] = self.search_exploits(device_info, os_version)
            results["recommendations"] = self.get_security_recommendations(os_version)
            logger.debug("Deep search completed.")
        except Exception as e:
            logger.error(f"Deep search failed: {e}", exc_info=True)
        return results

    def search_cve(self, device: str, os_version: str) -> list:
        """Return up to 10 ``{'cve_id', 'description', 'source'}`` dicts from cve.mitre.org."""
        cves = []
        try:
            query = f"{device} {os_version} CVE"
            search_url = "https://cve.mitre.org/cgi-bin/cvekey.cgi"
            logger.debug(f"Searching CVE Mitre: {search_url} (keyword={query!r})")
            soup = self._fetch_soup(search_url, params={"keyword": query})
            table = soup.find('div', id='TableWithRules') if soup else None
            if table:
                # The first row is skipped (treated as the header row).
                for row in table.find_all('tr')[1:]:
                    cols = row.find_all('td')
                    if len(cols) >= 2:
                        cves.append({
                            "cve_id": cols[0].get_text(strip=True),
                            "description": cols[1].get_text(strip=True),
                            "source": "CVE Mitre"
                        })
            logger.debug(f"Found {len(cves)} CVEs.")
            return cves[:10]
        except Exception as e:
            logger.error(f"CVE search failed: {e}", exc_info=True)
            return []

    def search_exploits(self, device: str, os_version: str) -> list:
        """Return up to 10 ``{'title', 'link', 'source'}`` dicts from exploit-db.com."""
        exploits = []
        try:
            query = f"{device} {os_version}"
            search_url = "https://www.exploit-db.com/search"
            logger.debug(f"Searching ExploitDB: {search_url} (q={query!r})")
            soup = self._fetch_soup(search_url, params={"q": query})
            cards = soup.select('.card .card-title') if soup else []
            for card in cards:
                anchor = card.find('a')
                if anchor is None or not anchor.get('href'):
                    # Skip a card without a link instead of aborting the search.
                    continue
                link = anchor['href']
                if not link.startswith('http'):
                    link = f"https://www.exploit-db.com{link}"
                exploits.append({
                    "title": card.get_text(strip=True),
                    "link": link,
                    "source": "ExploitDB"
                })
            logger.debug(f"Found {len(exploits)} exploits.")
            return exploits[:10]
        except Exception as e:
            logger.error(f"Exploit search failed: {e}", exc_info=True)
            return []

    def get_security_recommendations(self, os_version: str) -> list:
        """Return up to 5 recommendation strings for an Android or iOS version.

        Only OS strings containing "android" or "ios" (case-insensitive)
        trigger a lookup; anything else yields an empty list.
        """
        recommendations = []
        try:
            logger.debug(f"Getting security recommendations for OS: {os_version}")
            os_lower = os_version.lower()
            if "android" in os_lower:
                soup = self._fetch_soup("https://source.android.com/docs/security/bulletin")
                headings = soup.select('.devsite-article-body h2') if soup else []
                for heading in headings:
                    if os_version in heading.get_text():
                        next_ul = heading.find_next('ul')
                        if next_ul:
                            recommendations.extend(
                                item.get_text(strip=True) for item in next_ul.select('li')
                            )
            elif "ios" in os_lower:
                soup = self._fetch_soup("https://support.apple.com/en-us/HT201222")
                sections = soup.select('#sections') if soup else []
                for section in sections:
                    if os_version in section.get_text():
                        recommendations.extend(
                            item.get_text(strip=True) for item in section.select('li')
                        )
            logger.debug(f"Found {len(recommendations)} recommendations.")
            return recommendations[:5]
        except Exception as e:
            logger.error(f"Security recommendations search failed: {e}", exc_info=True)
            return []

    def _search_github(self, device_info: str) -> list:
        """Scrape GitHub search results for "<device_info> pentest"."""
        soup = self._fetch_soup("https://github.com/search", params={"q": f"{device_info} pentest"})
        found = []
        repos = soup.select('.repo-list-item') if soup else []
        for repo in repos:
            title_link = repo.select_one('.v-align-middle')
            if title_link is None or not title_link.get('href'):
                continue
            summary = repo.select_one('.mb-1')
            found.append({
                "title": title_link.get_text(strip=True),
                "description": summary.get_text(strip=True) if summary else "",
                "url": f"https://github.com{title_link['href']}",
                "source": "GitHub"
            })
        return found

    def _search_hackforums(self, device_info: str) -> list:
        """Scrape HackForums thread search results for *device_info*."""
        soup = self._fetch_soup(
            "https://hackforums.net/search.php",
            params={"action": "finduserthreads", "keywords": device_info},
        )
        found = []
        threads = soup.select('.thread') if soup else []
        for thread in threads:
            title_node = thread.select_one('.threadtitle')
            anchor = thread.select_one('.threadtitle a')
            if title_node is None or anchor is None or not anchor.get('href'):
                continue
            found.append({
                "title": title_node.get_text(strip=True),
                "description": "Forum discussion",
                "url": f"https://hackforums.net{anchor['href']}",
                "source": "HackForums"
            })
        return found

    def search_public_resources(self, device_info: str) -> list:
        """Return up to 10 ``{'title', 'description', 'url', 'source'}`` dicts.

        GitHub and HackForums are queried independently, so a failure of one
        source no longer discards the results of the other.
        """
        resources = []
        logger.debug(f"Searching public resources for device: {device_info}")
        sources = (("GitHub", self._search_github), ("HackForums", self._search_hackforums))
        for source_name, searcher in sources:
            try:
                resources.extend(searcher(device_info))
            except Exception as e:
                logger.error(f"Public resources search failed for {source_name}: {e}", exc_info=True)
        logger.debug(f"Found {len(resources)} public resources.")
        return resources[:10]
# --- Initialize Services (Local to Strategic Agent) ---
# Module-level singletons created at import time. Constructing KnowledgeIndex
# loads the embedding model and embeds the local knowledge file.
firebase_kb = FirebaseKnowledgeBase()
rag_index = KnowledgeIndex()
deep_search_engine = DeepSearchEngine()
# --- Strategic Agent Brain (formerly SmartExecutionEngine logic) ---
class StrategicAgentBrain:
def __init__(self):
    """Create an agent with no model loaded and an empty engagement state."""
    # --- model handle ---
    self.llm: Optional[Llama] = None
    self.loaded_model_name: Optional[str] = None  # file name of the loaded model

    # --- goal and plan tracking ---
    self.current_goal: Optional[str] = None
    self.goal_achieved = False
    self.current_phase: str = "initial_reconnaissance"
    self.current_phase_index: int = 0
    self.current_plan: List[Dict] = []

    # --- findings gathered so far ---
    self.gathered_info: List[str] = []
    self.identified_vulnerabilities: List[Dict] = []

    # --- command and conversation bookkeeping ---
    # NOTE: used_commands is a set, so the "last three commands" slices taken
    # from it elsewhere are in arbitrary order.
    self.used_commands = set()
    self.command_retry_counts: Dict[str, int] = {}
    self.execution_history = []
    self.conversation_history: List[Dict] = []

    # --- loop counters ---
    self.react_cycle_count = 0
    self.no_progress_count = 0

    logger.info("StrategicAgentBrain initialized.")
async def load_strategic_llm(self, model_url: str):
    """Download (if needed) and load the GGUF model at *model_url* on CPU.

    The model file is stored in DOWNLOAD_DIR under the last path segment of
    the URL. Any previously loaded model is unloaded first.

    Returns a ``(success, message)`` tuple; failures are logged and reported
    through the tuple rather than raised.

    NOTE: the download and the model load are blocking calls, so the event
    loop is stalled while this coroutine runs.
    """
    global strategic_llm, current_strategic_model_url
    logger.info(f"Attempting to load strategic LLM from URL: {model_url}")
    # Determine local path for the model
    model_filename = model_url.split('/')[-1]
    local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)
    if strategic_llm and current_strategic_model_url == model_url:
        logger.info(f"Strategic LLM model from {model_url} is already loaded.")
        self.llm = strategic_llm
        # Keep the name in sync so the message below never reports 'None'.
        self.loaded_model_name = model_filename
        return True, f"Model '{self.loaded_model_name}' is already loaded."
    # If a model is currently loaded, unload it first
    if strategic_llm:
        await self.unload_strategic_llm()
    # Ensure model is downloaded before attempting to load
    if not os.path.exists(local_model_path):
        logger.info(f"Model not found locally. Attempting to download from {model_url} to {local_model_path}...")
        # Download to a temporary name and rename on success. An interrupted
        # download then never leaves a truncated file at the final path, where
        # it would be mistaken for a complete model on the next call.
        partial_path = local_model_path + ".part"
        try:
            # (connect, read) timeouts: a stalled server must not hang forever.
            with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, local_model_path)
            logger.info(f"Model downloaded successfully to {local_model_path}.")
        except Exception as e:
            logger.error(f"Failed to download model from {model_url}: {e}", exc_info=True)
            try:
                os.remove(partial_path)
            except OSError:
                pass  # Nothing was written, or the file is already gone.
            return False, f"Failed to download model: {str(e)}"
    try:
        logger.info(f"Loading Strategic LLM model from {local_model_path}...")
        strategic_llm = Llama(
            model_path=local_model_path,
            n_ctx=3096,
            n_gpu_layers=0,  # Explicitly set to 0 for CPU-only
            n_threads=os.cpu_count(),  # Use all available CPU threads
            n_batch=512,
            verbose=False
        )
        current_strategic_model_url = model_url
        self.llm = strategic_llm
        self.loaded_model_name = model_filename  # Store the filename
        logger.info(f"Strategic LLM model {model_filename} loaded successfully (CPU-only).")
        return True, f"Model '{model_filename}' loaded successfully (CPU-only)."
    except Exception as e:
        logger.error(f"Failed to load Strategic LLM model from {local_model_path}: {e}", exc_info=True)
        strategic_llm = None
        current_strategic_model_url = None
        self.llm = None
        self.loaded_model_name = None
        return False, f"Failed to load model: {str(e)}"
async def unload_strategic_llm(self):
    """Drop the shared LLM instance, if any, and clear everything that refers to it."""
    global strategic_llm, current_strategic_model_url
    if not strategic_llm:
        # Nothing is loaded, so there is nothing to release.
        return
    logger.info("Unloading Strategic LLM model...")
    del strategic_llm
    strategic_llm = None
    current_strategic_model_url = None
    self.llm = None
    self.loaded_model_name = None
    # Ask the collector to reclaim the released model object right away.
    gc.collect()
    logger.info("Strategic LLM model unloaded.")
def _get_rag_context(self, query: str) -> str:
    """Return the RAG hits for *query* as a numbered text block ("" when there are none)."""
    hits = rag_index.retrieve(query)
    if not hits:
        return ""
    numbered = []
    for position, hit in enumerate(hits, start=1):
        # Fall back to 'completion' when the hit carries no 'text'.
        body = hit.get('text', '') or hit.get('completion', '')
        origin = hit.get('source', 'RAG')
        numbered.append(f"{position}. [{origin}] {body}\n")
    return "Relevant Knowledge for Current Context:\n" + "".join(numbered)
def _get_firebase_knowledge(self, goal: str, phase: str = None) -> str:
    """Return Firestore knowledge relevant to *goal* and *phase* as prompt text.

    The lookup is delegated to the module-level ``firebase_kb`` so there is a
    single implementation of the keyword query. Matches are rendered as a
    numbered list under a heading. An empty string is returned when nothing
    matches, Firestore is unavailable, or the query fails (the error is logged).
    """
    try:
        entries = firebase_kb.query(goal or "", phase)
    except Exception as e:
        logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
        return ""
    if not entries:
        return ""
    lines = ["Relevant Firebase Knowledge:"]
    for i, entry in enumerate(entries, 1):
        prompt_text = str(entry.get('prompt') or '').strip()
        completion_text = str(entry.get('completion') or '').strip()
        if prompt_text and completion_text:
            lines.append(f"{i}. {prompt_text} -> {completion_text}")
        else:
            # Show whichever half is present; both may be empty.
            lines.append(f"{i}. {prompt_text or completion_text}")
    # Always a string: the caller truncates it and embeds it in the prompt.
    return "\n".join(lines) + "\n"
def extract_device_info(self) -> str:
    """Return a device label taken from gathered info, else an IP from the goal.

    The first gathered entry containing a ``model:`` / ``device=`` style pair
    wins. Otherwise the first IPv4-looking token in the current goal is
    returned, or "Unknown Device" when there is none.
    """
    for entry in self.gathered_info:
        lowered = entry.lower()
        if "model" not in lowered and "device" not in lowered:
            continue
        labelled = re.search(r'(?:model|device)\s*[:=]\s*([^\n]+)', entry, re.IGNORECASE)
        if labelled is not None:
            return labelled.group(1).strip()
    address = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
    if address is None:
        return "Unknown Device"
    return address.group(0)
def extract_os_version(self) -> str:
    """Return the first OS version found in the gathered info.

    Recognises "Android <ver>", "iOS <ver>" and "Linux kernel x.y.z"
    (case-insensitive) and returns e.g. "Android 13", "iOS 16.4" or
    "Linux 5.10.0". Returns "Unknown OS Version" when nothing matches.

    Versions without a dot ("Android 13") are accepted. The patterns are
    anchored on a word boundary so that e.g. "BIOS 2.1" is not read as iOS.
    No keyword pre-filter is applied, which keeps the Linux pattern reachable
    for lines that mention neither Android nor iOS.
    """
    for info in self.gathered_info:
        android_match = re.search(r'\bandroid\s+(\d+(?:\.\d+)*)', info, re.IGNORECASE)
        if android_match:
            return f"Android {android_match.group(1)}"
        ios_match = re.search(r'\bios\s+(\d+(?:\.\d+)*)', info, re.IGNORECASE)
        if ios_match:
            return f"iOS {ios_match.group(1)}"
        linux_match = re.search(r'\blinux\s+kernel\s+(\d+\.\d+\.\d+)', info, re.IGNORECASE)
        if linux_match:
            return f"Linux {linux_match.group(1)}"
    return "Unknown OS Version"
def format_deep_search_results(self, results: dict) -> str:
    """Render a deep-search result dict as plain text for an LLM prompt.

    Sections are emitted only for non-empty lists: at most five
    vulnerabilities and exploits, and at most three recommendations and
    public resources.
    """
    pieces = [
        "Deep Search Results:\n",
        f"Device: {results.get('device', 'Unknown')}\n",
        f"OS Version: {results.get('os_version', 'Unknown')}\n\n",
    ]

    vulnerabilities = results.get('vulnerabilities')
    if vulnerabilities:
        pieces.append("Discovered Vulnerabilities:\n")
        pieces.extend(
            f"{rank}. {item.get('cve_id', 'CVE-XXXX-XXXX')}: {item.get('description', 'No description')}\n"
            for rank, item in enumerate(vulnerabilities[:5], 1)
        )
        pieces.append("\n")

    exploits = results.get('exploits')
    if exploits:
        pieces.append("Available Exploits:\n")
        pieces.extend(
            f"{rank}. {item.get('title', 'Untitled exploit')} [Source: {item.get('source', 'Unknown')}]\n"
            for rank, item in enumerate(exploits[:5], 1)
        )
        pieces.append("\n")

    recommendations = results.get('recommendations')
    if recommendations:
        pieces.append("Security Recommendations:\n")
        pieces.extend(
            f"{rank}. {advice}\n"
            for rank, advice in enumerate(recommendations[:3], 1)
        )
        pieces.append("\n")

    public_resources = results.get('public_resources')
    if public_resources:
        # Last section: no blank separator line follows it.
        pieces.append("Public Resources:\n")
        pieces.extend(
            f"{rank}. {item.get('title', 'Untitled resource')} [Source: {item.get('source', 'Unknown')}]\n"
            for rank, item in enumerate(public_resources[:3], 1)
        )

    return "".join(pieces)
def generate_deep_search_prompt(self, context: str) -> str:
    """Build the LLM prompt that asks for the next command given deep-search *context*.

    The prompt embeds the current goal, the current phase and up to three
    previously used commands (or "None" when there are none).
    """
    if self.used_commands:
        recent = ', '.join(list(self.used_commands)[-3:])
    else:
        recent = 'None'
    # The empty first and last entries give the leading and trailing newline.
    template_lines = (
        "",
        "You are an expert pentester. Below are deep search results for the target device.",
        f"Use this information to generate the next penetration testing command.{context}",
        f"Current Goal: {self.current_goal}",
        f"Current Phase: {self.current_phase}",
        f"Recent Command History:{recent}",
        "Based on this information, what is the SINGLE MOST EFFECTIVE shell command to execute next?",
        "Focus on exploiting the most critical vulnerabilities or gathering more information.",
        "Response Format:",
        "Command: <your_command_here>",
        "",
    )
    return "\n".join(template_lines)
def _generate_llm_prompt(self) -> str:
    """Assemble the main command-generation prompt from the agent's current state.

    Combines the system instruction, goal, phase objective, retrieved
    knowledge (RAG and Firebase), the last two execution results, the last
    two conversation turns and phase-specific advice. Each dynamic section is
    truncated to a fixed length before being embedded.
    """
    def clip(text, limit):
        # Truncate over-long sections and mark the cut.
        return text[:limit] + "... [truncated]" if len(text) > limit else text

    knowledge = self._get_rag_context(f"{self.current_goal} {self.current_phase}")
    firebase_notes = self._get_firebase_knowledge(self.current_goal, self.current_phase)
    dialogue = "\n".join(
        f"{turn['role']}: {turn['content']}" for turn in self.conversation_history[-2:]
    )
    if self.execution_history:
        outcomes = "\n".join(
            f"Command: {run['command']}\nResult: {run['output'][:100]}...\nSuccess: {run['success']}"
            for run in self.execution_history[-2:]
        )
    else:
        outcomes = "No previous results."
    advice = self._get_rag_context(self.current_phase)  # Using RAG for strategic advice too

    knowledge = clip(knowledge, 200)
    firebase_notes = clip(firebase_notes, 200)
    advice = clip(advice, 100)
    dialogue = clip(dialogue, 150)
    outcomes = clip(outcomes, 500)

    # Objective of the active plan step, if the plan has one at this index.
    if self.current_plan and self.current_phase_index < len(self.current_plan):
        objective = self.current_plan[self.current_phase_index]['objective']
    else:
        objective = 'No objective'
    if self.used_commands:
        recent_commands = ', '.join(list(self.used_commands)[-3:])
    else:
        recent_commands = 'None'

    prompt = f"""
System Instructions: {SYSTEM_INSTRUCTION}
Current Goal: '{self.current_goal}'
Current Phase: {self.current_phase} - {objective}
Based on the following knowledge and previous results, generate the SINGLE, VALID SHELL COMMAND to advance the penetration testing process.
**Knowledge from External Services (RAG & Firebase):**
{knowledge}
{firebase_notes}
**Previous Execution Results:**
{outcomes}
**Recent Conversation History:**
{dialogue}
**Strategic Advice for Current Phase:**
{advice}
***CRITICAL RULES FOR OUTPUT:***
1. **OUTPUT ONLY THE COMMAND.**
2. **DO NOT INCLUDE ANY REASONING, THOUGHTS, EXPLANATIONS, OR ANY OTHER TEXT.**
3. The command MUST be directly executable in a Linux terminal.
4. Avoid repeating these recent commands: {recent_commands}
5. If the previous command failed, try a different approach or a related tool.
6. For the 'android_enumeration' phase, prioritize ADB commands.
Example valid commands for initial reconnaissance of an Android phone:
nmap -sV -Pn 192.168.1.14
adb devices
adb connect 192.168.1.14:5555
Command:
"""
    return prompt
def _get_llm_response(self, custom_prompt: str = None) -> str:
    """Run the loaded LLM on a prompt and return its stripped text reply.

    Uses *custom_prompt* when given, otherwise the prompt built by
    ``_generate_llm_prompt``. When no model is loaded, the reply is empty, or
    inference raises, a ``Command: ...`` fallback string is returned instead.
    The fallback targets the first IPv4 address found in the current goal, or
    192.168.1.1 when there is none.
    """
    def fallback_target() -> str:
        # First IPv4-looking token in the goal, else a default address.
        found = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
        return found.group(0) if found else "192.168.1.1"

    if not self.llm:
        logger.error("Strategic LLM instance is None. Cannot get response. Please load a model first.")
        return f"Command: echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {fallback_target()}'"

    prompt = custom_prompt or self._generate_llm_prompt()
    logger.info(f"Sending prompt to Strategic LLM:\n{prompt[:500]}...")
    try:
        # Generation stops at the first newline, so the reply is one line.
        completion = self.llm(
            prompt,
            max_tokens=512,
            temperature=0.3,
            stop=["\n"]
        )
        reply = completion['choices'][0]['text'].strip()
        logger.info(f"Strategic LLM raw response: {reply}")
        if reply:
            return reply
        logger.warning("Strategic LLM returned an empty response. Using fallback command.")
        return f"Command: nmap -sV -Pn {fallback_target()}"
    except Exception as e:
        logger.error(f"Error during Strategic LLM inference: {e}", exc_info=True)
        logger.warning("Strategic LLM inference failed. Using fallback command.")
        return f"Command: nmap -sV -Pn {fallback_target()}"
def parse_llm_response(self, response: str) -> str:
    """Extract one new, validated shell command from an LLM reply.

    Extraction is tried in order: a fenced code block, a ``Command:`` line,
    then a bare reply that starts with a supported tool name. The candidate
    is stripped of a leading label, checked against a character whitelist and
    against reasoning-style markers, and rejected if it was already used.

    Returns the command (and records it in ``used_commands``), or None when
    nothing valid could be extracted -- despite the ``str`` annotation.
    """
    logger.info(f"Attempting to parse LLM response: '{response}'")
    candidate = None
    try:
        fenced = re.search(r'```(?:bash|sh)?\s*([\s\S]*?)```', response)
        if fenced:
            candidate = fenced.group(1).strip()
            logger.info(f"Command extracted from code block: '{candidate}'")

        if not candidate:
            labelled = re.search(r'^\s*Command\s*:\s*(.+)$', response, re.MULTILINE | re.IGNORECASE)
            if labelled:
                candidate = labelled.group(1).strip()
                logger.info(f"Command extracted from 'Command:' line: '{candidate}'")

        if not candidate:
            bare = response.strip()
            if bare.startswith(tuple(SUPPORTED_TOOLS)):
                candidate = bare
                logger.info(f"Command extracted as direct supported tool command: '{candidate}'")

        if not candidate:
            logger.warning("No valid command could be extracted from LLM response based on strict rules.")
            return None

        before_cleanup = candidate
        candidate = re.sub(
            r'^\s*(Command|Answer|Note|Result)\s*[:.-]?\s*', '', candidate, flags=re.IGNORECASE
        ).strip()
        logger.info(f"Cleaned command: from '{before_cleanup}' to '{candidate}'")

        # Whitelist of characters allowed in a command.
        if not re.match(r'^[a-zA-Z0-9_./:;= \-\'"\s]+$', candidate):
            logger.error(f"Invalid command characters detected after cleanup: '{candidate}'")
            return None

        if re.search(r'(reason|thought|explanation|rationale|note|result):', candidate, re.IGNORECASE):
            logger.warning(f"Command '{candidate}' appears to be reasoning/explanation. Rejecting.")
            return None

        if candidate in self.used_commands:
            logger.warning(f"Command '{candidate}' already used. Skipping.")
            return None

        self.used_commands.add(candidate)
        logger.info(f"Returning valid and new command: '{candidate}'")
        return candidate
    except Exception as e:
        logger.error(f"Error parsing LLM response: {e}", exc_info=True)
        return None
def set_goal(self, goal: str) -> None:
    """Start a new engagement: store ``goal`` and reset all per-goal state.

    A fresh plan is generated via ``_generate_strategic_plan`` and the agent
    is positioned on its first phase. Counters, findings, history, retry
    bookkeeping and the used-command set are cleared, and the conversation
    history is re-seeded with a single "New goal set" user message.

    Args:
        goal: Free-text description of the objective.
    """
    self.current_goal = goal
    self.goal_achieved = False
    self.react_cycle_count = 0
    self.no_progress_count = 0
    self.current_plan = self._generate_strategic_plan(goal)
    self.current_phase_index = 0
    # Keep the phase name in step with the index. Without this, a goal set
    # after a previous run kept the old phase name (e.g. "completed").
    # The generated plan always contains at least the initial phase.
    self.current_phase = self.current_plan[0]["phase"]
    self.identified_vulnerabilities = []
    self.gathered_info = []
    self.command_retry_counts = {}
    self.conversation_history = [{"role": "user", "content": f"New goal set: {goal}"}]
    self.used_commands.clear()
    self.execution_history = []
    logger.info(f"Strategic Agent Goal set: {goal}. Starting initial reconnaissance.")
def _generate_strategic_plan(self, goal: str) -> List[Dict]:
    """Build the ordered list of phases for ``goal``.

    Every plan starts with ``initial_reconnaissance`` and ends with
    ``reporting``. The phases in between are chosen by keyword match on the
    lower-cased goal: web/http goals get the web track, android/mobile/adb
    goals get the Android track, and everything else gets the network track.

    Args:
        goal: Free-text description of the objective.

    Returns:
        A list of ``{"phase": ..., "objective": ...}`` dicts in execution order.
    """
    logger.debug(f"Generating strategic plan for goal: {goal}")
    lowered = goal.lower()

    if any(keyword in lowered for keyword in ("web", "http")):
        track = [
            ("web_enumeration", "Enumerate web server for directories and files"),
            ("web_vulnerability_analysis", "Analyze web vulnerabilities (SQLi, XSS, etc.)"),
            ("web_exploitation", "Attempt to exploit web vulnerabilities"),
            ("post_exploitation", "Perform post exploitation activities"),
        ]
    elif any(keyword in lowered for keyword in ("android", "mobile", "adb")):
        track = [
            ("android_enumeration", "Enumerate Android device via ADB"),
            ("android_app_analysis", "Analyze Android application for vulnerabilities"),
            ("android_exploitation", "Attempt to exploit Android vulnerabilities"),
            ("data_extraction", "Extract sensitive data from device"),
        ]
    else:
        track = [
            ("network_scanning", "Perform detailed network scanning"),
            ("service_enumeration", "Enumerate services and identify versions"),
            ("vulnerability_analysis", "Analyze services for vulnerabilities"),
            ("exploitation", "Attempt to exploit vulnerabilities"),
            ("post_exploitation", "Perform post exploitation (privilege escalation, data exfiltration)"),
        ]

    stages = [
        ("initial_reconnaissance", f"Perform initial reconnaissance for {goal}"),
        *track,
        ("reporting", "Generate pentest report"),
    ]
    plan = [{"phase": name, "objective": objective} for name, objective in stages]

    logger.info(f"Generated strategic plan for goal '{goal}': {plan}")
    return plan
def evaluate_phase_completion(self) -> float:
    """Return the success ratio of commands recorded for the current phase.

    Only ``self.execution_history`` entries whose ``phase`` equals
    ``self.current_phase`` are counted; entries without a ``phase`` key are
    treated as phase ``''``.

    Returns:
        Successful commands divided by total commands for the phase, or
        ``0.0`` when the phase has no recorded commands.
    """
    outcomes = [
        entry['success']
        for entry in self.execution_history
        if entry.get('phase', '') == self.current_phase
    ]
    if not outcomes:
        return 0.0
    succeeded = sum(1 for ok in outcomes if ok)
    return succeeded / len(outcomes)
def advance_phase(self):
    """Move to the next plan phase, or mark the goal achieved on the last one.

    When a later phase exists, the index and ``current_phase`` are updated and
    the no-progress and ReAct cycle counters are reset. Otherwise
    ``current_phase`` becomes ``"completed"`` and ``goal_achieved`` is set.
    """
    last_index = len(self.current_plan) - 1
    if self.current_phase_index >= last_index:
        # Nothing left in the plan: the engagement is finished.
        self.current_phase = "completed"
        self.goal_achieved = True
        logger.info("Strategic Agent: All planned phases completed. Goal achieved!")
        return

    self.current_phase_index += 1
    self.current_phase = self.current_plan[self.current_phase_index]["phase"]
    logger.info(f"Strategic Agent advancing to new phase: {self.current_phase.replace('_', ' ').title()}")
    self.no_progress_count = 0
    self.react_cycle_count = 0
def observe_result(self, command: str, output: str, success: bool):
    """Record the outcome of an executed command and update progress state.

    The result is appended to ``execution_history`` (tagged with the phase it
    ran in), the raw output is added to ``gathered_info``, and the output is
    scanned for findings. A failure increments ``no_progress_count``; a
    success resets it and, unless the agent is already on the last phase,
    advances the phase once at least 80% of the current phase's recorded
    commands have succeeded.

    Args:
        command: The command line that was executed.
        output: Text output captured from the command.
        success: Whether the executor considered the command successful.
    """
    logger.debug(f"Strategic Agent observing result for command '{command}': Success={success}")
    self.execution_history.append({
        "command": command,
        "output": output,
        "success": success,
        # evaluate_phase_completion() filters history by this key; without it
        # no observed command was ever counted towards the current phase.
        "phase": self.current_phase,
        "timestamp": datetime.now().isoformat(),
    })
    self.gathered_info.append(output)
    self.analyze_command_output_strategic(command, output)

    if not success:
        self.no_progress_count += 1
        return

    self.no_progress_count = 0
    on_last_phase = self.current_phase_index >= len(self.current_plan) - 1
    if not on_last_phase and self.evaluate_phase_completion() >= 0.8:
        self.advance_phase()
def analyze_command_output_strategic(self, command: str, output: str):
    """Scan a command's output for findings, keyed on the tool that produced it.

    Recognised command prefixes and what they yield:

    * ``nmap``: a Medium finding when the output mentions both "open" and
      "vulnerable"; every ``<port>/tcp open <service>`` line is also added to
      ``gathered_info``.
    * ``nikto``: a High finding for each of the first three ``OSVDB-`` lines.
    * ``sqlmap``: a Critical finding when "injection" appears.
    * ``adb``: a High finding for "debuggable" and a Medium finding when both
      "permission" and "denied" appear.

    Any exception raised while analysing is logged and swallowed.

    Args:
        command: The command line that was executed.
        output: Text output captured from the command.
    """
    try:
        logger.debug(f"Analyzing strategic command output for: {command}")

        if command.startswith("nmap"):
            if "open" in output and "vulnerable" in output.lower():
                self.ingest_vulnerability(
                    description="Potential vulnerability found in NMAP scan",
                    severity="Medium",
                    cve_id="NMAP-SCAN",
                )
            for port, service in re.findall(r'(\d+)/tcp\s+open\s+(\S+)', output):
                self.gathered_info.append(f"Discovered open port {port} with service {service}")

        elif command.startswith("nikto"):
            if "OSVDB-" in output:
                # Cap at three findings per scan.
                for finding in re.findall(r'OSVDB-\d+:\s*(.+)', output)[:3]:
                    self.ingest_vulnerability(
                        description=f"Nikto vulnerability: {finding}",
                        severity="High",
                        cve_id="NIKTO-SCAN",
                    )

        elif command.startswith("sqlmap"):
            if "injection" in output.lower():
                self.ingest_vulnerability(
                    description="SQL injection vulnerability detected",
                    severity="Critical",
                    cve_id="SQLMAP-SCAN",
                )

        elif command.startswith("adb"):
            lowered = output.lower()
            if "debuggable" in lowered:
                self.ingest_vulnerability(
                    description="Debuggable Android application found",
                    severity="High",
                    cve_id="ADB-DEBUG",
                )
            if "permission" in lowered and "denied" in lowered:
                self.ingest_vulnerability(
                    description="Permission issue detected on Android device",
                    severity="Medium",
                    cve_id="ADB-PERMISSION",
                )
    except Exception as e:
        logger.error(f"Strategic Agent: Error analyzing command output: {e}", exc_info=True)
def ingest_vulnerability(self, description: str, severity: str, cve_id: Optional[str] = None, exploit_id: Optional[str] = None):
    """Append a timestamped finding to ``self.identified_vulnerabilities``.

    Args:
        description: Human-readable summary of the finding.
        severity: Severity label, stored as given.
        cve_id: Optional identifier; stored under ``cve_id`` only when truthy.
        exploit_id: Optional identifier; stored under ``exploit_id`` only when
            truthy.
    """
    record = {
        "description": description,
        "severity": severity,
        "timestamp": datetime.now().isoformat(),
    }
    # Optional identifiers are only recorded when actually provided.
    optional_ids = {"cve_id": cve_id, "exploit_id": exploit_id}
    record.update({key: value for key, value in optional_ids.items() if value})

    self.identified_vulnerabilities.append(record)
    logger.info(f"Strategic Agent identified vulnerability: {description} (Severity: {severity})")
# Instantiate the Strategic Agent Brain.
# This single module-level instance is the one the /strategic_agent/* endpoint
# handlers in this file read and mutate.
strategic_brain = StrategicAgentBrain()
# --- Request Models for API Endpoints ---
class RAGRequest(BaseModel):
    """Request body for ``POST /rag/retrieve``."""
    query: constr(min_length=3, max_length=500)  # search text, 3 to 500 characters
    top_k: int = Field(5, gt=0, le=20)  # number of results to retrieve; defaults to 5, allowed range 1-20
class FirebaseQueryRequest(BaseModel):
    """Request body for ``POST /firebase/query``."""
    goal: str  # goal text forwarded to firebase_kb.query
    # Declared Optional so the annotation matches the None default and an
    # explicit null from the client validates.
    phase: Optional[str] = None
    limit: int = 10  # limit forwarded to firebase_kb.query
class DeepSearchRequest(BaseModel):
    """Request body for ``POST /deep_search``."""
    device_info: str  # passed to both search_device_info and search_public_resources
    os_version: str  # passed to search_device_info
class SetGoalRequest(BaseModel):
    """Request body for ``POST /strategic_agent/set_goal``."""
    goal: str  # free-text goal handed to strategic_brain.set_goal
class GetNextCommandRequest(BaseModel):
    """Request body for ``POST /strategic_agent/get_next_command``.

    The three ``*_summary`` lists replace the strategic brain's corresponding
    state wholesale when the endpoint runs.
    """
    # NOTE(review): the next three fields are required by the schema, but the
    # handler in this file does not read them.
    current_state: str
    last_command_output: str
    last_command_success: bool
    execution_history_summary: List[Dict] = []  # becomes strategic_brain.execution_history
    gathered_info_summary: List[str] = []  # becomes strategic_brain.gathered_info
    identified_vulnerabilities_summary: List[Dict] = []  # becomes strategic_brain.identified_vulnerabilities
class ObserveResultRequest(BaseModel):
    """Request body for ``POST /strategic_agent/observe_result``."""
    command: str  # the command that was executed
    output: str  # output captured from that command
    success: bool  # whether the command succeeded
class LoadStrategicModelRequest(BaseModel):
    """Request body for ``POST /strategic_agent/load_model``."""
    model_url: str # Now expects a URL instead of a local path
# --- API Endpoints ---
@app.get("/health")
async def health_check():
    """Liveness probe: always reports that the service is running."""
    logger.debug("Health check requested.")
    payload = {"status": "ok", "message": "Knowledge service is running."}
    return payload
@app.post("/rag/retrieve")
async def rag_retrieve_endpoint(request: RAGRequest):
    """Run a RAG lookup for ``request.query`` and return up to ``top_k`` results.

    Raises:
        HTTPException: 500 with the error text when retrieval fails.
    """
    logger.debug(f"RAG retrieve endpoint called with query: {request.query}")
    try:
        matches = rag_index.retrieve(request.query, request.top_k)
    except Exception as e:
        logger.error(f"RAG retrieval failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": {"results": matches}, "error": None}
@app.post("/firebase/query")
async def firebase_query_endpoint(request: FirebaseQueryRequest):
    """Query the Firebase knowledge base with the request's goal, phase and limit.

    Raises:
        HTTPException: 500 with the error text when the query fails.
    """
    logger.debug(f"Firebase query endpoint called with goal: {request.goal}, phase: {request.phase}")
    try:
        matches = firebase_kb.query(request.goal, request.phase, request.limit)
    except Exception as e:
        logger.error(f"Firebase query failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": {"results": matches}, "error": None}
@app.post("/deep_search")
async def deep_search_endpoint(request: DeepSearchRequest):
    """Look up device information and attach matching public resources.

    The device search result is extended with a ``public_resources`` entry
    before being returned.

    Raises:
        HTTPException: 500 with the error text when either search fails.
    """
    logger.debug(f"Deep search endpoint called for device: {request.device_info}, OS: {request.os_version}")
    try:
        findings = deep_search_engine.search_device_info(request.device_info, request.os_version)
        findings["public_resources"] = deep_search_engine.search_public_resources(request.device_info)
    except Exception as e:
        logger.error(f"Deep search failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": findings, "error": None}
@app.post("/strategic_agent/load_model")
async def load_strategic_model(request: LoadStrategicModelRequest):
    """Ask the strategic brain to load the LLM located at ``request.model_url``.

    Raises:
        HTTPException: 500 carrying the loader's message when loading fails.
    """
    logger.info(f"Request to load strategic model: {request.model_url}")
    success, message = await strategic_brain.load_strategic_llm(request.model_url)
    if not success:
        logger.error(f"Failed to load strategic model: {message}")
        raise HTTPException(status_code=500, detail=message)

    logger.info(f"Strategic model loaded successfully: {message}")
    return {"status": "success", "message": message, "model": strategic_brain.loaded_model_name}
@app.post("/strategic_agent/unload_model")
async def unload_strategic_model():
    """Unload the strategic brain's LLM and confirm."""
    logger.info("Request to unload strategic model.")
    await strategic_brain.unload_strategic_llm()
    confirmation = {"status": "success", "message": "Strategic LLM unloaded."}
    return confirmation
@app.post("/strategic_agent/set_goal")
async def strategic_set_goal(request: SetGoalRequest):
    """Hand a new goal to the strategic brain and confirm it."""
    new_goal = request.goal
    logger.info(f"Strategic Agent received new goal: {new_goal}")
    # set_goal is synchronous, so it is called directly without await.
    strategic_brain.set_goal(new_goal)
    return {"status": "success", "message": f"Goal set to: {new_goal}"}
@app.post("/strategic_agent/get_next_command")
async def strategic_get_next_command(request: GetNextCommandRequest):
    """Produce the next command for the execution agent.

    The brain's execution history, gathered info and vulnerability list are
    first replaced by the summaries in the request. A prompt is then
    generated, sent to the LLM, and the reply parsed into a command.

    When no valid command comes back, a fallback is returned with status
    ``"fallback"``. It targets the first IPv4-looking token in the current
    goal, or ``192.168.1.1`` if there is none. With no LLM loaded the fallback
    is an ``echo`` explaining that; otherwise it is a plain nmap scan.
    """
    logger.debug("Strategic Agent received request for next command.")

    # Sync the brain with the execution agent's latest view of the world.
    strategic_brain.execution_history = request.execution_history_summary
    strategic_brain.gathered_info = request.gathered_info_summary
    strategic_brain.identified_vulnerabilities = request.identified_vulnerabilities_summary

    # Prompt -> LLM reply -> parsed command.
    prompt = strategic_brain._generate_llm_prompt()
    llm_reply = strategic_brain._get_llm_response(prompt)
    command = strategic_brain.parse_llm_response(llm_reply)

    if command:
        strategic_brain.used_commands.add(command)  # make sure the command is tracked as used
        logger.info(f"Strategic Agent generated command: {command}")
        return {"command": command, "status": "success"}

    # No usable command: fall back to a basic scan of the goal's target.
    ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', strategic_brain.current_goal or "")
    fallback_ip = ip_match.group(0) if ip_match else "192.168.1.1"
    logger.warning(f"Strategic Agent failed to generate command. Returning fallback: {fallback_ip}")

    if strategic_brain.llm is None:
        # Tell the operator why they are getting a fallback.
        return {
            "command": f"echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {fallback_ip}'",
            "status": "fallback",
            "message": "No LLM loaded on Strategic Agent. Please load one from the frontend settings.",
        }
    return {
        "command": f"nmap -sV -Pn {fallback_ip}",
        "status": "fallback",
        "message": "Strategic Agent could not determine a valid next command.",
    }
@app.post("/strategic_agent/observe_result")
async def strategic_observe_result(request: ObserveResultRequest):
    """Forward an executed command's outcome to the strategic brain."""
    logger.debug(f"Strategic Agent received observation for command: {request.command}, success: {request.success}")
    strategic_brain.observe_result(
        command=request.command,
        output=request.output,
        success=request.success,
    )
    acknowledgement = {"status": "success", "message": "Observation received and processed."}
    return acknowledgement
@app.get("/strategic_agent/get_status")
async def strategic_get_status():
    """Return a snapshot of the strategic agent's state for the frontend.

    Includes the goal, phase (title-cased), counters, vulnerability
    descriptions, previews of the last five gathered-info items, the last ten
    history entries (command, success, timestamp), the plan, and the loaded
    model name. ``strategicAgentStatus`` is ``"Running"`` while a goal is set
    and not yet achieved, otherwise ``"Idle"``.
    """
    logger.debug("Strategic Agent status requested.")

    def _preview(text, limit=100):
        # Only mark the text as truncated when something was actually cut off.
        return text if len(text) <= limit else text[:limit] + "..."

    # History and vulnerability entries may come from client-supplied summaries
    # (installed by get_next_command), so keys are read with .get() rather than
    # failing the whole status call on one incomplete entry.
    return {
        "currentGoal": strategic_brain.current_goal,
        "currentPhase": strategic_brain.current_phase.replace('_', ' ').title(),
        "reactCycleCount": strategic_brain.react_cycle_count,
        "noProgressCount": strategic_brain.no_progress_count,
        "identifiedVulnerabilities": [
            v.get('description', '') for v in strategic_brain.identified_vulnerabilities
        ],
        "gatheredInfo": [_preview(info) for info in strategic_brain.gathered_info[-5:]],
        "executionHistorySummary": [
            {
                "command": e.get('command'),
                "success": e.get('success'),
                "timestamp": e.get('timestamp'),
            }
            for e in strategic_brain.execution_history[-10:]
        ],
        "strategicPlan": strategic_brain.current_plan,
        "currentPhaseIndex": strategic_brain.current_phase_index,
        "goalAchieved": strategic_brain.goal_achieved,
        "strategicAgentStatus": "Running" if strategic_brain.current_goal and not strategic_brain.goal_achieved else "Idle",
        "loadedModel": strategic_brain.loaded_model_name,  # name of the loaded model
    }
@app.get("/api/models")
async def get_available_models_strategic():
    """List predefined Hugging Face models for strategic agent."""
    logger.debug("Request for available strategic models received.")
    # JSONResponse serializes its content itself. Passing an already-dumped
    # string made the body a JSON *string* containing JSON (double-encoded),
    # so the raw structure is handed over instead.
    return JSONResponse(content=HUGGINGFACE_MODELS)
# --- Startup Event: Download All Predefined Models (Hugging Face Spaces; no ngrok tunnel is started) ---
@app.on_event("startup")
async def startup_event_download_models(): # Renamed function
    """Download every model in ``HUGGINGFACE_MODELS`` that is not yet on disk.

    Each model is saved in ``DOWNLOAD_DIR`` under the last path segment of its
    URL. Downloads are best-effort: a failure is logged and the remaining
    models are still attempted.
    """
    logger.info("Application startup event triggered. Attempting to download all predefined models.")
    for model_info in HUGGINGFACE_MODELS:
        model_url = model_info["url"]
        model_name = model_info["name"]
        model_filename = model_url.split('/')[-1]
        local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)

        if os.path.exists(local_model_path):
            logger.info(f"Model '{model_name}' already exists at {local_model_path}. Skipping download.")
            continue

        logger.info(f"Downloading model '{model_name}' from {model_url} to {local_model_path}...")
        # Write to a temporary name and rename on success, so an interrupted
        # download is never mistaken for a complete model on the next startup.
        partial_path = local_model_path + ".part"
        try:
            os.makedirs(DOWNLOAD_DIR, exist_ok=True)
            # (connect, read) timeouts keep a stalled server from hanging
            # startup forever; the context manager releases the connection.
            with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            os.replace(partial_path, local_model_path)
            logger.info(f"Model '{model_name}' downloaded successfully.")
        except Exception as e:
            logger.error(f"Failed to download model '{model_name}': {e}", exc_info=True)
            # Drop whatever was written so the next startup retries cleanly.
            try:
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial download {partial_path}: {cleanup_error}")
    logger.info("Finished attempting to download all predefined models.")
# --- Shutdown Event (ngrok related parts removed) ---
@app.on_event("shutdown")
async def shutdown_event_cleanup(): # Renamed function
    """Log that the application is shutting down; no other cleanup is performed here."""
    logger.info("Application shutdown event triggered. Performing cleanup.")
    # No ngrok.kill() needed here as ngrok is not used
if __name__ == "__main__":
    import uvicorn

    # 7860 is the standard port for Hugging Face Spaces; "info" keeps the
    # default log output less verbose.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
    }
    logger.info("Starting FastAPI application on Hugging Face Spaces (port 7860)...")
    uvicorn.run(app, **server_options)
|