|
import gradio as gr |
|
import tempfile |
|
import os |
|
import fitz |
|
import uuid |
|
import shutil |
|
from pymilvus import MilvusClient |
|
import json |
|
import sqlite3 |
|
from datetime import datetime |
|
import hashlib |
|
import bcrypt |
|
import re |
|
from typing import List, Dict, Tuple, Optional |
|
import threading |
|
import queue |
|
import requests |
|
import base64 |
|
from PIL import Image
import io
import schemdraw
import schemdraw.elements as elm
import matplotlib.pyplot as plt
|
|
|
from middleware import Middleware |
|
from rag import Rag |
|
from pathlib import Path |
|
import subprocess |
|
import getpass |
|
|
|
from dotenv import load_dotenv, dotenv_values |
|
import dotenv |
|
import platform |
|
import time |
|
from pptxtopdf import convert |
|
|
|
|
|
try: |
|
from docx import Document |
|
from docx.shared import Inches, Pt |
|
from docx.enum.text import WD_ALIGN_PARAGRAPH |
|
from docx.enum.style import WD_STYLE_TYPE |
|
from docx.oxml.shared import OxmlElement, qn |
|
from docx.oxml.ns import nsdecls |
|
from docx.oxml import parse_xml |
|
DOCX_AVAILABLE = True |
|
except ImportError: |
|
DOCX_AVAILABLE = False |
|
print("Warning: python-docx not available. DOC export will be disabled.") |
|
|
|
try: |
|
import openpyxl |
|
from openpyxl import Workbook |
|
from openpyxl.styles import Font, PatternFill, Alignment, Border, Side |
|
from openpyxl.chart import BarChart, LineChart, PieChart, Reference |
|
from openpyxl.utils.dataframe import dataframe_to_rows |
|
import pandas as pd |
|
EXCEL_AVAILABLE = True |
|
except ImportError: |
|
EXCEL_AVAILABLE = False |
|
print("Warning: openpyxl/pandas not available. Excel export will be disabled.") |
|
|
|
|
|
dotenv_file = dotenv.find_dotenv() |
|
dotenv.load_dotenv(dotenv_file) |
|
|
|
|
|
|
|
rag = Rag() |
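# Single shared RAG engine instance, reused across all sessions and requests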
|
|
|
|
|
class DatabaseManager: |
|
def __init__(self, db_path="app_database.db"): |
|
self.db_path = db_path |
|
self.init_database() |
|
|
|
def init_database(self): |
|
"""Initialize database tables""" |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
|
|
cursor.execute(''' |
|
CREATE TABLE IF NOT EXISTS users ( |
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
username TEXT UNIQUE NOT NULL, |
|
password_hash TEXT NOT NULL, |
|
team TEXT NOT NULL, |
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP |
|
) |
|
''') |
|
|
|
|
|
cursor.execute(''' |
|
CREATE TABLE IF NOT EXISTS chat_history ( |
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
user_id INTEGER, |
|
query TEXT NOT NULL, |
|
response TEXT NOT NULL, |
|
cited_pages TEXT, |
|
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
|
FOREIGN KEY (user_id) REFERENCES users (id) |
|
) |
|
''') |
|
|
|
|
|
cursor.execute(''' |
|
CREATE TABLE IF NOT EXISTS document_collections ( |
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
collection_name TEXT UNIQUE NOT NULL, |
|
team TEXT NOT NULL, |
|
uploaded_by INTEGER, |
|
upload_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, |
|
file_count INTEGER DEFAULT 0, |
|
FOREIGN KEY (uploaded_by) REFERENCES users (id) |
|
) |
|
''') |
|
|
|
conn.commit() |
|
conn.close() |
|
|
|
def create_user(self, username: str, password: str, team: str) -> bool: |
|
"""Create a new user""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
|
|
password_hash = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()) |
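            # bcrypt.gensalt() embeds a per-user random salt in the returned hash, so only the hash needs storing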
|
|
|
cursor.execute( |
|
'INSERT INTO users (username, password_hash, team) VALUES (?, ?, ?)', |
|
(username, password_hash.decode('utf-8'), team) |
|
) |
|
conn.commit() |
|
conn.close() |
|
return True |
|
except sqlite3.IntegrityError: |
|
return False |
|
|
|
def authenticate_user(self, username: str, password: str) -> Optional[Dict]: |
|
"""Authenticate user and return user info""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
cursor.execute('SELECT id, username, password_hash, team FROM users WHERE username = ?', (username,)) |
|
user = cursor.fetchone() |
|
conn.close() |
|
|
|
if user and bcrypt.checkpw(password.encode('utf-8'), user[2].encode('utf-8')): |
|
return { |
|
'id': user[0], |
|
'username': user[1], |
|
'team': user[3] |
|
} |
|
return None |
|
except Exception as e: |
|
print(f"Authentication error: {e}") |
|
return None |
|
|
|
def save_chat_history(self, user_id: int, query: str, response: str, cited_pages: List[str]): |
|
"""Save chat interaction to database""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
cited_pages_json = json.dumps(cited_pages) |
|
cursor.execute( |
|
'INSERT INTO chat_history (user_id, query, response, cited_pages) VALUES (?, ?, ?, ?)', |
|
(user_id, query, response, cited_pages_json) |
|
) |
|
conn.commit() |
|
conn.close() |
|
except Exception as e: |
|
print(f"Error saving chat history: {e}") |
|
|
|
def get_chat_history(self, user_id: int, limit: int = 10) -> List[Dict]: |
|
"""Get recent chat history for user""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
cursor.execute(''' |
|
SELECT query, response, cited_pages, timestamp |
|
FROM chat_history |
|
WHERE user_id = ? |
|
ORDER BY timestamp DESC |
|
LIMIT ? |
|
''', (user_id, limit)) |
|
|
|
history = [] |
|
for row in cursor.fetchall(): |
|
history.append({ |
|
'query': row[0], |
|
'response': row[1], |
|
'cited_pages': json.loads(row[2]) if row[2] else [], |
|
'timestamp': row[3] |
|
}) |
|
|
|
conn.close() |
|
return history |
|
except Exception as e: |
|
print(f"Error getting chat history: {e}") |
|
return [] |
|
|
|
def save_document_collection(self, collection_name: str, team: str, user_id: int, file_count: int): |
|
"""Save document collection info""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
cursor.execute( |
|
'INSERT OR REPLACE INTO document_collections (collection_name, team, uploaded_by, file_count) VALUES (?, ?, ?, ?)', |
|
(collection_name, team, user_id, file_count) |
|
) |
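            # INSERT OR REPLACE is keyed on the UNIQUE collection_name, so re-uploading a collection overwrites its row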
|
conn.commit() |
|
conn.close() |
|
except Exception as e: |
|
print(f"Error saving document collection: {e}") |
|
|
|
def get_team_collections(self, team: str) -> List[str]: |
|
"""Get all collections for a team""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
cursor.execute('SELECT collection_name FROM document_collections WHERE team = ?', (team,)) |
|
collections = [row[0] for row in cursor.fetchall()] |
|
conn.close() |
|
return collections |
|
except Exception as e: |
|
print(f"Error getting team collections: {e}") |
|
return [] |
|
|
|
def clear_chat_history(self, user_id: int) -> bool: |
|
"""Clear all chat history for a user""" |
|
try: |
|
conn = sqlite3.connect(self.db_path) |
|
cursor = conn.cursor() |
|
|
|
cursor.execute('DELETE FROM chat_history WHERE user_id = ?', (user_id,)) |
|
conn.commit() |
|
conn.close() |
|
return True |
|
except Exception as e: |
|
print(f"Error clearing chat history: {e}") |
|
return False |
|
|
|
|
|
class SessionManager: |
|
def __init__(self): |
|
self.active_sessions = {} |
|
self.session_lock = threading.Lock() |
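        # Gradio handlers can run on multiple worker threads, so all session-dict access is lock-guarded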
|
|
|
def create_session(self, user_info: Dict) -> str: |
|
"""Create a new session for user""" |
|
session_id = str(uuid.uuid4()) |
|
with self.session_lock: |
|
self.active_sessions[session_id] = { |
|
'user_info': user_info, |
|
'created_at': datetime.now(), |
|
'last_activity': datetime.now() |
|
} |
|
return session_id |
|
|
|
def get_session(self, session_id: str) -> Optional[Dict]: |
|
"""Get session info""" |
|
with self.session_lock: |
|
if session_id in self.active_sessions: |
|
self.active_sessions[session_id]['last_activity'] = datetime.now() |
|
return self.active_sessions[session_id] |
|
return None |
|
|
|
def remove_session(self, session_id: str): |
|
"""Remove session""" |
|
with self.session_lock: |
|
if session_id in self.active_sessions: |
|
del self.active_sessions[session_id] |
|
|
|
|
|
db_manager = DatabaseManager() |
|
session_manager = SessionManager() |
|
|
|
|
|
def create_default_users(): |
|
"""Create default team users""" |
|
teams = ["Team_A", "Team_B"] |
|
for team in teams: |
|
username = f"admin_{team.lower()}" |
|
password = f"admin123_{team.lower()}" |
|
if not db_manager.authenticate_user(username, password): |
|
db_manager.create_user(username, password, team) |
|
print(f"Created default user: {username} for {team}") |
|
|
|
create_default_users() |
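# Runs at import time so the default team logins exist before the UI launches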
|
|
|
|
|
def start_services(): |
|
|
|
if platform.system() == "Windows": |
|
def is_docker_desktop_running(): |
|
try: |
|
|
|
result = subprocess.run( |
|
["tasklist", "/FI", "IMAGENAME eq Docker Desktop.exe"], |
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) |
|
return "Docker Desktop.exe" in result.stdout.decode() |
|
except Exception as e: |
|
print("Error checking Docker Desktop:", e) |
|
return False |
|
|
|
def start_docker_desktop(): |
|
|
|
docker_desktop_path = r"C:\Program Files\Docker\Docker\Docker Desktop.exe" |
|
if not os.path.exists(docker_desktop_path): |
|
print("Docker Desktop executable not found. Please verify the installation path.") |
|
return |
|
try: |
|
subprocess.Popen([docker_desktop_path], shell=True) |
|
print("Docker Desktop is starting...") |
|
except Exception as e: |
|
print("Error starting Docker Desktop:", e) |
|
|
|
if is_docker_desktop_running(): |
|
print("Docker Desktop is already running.") |
|
else: |
|
print("Docker Desktop is not running. Starting it now...") |
|
start_docker_desktop() |
|
|
|
time.sleep(15) |
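        # Fixed grace period while the Docker engine comes up; polling "docker info" would be more robust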
|
|
|
|
|
def is_ollama_running(): |
|
if platform.system() == "Windows": |
|
try: |
|
|
|
result = subprocess.run( |
|
['tasklist', '/FI', 'IMAGENAME eq ollama.exe'], |
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) |
|
return "ollama.exe" in result.stdout.decode().lower() |
|
except Exception as e: |
|
print("Error checking Ollama on Windows:", e) |
|
return False |
|
else: |
|
try: |
|
result = subprocess.run( |
|
['pgrep', '-f', 'ollama'], |
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) |
|
return result.returncode == 0 |
|
except Exception as e: |
|
print("Error checking Ollama:", e) |
|
return False |
|
|
|
def start_ollama(): |
|
if platform.system() == "Windows": |
|
try: |
|
subprocess.Popen(['ollama', 'serve'], shell=True) |
|
print("Ollama server started on Windows.") |
|
except Exception as e: |
|
print("Failed to start Ollama server on Windows:", e) |
|
else: |
|
try: |
|
subprocess.Popen(['ollama', 'serve']) |
|
print("Ollama server started.") |
|
except Exception as e: |
|
print("Failed to start Ollama server:", e) |
|
|
|
if is_ollama_running(): |
|
print("Ollama server is already running.") |
|
else: |
|
print("Ollama server is not running. Starting it...") |
|
start_ollama() |
|
|
|
|
|
def get_docker_containers(): |
|
try: |
|
result = subprocess.run( |
|
['docker', 'ps', '-aq'], |
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) |
|
if result.returncode != 0: |
|
print("Error retrieving Docker containers:", result.stderr.decode()) |
|
return [] |
|
return result.stdout.decode().splitlines() |
|
except Exception as e: |
|
print("Error retrieving Docker containers:", e) |
|
return [] |
|
|
|
def get_running_docker_containers(): |
|
try: |
|
result = subprocess.run( |
|
['docker', 'ps', '-q'], |
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) |
|
if result.returncode != 0: |
|
print("Error retrieving running Docker containers:", result.stderr.decode()) |
|
return [] |
|
return result.stdout.decode().splitlines() |
|
except Exception as e: |
|
print("Error retrieving running Docker containers:", e) |
|
return [] |
|
|
|
def start_docker_container(container_id): |
|
try: |
|
result = subprocess.run( |
|
['docker', 'start', container_id], |
|
stdout=subprocess.PIPE, stderr=subprocess.PIPE |
|
) |
|
if result.returncode == 0: |
|
print(f"Started Docker container {container_id}.") |
|
else: |
|
print(f"Failed to start Docker container {container_id}: {result.stderr.decode()}") |
|
except Exception as e: |
|
print(f"Error starting Docker container {container_id}: {e}") |
|
|
|
all_containers = set(get_docker_containers()) |
|
running_containers = set(get_running_docker_containers()) |
|
stopped_containers = all_containers - running_containers |
|
|
|
if stopped_containers: |
|
print(f"Found {len(stopped_containers)} stopped Docker container(s). Starting them...") |
|
for container_id in stopped_containers: |
|
start_docker_container(container_id) |
|
else: |
|
print("All Docker containers are already running.") |
|
|
|
|
|
start_services() |
|
|
|
def generate_uuid(state): |
|
|
|
if state["user_uuid"] is None: |
|
|
|
state["user_uuid"] = str(uuid.uuid4()) |
|
|
|
return state["user_uuid"] |
|
|
|
|
|
class PDFSearchApp: |
|
def __init__(self): |
|
self.indexed_docs = {} |
|
self.current_pdf = None |
|
self.db_manager = db_manager |
|
self.session_manager = session_manager |
|
|
|
def upload_and_convert(self, state, files, max_pages, session_id=None, folder_name=None): |
|
"""Upload and convert files with team-based organization""" |
|
|
|
if files is None: |
|
return "No file uploaded" |
|
|
|
try: |
|
|
|
user_info = None |
|
team = "default" |
|
if session_id: |
|
session = self.session_manager.get_session(session_id) |
|
if session: |
|
user_info = session['user_info'] |
|
team = user_info['team'] |
|
|
|
total_pages = 0 |
|
uploaded_files = [] |
|
|
|
|
|
if folder_name: |
|
folder_name = folder_name.replace(" ", "_").replace("-", "_") |
|
collection_name = f"{team}_{folder_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" |
|
else: |
|
collection_name = f"{team}_documents_{datetime.now().strftime('%Y%m%d_%H%M%S')}" |
|
|
|
for file in files[:]: |
|
|
|
filename = os.path.basename(file.name) |
|
name, ext = os.path.splitext(filename) |
|
pdf_path = file.name |
|
|
|
|
|
                if ext.lower() in [".ppt", ".pptx"]:
                    # Convert PowerPoint files to PDF so they can be indexed like PDFs
                    output_directory = os.path.dirname(file.name)
                    output_file = os.path.splitext(os.path.basename(file.name))[0] + '.pdf'
                    outfile = os.path.join(output_directory, output_file)
                    convert(file.name, outfile)
                    pdf_path = outfile
                    name, ext = os.path.splitext(os.path.basename(outfile))
|
|
|
|
|
doc_id = f"{collection_name}_{name.replace(' ', '_').replace('-', '_')}" |
|
|
|
print(f"Uploading file: {doc_id}") |
|
middleware = Middleware(collection_name, create_collection=True) |
|
|
|
pages = middleware.index(pdf_path, id=doc_id, max_pages=max_pages) |
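                # middleware.index is expected to return the list of indexed page images for this document (see middleware.py)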
|
total_pages += len(pages) if pages else 0 |
|
uploaded_files.append(doc_id) |
|
|
|
self.indexed_docs[doc_id] = True |
|
|
|
|
|
if user_info: |
|
self.db_manager.save_document_collection( |
|
collection_name, |
|
team, |
|
user_info['id'], |
|
len(uploaded_files) |
|
) |
|
|
|
return f"Uploaded {len(uploaded_files)} files with {total_pages} total pages to collection: {collection_name}" |
|
|
|
except Exception as e: |
|
return f"Error processing files: {str(e)}" |
|
|
|
|
|
    def display_file_list(self):
|
try: |
|
|
|
directory_path = "pages" |
|
current_working_directory = os.getcwd() |
|
directory_path = os.path.join(current_working_directory, directory_path) |
|
entries = os.listdir(directory_path) |
|
|
|
directories = [entry for entry in entries if os.path.isdir(os.path.join(directory_path, entry))] |
|
return directories |
|
except FileNotFoundError: |
|
return f"The directory {directory_path} does not exist." |
|
except PermissionError: |
|
return f"Permission denied to access {directory_path}." |
|
except Exception as e: |
|
return str(e) |
|
|
|
|
|
def search_documents(self, state, query, num_results, session_id=None): |
|
print(f"Searching for query: {query}") |
|
|
|
        if not query:
            print("Please enter a search query")
            return "Please enter a search query", [], "Please enter a search query", [], None, None, None
|
|
|
try: |
|
|
|
user_info = None |
|
if session_id: |
|
session = self.session_manager.get_session(session_id) |
|
if session: |
|
user_info = session['user_info'] |
|
|
|
middleware = Middleware("test", create_collection=False) |
|
|
|
|
|
|
|
|
|
search_results = middleware.search([query], topk=max(num_results * 3, 20))[0] |
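            # Over-fetch (3x the requested count, at least 20) so the selection step below has candidates to dedupe and re-rank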
|
|
|
|
|
print(f"π Retrieved {len(search_results)} total results from search") |
|
if len(search_results) > 0: |
|
print(f"π Top result score: {search_results[0][0]:.3f}") |
|
print(f"π Bottom result score: {search_results[-1][0]:.3f}") |
|
|
|
            if not search_results:
                return "No search results found", [], "No search results found for your query", [], None, None, None
|
|
|
|
|
selected_results = self._select_relevant_pages(search_results, query, num_results) |
|
|
|
|
|
cited_pages = [] |
|
img_paths = [] |
|
all_paths = [] |
|
page_scores = [] |
|
|
|
print(f"π Processing {len(selected_results)} selected results...") |
|
|
|
            for i, (score, page_num, coll_num) in enumerate(selected_results):
                # page_num from search is 0-based; the rendered page images on disk are 1-based
                display_page_num = page_num + 1
                img_path = f"pages/{coll_num}/page_{display_page_num}.png"
                path = f"pages/{coll_num}/page_{display_page_num}"

                if os.path.exists(img_path):
                    img_paths.append(img_path)
                    all_paths.append(path)
                    page_scores.append(score)
                    cited_pages.append(f"Page {display_page_num} from {coll_num}")
                    print(f"[OK] Retrieved page {i+1}: {img_path} (Score: {score:.3f})")
                else:
                    print(f"[WARN] Image file not found: {img_path}")

            print(f"[SEARCH] Final count: {len(img_paths)} valid pages out of {len(selected_results)} selected")

            if not img_paths:
                return "No valid image files found", [], "Error: No valid image files found for the search results", [], None, None, None
|
|
|
|
|
rag_response, csv_filepath, doc_filepath, excel_filepath = self._generate_multi_page_response(query, img_paths, cited_pages, page_scores) |
|
|
|
|
|
if user_info: |
|
self.db_manager.save_chat_history( |
|
user_info['id'], |
|
query, |
|
rag_response, |
|
cited_pages |
|
) |
|
|
|
|
|
csv_download = self._prepare_csv_download(csv_filepath) |
|
doc_download = self._prepare_doc_download(doc_filepath) |
|
excel_download = self._prepare_excel_download(excel_filepath) |
|
|
|
|
|
if len(img_paths) > 1: |
|
|
|
|
|
gallery_images = [] |
|
for i, img_path in enumerate(img_paths): |
|
|
|
page_info = cited_pages[i].split(" from ")[0] |
|
page_num = page_info.split("Page ")[1] |
|
gallery_images.append((img_path, f"Page {page_num}")) |
|
return ", ".join(all_paths), gallery_images, rag_response, cited_pages, csv_download, doc_download, excel_download |
|
else: |
|
|
|
page_info = cited_pages[0].split(" from ")[0] |
|
page_num = page_info.split("Page ")[1] |
|
return all_paths[0], [(img_paths[0], f"Page {page_num}")], rag_response, cited_pages, csv_download, doc_download, excel_download |
|
|
|
        except Exception as e:
            error_msg = f"Error during search: {str(e)}"
            return error_msg, [], error_msg, [], None, None, None
|
|
|
def _select_relevant_pages(self, search_results, query, num_results): |
|
""" |
|
Intelligent page selection using vision-guided chunking principles |
|
Based on research from M3DocRAG and multi-modal retrieval models |
|
""" |
|
if len(search_results) <= num_results: |
|
return search_results |
|
|
|
|
|
multi_page_keywords = [ |
|
'compare', 'difference', 'similarities', 'both', 'multiple', 'various', |
|
'different', 'types', 'kinds', 'categories', 'procedures', 'methods', |
|
'approaches', 'techniques', 'safety', 'protocols', 'guidelines', |
|
'overview', 'summary', 'comprehensive', 'complete', 'all', 'everything' |
|
] |
|
|
|
query_lower = query.lower() |
|
needs_multiple_pages = any(keyword in query_lower for keyword in multi_page_keywords) |
|
|
|
|
|
sorted_results = sorted(search_results, key=lambda x: x[0], reverse=True) |
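        # Results are (score, page_num, collection_name) tuples; rank best-first by similarity score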
|
|
|
|
|
|
|
|
|
|
|
selected = [] |
|
seen_collections = set() |
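        # First pass: take the top hit from each distinct collection (up to half the requested count) to keep source diversity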
|
|
|
|
|
for score, page_num, coll_num in sorted_results: |
|
if coll_num not in seen_collections and len(selected) < min(num_results // 2, len(search_results)): |
|
selected.append((score, page_num, coll_num)) |
|
seen_collections.add(coll_num) |
|
|
|
|
|
        # Fill the remainder with the best-scoring results not yet selected
        for score, page_num, coll_num in sorted_results:
            if (score, page_num, coll_num) not in selected and len(selected) < num_results:
                selected.append((score, page_num, coll_num))

        # Trim any overshoot and present the selection best-first
        selected = selected[:num_results]
        selected.sort(key=lambda x: x[0], reverse=True)

        print(f"Requested {num_results} pages, selected {len(selected)} pages from {len(seen_collections)} collections")

        if len(selected) != num_results:
            print(f"[WARN] Requested {num_results} pages but selected {len(selected)} pages")

        return selected
|
|
|
def _optimize_consecutive_pages(self, selected, all_results, target_count=None): |
|
""" |
|
Optimize selection to include consecutive pages when beneficial |
|
""" |
|
|
|
collection_pages = {} |
|
for score, page_num, coll_num in selected: |
|
if coll_num not in collection_pages: |
|
collection_pages[coll_num] = [] |
|
collection_pages[coll_num].append((score, page_num, coll_num)) |
|
|
|
optimized = [] |
|
for coll_num, pages in collection_pages.items(): |
|
if len(pages) > 1: |
|
|
|
page_nums = [p[1] for p in pages] |
|
page_nums.sort() |
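                # Pages form one consecutive run exactly when max - min == count - 1 (no gaps)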
|
|
|
|
|
if max(page_nums) - min(page_nums) == len(page_nums) - 1: |
|
|
|
for score, page_num, coll in all_results: |
|
if (coll == coll_num and |
|
min(page_nums) <= page_num <= max(page_nums) and |
|
(score, page_num, coll) not in optimized): |
|
optimized.append((score, page_num, coll)) |
|
else: |
|
optimized.extend(pages) |
|
else: |
|
optimized.extend(pages) |
|
|
|
|
|
if target_count and len(optimized) != target_count: |
|
if len(optimized) > target_count: |
|
|
|
optimized.sort(key=lambda x: x[0], reverse=True) |
|
optimized = optimized[:target_count] |
|
elif len(optimized) < target_count: |
|
|
|
for score, page_num, coll in all_results: |
|
if (score, page_num, coll) not in optimized and len(optimized) < target_count: |
|
optimized.append((score, page_num, coll)) |
|
|
|
return optimized |
|
|
|
def _generate_comprehensive_analysis(self, query, cited_pages, page_scores): |
|
""" |
|
Generate comprehensive analysis section based on research strategies |
|
Implements hierarchical retrieval insights and cross-reference analysis |
|
""" |
|
try: |
|
|
|
query_lower = query.lower() |
|
|
|
|
|
query_types = [] |
|
if any(word in query_lower for word in ['compare', 'difference', 'similarities', 'versus']): |
|
query_types.append("Comparative Analysis") |
|
if any(word in query_lower for word in ['procedure', 'method', 'how to', 'steps']): |
|
query_types.append("Procedural Information") |
|
if any(word in query_lower for word in ['safety', 'warning', 'danger', 'risk']): |
|
query_types.append("Safety Information") |
|
if any(word in query_lower for word in ['specification', 'technical', 'measurement', 'data']): |
|
query_types.append("Technical Specifications") |
|
if any(word in query_lower for word in ['overview', 'summary', 'comprehensive', 'complete']): |
|
query_types.append("Comprehensive Overview") |
|
if any(word in query_lower for word in ['table', 'csv', 'spreadsheet', 'data', 'list', 'chart']): |
|
query_types.append("Tabular Data Request") |
|
|
|
|
|
avg_score = sum(page_scores) / len(page_scores) if page_scores else 0 |
|
score_variance = sum((score - avg_score) ** 2 for score in page_scores) / len(page_scores) if page_scores else 0 |
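            # Population variance of the relevance scores; used below as a rough consistency signal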
|
|
|
|
|
analysis = f""" |
|
π¬ **Comprehensive Analysis & Insights**: |
|
|
|
π **Query Analysis**: |
|
β’ Query Type: {', '.join(query_types) if query_types else 'General Information'} |
|
β’ Information Complexity: {'High' if len(cited_pages) > 3 else 'Medium' if len(cited_pages) > 1 else 'Low'} |
|
β’ Cross-Reference Depth: {'Excellent' if len(set([p.split(' from ')[1].split(' (')[0] for p in cited_pages])) > 2 else 'Good' if len(set([p.split(' from ')[1].split(' (')[0] for p in cited_pages])) > 1 else 'Limited'} |
|
|
|
π **Information Quality Assessment**: |
|
β’ Average Relevance: {avg_score:.3f} ({'Excellent' if avg_score > 0.9 else 'Very Good' if avg_score > 0.8 else 'Good' if avg_score > 0.7 else 'Moderate' if avg_score > 0.6 else 'Basic'}) |
|
β’ Information Consistency: {'High' if score_variance < 0.1 else 'Moderate' if score_variance < 0.2 else 'Variable'} |
|
β’ Source Reliability: {'High' if avg_score > 0.8 and len(cited_pages) > 2 else 'Moderate' if avg_score > 0.6 else 'Requires Verification'} |
|
|
|
π― **Information Coverage Analysis**: |
|
β’ Primary Information: {'Comprehensive' if any('primary' in p.lower() or 'main' in p.lower() for p in cited_pages) else 'Standard'} |
|
β’ Supporting Details: {'Extensive' if len(cited_pages) > 3 else 'Adequate' if len(cited_pages) > 1 else 'Basic'} |
|
β’ Technical Depth: {'High' if any('technical' in p.lower() or 'specification' in p.lower() for p in cited_pages) else 'Standard'} |
|
|
|
π‘ **Strategic Insights**: |
|
β’ Information Gaps: {'Minimal' if avg_score > 0.8 and len(cited_pages) > 3 else 'Moderate' if avg_score > 0.6 else 'Significant - consider additional sources'} |
|
β’ Cross-Validation: {'Strong' if len(set([p.split(' from ')[1].split(' (')[0] for p in cited_pages])) > 1 else 'Limited to single source'} |
|
β’ Practical Applicability: {'High' if any('procedure' in p.lower() or 'method' in p.lower() for p in cited_pages) else 'Moderate'} |
|
|
|
π **Recommendations for Further Research**: |
|
β’ {'Consider additional technical specifications' if not any('technical' in p.lower() for p in cited_pages) else 'Technical coverage adequate'} |
|
β’ {'Seek safety guidelines and warnings' if not any('safety' in p.lower() for p in cited_pages) else 'Safety information included'} |
|
β’ {'Look for comparative analysis' if not any('compare' in p.lower() for p in cited_pages) else 'Comparative analysis available'} |
|
""" |
|
|
|
return analysis |
|
|
|
except Exception as e: |
|
print(f"Error generating comprehensive analysis: {e}") |
|
return "π¬ **Analysis**: Comprehensive analysis of retrieved information completed." |
|
|
|
|
|
|
|
def _detect_table_request(self, query): |
|
""" |
|
Detect if the user is requesting tabular data |
|
""" |
|
query_lower = query.lower() |
|
table_keywords = [ |
|
'table', 'csv', 'spreadsheet', 'data table', 'list', 'chart', |
|
'tabular', 'matrix', 'grid', 'dataset', 'data set', |
|
'show me a table', 'create a table', 'generate table', |
|
'in table format', 'as a table', 'tabular format' |
|
] |
|
|
|
return any(keyword in query_lower for keyword in table_keywords) |
|
|
|
def _detect_report_request(self, query): |
|
""" |
|
Detect if the user is requesting a comprehensive report |
|
""" |
|
query_lower = query.lower() |
|
report_keywords = [ |
|
'report', 'comprehensive report', 'detailed report', 'full report', |
|
'complete report', 'comprehensive analysis', 'detailed analysis', |
|
'full analysis', 'complete analysis', 'comprehensive overview', |
|
'detailed overview', 'full overview', 'complete overview', |
|
'comprehensive summary', 'detailed summary', 'full summary', |
|
'complete summary', 'comprehensive document', 'detailed document', |
|
'full document', 'complete document', 'comprehensive review', |
|
'detailed review', 'full review', 'complete review', |
|
'export report', 'generate report', 'create report', |
|
'doc format', 'word document', 'word doc', 'document format' |
|
] |
|
|
|
return any(keyword in query_lower for keyword in report_keywords) |
|
|
|
def _detect_chart_request(self, query): |
|
""" |
|
Detect if the user is requesting charts, graphs, or visualizations |
|
""" |
|
query_lower = query.lower() |
|
        chart_keywords = [
            'chart', 'graph', 'bar chart', 'line chart', 'pie chart',
            'bar graph', 'line graph', 'pie graph', 'histogram',
            'scatter plot', 'scatter chart', 'area chart', 'column chart',
            'visualization', 'visualize', 'plot', 'figure', 'diagram',
            'excel chart', 'excel graph', 'spreadsheet chart',
            'create chart', 'generate chart', 'make chart',
            'create graph', 'generate graph', 'make graph',
            'chart data', 'graph data', 'plot data', 'visualize data'
        ]
|
|
|
return any(keyword in query_lower for keyword in chart_keywords) |
|
|
|
def _extract_custom_headers(self, query): |
|
""" |
|
Extract custom headers from user query for both tables and charts |
|
Examples: |
|
- "create table with columns: Name, Age, Department" |
|
- "create chart with headers: Threat Type, Frequency, Risk Level" |
|
- "excel export with columns: Category, Value, Description" |
|
""" |
|
try: |
|
|
|
header_patterns = [ |
|
r'columns?:\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'headers?:\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'\bwith\s+columns?\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'\bwith\s+headers?\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'headers?\s*=\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'format:\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'chart\s+headers?:\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'excel\s+headers?:\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'chart\s+with\s+headers?:\s*([^,]+(?:,\s*[^,]+)*)', |
|
r'excel\s+with\s+headers?:\s*([^,]+(?:,\s*[^,]+)*)', |
|
] |
|
|
|
for pattern in header_patterns: |
|
match = re.search(pattern, query, re.IGNORECASE) |
|
if match: |
|
headers_str = match.group(1) |
|
|
|
headers = [h.strip() for h in headers_str.split(',')] |
|
|
|
headers = [h for h in headers if h] |
|
if headers: |
|
print(f"π Custom headers detected: {headers}") |
|
return headers |
|
|
|
return None |
|
|
|
except Exception as e: |
|
print(f"Error extracting custom headers: {e}") |
|
return None |
|
|
|
def _generate_csv_table_response(self, query, rag_response, cited_pages, page_scores): |
|
""" |
|
Generate a CSV table response when user requests tabular data |
|
""" |
|
try: |
|
|
|
custom_headers = self._extract_custom_headers(query) |
|
|
|
|
|
csv_data = self._extract_structured_data(rag_response, cited_pages, page_scores, custom_headers) |
|
|
|
if csv_data: |
|
|
|
csv_content = self._format_as_csv(csv_data) |
|
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
safe_query = "".join(c for c in query[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip() |
|
safe_query = safe_query.replace(' ', '_') |
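                # safe_query is now a filesystem-safe slug of the query's first 30 characters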
|
filename = f"table_{safe_query}_{timestamp}.csv" |
|
filepath = os.path.join("temp", filename) |
|
|
|
|
|
os.makedirs("temp", exist_ok=True) |
|
|
|
|
|
with open(filepath, 'w', encoding='utf-8') as f: |
|
f.write(csv_content) |
|
|
|
|
|
header_info = "" |
|
if custom_headers: |
|
header_info = f""" |
|
π **Custom Headers Applied**: |
|
β’ Headers: {', '.join(custom_headers)} |
|
β’ Data automatically mapped to your specified columns |
|
""" |
|
|
|
table_response = f""" |
|
{rag_response} |
|
|
|
π **CSV Table Generated Successfully**: |
|
|
|
```csv |
|
{csv_content} |
|
``` |
|
|
|
{header_info} |
|
|
|
πΎ **Download Options**: |
|
β’ **Direct Download**: Click the download button below |
|
β’ **Manual Copy**: Copy the CSV content above and save as .csv file |
|
|
|
π **Table Information**: |
|
β’ Rows: {len(csv_data) if csv_data else 0} |
|
β’ Columns: {len(csv_data[0]) if csv_data and len(csv_data) > 0 else 0} |
|
β’ Data Source: {len(cited_pages)} document pages |
|
β’ Filename: {filename} |
|
""" |
|
return table_response, filepath |
|
else: |
|
|
|
header_suggestion = "" |
|
if custom_headers: |
|
header_suggestion = f""" |
|
π **Custom Headers Detected**: {', '.join(custom_headers)} |
|
The system found your specified headers but couldn't extract matching data from the response. |
|
""" |
|
|
|
fallback_response = f""" |
|
{rag_response} |
|
|
|
π **Table Request Detected**: |
|
The system detected you requested tabular data, but the current response doesn't contain structured information suitable for a CSV table. |
|
|
|
{header_suggestion} |
|
|
|
π‘ **Suggestions**: |
|
β’ Try asking for specific data types (e.g., "list of safety procedures", "compare different methods") |
|
β’ Request numerical data or comparisons |
|
β’ Ask for categorized information |
|
β’ Specify custom headers: "create table with columns: Name, Age, Department" |
|
""" |
|
return fallback_response, None |
|
|
|
except Exception as e: |
|
print(f"Error generating CSV table response: {e}") |
|
return rag_response, None |
|
|
|
def _extract_structured_data(self, rag_response, cited_pages, page_scores, custom_headers=None): |
|
""" |
|
Extract ANY structured data from RAG response - no predefined templates |
|
""" |
|
try: |
|
lines = rag_response.split('\n') |
|
structured_data = [] |
|
|
|
|
|
if custom_headers: |
|
headers = custom_headers |
|
structured_data = [headers] |
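                # First row of the output table is the user-supplied header row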
|
|
|
|
|
data_rows = [] |
|
|
|
|
|
for line in lines: |
|
line = line.strip() |
|
if line and not line.startswith('#'): |
|
|
|
data_row = self._extract_data_from_line(line, headers) |
|
if data_row: |
|
data_rows.append(data_row) |
|
|
|
|
|
if data_rows: |
|
structured_data.extend(data_rows) |
|
else: |
|
|
|
for i, citation in enumerate(cited_pages): |
|
row = self._create_placeholder_row(citation, headers, i) |
|
structured_data.append(row) |
|
|
|
return structured_data |
|
|
|
|
|
else: |
|
|
|
table_data = self._find_table_structures(lines) |
|
if table_data: |
|
return table_data |
|
|
|
|
|
list_data = self._find_list_structures(lines) |
|
if list_data: |
|
return list_data |
|
|
|
|
|
kv_data = self._find_key_value_structures(lines) |
|
if kv_data: |
|
return kv_data |
|
|
|
|
|
return self._create_summary_table(cited_pages) |
|
|
|
except Exception as e: |
|
print(f"Error extracting structured data: {e}") |
|
return None |
|
|
|
def _extract_data_from_line(self, line, headers): |
|
"""Extract data from a line that could fit the specified headers""" |
|
try: |
|
|
|
            line = re.sub(r'^[\d•\-\.\s]+', '', line)
|
|
|
|
|
if len(headers) > 1: |
|
|
|
if ',' in line: |
|
parts = [p.strip() for p in line.split(',')] |
|
elif ';' in line: |
|
parts = [p.strip() for p in line.split(';')] |
|
elif ' - ' in line: |
|
parts = [p.strip() for p in line.split(' - ')] |
|
elif ':' in line: |
|
parts = [p.strip() for p in line.split(':', 1)] |
|
else: |
|
|
|
parts = [line] + [''] * (len(headers) - 1) |
|
|
|
|
|
while len(parts) < len(headers): |
|
parts.append('') |
|
return parts[:len(headers)] |
|
else: |
|
return [line] |
|
|
|
except Exception as e: |
|
print(f"Error extracting data from line: {e}") |
|
return None |
|
|
|
def _create_placeholder_row(self, citation, headers, index): |
|
"""Create a placeholder row based on available data""" |
|
try: |
|
row = [] |
|
for header in headers: |
|
header_lower = header.lower() |
|
|
|
if 'page' in header_lower or 'number' in header_lower: |
|
page_num = citation.split('Page ')[1].split(' from')[0] if 'Page ' in citation else str(index + 1) |
|
row.append(page_num) |
|
elif 'collection' in header_lower or 'source' in header_lower or 'document' in header_lower: |
|
collection = citation.split(' from ')[1] if ' from ' in citation else 'Unknown' |
|
row.append(collection) |
|
elif 'content' in header_lower or 'description' in header_lower or 'summary' in header_lower: |
|
row.append(f"Content from {citation}") |
|
else: |
|
|
|
if 'page' in citation: |
|
row.append(citation) |
|
else: |
|
row.append('') |
|
|
|
return row |
|
|
|
except Exception as e: |
|
print(f"Error creating placeholder row: {e}") |
|
return [''] * len(headers) |
|
|
|
def _find_table_structures(self, lines): |
|
"""Find any table-like structures in the text""" |
|
try: |
|
table_lines = [] |
|
for line in lines: |
|
line = line.strip() |
|
|
|
if '|' in line or '\t' in line or re.search(r'\s{3,}', line): |
|
table_lines.append(line) |
|
|
|
if table_lines: |
|
|
|
first_line = table_lines[0] |
|
if '|' in first_line: |
|
headers = [h.strip() for h in first_line.split('|')] |
|
else: |
|
headers = re.split(r'\s{3,}', first_line) |
|
|
|
structured_data = [headers] |
|
|
|
|
|
for line in table_lines[1:]: |
|
if '|' in line: |
|
columns = [col.strip() for col in line.split('|')] |
|
else: |
|
columns = re.split(r'\s{3,}', line) |
|
|
|
if len(columns) >= 2: |
|
structured_data.append(columns) |
|
|
|
return structured_data |
|
|
|
return None |
|
|
|
except Exception as e: |
|
print(f"Error finding table structures: {e}") |
|
return None |
|
|
|
def _find_list_structures(self, lines): |
|
"""Find any list-like structures in the text""" |
|
try: |
|
items = [] |
|
for line in lines: |
|
line = line.strip() |
|
|
|
                if re.match(r'^[\d•\-\.]+', line):
                    item = re.sub(r'^[\d•\-\.\s]+', '', line)
|
if item: |
|
items.append(item) |
|
|
|
if items: |
|
|
|
structured_data = [['Item', 'Description']] |
|
for i, item in enumerate(items, 1): |
|
structured_data.append([str(i), item]) |
|
|
|
return structured_data |
|
|
|
return None |
|
|
|
except Exception as e: |
|
print(f"Error finding list structures: {e}") |
|
return None |
|
|
|
def _find_key_value_structures(self, lines): |
|
"""Find any key-value structures in the text""" |
|
try: |
|
kv_pairs = [] |
|
for line in lines: |
|
line = line.strip() |
|
|
|
if re.match(r'^[A-Za-z\s]+:\s+', line): |
|
kv_pairs.append(line) |
|
|
|
if kv_pairs: |
|
structured_data = [['Property', 'Value']] |
|
for pair in kv_pairs: |
|
if ':' in pair: |
|
key, value = pair.split(':', 1) |
|
structured_data.append([key.strip(), value.strip()]) |
|
|
|
return structured_data |
|
|
|
return None |
|
|
|
except Exception as e: |
|
print(f"Error finding key-value structures: {e}") |
|
return None |
|
|
|
def _create_summary_table(self, cited_pages): |
|
"""Create a simple summary table as last resort""" |
|
try: |
|
structured_data = [['Page', 'Collection', 'Content']] |
|
for i, citation in enumerate(cited_pages): |
|
collection = citation.split(' from ')[1] if ' from ' in citation else 'Unknown' |
|
page_num = citation.split('Page ')[1].split(' from')[0] if 'Page ' in citation else str(i+1) |
|
structured_data.append([page_num, collection, f"Content from {citation}"]) |
|
|
|
return structured_data |
|
|
|
except Exception as e: |
|
print(f"Error creating summary table: {e}") |
|
return None |
|
|
|
|
|
|
def _format_as_csv(self, data): |
|
""" |
|
Format structured data as CSV |
|
""" |
|
try: |
|
csv_lines = [] |
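            # Hand-rolled CSV writer; the stdlib csv module would also handle quoting, but this keeps the output a plain string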
|
for row in data: |
|
|
|
escaped_row = [] |
|
for cell in row: |
|
cell_str = str(cell) |
|
                    if ',' in cell_str or '"' in cell_str or '\n' in cell_str:
                        # RFC 4180-style escaping: double embedded quotes, then wrap the cell in quotes
                        escaped = cell_str.replace('"', '""')
                        cell_str = f'"{escaped}"'
|
escaped_row.append(cell_str) |
|
csv_lines.append(','.join(escaped_row)) |
|
|
|
return '\n'.join(csv_lines) |
|
|
|
except Exception as e: |
|
print(f"Error formatting CSV: {e}") |
|
return "Error,Generating,CSV,Format" |
|
|
|
def _prepare_csv_download(self, csv_filepath): |
|
""" |
|
Prepare CSV file for download in Gradio |
|
""" |
|
if csv_filepath and os.path.exists(csv_filepath): |
|
return csv_filepath |
|
else: |
|
return None |
|
|
|
def _generate_comprehensive_doc_report(self, query, rag_response, cited_pages, page_scores, user_info=None): |
|
""" |
|
Generate a comprehensive DOC report with proper formatting and structure |
|
""" |
|
if not DOCX_AVAILABLE: |
|
return None, "DOC export not available - python-docx library not installed" |
|
|
|
try: |
|
print("π [REPORT] Generating comprehensive DOC report...") |
|
|
|
|
|
doc = Document() |
|
|
|
|
|
self._setup_document_styles(doc) |
|
|
|
|
|
self._add_title_page(doc, query, user_info) |
|
|
|
|
|
self._add_executive_summary(doc, query, rag_response) |
|
|
|
|
|
self._add_detailed_analysis(doc, rag_response, cited_pages, page_scores) |
|
|
|
|
|
self._add_methodology_section(doc, cited_pages, page_scores) |
|
|
|
|
|
self._add_findings_conclusions(doc, rag_response, cited_pages) |
|
|
|
|
|
self._add_appendices(doc, cited_pages, page_scores) |
|
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
safe_query = "".join(c for c in query[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip() |
|
safe_query = safe_query.replace(' ', '_') |
|
filename = f"comprehensive_report_{safe_query}_{timestamp}.docx" |
|
filepath = os.path.join("temp", filename) |
|
|
|
|
|
os.makedirs("temp", exist_ok=True) |
|
|
|
|
|
doc.save(filepath) |
|
|
|
print(f"β
[REPORT] Comprehensive DOC report generated: {filepath}") |
|
return filepath, None |
|
|
|
except Exception as e: |
|
error_msg = f"Error generating DOC report: {str(e)}" |
|
print(f"β [REPORT] {error_msg}") |
|
return None, error_msg |
|
|
|
def _setup_document_styles(self, doc): |
|
"""Set up professional document styles""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
title_style = doc.styles.add_style('CustomTitle', WD_STYLE_TYPE.PARAGRAPH) |
|
title_font = title_style.font |
|
title_font.name = 'Calibri' |
|
title_font.size = Pt(24) |
|
title_font.bold = True |
|
title_font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
h1_style = doc.styles.add_style('CustomHeading1', WD_STYLE_TYPE.PARAGRAPH) |
|
h1_font = h1_style.font |
|
h1_font.name = 'Calibri' |
|
h1_font.size = Pt(16) |
|
h1_font.bold = True |
|
h1_font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
h2_style = doc.styles.add_style('CustomHeading2', WD_STYLE_TYPE.PARAGRAPH) |
|
h2_font = h2_style.font |
|
h2_font.name = 'Calibri' |
|
h2_font.size = Pt(14) |
|
h2_font.bold = True |
|
h2_font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
body_style = doc.styles.add_style('CustomBody', WD_STYLE_TYPE.PARAGRAPH) |
|
body_font = body_style.font |
|
body_font.name = 'Calibri' |
|
body_font.size = Pt(11) |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not set up custom styles: {e}") |
|
|
|
def _add_title_page(self, doc, query, user_info): |
|
"""Add professional title page for security analysis report""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
title = doc.add_paragraph() |
|
title.alignment = WD_ALIGN_PARAGRAPH.CENTER |
|
title_run = title.add_run("SECURITY THREAT ANALYSIS REPORT") |
|
title_run.font.name = 'Calibri' |
|
title_run.font.size = Pt(24) |
|
title_run.font.bold = True |
|
title_run.font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
subtitle = doc.add_paragraph() |
|
subtitle.alignment = WD_ALIGN_PARAGRAPH.CENTER |
|
subtitle_run = subtitle.add_run(f"Threat Intelligence Query: {query}") |
|
subtitle_run.font.name = 'Calibri' |
|
subtitle_run.font.size = Pt(14) |
|
subtitle_run.font.italic = True |
|
|
|
|
|
doc.add_paragraph() |
|
doc.add_paragraph() |
|
|
|
|
|
classification = doc.add_paragraph() |
|
classification.alignment = WD_ALIGN_PARAGRAPH.CENTER |
|
classification_run = classification.add_run("SECURITY ANALYSIS & THREAT INTELLIGENCE") |
|
classification_run.font.name = 'Calibri' |
|
classification_run.font.size = Pt(12) |
|
classification_run.font.bold = True |
|
classification_run.font.color.rgb = RGBColor(220, 53, 69) |
|
|
|
|
|
details = doc.add_paragraph() |
|
details.alignment = WD_ALIGN_PARAGRAPH.CENTER |
|
details_run = details.add_run(f"Generated on: {datetime.now().strftime('%B %d, %Y at %I:%M %p')}") |
|
details_run.font.name = 'Calibri' |
|
details_run.font.size = Pt(11) |
|
|
|
if user_info: |
|
user_details = doc.add_paragraph() |
|
user_details.alignment = WD_ALIGN_PARAGRAPH.CENTER |
|
user_run = user_details.add_run(f"Generated by: {user_info['username']} ({user_info['team']})") |
|
user_run.font.name = 'Calibri' |
|
user_run.font.size = Pt(11) |
|
|
|
|
|
doc.add_page_break() |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not add title page: {e}") |
|
|
|
def _add_executive_summary(self, doc, query, rag_response): |
|
"""Add executive summary section aligned with security analysis framework""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
heading = doc.add_paragraph() |
|
heading_run = heading.add_run("EXECUTIVE SUMMARY") |
|
heading_run.font.name = 'Calibri' |
|
heading_run.font.size = Pt(16) |
|
heading_run.font.bold = True |
|
heading_run.font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
purpose = doc.add_paragraph() |
|
purpose_run = purpose.add_run("This security analysis report provides comprehensive threat assessment and operational insights based on the query: ") |
|
purpose_run.font.name = 'Calibri' |
|
purpose_run.font.size = Pt(11) |
|
|
|
|
|
query_text = doc.add_paragraph() |
|
query_run = query_text.add_run(f'"{query}"') |
|
query_run.font.name = 'Calibri' |
|
query_run.font.size = Pt(11) |
|
query_run.font.bold = True |
|
|
|
|
|
framework_heading = doc.add_paragraph() |
|
framework_run = framework_heading.add_run("Analysis Framework:") |
|
framework_run.font.name = 'Calibri' |
|
framework_run.font.size = Pt(12) |
|
framework_run.font.bold = True |
|
|
|
|
|
            framework_components = [
                "• Fact-Finding & Contextualization: Background information and context development",
                "• Case Study Identification: Incident prevalence and TTP extraction",
                "• Analytical Assessment: Intent, motivation, and threat landscape evaluation",
                "• Operational Relevance: Ground-level actionable insights and recommendations"
            ]
|
|
|
for component in framework_components: |
|
comp_para = doc.add_paragraph() |
|
comp_run = comp_para.add_run(component) |
|
comp_run.font.name = 'Calibri' |
|
comp_run.font.size = Pt(11) |
|
|
|
|
|
findings_heading = doc.add_paragraph() |
|
findings_run = findings_heading.add_run("Key Findings:") |
|
findings_run.font.name = 'Calibri' |
|
findings_run.font.size = Pt(12) |
|
findings_run.font.bold = True |
|
|
|
|
|
key_points = self._extract_key_points(rag_response) |
|
for point in key_points[:5]: |
|
point_para = doc.add_paragraph() |
|
                point_run = point_para.add_run(f"• {point}")
|
point_run.font.name = 'Calibri' |
|
point_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not add executive summary: {e}") |
|
|
|
def _add_detailed_analysis(self, doc, rag_response, cited_pages, page_scores): |
|
"""Add detailed analysis section aligned with security analysis framework""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
heading = doc.add_paragraph() |
|
heading_run = heading.add_run("DETAILED ANALYSIS") |
|
heading_run.font.name = 'Calibri' |
|
heading_run.font.size = Pt(16) |
|
heading_run.font.bold = True |
|
heading_run.font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
fact_finding_heading = doc.add_paragraph() |
|
fact_finding_run = fact_finding_heading.add_run("1. FACT-FINDING & CONTEXTUALIZATION") |
|
fact_finding_run.font.name = 'Calibri' |
|
fact_finding_run.font.size = Pt(14) |
|
fact_finding_run.font.bold = True |
|
fact_finding_run.font.color.rgb = RGBColor(40, 167, 69) |
|
|
|
fact_finding_para = doc.add_paragraph() |
|
fact_finding_para_run = fact_finding_para.add_run("This section provides background information for readers to understand the origin, development, and context of the subject topic.") |
|
fact_finding_para_run.font.name = 'Calibri' |
|
fact_finding_para_run.font.size = Pt(11) |
|
|
|
|
|
context_info = self._extract_contextual_info(rag_response) |
|
for info in context_info: |
|
info_para = doc.add_paragraph() |
|
                info_run = info_para.add_run(f"• {info}")
|
info_run.font.name = 'Calibri' |
|
info_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
|
|
case_study_heading = doc.add_paragraph() |
|
case_study_run = case_study_heading.add_run("2. CASE STUDY IDENTIFICATION") |
|
case_study_run.font.name = 'Calibri' |
|
case_study_run.font.size = Pt(14) |
|
case_study_run.font.bold = True |
|
case_study_run.font.color.rgb = RGBColor(255, 193, 7) |
|
|
|
case_study_para = doc.add_paragraph() |
|
case_study_para_run = case_study_para.add_run("This section provides context and prevalence assessment, highlighting past incidents to establish patterns and extract relevant TTPs for analysis.") |
|
case_study_para_run.font.name = 'Calibri' |
|
case_study_para_run.font.size = Pt(11) |
|
|
|
|
|
case_studies = self._extract_case_studies(rag_response) |
|
for case in case_studies: |
|
case_para = doc.add_paragraph() |
|
                case_run = case_para.add_run(f"• {case}")
|
case_run.font.name = 'Calibri' |
|
case_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
|
|
analytical_heading = doc.add_paragraph() |
|
analytical_run = analytical_heading.add_run("3. ANALYTICAL ASSESSMENT") |
|
analytical_run.font.name = 'Calibri' |
|
analytical_run.font.size = Pt(14) |
|
analytical_run.font.bold = True |
|
analytical_run.font.color.rgb = RGBColor(220, 53, 69) |
|
|
|
analytical_para = doc.add_paragraph() |
|
analytical_para_run = analytical_para.add_run("This section evaluates gathered information to assess intent, motivation, TTPs, emerging trends, and relevance to threat landscapes.") |
|
analytical_para_run.font.name = 'Calibri' |
|
analytical_para_run.font.size = Pt(11) |
|
|
|
|
|
analytical_insights = self._extract_analytical_insights(rag_response) |
|
for insight in analytical_insights: |
|
insight_para = doc.add_paragraph() |
|
                insight_run = insight_para.add_run(f"• {insight}")
|
insight_run.font.name = 'Calibri' |
|
insight_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
|
|
operational_heading = doc.add_paragraph() |
|
operational_run = operational_heading.add_run("4. OPERATIONAL RELEVANCE") |
|
operational_run.font.name = 'Calibri' |
|
operational_run.font.size = Pt(14) |
|
operational_run.font.bold = True |
|
operational_run.font.color.rgb = RGBColor(111, 66, 193) |
|
|
|
operational_para = doc.add_paragraph() |
|
operational_para_run = operational_para.add_run("This section translates research insights into actionable knowledge for ground-level personnel, highlighting operational risks and procedural recommendations.") |
|
operational_para_run.font.name = 'Calibri' |
|
operational_para_run.font.size = Pt(11) |
|
|
|
|
|
operational_insights = self._extract_operational_insights(rag_response) |
|
for insight in operational_insights: |
|
insight_para = doc.add_paragraph() |
|
                insight_run = insight_para.add_run(f"• {insight}")
|
insight_run.font.name = 'Calibri' |
|
insight_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
|
|
main_analysis_heading = doc.add_paragraph() |
|
main_analysis_run = main_analysis_heading.add_run("COMPREHENSIVE ANALYSIS") |
|
main_analysis_run.font.name = 'Calibri' |
|
main_analysis_run.font.size = Pt(12) |
|
main_analysis_run.font.bold = True |
|
|
|
response_para = doc.add_paragraph() |
|
response_run = response_para.add_run(rag_response) |
|
response_run.font.name = 'Calibri' |
|
response_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not add detailed analysis: {e}") |
|
|
|
def _add_methodology_section(self, doc, cited_pages, page_scores): |
|
"""Add methodology section aligned with security analysis framework""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
heading = doc.add_paragraph() |
|
heading_run = heading.add_run("METHODOLOGY") |
|
heading_run.font.name = 'Calibri' |
|
heading_run.font.size = Pt(16) |
|
heading_run.font.bold = True |
|
heading_run.font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
method_para = doc.add_paragraph() |
|
method_run = method_para.add_run("This security analysis was conducted using advanced AI-powered threat intelligence and document analysis techniques:") |
|
method_run.font.name = 'Calibri' |
|
method_run.font.size = Pt(11) |
|
|
|
|
|
framework_heading = doc.add_paragraph() |
|
framework_run = framework_heading.add_run("Security Analysis Framework:") |
|
framework_run.font.name = 'Calibri' |
|
framework_run.font.size = Pt(12) |
|
framework_run.font.bold = True |
|
|
|
            framework_components = [
                "• Fact-Finding & Contextualization: Background research and context development",
                "• Case Study Identification: Incident analysis and TTP extraction",
                "• Analytical Assessment: Threat landscape evaluation and risk assessment",
                "• Operational Relevance: Ground-level actionable intelligence generation"
            ]
|
|
|
for component in framework_components: |
|
comp_para = doc.add_paragraph() |
|
comp_run = comp_para.add_run(component) |
|
comp_run.font.name = 'Calibri' |
|
comp_run.font.size = Pt(11) |
|
|
|
|
|
sources_heading = doc.add_paragraph() |
|
sources_run = sources_heading.add_run("Intelligence Sources:") |
|
sources_run.font.name = 'Calibri' |
|
sources_run.font.size = Pt(12) |
|
sources_run.font.bold = True |
|
|
|
|
|
for i, citation in enumerate(cited_pages): |
|
source_para = doc.add_paragraph() |
|
source_run = source_para.add_run(f"{i+1}. {citation}") |
|
source_run.font.name = 'Calibri' |
|
source_run.font.size = Pt(11) |
|
|
|
|
|
approach_heading = doc.add_paragraph() |
|
approach_run = approach_heading.add_run("Technical Analysis Approach:") |
|
approach_run.font.name = 'Calibri' |
|
approach_run.font.size = Pt(12) |
|
approach_run.font.bold = True |
|
|
|
approach_para = doc.add_paragraph() |
|
            approach_run = approach_para.add_run("• Multi-modal document analysis using AI vision models for threat pattern recognition")
|
approach_run.font.name = 'Calibri' |
|
approach_run.font.size = Pt(11) |
|
|
|
approach2_para = doc.add_paragraph() |
|
            approach2_run = approach2_para.add_run("• Intelligent content retrieval and relevance scoring for threat intelligence prioritization")
|
approach2_run.font.name = 'Calibri' |
|
approach2_run.font.size = Pt(11) |
|
|
|
approach3_para = doc.add_paragraph() |
|
            approach3_run = approach3_para.add_run("• Comprehensive threat synthesis and actionable intelligence generation")
|
approach3_run.font.name = 'Calibri' |
|
approach3_run.font.size = Pt(11) |
|
|
|
approach4_para = doc.add_paragraph() |
|
            approach4_run = approach4_para.add_run("• Evidence-based risk assessment and operational recommendation development")
|
approach4_run.font.name = 'Calibri' |
|
approach4_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not add methodology section: {e}") |
|
|
|
def _add_findings_conclusions(self, doc, rag_response, cited_pages): |
|
"""Add findings and conclusions section aligned with security analysis framework""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
heading = doc.add_paragraph() |
|
heading_run = heading.add_run("FINDINGS AND CONCLUSIONS") |
|
heading_run.font.name = 'Calibri' |
|
heading_run.font.size = Pt(16) |
|
heading_run.font.bold = True |
|
heading_run.font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
threat_heading = doc.add_paragraph() |
|
threat_run = threat_heading.add_run("Threat Assessment Summary:") |
|
threat_run.font.name = 'Calibri' |
|
threat_run.font.size = Pt(12) |
|
threat_run.font.bold = True |
|
|
|
|
|
threat_findings = self._extract_threat_findings(rag_response) |
|
for finding in threat_findings: |
|
finding_para = doc.add_paragraph() |
|
                finding_run = finding_para.add_run(f"• {finding}")
|
finding_run.font.name = 'Calibri' |
|
finding_run.font.size = Pt(11) |
|
|
|
|
|
ttp_heading = doc.add_paragraph() |
|
ttp_run = ttp_heading.add_run("Tactics, Techniques, and Procedures (TTPs):") |
|
ttp_run.font.name = 'Calibri' |
|
ttp_run.font.size = Pt(12) |
|
ttp_run.font.bold = True |
|
|
|
|
|
ttps = self._extract_ttps(rag_response) |
|
for ttp in ttps: |
|
ttp_para = doc.add_paragraph() |
|
                ttp_run = ttp_para.add_run(f"• {ttp}")
|
ttp_run.font.name = 'Calibri' |
|
ttp_run.font.size = Pt(11) |
|
|
|
|
|
recommendations_heading = doc.add_paragraph() |
|
recommendations_run = recommendations_heading.add_run("Operational Recommendations:") |
|
recommendations_run.font.name = 'Calibri' |
|
recommendations_run.font.size = Pt(12) |
|
recommendations_run.font.bold = True |
|
|
|
|
|
recommendations = self._extract_operational_recommendations(rag_response) |
|
for rec in recommendations: |
|
rec_para = doc.add_paragraph() |
|
                rec_run = rec_para.add_run(f"• {rec}")
|
rec_run.font.name = 'Calibri' |
|
rec_run.font.size = Pt(11) |
|
|
|
|
|
risk_heading = doc.add_paragraph() |
|
risk_run = risk_heading.add_run("Risk Assessment:") |
|
risk_run.font.name = 'Calibri' |
|
risk_run.font.size = Pt(12) |
|
risk_run.font.bold = True |
|
|
|
|
|
risks = self._extract_risk_assessment(rag_response) |
|
for risk in risks: |
|
risk_para = doc.add_paragraph() |
|
risk_run = risk_para.add_run(f"β’ {risk}") |
|
risk_run.font.name = 'Calibri' |
|
risk_run.font.size = Pt(11) |
|
|
|
|
|
conclusions_heading = doc.add_paragraph() |
|
conclusions_run = conclusions_heading.add_run("Conclusions:") |
|
conclusions_run.font.name = 'Calibri' |
|
conclusions_run.font.size = Pt(12) |
|
conclusions_run.font.bold = True |
|
|
|
conclusions_para = doc.add_paragraph() |
|
conclusions_run = conclusions_para.add_run("This security analysis provides actionable intelligence for threat mitigation and operational preparedness. The findings support evidence-based decision making for security operations and risk management.") |
|
conclusions_run.font.name = 'Calibri' |
|
conclusions_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not add findings and conclusions: {e}") |
|
|
|
def _add_appendices(self, doc, cited_pages, page_scores): |
|
"""Add appendices section""" |
|
try: |
|
|
|
from docx.shared import RGBColor |
|
|
|
|
|
heading = doc.add_paragraph() |
|
heading_run = heading.add_run("APPENDICES") |
|
heading_run.font.name = 'Calibri' |
|
heading_run.font.size = Pt(16) |
|
heading_run.font.bold = True |
|
heading_run.font.color.rgb = RGBColor(47, 84, 150) |
|
|
|
|
|
appendix_a = doc.add_paragraph() |
|
appendix_a_run = appendix_a.add_run("Appendix A: Document Sources and Relevance Scores") |
|
appendix_a_run.font.name = 'Calibri' |
|
appendix_a_run.font.size = Pt(12) |
|
appendix_a_run.font.bold = True |
|
|
|
for i, (citation, score) in enumerate(zip(cited_pages, page_scores)): |
|
source_para = doc.add_paragraph() |
|
source_run = source_para.add_run(f"{i+1}. {citation} (Relevance Score: {score:.3f})") |
|
source_run.font.name = 'Calibri' |
|
source_run.font.size = Pt(11) |
|
|
|
doc.add_paragraph() |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not add appendices: {e}") |
|
|
|
def _extract_key_points(self, rag_response): |
|
"""Extract key points from RAG response""" |
|
try: |
|
|
|
sentences = re.split(r'[.!?]+', rag_response) |
|
key_points = [] |
|
|
|
|
|
key_indicators = ['important', 'key', 'critical', 'essential', 'significant', 'major', 'primary', 'main'] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 20 and any(indicator in sentence.lower() for indicator in key_indicators): |
|
key_points.append(sentence) |
|
|
|
|
|
if len(key_points) < 3: |
|
key_points = [s.strip() for s in sentences[:5] if len(s.strip()) > 20] |
|
|
|
return key_points[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract key points: {e}") |
|
return ["Analysis completed successfully", "Comprehensive review performed", "Key insights identified"] |
|
|
|
def _extract_contextual_info(self, rag_response): |
|
"""Extract contextual information for fact-finding section""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
contextual_info = [] |
|
|
|
|
|
context_indicators = [ |
|
'background', 'history', 'origin', 'development', 'context', 'definition', |
|
'introduction', 'overview', 'description', 'characteristics', 'features', |
|
'components', 'types', 'categories', 'classification', 'structure' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in context_indicators): |
|
contextual_info.append(sentence) |
|
|
|
|
|
if len(contextual_info) < 3: |
|
contextual_info = [s.strip() for s in sentences[:3] if len(s.strip()) > 15] |
|
|
|
return contextual_info[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract contextual info: {e}") |
|
return ["Background information extracted from analysis", "Contextual details identified", "Historical context established"] |
|
|
|
def _extract_case_studies(self, rag_response): |
|
"""Extract case study information for incident identification""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
case_studies = [] |
|
|
|
|
|
case_indicators = [ |
|
'incident', 'case', 'example', 'instance', 'occurrence', 'event', |
|
'attack', 'threat', 'vulnerability', 'exploit', 'breach', 'compromise', |
|
'pattern', 'trend', 'frequency', 'prevalence', 'statistics', 'data' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in case_indicators): |
|
case_studies.append(sentence) |
|
|
|
|
|
if len(case_studies) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and (re.search(r'\d+', sentence) or any(word in sentence.lower() for word in ['first', 'second', 'third', 'recent', 'previous'])): |
|
case_studies.append(sentence) |
|
|
|
return case_studies[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract case studies: {e}") |
|
return ["Incident patterns identified", "Case study information extracted", "Prevalence data analyzed"] |
|
|
|
def _extract_analytical_insights(self, rag_response): |
|
"""Extract analytical insights for threat assessment""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
analytical_insights = [] |
|
|
|
|
|
analytical_indicators = [ |
|
'intent', 'motivation', 'purpose', 'objective', 'goal', 'target', |
|
'technique', 'procedure', 'method', 'approach', 'strategy', 'tactic', |
|
'trend', 'emerging', 'evolution', 'development', 'change', 'shift', |
|
'threat', 'risk', 'vulnerability', 'impact', 'consequence', 'effect' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in analytical_indicators): |
|
analytical_insights.append(sentence) |
|
|
|
|
|
if len(analytical_insights) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(word in sentence.lower() for word in ['because', 'therefore', 'however', 'although', 'while', 'despite']): |
|
analytical_insights.append(sentence) |
|
|
|
return analytical_insights[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract analytical insights: {e}") |
|
return ["Analytical assessment completed", "Threat landscape evaluated", "Risk factors identified"] |
|
|
|
def _extract_operational_insights(self, rag_response): |
|
"""Extract operational insights for ground-level recommendations""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
operational_insights = [] |
|
|
|
|
|
operational_indicators = [ |
|
'recommendation', 'action', 'procedure', 'protocol', 'guideline', |
|
'training', 'awareness', 'vigilance', 'monitoring', 'detection', |
|
'prevention', 'mitigation', 'response', 'recovery', 'preparation', |
|
'equipment', 'tool', 'technology', 'system', 'process', 'workflow' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in operational_indicators): |
|
operational_insights.append(sentence) |
|
|
|
|
|
if len(operational_insights) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(word in sentence.lower() for word in ['should', 'must', 'need', 'require', 'implement', 'establish', 'develop']): |
|
operational_insights.append(sentence) |
|
|
|
return operational_insights[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract operational insights: {e}") |
|
return ["Operational recommendations identified", "Ground-level procedures suggested", "Training requirements outlined"] |
|
|
|
def _extract_findings(self, rag_response): |
|
"""Extract findings from RAG response""" |
|
try: |
|
|
|
sentences = re.split(r'[.!?]+', rag_response) |
|
findings = [] |
|
|
|
|
|
finding_indicators = ['found', 'discovered', 'identified', 'revealed', 'shows', 'indicates', 'demonstrates', 'suggests'] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in finding_indicators): |
|
findings.append(sentence) |
|
|
|
|
|
if len(findings) < 3: |
|
findings = [s.strip() for s in sentences[:5] if len(s.strip()) > 15] |
|
|
|
return findings[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract findings: {e}") |
|
return ["Analysis completed successfully", "Comprehensive review performed", "Key insights identified"] |
|
|
|
def _extract_threat_findings(self, rag_response): |
|
"""Extract threat-related findings for security analysis""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
threat_findings = [] |
|
|
|
|
|
threat_indicators = [ |
|
'threat', 'attack', 'vulnerability', 'exploit', 'breach', 'compromise', |
|
'malware', 'phishing', 'social engineering', 'ransomware', 'ddos', |
|
'intrusion', 'infiltration', 'espionage', 'sabotage', 'terrorism' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in threat_indicators): |
|
threat_findings.append(sentence) |
|
|
|
|
|
if len(threat_findings) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(word in sentence.lower() for word in ['security', 'risk', 'danger', 'hazard', 'warning']): |
|
threat_findings.append(sentence) |
|
|
|
return threat_findings[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract threat findings: {e}") |
|
return ["Threat assessment completed", "Security vulnerabilities identified", "Risk factors analyzed"] |
|
|
|
def _extract_ttps(self, rag_response): |
|
"""Extract Tactics, Techniques, and Procedures (TTPs)""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
ttps = [] |
|
|
|
|
|
ttp_indicators = [ |
|
'technique', 'procedure', 'method', 'approach', 'strategy', 'tactic', |
|
'process', 'workflow', 'protocol', 'standard', 'practice', 'modus operandi', |
|
'attack vector', 'exploitation', 'infiltration', 'persistence', 'exfiltration' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in ttp_indicators): |
|
ttps.append(sentence) |
|
|
|
|
|
if len(ttps) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(word in sentence.lower() for word in ['step', 'phase', 'stage', 'sequence', 'order']): |
|
ttps.append(sentence) |
|
|
|
return ttps[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract TTPs: {e}") |
|
return ["TTP analysis completed", "Attack methods identified", "Procedural patterns extracted"] |
|
|
|
def _extract_operational_recommendations(self, rag_response): |
|
"""Extract operational recommendations for ground-level personnel""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
recommendations = [] |
|
|
|
|
|
recommendation_indicators = [ |
|
'recommend', 'suggest', 'advise', 'propose', 'should', 'must', 'need', |
|
'implement', 'establish', 'develop', 'create', 'adopt', 'apply', |
|
'training', 'awareness', 'education', 'preparation', 'readiness' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in recommendation_indicators): |
|
recommendations.append(sentence) |
|
|
|
|
|
if len(recommendations) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(word in sentence.lower() for word in ['action', 'measure', 'step', 'procedure', 'protocol']): |
|
recommendations.append(sentence) |
|
|
|
return recommendations[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract operational recommendations: {e}") |
|
return ["Operational procedures recommended", "Training requirements identified", "Security measures suggested"] |
|
|
|
def _extract_risk_assessment(self, rag_response): |
|
"""Extract risk assessment information""" |
|
try: |
|
sentences = re.split(r'[.!?]+', rag_response) |
|
risks = [] |
|
|
|
|
|
risk_indicators = [ |
|
'risk', 'danger', 'hazard', 'threat', 'vulnerability', 'exposure', |
|
'probability', 'likelihood', 'impact', 'consequence', 'severity', |
|
'critical', 'high', 'medium', 'low', 'minimal', 'significant' |
|
] |
|
|
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(indicator in sentence.lower() for indicator in risk_indicators): |
|
risks.append(sentence) |
|
|
|
|
|
if len(risks) < 3: |
|
for sentence in sentences: |
|
sentence = sentence.strip() |
|
if len(sentence) > 15 and any(word in sentence.lower() for word in ['potential', 'possible', 'likely', 'unlikely', 'certain']): |
|
risks.append(sentence) |
|
|
|
return risks[:5] |
|
|
|
except Exception as e: |
|
print(f"Warning: Could not extract risk assessment: {e}") |
|
return ["Risk assessment completed", "Vulnerability analysis performed", "Threat evaluation conducted"] |
|
|
|
def _generate_enhanced_excel_export(self, query, rag_response, cited_pages, page_scores, custom_headers=None): |
|
""" |
|
Generate enhanced Excel export with proper formatting for charts and graphs |
|
""" |
|
if not EXCEL_AVAILABLE: |
|
return None, "Excel export not available - openpyxl/pandas libraries not installed" |
|
|
|
try: |
|
print("π [EXCEL] Generating enhanced Excel export...") |
|
|
|
|
|
if custom_headers is None: |
|
custom_headers = self._extract_custom_headers(query) |
|
|
|
|
|
wb = Workbook() |
|
|
|
|
|
wb.remove(wb.active) |
|
|
|
|
|
data_sheet = wb.create_sheet("Data") |
|
|
|
|
|
summary_sheet = wb.create_sheet("Summary") |
|
|
|
|
|
charts_sheet = wb.create_sheet("Charts") |
|
|
|
|
|
structured_data = self._extract_structured_data_for_excel(rag_response, cited_pages, page_scores, custom_headers) |
|
|
|
|
|
self._populate_data_sheet(data_sheet, structured_data, query) |
|
|
|
|
|
self._populate_summary_sheet(summary_sheet, query, cited_pages, page_scores) |
|
|
|
|
|
if self._detect_chart_request(query): |
|
self._create_excel_charts(charts_sheet, structured_data, query, custom_headers) |
|
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
safe_query = "".join(c for c in query[:30] if c.isalnum() or c in (' ', '-', '_')).rstrip() |
|
safe_query = safe_query.replace(' ', '_') |
|
filename = f"enhanced_export_{safe_query}_{timestamp}.xlsx" |
|
filepath = os.path.join("temp", filename) |
|
|
|
|
|
os.makedirs("temp", exist_ok=True) |
|
|
|
|
|
wb.save(filepath) |
|
|
|
print(f"β
[EXCEL] Enhanced Excel export generated: {filepath}") |
|
return filepath, None |
|
|
|
except Exception as e: |
|
error_msg = f"Error generating Excel export: {str(e)}" |
|
print(f"β [EXCEL] {error_msg}") |
|
return None, error_msg |
|
|
|
def _extract_structured_data_for_excel(self, rag_response, cited_pages, page_scores, custom_headers=None): |
|
"""Extract structured data specifically for Excel export""" |
|
try: |
|
|
|
if custom_headers: |
|
headers = custom_headers |
|
print(f"π [EXCEL] Using custom headers: {headers}") |
|
else: |
|
|
|
headers = self._auto_detect_excel_headers(rag_response, cited_pages) |
|
print(f"π [EXCEL] Auto-detected headers: {headers}") |
|
|
|
|
|
data_rows = [] |
|
|
|
|
|
if custom_headers: |
|
mapped_data = self._map_data_to_custom_headers(rag_response, cited_pages, page_scores, custom_headers) |
|
if mapped_data: |
|
data_rows.extend(mapped_data) |
|
|
|
|
|
if not data_rows: |
|
|
|
numerical_data = self._extract_numerical_data(rag_response) |
|
if numerical_data: |
|
data_rows.extend(numerical_data) |
|
|
|
|
|
categorical_data = self._extract_categorical_data(rag_response, cited_pages) |
|
if categorical_data: |
|
data_rows.extend(categorical_data) |
|
|
|
|
|
source_data = self._extract_source_data(cited_pages, page_scores) |
|
if source_data: |
|
data_rows.extend(source_data) |
|
|
|
|
|
if not data_rows: |
|
data_rows = self._create_summary_data(rag_response, cited_pages, page_scores) |
|
|
|
return { |
|
'headers': headers, |
|
'data': data_rows |
|
} |
|
|
|
except Exception as e: |
|
print(f"Error extracting structured data for Excel: {e}") |
|
return { |
|
'headers': ['Category', 'Value', 'Description'], |
|
'data': [['Analysis', 'Completed', 'Data extracted successfully']] |
|
} |
|
|
|
def _auto_detect_excel_headers(self, rag_response, cited_pages): |
|
"""Auto-detect contextually appropriate headers for Excel export based on query content""" |
|
try: |
|
headers = [] |
|
|
|
|
|
rag_lower = rag_response.lower() |
|
|
|
|
|
if any(word in rag_lower for word in ['threat', 'attack', 'vulnerability', 'security', 'risk']): |
|
if 'threat' in rag_lower or 'attack' in rag_lower: |
|
headers.append('Threat Type') |
|
if 'frequency' in rag_lower or 'count' in rag_lower or 'percentage' in rag_lower: |
|
headers.append('Frequency') |
|
if 'risk' in rag_lower or 'severity' in rag_lower: |
|
headers.append('Risk Level') |
|
if 'impact' in rag_lower or 'damage' in rag_lower: |
|
headers.append('Impact') |
|
if 'mitigation' in rag_lower or 'solution' in rag_lower: |
|
headers.append('Mitigation') |
|
|
|
|
|
elif any(word in rag_lower for word in ['sales', 'revenue', 'performance', 'growth', 'profit']): |
|
if 'month' in rag_lower or 'quarter' in rag_lower or 'year' in rag_lower: |
|
headers.append('Time Period') |
|
if 'sales' in rag_lower or 'revenue' in rag_lower: |
|
headers.append('Sales/Revenue') |
|
if 'growth' in rag_lower or 'increase' in rag_lower: |
|
headers.append('Growth Rate') |
|
if 'region' in rag_lower or 'location' in rag_lower: |
|
headers.append('Region') |
|
|
|
|
|
elif any(word in rag_lower for word in ['system', 'component', 'device', 'technology', 'software']): |
|
if 'component' in rag_lower or 'device' in rag_lower: |
|
headers.append('Component') |
|
if 'status' in rag_lower or 'condition' in rag_lower: |
|
headers.append('Status') |
|
if 'priority' in rag_lower or 'importance' in rag_lower: |
|
headers.append('Priority') |
|
if 'version' in rag_lower or 'release' in rag_lower: |
|
headers.append('Version') |
|
|
|
|
|
elif any(word in rag_lower for word in ['data', 'statistics', 'analysis', 'report', 'survey']): |
|
if 'category' in rag_lower or 'type' in rag_lower: |
|
headers.append('Category') |
|
if 'value' in rag_lower or 'number' in rag_lower or 'count' in rag_lower: |
|
headers.append('Value') |
|
if 'percentage' in rag_lower or 'rate' in rag_lower: |
|
headers.append('Percentage') |
|
if 'trend' in rag_lower or 'change' in rag_lower: |
|
headers.append('Trend') |
|
|
|
|
|
else: |
|
|
|
if re.search(r'\d+', rag_response): |
|
headers.append('Value') |
|
|
|
|
|
if any(word in rag_lower for word in ['type', 'category', 'class', 'group']): |
|
headers.append('Category') |
|
|
|
|
|
if len(rag_response) > 100: |
|
headers.append('Description') |
|
|
|
|
|
if cited_pages: |
|
headers.append('Source') |
|
|
|
|
|
if any(word in rag_lower for word in ['score', 'rating', 'level', 'grade']): |
|
headers.append('Score') |
|
|
|
|
|
if len(headers) < 2: |
|
if 'Category' not in headers: |
|
headers.append('Category') |
|
if 'Value' not in headers: |
|
headers.append('Value') |
|
|
|
if len(headers) < 3: |
|
if 'Description' not in headers: |
|
headers.append('Description') |
|
|
|
|
|
headers = headers[:4] |
|
|
|
print(f"π [EXCEL] Auto-detected contextually relevant headers: {headers}") |
|
return headers |
|
|
|
except Exception as e: |
|
print(f"Error auto-detecting headers: {e}") |
|
return ['Category', 'Value', 'Description'] |
|
|
|
def _extract_numerical_data(self, rag_response): |
|
"""Extract numerical data from RAG response""" |
|
try: |
|
data_rows = [] |
|
|
|
|
|
number_patterns = [ |
|
r'(\d+(?:\.\d+)?)\s*(percent|%|units|items|components|devices|procedures)', |
|
r'(\d+(?:\.\d+)?)\s*(voltage|current|resistance|power|frequency)', |
|
r'(\d+(?:\.\d+)?)\s*(safety|risk|danger|warning)', |
|
r'(\d+(?:\.\d+)?)\s*(steps|phases|stages|levels)' |
|
] |
|
|
|
for pattern in number_patterns: |
|
matches = re.findall(pattern, rag_response, re.IGNORECASE) |
|
for match in matches: |
|
value, category = match |
|
data_rows.append([category.title(), value, "Found in analysis"])
|
|
|
return data_rows |
|
|
|
except Exception as e: |
|
print(f"Error extracting numerical data: {e}") |
|
return [] |
|
|
|
def _extract_categorical_data(self, rag_response, cited_pages): |
|
"""Extract categorical data from RAG response""" |
|
try: |
|
data_rows = [] |
|
|
|
|
|
categories = [] |
|
|
|
|
|
category_patterns = [ |
|
r'(safety|security|warning|danger|risk)', |
|
r'(procedure|method|technique|approach)', |
|
r'(component|device|equipment|tool)', |
|
r'(type|category|class|group)', |
|
r'(input|output|control|monitoring)' |
|
] |
|
|
|
for pattern in category_patterns: |
|
matches = re.findall(pattern, rag_response, re.IGNORECASE) |
|
categories.extend(matches) |
|
|
|
|
|
categories = list(set(categories)) |
|
|
|
for category in categories[:10]: |
|
data_rows.append([category.title(), 'Identified', "Category found in analysis"])
|
|
|
return data_rows |
|
|
|
except Exception as e: |
|
print(f"Error extracting categorical data: {e}") |
|
return [] |
|
|
|
def _extract_source_data(self, cited_pages, page_scores): |
|
"""Extract source information for Excel""" |
|
try: |
|
data_rows = [] |
|
|
|
for i, (citation, score) in enumerate(zip(cited_pages, page_scores)): |
|
collection = citation.split(' from ')[1] if ' from ' in citation else 'Unknown' |
|
page_num = citation.split('Page ')[1].split(' from')[0] if 'Page ' in citation else str(i+1) |
|
|
|
data_rows.append([ |
|
f"Source {i+1}", |
|
collection, |
|
f"Page {page_num} (Score: {score:.3f})" |
|
]) |
|
|
|
return data_rows |
|
|
|
except Exception as e: |
|
print(f"Error extracting source data: {e}") |
|
return [] |
|
|
|
def _map_data_to_custom_headers(self, rag_response, cited_pages, page_scores, custom_headers): |
|
"""Map extracted data to custom headers for Excel export with context-aware sample data""" |
|
try: |
|
data_rows = [] |
|
|
|
|
|
numerical_data = self._extract_numerical_data(rag_response) |
|
categorical_data = self._extract_categorical_data(rag_response, cited_pages) |
|
source_data = self._extract_source_data(cited_pages, page_scores) |
|
|
|
|
|
all_data = [] |
|
if numerical_data: |
|
all_data.extend(numerical_data) |
|
if categorical_data: |
|
all_data.extend(categorical_data) |
|
if source_data: |
|
all_data.extend(source_data) |
|
|
|
|
|
for i, data_row in enumerate(all_data): |
|
mapped_row = [] |
|
|
|
|
|
while len(mapped_row) < len(custom_headers): |
|
if len(data_row) > len(mapped_row): |
|
mapped_row.append(data_row[len(mapped_row)]) |
|
else: |
|
|
|
header = custom_headers[len(mapped_row)] |
|
mapped_row.append(self._generate_contextual_sample_data(header, i, rag_response)) |
|
|
|
|
|
mapped_row = mapped_row[:len(custom_headers)] |
|
data_rows.append(mapped_row) |
|
|
|
|
|
if not data_rows: |
|
data_rows = self._create_contextual_sample_data(custom_headers, rag_response) |
|
|
|
print(f"π [EXCEL] Mapped {len(data_rows)} rows to custom headers") |
|
return data_rows |
|
|
|
except Exception as e: |
|
print(f"Error mapping data to custom headers: {e}") |
|
return [] |
|
|
|
def _generate_contextual_sample_data(self, header, index, rag_response): |
|
"""Generate contextually relevant sample data based on header and content""" |
|
try: |
|
header_lower = header.lower() |
|
rag_lower = rag_response.lower() |
|
|
|
|
|
if any(word in rag_lower for word in ['threat', 'attack', 'security', 'vulnerability']): |
|
if 'threat' in header_lower or 'attack' in header_lower: |
|
threats = ['Phishing', 'Malware', 'DDoS', 'Social Engineering', 'Ransomware'] |
|
return threats[index % len(threats)] |
|
elif 'frequency' in header_lower or 'count' in header_lower: |
|
return str((index + 1) * 15) + '%' |
|
elif 'risk' in header_lower or 'severity' in header_lower: |
|
risk_levels = ['Low', 'Medium', 'High', 'Critical'] |
|
return risk_levels[index % len(risk_levels)] |
|
elif 'impact' in header_lower: |
|
impacts = ['Minimal', 'Moderate', 'Significant', 'Severe'] |
|
return impacts[index % len(impacts)] |
|
elif 'mitigation' in header_lower: |
|
mitigations = ['Training', 'Firewall', 'Monitoring', 'Backup'] |
|
return mitigations[index % len(mitigations)] |
|
|
|
|
|
elif any(word in rag_lower for word in ['sales', 'revenue', 'business', 'performance']): |
|
if 'time' in header_lower or 'period' in header_lower: |
|
periods = ['Q1 2024', 'Q2 2024', 'Q3 2024', 'Q4 2024'] |
|
return periods[index % len(periods)] |
|
elif 'sales' in header_lower or 'revenue' in header_lower: |
|
return f"${(index + 1) * 10000:,}" |
|
elif 'growth' in header_lower: |
|
return f"+{(index + 1) * 5}%" |
|
elif 'region' in header_lower: |
|
regions = ['North', 'South', 'East', 'West'] |
|
return regions[index % len(regions)] |
|
|
|
|
|
elif any(word in rag_lower for word in ['system', 'component', 'device', 'technology']): |
|
if 'component' in header_lower: |
|
components = ['Server', 'Database', 'Network', 'Application'] |
|
return components[index % len(components)] |
|
elif 'status' in header_lower: |
|
statuses = ['Active', 'Inactive', 'Maintenance', 'Error'] |
|
return statuses[index % len(statuses)] |
|
elif 'priority' in header_lower: |
|
priorities = ['Low', 'Medium', 'High', 'Critical'] |
|
return priorities[index % len(priorities)] |
|
elif 'version' in header_lower: |
|
return f"v{index + 1}.{index + 2}" |
|
|
|
|
|
else: |
|
if any(word in header_lower for word in ['name', 'title', 'category', 'type']): |
|
return f"Item {index + 1}" |
|
elif any(word in header_lower for word in ['value', 'score', 'number', 'count']): |
|
return str((index + 1) * 10) |
|
elif any(word in header_lower for word in ['description', 'detail', 'info']): |
|
return f"Sample description for {header}" |
|
else: |
|
return f"Sample {header} {index + 1}" |
|
|
|
except Exception as e: |
|
print(f"Error generating contextual sample data: {e}") |
|
return f"Sample {header} {index + 1}" |
|
|
|
def _create_contextual_sample_data(self, custom_headers, rag_response): |
|
"""Create contextually relevant sample data based on headers and content""" |
|
try: |
|
data_rows = [] |
|
rag_lower = rag_response.lower() |
|
|
|
|
|
if any(word in rag_lower for word in ['threat', 'attack', 'security']): |
|
sample_count = 4 |
|
elif any(word in rag_lower for word in ['sales', 'revenue', 'business']): |
|
sample_count = 4 |
|
elif any(word in rag_lower for word in ['system', 'component', 'device']): |
|
sample_count = 4 |
|
else: |
|
sample_count = 5 |
|
|
|
for i in range(sample_count): |
|
sample_row = [] |
|
for header in custom_headers: |
|
sample_row.append(self._generate_contextual_sample_data(header, i, rag_response)) |
|
data_rows.append(sample_row) |
|
|
|
return data_rows |
|
|
|
except Exception as e: |
|
print(f"Error creating contextual sample data: {e}") |
|
return [] |
|
|
|
def _create_summary_data(self, rag_response, cited_pages, page_scores): |
|
"""Create summary data when no structured data is found""" |
|
try: |
|
data_rows = [] |
|
|
|
|
|
data_rows.append(['Analysis Type', 'Comprehensive Review', 'AI-powered document analysis']) |
|
|
|
|
|
data_rows.append(['Sources Analyzed', str(len(cited_pages)), f"From {len(set([p.split(' from ')[1] for p in cited_pages if ' from ' in p]))} collections"]) |
|
|
|
|
|
if page_scores: |
|
avg_score = sum(page_scores) / len(page_scores) |
|
data_rows.append(['Average Relevance', f"{avg_score:.3f}", 'Based on AI relevance scoring']) |
|
|
|
|
|
data_rows.append(['Response Length', f"{len(rag_response)} characters", 'Comprehensive analysis provided']) |
|
|
|
return data_rows |
|
|
|
except Exception as e: |
|
print(f"Error creating summary data: {e}") |
|
return [['Analysis', 'Completed', 'Data extracted successfully']] |
|
|
|
def _populate_data_sheet(self, sheet, structured_data, query): |
|
"""Populate the data sheet with structured information""" |
|
try: |
|
|
|
sheet['A1'] = f"Data Export for Query: {query}" |
|
sheet['A1'].font = Font(color="FFFFFF", bold=True, size=14)

sheet['A1'].fill = PatternFill(start_color="2F5496", end_color="2F5496", fill_type="solid")
|
|
|
|
|
headers = structured_data['headers'] |
|
for col, header in enumerate(headers, 1): |
|
cell = sheet.cell(row=3, column=col, value=header) |
|
cell.font = Font(bold=True) |
|
cell.fill = PatternFill(start_color="D9E2F3", end_color="D9E2F3", fill_type="solid") |
|
cell.border = Border( |
|
left=Side(style='thin'), |
|
right=Side(style='thin'), |
|
top=Side(style='thin'), |
|
bottom=Side(style='thin') |
|
) |
|
|
|
|
|
data = structured_data['data'] |
|
for row_idx, row_data in enumerate(data, 4): |
|
for col_idx, value in enumerate(row_data, 1): |
|
cell = sheet.cell(row=row_idx, column=col_idx, value=value) |
|
cell.border = Border( |
|
left=Side(style='thin'), |
|
right=Side(style='thin'), |
|
top=Side(style='thin'), |
|
bottom=Side(style='thin') |
|
) |
|
|
|
|
|
for column in sheet.columns: |
|
max_length = 0 |
|
column_letter = column[0].column_letter |
|
for cell in column: |
|
try: |
|
if len(str(cell.value)) > max_length: |
|
max_length = len(str(cell.value)) |
|
except Exception:
|
pass |
|
adjusted_width = min(max_length + 2, 50) |
|
sheet.column_dimensions[column_letter].width = adjusted_width |
|
|
|
except Exception as e: |
|
print(f"Error populating data sheet: {e}") |
|
|
|
def _populate_summary_sheet(self, sheet, query, cited_pages, page_scores): |
|
"""Populate the summary sheet with analysis overview""" |
|
try: |
|
|
|
sheet['A1'] = "Analysis Summary" |
|
sheet['A1'].font = Font(color="FFFFFF", bold=True, size=16)

sheet['A1'].fill = PatternFill(start_color="2F5496", end_color="2F5496", fill_type="solid")
|
|
|
|
|
sheet['A3'] = "Query:" |
|
sheet['A3'].font = Font(bold=True) |
|
sheet['B3'] = query |
|
|
|
|
|
sheet['A5'] = "Analysis Statistics:" |
|
sheet['A5'].font = Font(bold=True) |
|
|
|
sheet['A6'] = "Sources Analyzed:" |
|
sheet['B6'] = len(cited_pages) |
|
|
|
sheet['A7'] = "Collections Used:" |
|
collections = set([p.split(' from ')[1] for p in cited_pages if ' from ' in p]) |
|
sheet['B7'] = len(collections) |
|
|
|
if page_scores: |
|
sheet['A8'] = "Average Relevance Score:" |
|
avg_score = sum(page_scores) / len(page_scores) |
|
sheet['B8'] = f"{avg_score:.3f}" |
|
|
|
sheet['A9'] = "Analysis Date:" |
|
sheet['B9'] = datetime.now().strftime('%B %d, %Y at %I:%M %p') |
|
|
|
|
|
sheet['A11'] = "Source Details:" |
|
sheet['A11'].font = Font(bold=True) |
|
|
|
for i, (citation, score) in enumerate(zip(cited_pages, page_scores)): |
|
row = 12 + i |
|
sheet[f'A{row}'] = f"Source {i+1}:" |
|
sheet[f'B{row}'] = citation |
|
sheet[f'C{row}'] = f"Score: {score:.3f}" |
|
|
|
|
|
for column in sheet.columns: |
|
max_length = 0 |
|
column_letter = column[0].column_letter |
|
for cell in column: |
|
try: |
|
if len(str(cell.value)) > max_length: |
|
max_length = len(str(cell.value)) |
|
except Exception:
|
pass |
|
adjusted_width = min(max_length + 2, 50) |
|
sheet.column_dimensions[column_letter].width = adjusted_width |
|
|
|
except Exception as e: |
|
print(f"Error populating summary sheet: {e}") |
|
|
|
def _create_excel_charts(self, sheet, structured_data, query, custom_headers=None): |
|
"""Create Excel charts based on the data with custom headers""" |
|
try: |
|
|
|
sheet['A1'] = "Data Visualizations" |
|
sheet['A1'].font = Font(color="FFFFFF", bold=True, size=16)

sheet['A1'].fill = PatternFill(start_color="2F5496", end_color="2F5496", fill_type="solid")
|
|
|
|
|
if custom_headers and len(custom_headers) >= 2: |
|
|
|
x_axis_title = custom_headers[0] if len(custom_headers) > 0 else "Categories" |
|
y_axis_title = custom_headers[1] if len(custom_headers) > 1 else "Values" |
|
|
|
|
|
if len(custom_headers) >= 3: |
|
chart_title = f"Analysis: {x_axis_title} vs {y_axis_title} by {custom_headers[2]}" |
|
else: |
|
chart_title = f"Analysis: {x_axis_title} vs {y_axis_title}" |
|
|
|
|
|
if len(structured_data['data']) > 1: |
|
chart = BarChart() |
|
chart.title = chart_title |
|
chart.x_axis.title = x_axis_title |
|
chart.y_axis.title = y_axis_title |
|
|
|
|
|
sheet.add_chart(chart, "A3") |
|
|
|
|
|
if len(structured_data['data']) > 2 and len(custom_headers) >= 3: |
|
pie_chart = PieChart() |
|
pie_chart.title = f"Distribution by {custom_headers[2]}" |
|
|
|
|
|
sheet.add_chart(pie_chart, "A15") |
|
elif len(structured_data['data']) > 2: |
|
|
|
pie_chart = PieChart() |
|
pie_chart.title = "Data Distribution" |
|
sheet.add_chart(pie_chart, "A15") |
|
else: |
|
|
|
if len(structured_data['data']) > 1: |
|
chart = BarChart() |
|
chart.title = f"Analysis Results for: {query[:30]}..." |
|
chart.x_axis.title = "Categories" |
|
chart.y_axis.title = "Values" |
|
|
|
|
|
sheet.add_chart(chart, "A3") |
|
|
|
|
|
if len(structured_data['data']) > 2: |
|
pie_chart = PieChart() |
|
pie_chart.title = "Data Distribution" |
|
|
|
|
|
sheet.add_chart(pie_chart, "A15") |
|
|
|
except Exception as e: |
|
print(f"Error creating Excel charts: {e}") |
|
|
|
def _prepare_doc_download(self, doc_filepath): |
|
""" |
|
Prepare DOC file for download in Gradio |
|
""" |
|
if doc_filepath and os.path.exists(doc_filepath): |
|
return doc_filepath |
|
else: |
|
return None |
|
|
|
def _prepare_excel_download(self, excel_filepath): |
|
""" |
|
Prepare Excel file for download in Gradio |
|
""" |
|
if excel_filepath and os.path.exists(excel_filepath): |
|
return excel_filepath |
|
else: |
|
return None |
|
|
|
def _generate_multi_page_response(self, query, img_paths, cited_pages, page_scores): |
|
""" |
|
Enhanced RAG response generation with multi-page citations |
|
Implements comprehensive detail enhancement based on research strategies |
|
""" |
|
try: |
|
|
|
detailed_prompt = f""" |
|
Please provide a comprehensive and detailed answer to the following query. |
|
Use all available information from the provided document pages to give a thorough response. |
|
|
|
Query: {query} |
|
|
|
Instructions for detailed response: |
|
1. Provide extensive background information and context |
|
2. Include specific details, examples, and data points from the documents |
|
3. Explain concepts thoroughly with step-by-step breakdowns |
|
4. Provide comprehensive analysis rather than simple answers when requested |
|
|
|
""" |
|
|
|
|
|
rag_response = rag.get_answer_from_openai(detailed_prompt, img_paths) |
|
|
|
|
|
citation_text = "π **Sources**:\n\n" |
|
|
|
|
|
collection_groups = {} |
|
for i, citation in enumerate(cited_pages): |
|
collection_name = citation.split(" from ")[1].split(" (")[0] |
|
if collection_name not in collection_groups: |
|
collection_groups[collection_name] = [] |
|
collection_groups[collection_name].append(citation) |
|
|
|
|
|
for collection_name, citations in collection_groups.items(): |
|
citation_text += f"π **{collection_name}**:\n" |
|
for citation in citations: |
|
|
|
clean_citation = citation.split(" (Relevance:")[0] |
|
citation_text += f" β’ {clean_citation}\n" |
|
citation_text += "\n" |
|
|
|
|
|
csv_filepath = None |
|
doc_filepath = None |
|
excel_filepath = None |
|
|
|
|
|
if self._detect_table_request(query): |
|
print("π Table request detected - generating CSV response") |
|
enhanced_rag_response, csv_filepath = self._generate_csv_table_response(query, rag_response, cited_pages, page_scores) |
|
else: |
|
enhanced_rag_response = rag_response |
|
|
|
|
|
if self._detect_report_request(query): |
|
print("π Report request detected - generating DOC report") |
|
doc_filepath, doc_error = self._generate_comprehensive_doc_report(query, rag_response, cited_pages, page_scores) |
|
if doc_error: |
|
print(f"β οΈ DOC report generation failed: {doc_error}") |
|
|
|
|
|
if self._detect_chart_request(query) or self._detect_table_request(query): |
|
print("π Chart/Excel request detected - generating enhanced Excel export") |
|
|
|
excel_custom_headers = self._extract_custom_headers(query) |
|
excel_filepath, excel_error = self._generate_enhanced_excel_export(query, rag_response, cited_pages, page_scores, excel_custom_headers) |
|
if excel_error: |
|
print(f"β οΈ Excel export generation failed: {excel_error}") |
|
|
|
|
|
export_info = "" |
|
|
|
if doc_filepath: |
|
export_info += f""" |
|
📄 **Comprehensive Report Generated**:

• **Format**: Microsoft Word Document (.docx)

• **Content**: Executive summary, detailed analysis, methodology, findings, and appendices

• **Download**: Available below
|
""" |
|
|
|
if excel_filepath: |
|
export_info += f""" |
|
📊 **Enhanced Excel Export Generated**:

• **Format**: Microsoft Excel (.xlsx)

• **Content**: Multiple sheets with data, summary, and charts

• **Features**: Formatted tables, auto-generated charts, source analysis

• **Download**: Available below
|
""" |
|
|
|
if csv_filepath: |
|
export_info += f""" |
|
📋 **CSV Table Generated**:

• **Format**: Comma-Separated Values (.csv)

• **Content**: Structured data table

• **Download**: Available below
|
""" |
|
|
|
final_response = f""" |
|
{enhanced_rag_response} |
|
|
|
{citation_text} |
|
|
|
{export_info} |
|
""" |
|
|
|
return final_response, csv_filepath, doc_filepath, excel_filepath |
|
|
|
except Exception as e: |
|
print(f"Error generating multi-page response: {e}") |
|
|
|
# Fall back to the raw query here: detailed_prompt may be unbound if the
# failure occurred before it was constructed
return rag.get_answer_from_openai(query, img_paths), None, None, None
|
|
|
def authenticate_user(self, username, password): |
|
"""Authenticate user and create session""" |
|
user_info = self.db_manager.authenticate_user(username, password) |
|
if user_info: |
|
session_id = self.session_manager.create_session(user_info) |
|
return f"Welcome {user_info['username']} from {user_info['team']}!", session_id, user_info['team'] |
|
else: |
|
return "Invalid username or password", None, None |
|
|
|
def logout_user(self, session_id): |
|
"""Logout user and remove session""" |
|
if session_id: |
|
self.session_manager.remove_session(session_id) |
|
return "Logged out successfully", None, None |
|
|
|
def get_chat_history(self, session_id, limit=10): |
|
"""Get chat history for logged-in user in a user-friendly format""" |
|
if not session_id: |
|
return "π **Please log in to view chat history**" |
|
|
|
session = self.session_manager.get_session(session_id) |
|
if not session: |
|
return "β° **Session expired. Please log in again.**" |
|
|
|
user_info = session['user_info'] |
|
history = self.db_manager.get_chat_history(user_info['id'], limit) |
|
|
|
if not history: |
|
return "π **No chat history found.**\n\nStart a conversation to see your chat history here!" |
|
|
|
|
|
def format_timestamp(timestamp_str): |
|
try: |
|
|
|
dt = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00')) |
|
return dt.strftime("%B %d, %Y at %I:%M %p") |
|
except Exception:
|
return timestamp_str |
|
|
|
|
|
def truncate_response(response, max_length=300): |
|
if len(response) <= max_length: |
|
return response |
|
return response[:max_length] + "..." |
|
|
|
history_text = f""" |
|
# 💬 Chat History for {user_info['username']} ({user_info['team']})



📊 **Showing last {len(history)} conversations**
|
|
|
--- |
|
""" |
|
|
|
for i, entry in enumerate(reversed(history), 1): |
|
|
|
conversation_entry = f""" |
|
## 🗨️ Conversation #{len(history) - i + 1}
|
|
|
**❓ Your Question:**
|
{entry['query']} |
|
|
|
**🤖 AI Response:**
|
{truncate_response(entry['response'])} |
|
|
|
**📚 Sources Referenced:**
|
{', '.join(entry['cited_pages']) if entry['cited_pages'] else 'No specific pages cited'} |
|
|
|
**📅 Date:** {format_timestamp(entry['timestamp'])}
|
|
|
--- |
|
""" |
|
history_text += conversation_entry |
|
|
|
|
|
history_text += f""" |
|
## 📊 Summary

• **Total Conversations:** {len(history)}

• **Date Range:** {format_timestamp(history[-1]['timestamp'])} to {format_timestamp(history[0]['timestamp'])}

• **Team:** {user_info['team']}

• **User:** {user_info['username']}
|
""" |
|
|
|
return history_text |
|
|
|
def clear_chat_history(self, session_id): |
|
"""Clear chat history for logged-in user""" |
|
if not session_id: |
|
return "π **Please log in to manage chat history**" |
|
|
|
session = self.session_manager.get_session(session_id) |
|
if not session: |
|
return "β° **Session expired. Please log in again.**" |
|
|
|
user_info = session['user_info'] |
|
success = self.db_manager.clear_chat_history(user_info['id']) |
|
|
|
if success: |
|
return "ποΈ **Chat history cleared successfully!**\n\nYour conversation history has been removed." |
|
else: |
|
return "β **Error clearing chat history.**\n\nPlease try again or contact support." |
|
|
|
def get_team_collections(self, session_id): |
|
"""Get available collections for the user's team""" |
|
if not session_id: |
|
return "Please log in to view team collections" |
|
|
|
session = self.session_manager.get_session(session_id) |
|
if not session: |
|
return "Session expired. Please log in again." |
|
|
|
team = session['user_info']['team'] |
|
collections = self.db_manager.get_team_collections(team) |
|
|
|
if not collections: |
|
return f"No collections found for {team}" |
|
|
|
return f"**{team} Collections:**\n" + "\n".join([f"- {coll}" for coll in collections]) |
|
|
|
def delete(self, state, choice, session_id=None): |
|
"""Delete collection with team-based access control""" |
|
if session_id: |
|
session = self.session_manager.get_session(session_id) |
|
if not session: |
|
return "Session expired. Please log in again." |
|
|
|
team = session['user_info']['team'] |
|
|
|
team_collections = self.db_manager.get_team_collections(team) |
|
if choice not in team_collections: |
|
return f"Access denied. Collection {choice} does not belong to {team}" |
|
|
|
|
|
client = MilvusClient( |
|
uri="http://localhost:19530", |
|
token="root:Milvus" |
|
) |
|
path = f"pages/{choice}" |
|
if os.path.exists(path): |
|
shutil.rmtree(path) |
|
|
|
client.drop_collection(collection_name=choice) |
|
return f"Deleted {choice}" |
|
else: |
|
return "Directory not found" |
|
|
|
|
|
|
|
|
|
|
|
|
|
def describe_image_with_gemma3(self, image): |
|
"""Describe image using Gemma3 vision model via Ollama""" |
|
try: |
|
print("π [CIRCUIT] Starting image description with Gemma3...") |
|
|
|
if image is None: |
|
print("β [CIRCUIT] No image provided") |
|
return "No image provided" |
|
|
|
print("πΈ [CIRCUIT] Converting image to base64...") |
|
|
|
buffered = io.BytesIO() |
|
image.save(buffered, format="PNG") |
|
img_str = base64.b64encode(buffered.getvalue()).decode() |
|
print("β
[CIRCUIT] Image converted successfully") |
|
|
|
|
|
print("π€ [CIRCUIT] Preparing request for Gemma3 model...") |
|
payload = { |
|
"model": "gemma3:4b", |
|
"prompt": "Just generate a netlist of circuit components of the image with explanations ONLY, NO OTHER TEXT", |
|
"images": [img_str], |
|
"stream": False |
|
} |
|
|
|
print("π [CIRCUIT] Sending request to Ollama Gemma3...") |
|
|
|
response = requests.post("http://localhost:11434/api/generate", json=payload, timeout=1200) |
|
|
|
if response.status_code == 200: |
|
result = response.json() |
|
description = result.get('response', 'No description generated') |
|
print(f"β
[CIRCUIT] Image description completed successfully") |
|
print(f"π [CIRCUIT] Description length: {len(description)} characters") |
|
return description |
|
else: |
|
error_msg = f"Error: {response.status_code} - {response.text}" |
|
print(f"β [CIRCUIT] {error_msg}") |
|
return error_msg |
|
|
|
except Exception as e: |
|
error_msg = f"Error describing image: {str(e)}" |
|
print(f"β [CIRCUIT] {error_msg}") |
|
return error_msg |
|
|
|
def generate_circuit_with_deepseek(self, image_description, max_retries=3): |
|
"""Generate netlist and circuit diagram using DeepSeek R1 with error handling and retry logic""" |
|
previous_error = None |
|
consecutive_failures = 0 |
|
|
|
for attempt in range(max_retries): |
|
try: |
|
print(f"π§ [CIRCUIT] Starting circuit generation with DeepSeek R1 (Attempt {attempt + 1}/{max_retries})...") |
|
|
|
if not image_description or image_description == "No image provided": |
|
print("β [CIRCUIT] No image description available") |
|
return "No image description available" |
|
|
|
print("π [CIRCUIT] Preparing prompt for DeepSeek R1...") |
|
|
|
|
|
if attempt == 0: |
|
|
|
unique_filename = self._generate_unique_filename() |
|
|
|
|
|
circuit_data = self._parse_complex_circuit_description(image_description) |
|
|
|
|
|
if circuit_data and circuit_data.get('complexity_level') in ['complex', 'very_complex']: |
|
print(f"Using specialized prompt for {circuit_data['complexity_level']} circuit") |
|
prompt = self._generate_complex_circuit_prompt(circuit_data, unique_filename) |
|
if not prompt: |
|
|
|
prompt = f"""Generate a complex circuit diagram using the python schemdraw library based on this detailed description. |
|
|
|
COMPLEX CIRCUIT REQUIREMENTS: |
|
1. **Component Mapping**: Map ALL components from the description to schemdraw equivalents: |
|
- Resistors: elm.Resistor with proper values |
|
- Capacitors: elm.Capacitor with proper values |
|
- Inductors: elm.Inductor with proper values |
|
- Diodes: elm.Diode, elm.LED, elm.Zener with proper types |
|
- Transistors: elm.Transistor, elm.BjtNpn, elm.BjtPnp, elm.FetN, elm.FetP |
|
- ICs: elm.RBox with proper labels and pin configurations |
|
- Power sources: elm.SourceV, elm.Battery, elm.SourceSin, elm.SourceSquare |
|
- Switches: elm.Switch, elm.SwitchSpdt |
|
- Connectors: elm.Connector, elm.Dot for connection points |
|
|
|
2. **Complex Topology Handling**: |
|
- Use elm.Dot for wire junctions and connection points |
|
- Use elm.Line for explicit wire connections |
|
- Use elm.Label for power rails and voltage/current labels |
|
- Use elm.Text for component labels and values |
|
- Use elm.Node for connection nodes |
|
- Handle multiple power rails (VCC, GND, VDD, etc.) |
|
- Support feedback loops and control paths |
|
- Handle parallel and series connections properly |
|
|
|
3. **Advanced Positioning**: |
|
- Use .up(), .down(), .left(), .right() for basic positioning |
|
- Use .to() for precise connections: .to(d.elements[0].start) |
|
- Use .at() for absolute positioning when needed |
|
- Use .move() for relative positioning |
|
- Arrange components in logical blocks and sections |
|
- Use consistent spacing and alignment |
|
|
|
4. **Component Labeling**: |
|
- Label ALL components with their values and designators |
|
- Use .label() method for component values |
|
- Use elm.Text for additional labels and annotations |
|
- Include voltage/current ratings where applicable |
|
- Add pin numbers for ICs and connectors |
|
|
|
5. **Circuit Organization**: |
|
- Group related components together |
|
- Use clear signal flow from left to right or top to bottom |
|
- Separate power supply sections from signal processing |
|
- Use consistent naming conventions |
|
- Minimize wire crossings and clutter |
|
|
|
IMPORTANT REQUIREMENTS: |
|
1. Use ONLY ASCII characters - replace Ω with 'Ohm', μ with 'u', ° with 'deg'
|
2. Use ONLY components available in schemdraw.elements library |
|
3. If a component is not in schemdraw.elements, use elm.RBox and label it appropriately |
|
4. Do NOT use matplotlib or any other plotting library |
|
5. Generate a complete, executable Python script |
|
6. ALWAYS use d.save() to save the diagram, NEVER use d.draw() |
|
7. Save the output as a PNG file with the EXACT filename: {unique_filename} |
|
8. Handle all connections properly using schemdraw's native positioning methods |
|
9. Create a functional circuit that matches the description - all components must be properly connected |
|
10. INCLUDE ALL COMPONENTS mentioned in the description - do not miss any components |
|
11. Use .to() method for precise connections and circuit completion |
|
12. Support complex topologies with multiple power rails and signal paths |
|
13. NEVER use d.element - this is INVALID and will cause errors |
|
14. NEVER use d.last_end, d.last_start, d.end, d.start, d.position - these are INVALID attributes |
|
15. CRITICAL: If you use d.element, the circuit will fail validation and not be generated |
|
|
|
Description of the circuit: {image_description} |
|
|
|
CORRECT SCHEMDRAW API USAGE: |
|
- Use d += elm.Component() to add components |
|
- Use .up(), .down(), .left(), .right() for positioning |
|
- Use .to() to connect to specific points: .to(d.elements[0].start) |
|
- Use .label() to add labels: .label('10V') |
|
- Use .at() for absolute positioning: .at((x, y)) |
|
- Use d.save() to save the diagram |
|
- Use elm.Dot for connection points |
|
- NEVER use d.element - this is INVALID and will cause errors |
|
- ALWAYS use d.elements[-1] instead of d.element |
|
- NEVER use d.last_end, d.last_start, d.end, d.start, d.position - these are INVALID attributes |
|
- Use elm.Line for explicit wire connections |
|
- Use elm.Text for additional labels |
|
- DO NOT use: d.last_end, d.last_start, d.end, d.start, d.position, d.element |
|
|
|
COMPLEX CIRCUIT EXAMPLE (for reference only): |
|
```python |
|
import schemdraw |
|
import schemdraw.elements as elm |
|
|
|
d = schemdraw.Drawing() |
|
# Power supply section |
|
d += elm.SourceV().up().label('12V').at((0, 0)) |
|
d += elm.Resistor().right().label('1KOhm') |
|
d += elm.Capacitor().down().label('100uF') |
|
d += elm.Line().left().to(d.elements[0].start) # Close main loop |
|
|
|
# Signal processing section |
|
d += elm.Dot().at((4, 0)) |
|
d += elm.Transistor().up().label('Q1') |
|
d += elm.Resistor().right().label('10KOhm') |
|
d += elm.Line().down().to(d.elements[-2].start) # Close secondary loop |
|
d += elm.Line().left().to(d.elements[0].start) # Ensure complete closure |
|
d.save('{unique_filename}') |
|
``` |
|
|
|
IMPORTANT: Always use .to(d.elements[0].start) to close the circuit loop back to the power source! |
|
|
|
CRITICAL REQUIREMENTS: |
|
- Create a circuit that accurately represents the complex description provided |
|
- Use appropriate components and values that match the actual circuit described |
|
- INCLUDE ALL COMPONENTS listed above - missing components will cause validation failure |
|
- Ensure all components are properly connected and labeled |
|
- Handle complex topologies with multiple power rails and signal paths |
|
- Use proper component positioning and wire routing |
|
- Support feedback loops, control paths, and complex connections |
|
- Arrange components logically with clear signal flow |
|
- Use consistent labeling and naming conventions |
|
- Minimize wire clutter while maintaining circuit clarity |
|
|
|
CRITICAL CIRCUIT CLOSURE REQUIREMENTS: |
|
- ALWAYS close the circuit loop using .to() method: d += elm.Line().to(d.elements[0].start) |
|
- Ensure ALL components are connected in a complete loop |
|
- Use explicit Line() elements to connect components when needed |
|
- Start with a power source (elm.SourceV, elm.Battery) |
|
- End with a connection back to the power source |
|
- Use proper positioning to create logical circuit flow |
|
- For complex circuits, use multiple .to() connections to ensure complete closure |
|
""" |
|
else: |
|
|
|
prompt = f"""Generate a circuit diagram using the python schemdraw library based on this description. |
|
|
|
IMPORTANT REQUIREMENTS: |
|
1. Use ONLY ASCII characters - replace Ω with 'Ohm', μ with 'u', ° with 'deg'
|
2. Use ONLY components available in schemdraw.elements library |
|
3. If a component is not in schemdraw.elements, use an RBox element (schemdraw.elements.twoterm.RBox) and label it with the component name
|
4. Do NOT use matplotlib or any other plotting library |
|
5. Generate a complete, executable Python script |
|
6. Use d.save() to save the diagram, NOT d.draw() |
|
7. Save the output as a PNG file with the EXACT filename: {unique_filename} |
|
8. Handle all connections properly using schemdraw's native positioning methods |
|
9. Create a CLOSED LOOP circuit that matches the description - all components must form a complete loop |
|
10. INCLUDE ALL COMPONENTS mentioned in the description - do not miss any components |
|
11. DO NOT use any grounding elements (elm.Ground, elm.GroundChassis, etc.) - create a complete closed loop circuit |
|
12. Use .to() method to explicitly close the circuit loop back to the starting point |
|
|
|
Description of the circuit: {image_description} |
|
|
|
CORRECT USAGE EXAMPLE (for reference only): |
|
import schemdraw |
|
import schemdraw.elements as elm |
|
|
|
d = schemdraw.Drawing() |
|
d += elm.SourceV().up().label('10V') |
|
d += elm.Resistor().right().label('100KOhm') |
|
d += elm.Capacitor().down().label('0.1uF') |
|
d += elm.Line().left().to(d.elements[0].start) # Clean connection back to voltage source |
|
d.save('{unique_filename}') |
|
|
|
IMPORTANT: Always use .to(d.elements[0].start) to close the circuit loop back to the power source! |
|
|
|
CRITICAL REQUIREMENTS: |
|
- Do NOT copy the example circuit above |
|
- Create a completely different circuit that accurately represents the description provided |
|
- Use different components, values, and layout that match the actual circuit described in the image |
|
- INCLUDE ALL COMPONENTS listed above - missing components will cause validation failure |
|
- Ensure all components are properly connected and labeled |
|
- ENSURE COMPLETE CIRCUIT CONNECTIVITY - all components must form a connected, working circuit |
|
- Include power sources (voltage/current sources) and ground connections where appropriate |
|
- Use explicit Line() elements to connect components when needed |
|
- Create logical circuit flow with proper component sequencing |
|
- MINIMIZE WIRE CLUTTER - use direct component connections instead of unnecessary Line() elements |
|
- Use net labels (VoltageLabel, CurrentLabel) for power rails instead of long wires |
|
- Arrange components in clean, symmetrical layouts with consistent spacing |
|
- Use horizontal and vertical connections only - avoid diagonal wires |
|
- Limit wire junctions to maximum 3 connections per point
|
|
|
COMMON ERRORS TO AVOID: |
|
- Do NOT use: elm.Tip, elm.DCSourceV, elm.SpiceNetlist |
|
- Do NOT use: matplotlib, pyplot, or any plotting libraries |
|
- Do NOT use Unicode characters in labels or component names |
|
- Do NOT use components not in schemdraw.elements |
|
- Do NOT use invalid assignment syntax like "light_bulb = d += elm.Lamp()" - use "d += elm.Lamp()" only |
|
- Do NOT use any grounding elements (elm.Ground, elm.GroundChassis, elm.GroundSignal) - create closed loop circuits only |
|
- Do NOT use excessive Line() elements - minimize unnecessary wires and use direct connections |
|
- Do NOT use redundant wire patterns (up().down(), left().right(), etc.) - use efficient routing |
|
- Do NOT use any other filename - use exactly: {unique_filename} |
|
- Do NOT copy the example circuit - create your own unique design |
|
- Do NOT miss any components from the description |
|
- DO NOT use: elm.Lightbulb, use elm.Lamp instead! |
|
|
|
CRITICAL CIRCUIT CLOSURE REQUIREMENTS: |
|
- ALWAYS close the circuit loop using .to() method: d += elm.Line().to(d.elements[0].start) |
|
- Ensure ALL components are connected in a complete loop |
|
- Use explicit Line() elements to connect components when needed |
|
- Start with a power source (elm.SourceV, elm.Battery) |
|
- End with a connection back to the power source |
|
- Use proper positioning to create logical circuit flow |
|
|
|
Generate ONLY the Python code, no explanations or markdown formatting.""" |
|
else: |
|
|
|
prompt = self._create_retry_prompt(image_description, previous_error) |
|
|
|
|
|
print("π€ [CIRCUIT] Preparing request for Reasoning model...") |
|
payload = { |
|
"model": "qwen3-coder:latest", |
|
"prompt": prompt, |
|
"stream": False, |
|
|
|
"temperature": 0.5, |
|
} |
|
|
|
print("π [CIRCUIT] Sending request to Reasoning Model...") |
|
response = requests.post("http://localhost:11434/api/generate", json=payload, timeout=3000) |
|
|
|
if response.status_code == 200: |
|
result = response.json() |
|
generated_code = result.get('response', '') |
|
print(f"β
[CIRCUIT] DeepSeek R1 response received successfully") |
|
print(f"π [CIRCUIT] Generated code length: {len(generated_code)} characters") |
|
|
|
|
|
print("π§ [CIRCUIT] Extracting Python code from response...") |
|
extracted_code = self._extract_python_code(generated_code) |
|
print(f"π [CIRCUIT] Extracted code length: {len(extracted_code)} characters") |
|
|
|
|
|
print("π§ [CIRCUIT] Fixing circuit structure and enhancing connections...") |
|
enhanced_code = self._fix_circuit_structure(extracted_code) |
|
|
|
|
|
if not self._validate_circuit_code(enhanced_code): |
|
print("β οΈ [CIRCUIT] Enhanced code validation failed, will retry...") |
|
if attempt < max_retries - 1: |
|
continue |
|
else: |
|
return "Error: Enhanced code failed validation after all retries" |
|
|
|
|
|
|
|
|
|
|
|
print("βοΈ [CIRCUIT] Executing enhanced circuit code...") |
|
result = self._execute_generated_circuit_code(enhanced_code) |
|
|
|
|
|
if result and result.endswith('.png'): |
|
print(f"β
[CIRCUIT] Circuit generation successful on attempt {attempt + 1}") |
|
consecutive_failures = 0 |
|
|
|
|
|
if attempt == max_retries - 1: |
|
print("β
[CIRCUIT] Circuit generated successfully") |
|
return f"{result} (Note: Circuit generated successfully)" |
|
|
|
return result |
|
else: |
|
print(f"β οΈ [CIRCUIT] Circuit execution failed: {result}") |
|
consecutive_failures += 1 |
|
previous_error = result |
|
|
|
|
|
if consecutive_failures >= 2 and attempt == max_retries - 1: |
|
print("β οΈ [CIRCUIT] Multiple consecutive failures detected, providing partial result...") |
|
return f"Partial circuit generated (Note: Some components may be missing due to generation difficulties)" |
|
|
|
if attempt < max_retries - 1: |
|
print(f"π [CIRCUIT] Retrying... (Attempt {attempt + 2}/{max_retries})") |
|
continue |
|
else: |
|
return f"Error: Circuit generation failed after {max_retries} attempts. Last error: {result}" |
|
else: |
|
error_msg = f"Error: {response.status_code} - {response.text}" |
|
print(f"β [CIRCUIT] {error_msg}") |
|
previous_error = error_msg |
|
if attempt < max_retries - 1: |
|
print(f"π [CIRCUIT] Retrying... (Attempt {attempt + 2}/{max_retries})") |
|
continue |
|
else: |
|
return error_msg |
|
|
|
except Exception as e: |
|
error_msg = f"Error generating circuit: {str(e)}" |
|
print(f"β [CIRCUIT] {error_msg}") |
|
previous_error = error_msg |
|
if attempt < max_retries - 1: |
|
print(f"π [CIRCUIT] Retrying... (Attempt {attempt + 2}/{max_retries})") |
|
continue |
|
else: |
|
return error_msg |
|
|
|
return f"Error: Circuit generation failed after {max_retries} attempts" |
|
|
|
def _create_retry_prompt(self, image_description, previous_error): |
|
"""Create an enhanced prompt for retry attempts with error feedback""" |
|
|
|
unique_filename = self._generate_unique_filename() |
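# A fresh timestamped filename per attempt helps avoid picking up a stale PNG from an earlier run.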
|
|
|
prompt = f"""The previous attempt to generate a circuit diagram failed. Please fix the issues and try again. |
|
|
|
PREVIOUS ERROR: {previous_error} |
|
|
|
IMPORTANT REQUIREMENTS (must follow exactly): |
|
1. Use ONLY ASCII characters - replace Ξ© with 'Ohm', ΞΌ with 'u', Β° with 'deg' |
|
2. Use ONLY components available in schemdraw.elements library |
|
3. If a component is not in schemdraw.elements, use an RBox element (schemdraw.elements.twoterm.RBox) and label it with the component name
|
4. Do NOT use matplotlib or any other plotting library |
|
5. Generate a complete, executable Python script |
|
6. Use d.save() to save the diagram, NOT d.draw() |
|
7. Save the output as a PNG file with the EXACT filename: {unique_filename} |
|
8. Handle all connections properly using schemdraw's native positioning methods |
|
9. Create a CLOSED LOOP circuit that matches the description - all components must form a complete loop |
|
10. INCLUDE ALL COMPONENTS mentioned in the description - do not miss any components |
|
11. DO NOT use any grounding elements (elm.Ground, elm.GroundChassis, etc.) - create a complete closed loop circuit |
|
12. Use .to() method to explicitly close the circuit loop back to the starting point |
|
|
|
Description of the circuit: {image_description} |
|
|
|
CORRECT USAGE EXAMPLE (for reference only - create your own unique circuit): |
|
```python |
|
import schemdraw |
|
import schemdraw.elements as elm |
|
|
|
d = schemdraw.Drawing() |
|
d += elm.SourceV().up().label('10V') |
|
d += elm.Resistor().right().label('100KOhm') |
|
d += elm.Capacitor().down().label('0.1uF') |
|
d += elm.Line().left().to(d.elements[0].start) # Close the loop back to voltage source |
|
d.save('{unique_filename}') |
|
``` |
|
|
|
IMPORTANT: Always use .to(d.elements[0].start) to close the circuit loop back to the power source! |
|
|
|
CRITICAL REQUIREMENTS: |
|
- Create a circuit that accurately represents the description provided |
|
- Use different components, values, and layout that match the actual circuit described in the image |
|
- INCLUDE ALL COMPONENTS listed above - missing components will cause validation failure |
|
- Ensure all components are properly connected and labeled |
|
|
|
COMMON ERRORS TO AVOID: |
|
- Do NOT use: elm.Tip, elm.DCSourceV, elm.SpiceNetlist |
|
- Do NOT use: matplotlib, pyplot, or any plotting libraries |
|
- Do NOT use Unicode characters in labels or component names |
|
- Do NOT use components not in schemdraw.elements |
|
- Do NOT use invalid assignment syntax like "light_bulb = d += elm.Lamp()" - use "d += elm.Lamp()" only |
|
- Do NOT use any other filename - use exactly: {unique_filename} |
|
- Do NOT miss any components from the description |
|
|
|
CRITICAL CIRCUIT CLOSURE REQUIREMENTS: |
|
- ALWAYS close the circuit loop using .to() method: d += elm.Line().to(d.elements[0].start) |
|
- Ensure ALL components are connected in a complete loop |
|
- Use explicit Line() elements to connect components when needed |
|
- Start with a power source (elm.SourceV, elm.Battery) |
|
- End with a connection back to the power source |
|
- Use proper positioning to create logical circuit flow |
|
|
|
Generate ONLY the Python code, no explanations or markdown formatting.""" |
|
return prompt |
|
|
|
def _cleanup_previous_circuit_files(self): |
|
"""Clean up previous circuit diagram files to ensure fresh generation""" |
|
try: |
|
print("π§Ή [CIRCUIT] Cleaning up previous circuit diagram files...") |
|
circuit_files = [] |
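# Collect any leftover PNGs whose names suggest they are circuit output from a previous run.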
|
|
|
|
|
for file in os.listdir('.'): |
|
if file.endswith('.png') and any(keyword in file.lower() for keyword in ['circuit', 'diagram', 'schematic']): |
|
circuit_files.append(file) |
|
|
|
|
|
for file in circuit_files: |
|
try: |
|
os.remove(file) |
|
print(f"ποΈ [CIRCUIT] Removed previous circuit file: {file}") |
|
except Exception as e: |
|
print(f"β οΈ [CIRCUIT] Failed to remove {file}: {str(e)}") |
|
|
|
print(f"β
[CIRCUIT] Cleaned up {len(circuit_files)} previous circuit files") |
|
|
|
except Exception as e: |
|
print(f"β οΈ [CIRCUIT] Error during cleanup: {str(e)}") |
|
|
|
def _generate_unique_filename(self): |
|
"""Generate a unique filename for the circuit diagram""" |
|
import time |
|
timestamp = int(time.time()) |
|
return f"circuit_diagram_{timestamp}.png" |
|
|
|
def _preprocess_circuit_image(self, image): |
|
"""Preprocess circuit image for better component detection""" |
|
try: |
|
print("Preprocessing circuit image...") |
|
|
|
|
|
if image.mode != 'RGB': |
|
image = image.convert('RGB') |
|
|
|
|
|
from PIL import ImageEnhance, ImageFilter |
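# Enhancement pipeline: boost contrast (1.5x), sharpen edges, then brighten slightly (1.2x) so component outlines and labels are easier for the vision model to read.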
|
|
|
|
|
enhancer = ImageEnhance.Contrast(image) |
|
image = enhancer.enhance(1.5) |
|
|
|
|
|
image = image.filter(ImageFilter.SHARPEN) |
|
|
|
|
|
enhancer = ImageEnhance.Brightness(image) |
|
image = enhancer.enhance(1.2) |
|
|
|
print("Image preprocessing completed") |
|
return image |
|
|
|
except Exception as e: |
|
print(f"Image preprocessing failed: {str(e)}") |
|
return image |
|
|
|
def _parse_complex_circuit_description(self, image_description): |
|
"""Parse complex circuit description and extract structured component information""" |
|
try: |
|
print("π [CIRCUIT] Parsing complex circuit description...") |
|
|
|
|
|
circuit_data = { |
|
'components': [], |
|
'connections': [], |
|
'power_rails': [], |
|
'signal_paths': [], |
|
'circuit_function': '', |
|
'complexity_level': 'simple' |
|
} |
|
|
|
|
|
import re |
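# The pattern groups below cover switches, power sources/rails, passives, semiconductors, coils/initiators, and structural keywords (sections, paths).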
|
|
|
|
|
component_patterns = [ |
|
|
|
r'\bSW\d+\b', |
|
r'\bDPDT\b', |
|
r'\bswitch\b', |
|
r'\bsafety\s*switch\b', |
|
r'\barming\s*arm\b', |
|
r'\bSAKLAR\s*PENGAMAN\b', |
|
|
|
|
|
r'\bBAT\d+\b', |
|
r'\bbattery\b', |
|
r'\b9V\b', |
|
r'\b12V\b', |
|
r'\bvoltage\s*source\b', |
|
r'\bpower\s*supply\b', |
|
r'\bVCC\b', r'\bGND\b', r'\bVDD\b', r'\bVSS\b', |
|
|
|
|
|
r'\bR\d+\b', |
|
r'\bresistor\b', |
|
r'\b1k\b', r'\b2k\b', r'\b100\b', r'\b10k\b', r'\b100k\b', |
|
r'\bohm\b', r'\bΩ\b',
|
|
|
|
|
r'\bLED\s*D\d+\b', |
|
r'\bled\b', |
|
r'\bblue\b', |
|
r'\bindicator\b', |
|
r'\bstatus\s*light\b', |
|
r'\bIDIKATOR\b', r'\bINDIKATOR\b', |
|
|
|
|
|
r'\bSCR\b', |
|
r'\bU\d+\b', |
|
r'\bSilicon\s*Controlled\s*Rectifier\b', |
|
r'\bthyristor\b', |
|
r'\btransistor\b', |
|
r'\bBJT\b', r'\bFET\b', r'\bMOSFET\b', |
|
r'\bopamp\b', r'\boperational\s*amplifier\b', |
|
r'\bIC\b', r'\bintegrated\s*circuit\b', |
|
|
|
|
|
r'\bL\d+\b', |
|
r'\binisiator\b', |
|
r'\binitiator\b', |
|
r'\bcoil\b', |
|
r'\b12V\s*inisiator\b', |
|
r'\binductor\b', |
|
|
|
|
|
r'\bcapacitor\b', r'\bcondenser\b', |
|
r'\bdiode\b', r'\brectifier\b', |
|
r'\bwire\b', r'\bconnection\b', |
|
r'\bterminal\b', r'\bnode\b', |
|
r'\bground\b', r'\bearth\b', |
|
|
|
|
|
r'\binput\s*section\b', |
|
r'\bcontrol\s*section\b', |
|
r'\boutput\s*section\b', |
|
r'\bpower\s*rail\b', |
|
r'\bsignal\s*path\b' |
|
] |
|
|
|
|
|
for pattern in component_patterns: |
|
matches = re.findall(pattern, image_description, re.IGNORECASE) |
|
circuit_data['components'].extend(matches) |
|
|
|
|
|
circuit_data['components'] = list(set(circuit_data['components'])) |
|
circuit_data['components'] = [comp for comp in circuit_data['components'] if len(comp) > 1] |
|
|
|
|
|
if 'COMPONENTS:' in image_description and not circuit_data['components']: |
|
components_section = image_description.split('COMPONENTS:')[1].split('CONNECTIONS:')[0] |
|
for line in components_section.strip().split('\n'): |
|
if line.strip().startswith('-'): |
|
component_info = line.strip()[1:].strip() |
|
circuit_data['components'].append(component_info) |
|
|
|
|
|
connection_patterns = [ |
|
|
|
r'\bpositive\s+terminal\b', |
|
r'\bnegative\s+terminal\b', |
|
r'\bconnected\s+to\b', |
|
r'\bconnected\s+between\b', |
|
r'\bconnected\s+together\b', |
|
r'\bconnected\s+via\b', |
|
r'\bconnected\s+through\b', |
|
|
|
|
|
r'\banode\b', |
|
r'\bcathode\b', |
|
r'\bgate\b', |
|
r'\bcollector\b', |
|
r'\bemitter\b', |
|
r'\bbase\b', |
|
r'\bdrain\b', |
|
r'\bsource\b', |
|
r'\bterminal\b', |
|
r'\bpin\b', |
|
|
|
|
|
r'\bground\b', |
|
r'\bcommon\s+ground\b', |
|
r'\bearth\b', |
|
r'\bVCC\b', |
|
r'\bGND\b', |
|
r'\bVDD\b', |
|
r'\bVSS\b', |
|
r'\bpower\s+rail\b', |
|
r'\bvoltage\s+rail\b', |
|
|
|
|
|
r'\boutput\s+throw\b', |
|
r'\binput\s+pole\b', |
|
r'\bswitch\s+position\b', |
|
r'\bswitch\s+state\b', |
|
r'\barming\s+position\b', |
|
r'\bsafety\s+position\b', |
|
|
|
|
|
r'\bone\s+end\b', |
|
r'\bother\s+end\b', |
|
r'\bwire\b', |
|
r'\bline\b', |
|
r'\bconnection\b', |
|
r'\bjunction\b', |
|
r'\bnode\b', |
|
r'\bpoint\b', |
|
|
|
|
|
r'\bsignal\s+path\b', |
|
r'\bcurrent\s+flow\b', |
|
r'\bvoltage\s+path\b', |
|
r'\bcontrol\s+signal\b', |
|
r'\btrigger\s+signal\b', |
|
r'\boutput\s+signal\b', |
|
|
|
|
|
r'\bseries\s+connection\b', |
|
r'\bparallel\s+connection\b', |
|
r'\bbranch\b', |
|
r'\bloop\b', |
|
r'\bcircuit\s+path\b', |
|
r'\breturn\s+path\b' |
|
] |
|
|
|
|
|
for pattern in connection_patterns: |
|
matches = re.findall(pattern, image_description, re.IGNORECASE) |
|
circuit_data['connections'].extend(matches) |
|
|
|
|
|
circuit_data['connections'] = list(set(circuit_data['connections'])) |
|
|
|
|
|
power_rail_patterns = [ |
|
|
|
r'\bVCC\b', r'\bGND\b', r'\bVDD\b', r'\bVSS\b', r'\bVEE\b', r'\bVBB\b', |
|
r'\bpower\s+rail\b', r'\bvoltage\s+rail\b', r'\bpositive\s+rail\b', |
|
r'\bnegative\s+rail\b', r'\bground\s+rail\b', |
|
r'\b12V\s+rail\b', r'\b5V\s+rail\b', r'\b3\.3V\s+rail\b', r'\b9V\s+rail\b', |
|
|
|
|
|
r'\bpower\s+supply\b', r'\bvoltage\s+supply\b', r'\bcurrent\s+supply\b', |
|
r'\bBAT\d+\b', r'\bbattery\b', r'\b9V\b', r'\b12V\b', r'\b5V\b', r'\b3\.3V\b', |
|
r'\bvoltage\s+source\b', r'\bcurrent\s+source\b', r'\bSourceV\b', r'\bSourceI\b', |
|
|
|
|
|
r'\bpower\s+distribution\b', r'\bvoltage\s+distribution\b', |
|
r'\bpower\s+bus\b', r'\bvoltage\s+bus\b', r'\bpower\s+line\b', r'\bvoltage\s+line\b' |
|
] |
|
|
|
for pattern in power_rail_patterns: |
|
matches = re.findall(pattern, image_description, re.IGNORECASE) |
|
circuit_data['power_rails'].extend(matches) |
|
|
|
|
|
circuit_data['power_rails'] = list(set(circuit_data['power_rails'])) |
|
|
|
|
|
if 'CONNECTIONS:' in image_description and not circuit_data['connections']: |
|
connections_section = image_description.split('CONNECTIONS:')[1].split('CIRCUIT FUNCTION:')[0] |
|
for line in connections_section.strip().split('\n'): |
|
if line.strip().startswith('-'): |
|
connection_info = line.strip()[1:].strip() |
|
circuit_data['connections'].append(connection_info) |
|
|
|
|
|
if 'CIRCUIT FUNCTION:' in image_description: |
|
function_section = image_description.split('CIRCUIT FUNCTION:')[1] |
|
circuit_data['circuit_function'] = function_section.strip() |
|
|
|
|
|
component_count = len(circuit_data['components']) |
|
connection_count = len(circuit_data['connections']) |
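# Rough complexity heuristic based purely on how many components/connections the regex pass detected.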
|
|
|
if component_count > 15 or connection_count > 20: |
|
circuit_data['complexity_level'] = 'very_complex' |
|
elif component_count > 10 or connection_count > 15: |
|
circuit_data['complexity_level'] = 'complex' |
|
elif component_count > 5 or connection_count > 10: |
|
circuit_data['complexity_level'] = 'moderate' |
|
else: |
|
circuit_data['complexity_level'] = 'simple' |
|
|
|
print(f"π [CIRCUIT] Circuit complexity: {circuit_data['complexity_level']}") |
|
print(f"π [CIRCUIT] Components found: {component_count}") |
|
print(f"π [CIRCUIT] Connections found: {connection_count}") |
|
print(f"β‘ [CIRCUIT] Power rails and supplies found: {len(circuit_data['power_rails'])}") |
|
if circuit_data['power_rails']: |
|
print(f" - Power rails/supplies: {', '.join(circuit_data['power_rails'])}") |
|
|
|
return circuit_data |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error parsing complex circuit description: {str(e)}") |
|
return None |
|
|
|
def _generate_complex_circuit_prompt(self, circuit_data, unique_filename): |
|
"""Generate a specialized prompt for complex circuit generation""" |
|
try: |
|
print("Generating specialized prompt for complex circuit...") |
|
|
|
complexity_level = circuit_data.get('complexity_level', 'simple') |
|
components = circuit_data.get('components', []) |
|
connections = circuit_data.get('connections', []) |
|
power_rails = circuit_data.get('power_rails', []) |
|
circuit_function = circuit_data.get('circuit_function', '') |
|
|
|
|
|
prompt = f"""Generate a {complexity_level} circuit diagram using the python schemdraw library. |
|
|
|
CIRCUIT ANALYSIS: |
|
- Complexity Level: {complexity_level} |
|
- Component Count: {len(components)} |
|
- Connection Count: {len(connections)} |
|
- Power Rails: {len(power_rails)} ({', '.join(power_rails) if power_rails else 'None detected'}) |
|
- Circuit Function: {circuit_function} |
|
|
|
COMPONENTS TO INCLUDE: |
|
""" |
|
|
|
|
|
for i, component in enumerate(components[:10]): |
|
prompt += f"- Component {i+1}: {component}\n" |
|
|
|
if len(components) > 10: |
|
prompt += f"- ... and {len(components) - 10} more components\n" |
|
|
|
prompt += f""" |
|
POWER RAILS AND SUPPLIES TO IMPLEMENT: |
|
""" |
|
|
|
|
|
if power_rails: |
|
for i, rail in enumerate(power_rails): |
|
prompt += f"- Power Rail/Supply {i+1}: {rail}\n" |
|
else: |
|
prompt += "- Power Rails/Supplies: Use standard VCC/GND rails and power supplies as needed\n" |
|
|
|
prompt += f""" |
|
CONNECTIONS TO IMPLEMENT: |
|
""" |
|
|
|
|
|
for i, connection in enumerate(connections[:10]): |
|
prompt += f"- Connection {i+1}: {connection}\n" |
|
|
|
if len(connections) > 10: |
|
prompt += f"- ... and {len(connections) - 10} more connections\n" |
|
|
|
|
|
if complexity_level == 'very_complex': |
|
prompt += """ |
|
VERY COMPLEX CIRCUIT REQUIREMENTS: |
|
- Use modular design with clear sections |
|
- Implement multiple power rails (VCC, GND, VDD, etc.) |
|
- Use elm.Dot for wire junctions and connection points |
|
- Use elm.Label for power rails and voltage/current labels |
|
- Organize components in logical blocks |
|
- Use absolute positioning (.at()) for precise placement |
|
- Minimize wire crossings and clutter |
|
- Support feedback loops and control paths |
|
- NEVER use d.element - this is INVALID and will cause errors |
|
- ALWAYS use d.elements[-1] instead of d.element |
|
- NEVER use d.last_end, d.last_start, d.end, d.start, d.position - these are INVALID attributes |
|
|
|
SPECIALIZED COMPONENT HANDLING: |
|
- DPDT switches: Use elm.Switch for double-pole double-throw switches |
|
- SCR/Thyristor: Use elm.SCR for Silicon Controlled Rectifiers |
|
- Multiple batteries: Use elm.Battery with proper labeling (BAT1, BAT2) |
|
- Indicator LEDs: Use elm.LED with color specifications |
|
- Initiator/Coil: Use elm.Inductor for coils and initiators |
|
- Safety switches: Use elm.Switch with safety labels |
|
- Power distribution: Use elm.Label for multiple voltage rails |
|
- Ground connections: Use elm.Ground for common ground points |
|
|
|
CIRCUIT ORGANIZATION: |
|
- Input section: Safety switches and indicators (left side) |
|
- Control section: Logic and power supplies (middle) |
|
- Output section: Initiator and final controls (right side) |
|
- Use elm.Text for section labels and component descriptions |
|
""" |
|
elif complexity_level == 'complex': |
|
prompt += """ |
|
COMPLEX CIRCUIT REQUIREMENTS: |
|
- Use clear signal flow from input to output |
|
- Implement proper power distribution |
|
- Use elm.Dot for connection points |
|
- Group related components together |
|
- Use consistent spacing and alignment |
|
- Support multiple signal paths |
|
""" |
|
else: |
|
prompt += """ |
|
STANDARD CIRCUIT REQUIREMENTS: |
|
- Use logical component arrangement |
|
- Implement proper connections |
|
- Use clear labeling |
|
- Maintain circuit clarity |
|
""" |
|
|
|
|
|
prompt += f""" |
|
STANDARD REQUIREMENTS: |
|
- Use ONLY ASCII characters |
|
- Use ONLY schemdraw.elements components |
|
- Generate complete, executable Python script |
|
- Use d.save() with filename: {unique_filename} |
|
- Use proper positioning methods (.up(), .down(), .left(), .right(), .to()) |
|
- Label all components appropriately |
|
- Handle all connections properly |
|
|
|
CRITICAL CIRCUIT CLOSURE REQUIREMENTS: |
|
- ALWAYS close the circuit loop using .to() method: d += elm.Line().to(d.elements[0].start) |
|
- Ensure ALL components are connected in a complete loop |
|
- Use explicit Line() elements to connect components when needed |
|
- Start with a power source (elm.SourceV, elm.Battery) |
|
- End with a connection back to the power source |
|
- Use proper positioning to create logical circuit flow |
|
|
|
Generate ONLY the Python code, no explanations.""" |
|
|
|
return prompt |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error generating complex circuit prompt: {str(e)}") |
|
return None |
|
|
|
def _fix_component_naming_issues(self, code): |
|
"""Fix common component naming issues in generated code""" |
|
try: |
|
print("π§ [CIRCUIT] Fixing component naming issues...") |
|
|
|
|
|
# A single replace covers every occurrence, including the 'elm.IC(' and 'elm.IC)' forms.
fixed_code = code.replace('elm.IC', 'elm.Ic')
|
|
|
|
|
if fixed_code != code: |
|
print("β
[CIRCUIT] Fixed component naming issues") |
|
else: |
|
print("β
[CIRCUIT] No component naming issues found") |
|
|
|
return fixed_code |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error fixing component naming issues: {str(e)}") |
|
return code |
|
|
|
def _execute_generated_circuit_code(self, generated_code): |
|
"""Execute the generated circuit code and return the diagram file""" |
|
temp_script = None |
|
try: |
|
|
|
self._cleanup_previous_circuit_files() |
|
|
|
|
|
expected_filename = None |
|
import re |
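# Recover the intended output filename from the generated d.save('...') call so we can verify it after execution.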
|
save_match = re.search(r"d\.save\(['\"]([^'\"]+)['\"]\)", generated_code) |
|
if save_match: |
|
expected_filename = save_match.group(1) |
|
print(f"π― [CIRCUIT] Expected filename from code: {expected_filename}") |
|
|
|
print("π§ [CIRCUIT] Normalizing Unicode characters in generated code...") |
|
|
|
import unicodedata |
|
normalized_code = unicodedata.normalize('NFD', generated_code) |
|
|
|
normalized_code = normalized_code.replace('Ω', 'Ohm')
normalized_code = normalized_code.replace('μ', 'u')
normalized_code = normalized_code.replace('°', 'deg')
normalized_code = normalized_code.replace('±', '+/-')
normalized_code = normalized_code.replace('≤', '<=')
normalized_code = normalized_code.replace('≥', '>=')
normalized_code = normalized_code.replace('≠', '!=')
|
print("β
[CIRCUIT] Unicode normalization completed") |
|
|
|
print("π [CIRCUIT] Creating temporary Python script...") |
|
|
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f: |
|
f.write(normalized_code) |
|
temp_script = f.name |
|
print(f"π [CIRCUIT] Temporary script created: {temp_script}") |
|
|
|
print("βοΈ [CIRCUIT] Setting up execution environment...") |
|
|
|
env = os.environ.copy() |
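# Force UTF-8 I/O in the child process so stray non-ASCII output doesn't crash on Windows consoles.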
|
env['PYTHONIOENCODING'] = 'utf-8' |
|
|
|
print("π [CIRCUIT] Executing generated Python script...") |
|
import sys  # use the current interpreter rather than whatever 'python' resolves to on PATH
result = subprocess.run([sys.executable, temp_script],
capture_output=True, text=True, timeout=60,
env=env, encoding='utf-8')
|
|
|
if result.returncode == 0: |
|
print("β
[CIRCUIT] Script executed successfully") |
|
print("π [CIRCUIT] Searching for generated PNG files...") |
|
|
|
|
|
if expected_filename and os.path.exists(expected_filename): |
|
print(f"β
[CIRCUIT] Found expected file: {expected_filename}") |
|
return expected_filename |
|
|
|
|
|
generated_files = [] |
|
for file in os.listdir('.'): |
|
if file.endswith('.png'): |
|
generated_files.append(file) |
|
|
|
if generated_files: |
|
|
|
circuit_files = [f for f in generated_files if 'circuit' in f.lower()] |
|
if circuit_files: |
|
selected_file = circuit_files[0] |
|
print(f"β
[CIRCUIT] Found generated circuit diagram: {selected_file}") |
|
return selected_file |
|
else: |
|
|
|
selected_file = generated_files[0] |
|
print(f"β
[CIRCUIT] Found generated diagram: {selected_file}") |
|
return selected_file |
|
else: |
|
print("β [CIRCUIT] No PNG files found after successful execution") |
|
return "Error: No PNG files generated despite successful script execution" |
|
else: |
|
print(f"β [CIRCUIT] Script execution failed with return code: {result.returncode}") |
|
print(f"π [CIRCUIT] Error output: {result.stderr}") |
|
print(f"π [CIRCUIT] Standard output: {result.stdout}") |
|
|
|
|
|
error_msg = result.stderr.strip() |
|
if "ModuleNotFoundError" in error_msg: |
|
return f"Error: Missing required module - {error_msg}" |
|
elif "AttributeError: module 'schemdraw.elements' has no attribute 'IC'. Did you mean: 'Ic'?" in error_msg: |
|
return f"Error: Use 'elm.Ic' instead of 'elm.IC' for integrated circuits - {error_msg}" |
|
elif "AttributeError" in error_msg: |
|
return f"Error: Invalid component or method used - {error_msg}" |
|
elif "SyntaxError" in error_msg: |
|
return f"Error: Syntax error in generated code - {error_msg}" |
|
elif "ImportError" in error_msg: |
|
return f"Error: Import error - {error_msg}" |
|
elif "d.draw()" in error_msg: |
|
|
|
return f"Warning: d.draw() was used but may not generate a file. Consider using d.save() for better results." |
|
elif "Duplicate `at` parameter in element" in error_msg: |
|
return f"Warning: Duplicate positioning parameters detected - {error_msg}" |
|
else: |
|
return f"Error: Script execution failed - {error_msg}" |
|
|
|
except subprocess.TimeoutExpired: |
|
print("β [CIRCUIT] Script execution timed out") |
|
return "Error: Script execution timed out (60 seconds)" |
|
except Exception as e: |
|
print(f"β [CIRCUIT] Exception during code execution: {str(e)}") |
|
return f"Error: Exception during code execution - {str(e)}" |
|
finally: |
|
|
|
if temp_script and os.path.exists(temp_script): |
|
try: |
|
os.unlink(temp_script) |
|
print("π§Ή [CIRCUIT] Temporary script cleaned up") |
|
except Exception as e: |
|
print(f"β οΈ [CIRCUIT] Failed to clean up temporary script: {str(e)}") |
|
|
|
def _validate_circuit_code(self, code): |
|
"""Validate the generated circuit code for common issues""" |
|
try: |
|
print("π [CIRCUIT] Validating generated code...") |
|
|
|
|
|
if 'import schemdraw' not in code: |
|
print("β [CIRCUIT] Missing schemdraw import") |
|
return False |
|
|
|
|
|
forbidden_components = [ |
|
'elm.Tip', 'elm.DCSourceV', 'elm.SpiceNetlist', 'elm.SpiceNetlistElement', |
|
'matplotlib', 'pyplot', 'plt', 'import matplotlib', 'from matplotlib' |
|
] |
|
|
|
for component in forbidden_components: |
|
if component in code: |
|
print(f"β [CIRCUIT] Forbidden component found: {component}") |
|
return False |
|
|
|
|
|
import re |
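# "x = d += elm..." is a SyntaxError in Python; reject these common LLM assignment mistakes before execution.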
|
invalid_assignment_patterns = [ |
|
r'\w+\s*=\s*d\s*\+=', |
|
r'\w+\s*=\s*d\.add\(', |
|
r'\w+\s*=\s*d\.append\(', |
|
] |
|
for pattern in invalid_assignment_patterns: |
|
if re.search(pattern, code): |
|
print(f"β [CIRCUIT] Invalid assignment syntax detected: {pattern}") |
|
return False |
|
|
|
|
|
grounding_elements = ['elm.Ground', 'elm.GroundChassis', 'elm.GroundSignal']
|
for ground_element in grounding_elements: |
|
if ground_element in code: |
|
print(f"β [CIRCUIT] Grounding element found: {ground_element} - closed loop circuits should not have grounding elements") |
|
return False |
|
|
|
|
|
if not self._validate_closed_loop_circuit(code): |
|
print("β [CIRCUIT] Circuit is not a complete closed loop") |
|
return False |
|
|
|
|
|
|
|
if 'd.draw()' in code: |
|
print("β οΈ [CIRCUIT] d.draw() found - allowing to pass validation") |
|
|
|
|
|
|
|
unicode_chars = ['Ω', 'μ', '°', '±', '≤', '≥', '≠', '←', '→', '↑', '≫', '↓']
|
for char in unicode_chars: |
|
if char in code: |
|
print(f"β [CIRCUIT] Unicode character found: {char}") |
|
return False |
|
|
|
|
|
if 'd.save(' not in code: |
|
print("β [CIRCUIT] Missing d.save() method") |
|
return False |
|
|
|
|
|
if 'schemdraw.Drawing()' not in code: |
|
print("β [CIRCUIT] Missing schemdraw.Drawing() initialization") |
|
return False |
|
|
|
|
|
example_components = ['100KOhm', '0.1uF', '10V'] |
|
example_count = sum(1 for component in example_components if component in code) |
|
if example_count >= 2: |
|
print("β οΈ [CIRCUIT] Circuit appears to be copying example values too closely") |
|
|
|
|
|
|
|
component_patterns = [ |
|
'elm.Resistor', 'elm.Capacitor', 'elm.Inductor', 'elm.Diode', |
|
'elm.SourceV', 'elm.SourceI', 'elm.Ground', 'elm.Line', 'elm.Dot', |
|
'elm.Rect', 'elm.RBox', 'elm.Circle', 'elm.Transistor', 'elm.OpAmp', |
|
'elm.Switch', 'elm.LED', 'elm.Motor', 'elm.Relay', 'elm.Crystal', |
|
'elm.Transformer', 'elm.Potentiometer', 'elm.Thermistor', 'elm.Varistor', |
|
'elm.Fuse', 'elm.Connector', 'elm.Ic', 'elm.Battery', 'elm.CurrentLabel', |
|
'elm.VoltageLabel', 'elm.Node', 'elm.Dot2', 'elm.Contact', 'elm.Arrow', |
|
'elm.Text', 'elm.Lamp' |
|
] |
|
component_count = sum(1 for pattern in component_patterns if pattern in code) |
|
if component_count < 3: |
|
print("β οΈ [CIRCUIT] Circuit appears too simple - may be copying example") |
|
|
|
|
|
|
|
label_count = code.count('.label(') |
|
if component_count > 0 and label_count < component_count * 0.5: |
|
print("β οΈ [CIRCUIT] Many components are not labeled - consider adding labels") |
|
|
|
|
|
print("β
[CIRCUIT] Code validation passed") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error during code validation: {str(e)}") |
|
return False |
|
|
|
def _validate_closed_loop_circuit(self, code): |
|
"""Validate that the circuit forms a complete closed loop without grounding elements""" |
|
try: |
|
print("π [CIRCUIT] Validating closed loop circuit structure...") |
|
|
|
|
|
lines = code.split('\n') |
|
component_lines = [] |
|
|
|
for line in lines: |
|
line = line.strip() |
|
if line.startswith('d += elm.') and not line.startswith('d += elm.Ground'): |
|
component_lines.append(line) |
|
|
|
if len(component_lines) < 3: |
|
print("β [CIRCUIT] Circuit must have at least 3 components for a closed loop") |
|
return False |
|
|
|
|
|
power_sources = ['elm.SourceV', 'elm.SourceI', 'elm.Battery', 'elm.SourceSin', 'elm.SourceSquare'] |
|
has_power = any(source in code for source in power_sources) |
|
if not has_power: |
|
print("β [CIRCUIT] Closed loop circuit must have a power source") |
|
return False |
|
|
|
|
|
connection_methods = ['.up()', '.down()', '.left()', '.right()', '.to('] |
|
has_connections = any(method in code for method in connection_methods) |
|
if not has_connections: |
|
print("β [CIRCUIT] Circuit components must be properly connected using directional methods") |
|
return False |
|
|
|
|
|
if '.to(' not in code: |
|
|
|
|
|
print("β οΈ [CIRCUIT] Consider using .to() method to explicitly close the circuit loop") |
|
|
|
print("β
[CIRCUIT] Closed loop circuit validation passed") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error validating closed loop circuit: {str(e)}") |
|
return False |
|
|
|
def _extract_python_code(self, response_text): |
|
"""Extract Python code from AI model response, handling markdown code blocks""" |
|
try: |
|
print("π [CIRCUIT] Analyzing response for code blocks...") |
|
|
|
|
|
if '```python' in response_text: |
|
print("π¦ [CIRCUIT] Found Python code block, extracting...") |
|
|
|
start_marker = '```python' |
|
end_marker = '```' |
|
|
|
start_idx = response_text.find(start_marker) |
|
if start_idx != -1: |
|
|
|
code_start = start_idx + len(start_marker) |
|
end_idx = response_text.find(end_marker, code_start) |
|
|
|
if end_idx != -1: |
|
extracted_code = response_text[code_start:end_idx].strip() |
|
print("β
[CIRCUIT] Successfully extracted Python code from markdown block") |
|
return extracted_code |
|
else: |
|
print("β οΈ [CIRCUIT] Found start marker but no end marker, using rest of text") |
|
return response_text[code_start:].strip() |
|
else: |
|
print("β οΈ [CIRCUIT] No start marker found") |
|
return response_text |
|
|
|
|
|
elif '```' in response_text: |
|
print("π¦ [CIRCUIT] Found generic code block, extracting...") |
|
|
|
start_marker = '```' |
|
end_marker = '```' |
|
|
|
start_idx = response_text.find(start_marker) |
|
if start_idx != -1: |
|
code_start = start_idx + len(start_marker) |
|
end_idx = response_text.find(end_marker, code_start) |
|
|
|
if end_idx != -1: |
|
extracted_code = response_text[code_start:end_idx].strip() |
|
|
|
if extracted_code.startswith('python'): |
|
extracted_code = extracted_code[6:].strip() |
|
print("β
[CIRCUIT] Successfully extracted code from generic block") |
|
return extracted_code |
|
else: |
|
print("β οΈ [CIRCUIT] Found start marker but no end marker, using rest of text") |
|
return response_text[code_start:].strip() |
|
else: |
|
print("β οΈ [CIRCUIT] No start marker found") |
|
return response_text |
|
|
|
else: |
|
print("π [CIRCUIT] No code blocks found, using response as-is") |
|
return response_text |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error extracting Python code: {str(e)}") |
|
return response_text |
|
|
|
def process_circuit_image(self, image): |
|
"""Main function to process uploaded circuit image""" |
|
try: |
|
print("=" * 60) |
|
print("π [CIRCUIT] Starting circuit diagram generation process") |
|
print("=" * 60) |
|
|
|
if image is None: |
|
print("β [CIRCUIT] No image uploaded") |
|
return "No image uploaded", None |
|
|
|
print("πΈ [CIRCUIT] Image uploaded successfully") |
|
|
|
|
|
print("\n" + "=" * 40) |
|
print("π STEP 1: Image Description with Gemma3") |
|
print("=" * 40) |
|
description = self.describe_image_with_gemma3(image) |
|
|
|
|
|
print("\n" + "=" * 40) |
|
print("π§ STEP 2: Circuit Generation with DeepSeek R1") |
|
print("=" * 40) |
|
circuit_result = self.generate_circuit_with_deepseek(description) |
|
|
|
|
|
print("\n" + "=" * 40) |
|
print("π STEP 3: Finalizing Results") |
|
print("=" * 40) |
|
|
|
if circuit_result and (circuit_result.endswith('.png') or 'circuit_diagram_' in circuit_result): |
|
print(f"β
[CIRCUIT] Circuit diagram generated successfully: {circuit_result}") |
|
print("=" * 60) |
|
print("π [CIRCUIT] Process completed successfully!") |
|
print("=" * 60) |
|
|
|
|
|
if "(Note:" in circuit_result: |
|
|
|
filename = circuit_result.split(' (Note:')[0] |
|
note = circuit_result.split('(Note:')[1].rstrip(')') |
|
return f"Image Description: {description}\n\nCircuit Generated: {filename}\n\n{note}", filename |
|
else: |
|
return f"Image Description: {description}\n\nCircuit Generated: {circuit_result}", circuit_result |
|
else: |
|
print(f"β οΈ [CIRCUIT] Circuit generation failed: {circuit_result}") |
|
print("=" * 60) |
|
print("β [CIRCUIT] Process completed with errors") |
|
print("=" * 60) |
|
|
|
|
|
error_details = "" |
|
if "Error:" in circuit_result: |
|
error_details = f"\n\nError Details:\n{circuit_result}" |
|
|
|
return f"Image Description: {description}\n\nCircuit Generation Failed{error_details}", None |
|
|
|
except Exception as e: |
|
error_msg = f"Error processing circuit image: {str(e)}" |
|
print(f"β [CIRCUIT] {error_msg}") |
|
print("=" * 60) |
|
print("π₯ [CIRCUIT] Process failed!") |
|
print("=" * 60) |
|
return error_msg, None |
|
|
|
def _enhance_circuit_connections(self, code): |
|
"""Enhance circuit connections to ensure proper closure and connectivity""" |
|
try: |
|
print("π§ [CIRCUIT] Enhancing circuit connections for proper closure...") |
|
|
|
lines = code.split('\n') |
|
component_lines = [] |
|
connection_lines = [] |
|
|
|
|
|
for i, line in enumerate(lines): |
|
line = line.strip() |
|
if line.startswith('d += elm.') and not line.startswith('d += elm.Ground'): |
|
component_lines.append((i, line)) |
|
elif line.startswith('d += elm.Line') or line.startswith('d += elm.Dot'): |
|
connection_lines.append((i, line)) |
|
|
|
if len(component_lines) < 2: |
|
print("β οΈ [CIRCUIT] Not enough components to enhance connections") |
|
return code |
|
|
|
|
|
has_closure = any('.to(' in line for _, line in component_lines + connection_lines) |
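# If nothing routes back with .to(...), append an explicit closing wire to the first element's start.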
|
|
|
if not has_closure: |
|
print("π [CIRCUIT] Adding circuit closure connection...") |
|
|
|
|
|
last_component_idx, last_component_line = component_lines[-1] |
|
|
|
|
|
closure_line = f"d += elm.Line().to(d.elements[0].start)" |
|
|
|
|
|
lines.insert(last_component_idx + 1, closure_line) |
|
|
|
print("β
[CIRCUIT] Added circuit closure connection") |
|
|
|
|
|
enhanced_code = self._add_missing_connections(lines) |
|
|
|
return enhanced_code |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error enhancing circuit connections: {str(e)}") |
|
return code |
|
|
|
def _add_missing_connections(self, lines): |
|
"""Add missing connections between components""" |
|
try: |
|
print("π [CIRCUIT] Adding missing connections between components...") |
|
|
|
|
|
component_indices = [] |
|
for i, line in enumerate(lines): |
|
if line.strip().startswith('d += elm.') and not line.strip().startswith('d += elm.Ground'): |
|
component_indices.append(i) |
|
|
|
if len(component_indices) < 2: |
|
return '\n'.join(lines) |
|
|
|
|
|
enhanced_lines = lines.copy() |
|
insertions = 0 |
|
|
|
for i in range(len(component_indices) - 1): |
|
current_idx = component_indices[i] + insertions |
|
next_idx = component_indices[i + 1] + insertions |
|
|
|
|
|
has_connection = False |
|
for j in range(current_idx + 1, next_idx): |
|
if j < len(enhanced_lines) and enhanced_lines[j].strip().startswith('d += elm.Line'): |
|
has_connection = True |
|
break |
|
|
|
if not has_connection: |
|
|
|
connection_line = "d += elm.Line().right()" |
|
enhanced_lines.insert(next_idx, connection_line) |
|
insertions += 1 |
|
print(f"π [CIRCUIT] Added connection between components {i+1} and {i+2}") |
|
|
|
return '\n'.join(enhanced_lines) |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error adding missing connections: {str(e)}") |
|
return '\n'.join(lines) |
|
|
|
def _validate_circuit_connectivity(self, code): |
|
"""Validate that all components are properly connected""" |
|
try: |
|
print("π [CIRCUIT] Validating circuit connectivity...") |
|
|
|
lines = code.split('\n') |
|
component_count = 0 |
|
connection_count = 0 |
|
|
|
for line in lines: |
|
line = line.strip() |
|
if line.startswith('d += elm.') and not line.startswith('d += elm.Ground'): |
|
component_count += 1 |
|
elif line.startswith('d += elm.Line') or line.startswith('d += elm.Dot'): |
|
connection_count += 1 |
|
|
|
|
|
if component_count < 2: |
|
print("β [CIRCUIT] Circuit needs at least 2 components") |
|
return False |
|
|
|
if connection_count < 1: |
|
print("β [CIRCUIT] Circuit needs at least 1 connection") |
|
return False |
|
|
|
|
|
has_closure = '.to(' in code |
|
if not has_closure: |
|
print("β οΈ [CIRCUIT] Circuit may not be properly closed") |
|
|
|
print(f"β
[CIRCUIT] Circuit connectivity validation passed - {component_count} components, {connection_count} connections") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error validating circuit connectivity: {str(e)}") |
|
return False |
|
|
|
def _fix_circuit_structure(self, code): |
|
"""Fix common circuit structure issues""" |
|
try: |
|
print("π§ [CIRCUIT] Fixing circuit structure issues...") |
|
|
|
lines = code.split('\n') |
|
fixed_lines = [] |
|
|
|
for line in lines: |
|
line = line.strip() |
|
|
|
|
|
if 'd += elm.' in line: |
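# If the element has no explicit direction or position, give it a sensible default by component type.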
|
|
|
if not any(method in line for method in ['.up()', '.down()', '.left()', '.right()', '.to(', '.at(']): |
|
|
|
if 'elm.SourceV' in line or 'elm.Battery' in line: |
|
line = line.rstrip() + '.up()' |
|
elif 'elm.Resistor' in line or 'elm.Capacitor' in line: |
|
line = line.rstrip() + '.right()' |
|
elif 'elm.LED' in line or 'elm.Diode' in line: |
|
line = line.rstrip() + '.down()' |
|
|
|
|
|
# A single replace covers all occurrences, including calls like elm.IC(...)
line = line.replace('elm.IC', 'elm.Ic')
|
|
|
fixed_lines.append(line) |
|
|
|
|
|
fixed_code = '\n'.join(fixed_lines) |
|
enhanced_code = self._enhance_circuit_connections(fixed_code) |
|
|
|
print("β
[CIRCUIT] Circuit structure fixes applied") |
|
return enhanced_code |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error fixing circuit structure: {str(e)}") |
|
return code |
|
|
|
def _generate_robust_circuit_template(self, components, unique_filename): |
|
"""Generate a robust circuit template with proper connections""" |
|
try: |
|
print("π§ [CIRCUIT] Generating robust circuit template...") |
|
|
|
template = f"""import schemdraw |
|
import schemdraw.elements as elm |
|
|
|
d = schemdraw.Drawing() |
|
|
|
# Power source |
|
d += elm.SourceV().up().label('12V').at((0, 0)) |
|
|
|
# Main circuit components |
|
""" |
|
|
|
|
|
for i, component in enumerate(components[:5]): |
|
if 'resistor' in component.lower(): |
|
template += f"d += elm.Resistor().right().label('R{i+1}')\n" |
|
elif 'capacitor' in component.lower(): |
|
template += f"d += elm.Capacitor().down().label('C{i+1}')\n" |
|
elif 'led' in component.lower(): |
|
template += f"d += elm.LED().right().label('LED{i+1}')\n" |
|
elif 'switch' in component.lower(): |
|
template += f"d += elm.Switch().up().label('SW{i+1}')\n" |
|
elif 'battery' in component.lower() or 'power' in component.lower(): |
|
template += f"d += elm.Battery().up().label('BAT{i+1}')\n" |
|
else: |
|
template += f"d += elm.RBox().right().label('{component}')\n" |
|
|
|
|
|
template += f""" |
|
# Close the circuit loop |
|
d += elm.Line().left().to(d.elements[0].start) |
|
|
|
# Save the diagram |
|
d.save('{unique_filename}') |
|
""" |
|
|
|
print("β
[CIRCUIT] Robust circuit template generated") |
|
return template |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error generating robust circuit template: {str(e)}") |
|
return None |
|
|
|
def _create_validated_circuit_template(self, image_description, unique_filename): |
|
"""Create a validated circuit template based on image description""" |
|
try: |
|
print("π§ [CIRCUIT] Creating validated circuit template...") |
|
|
|
|
|
components = self._extract_components_from_description(image_description) |
|
|
|
if not components: |
|
print("β οΈ [CIRCUIT] No specific components found, using generic template") |
|
return self._generate_generic_validated_template(unique_filename) |
|
|
|
|
|
template = f"""import schemdraw |
|
import schemdraw.elements as elm |
|
|
|
d = schemdraw.Drawing() |
|
|
|
# Power source - always start with power |
|
d += elm.SourceV().up().label('12V').at((0, 0)) |
|
|
|
# Circuit components based on image description |
|
""" |
|
|
|
|
|
component_count = 0 |
|
for component in components[:6]: |
|
component_count += 1 |
|
component_type = component.get('type', 'RBox') |
|
value = component.get('value', str(component_count)) |
|
|
|
if component_type.lower() == 'resistor': |
|
template += f"d += elm.Resistor().right().label('R{component_count}')\n" |
|
elif component_type.lower() == 'capacitor': |
|
template += f"d += elm.Capacitor().down().label('C{component_count}')\n" |
|
elif component_type.lower() == 'led': |
|
template += f"d += elm.LED().right().label('LED{component_count}')\n" |
|
elif component_type.lower() == 'diode': |
|
template += f"d += elm.Diode().right().label('D{component_count}')\n" |
|
elif component_type.lower() == 'switch': |
|
template += f"d += elm.Switch().up().label('SW{component_count}')\n" |
|
elif component_type.lower() == 'transistor': |
|
template += f"d += elm.Transistor().up().label('Q{component_count}')\n" |
|
elif component_type.lower() == 'battery': |
|
template += f"d += elm.Battery().up().label('BAT{component_count}')\n" |
|
elif component_type.lower() == 'sourcev': |
|
template += f"d += elm.SourceV().up().label('V{component_count}')\n" |
|
elif component_type.lower() == 'ic': |
|
template += f"d += elm.Ic().right().label('IC{component_count}')\n" |
|
else: |
|
template += f"d += elm.RBox().right().label('{component_type}{component_count}')\n" |
|
|
|
|
|
template += f""" |
|
# Ensure circuit closure - critical for proper operation |
|
d += elm.Line().left().to(d.elements[0].start) |
|
|
|
# Save the validated circuit diagram |
|
d.save('{unique_filename}') |
|
""" |
|
|
|
print(f"β
[CIRCUIT] Validated circuit template created with {component_count} components") |
|
return template |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error creating validated circuit template: {str(e)}") |
|
return self._generate_generic_validated_template(unique_filename) |
|
|
|
def _generate_generic_validated_template(self, unique_filename): |
|
"""Generate a generic but validated circuit template""" |
|
try: |
|
print("π§ [CIRCUIT] Generating generic validated template...") |
|
|
|
template = f"""import schemdraw |
|
import schemdraw.elements as elm |
|
|
|
d = schemdraw.Drawing() |
|
|
|
# Power source - essential for circuit operation |
|
d += elm.SourceV().up().label('12V').at((0, 0)) |
|
|
|
# Basic circuit components with proper connections |
|
d += elm.Resistor().right().label('R1') |
|
d += elm.LED().down().label('LED1') |
|
d += elm.Capacitor().left().label('C1') |
|
|
|
# Critical: Close the circuit loop for proper current flow |
|
d += elm.Line().up().to(d.elements[0].start) |
|
|
|
# Save the validated circuit |
|
d.save('{unique_filename}') |
|
""" |
|
|
|
print("β
[CIRCUIT] Generic validated template generated") |
|
return template |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error generating generic template: {str(e)}") |
|
return None |
|
|
|
def _extract_components_from_description(self, image_description): |
|
"""Extract component information from the image description""" |
|
try: |
|
components = [] |
|
|
|
|
|
component_patterns = [
(r'resistor[s]?\s+(\w+)', 'Resistor'),
(r'capacitor[s]?\s+(\w+)', 'Capacitor'),
(r'led[s]?\s+(\w+)', 'LED'),
(r'diode[s]?\s+(\w+)', 'Diode'),
(r'switch(?:es)?\s+(\w+)', 'Switch'),
(r'transistor[s]?\s+(\w+)', 'Transistor'),
(r'bjt[s]?\s+(\w+)', 'Transistor'),
(r'batter(?:y|ies)\s+(\w+)', 'Battery'),
(r'voltage\s+source[s]?\s+(\w+)', 'SourceV'),
(r'power\s+supp(?:ly|lies)\s+(\w+)', 'SourceV'),
(r'ic[s]?\s+(\w+)', 'Ic'),
(r'integrated\s+circuit[s]?\s+(\w+)', 'Ic'),
(r'inductor[s]?\s+(\w+)', 'Inductor'),
(r'relay[s]?\s+(\w+)', 'Relay'),
(r'motor[s]?\s+(\w+)', 'Motor'),
(r'fuse[s]?\s+(\w+)', 'Fuse'),
(r'connector[s]?\s+(\w+)', 'Connector'),
]
|
|
|
import re |
|
for pattern, component_type in component_patterns: |
|
matches = re.findall(pattern, image_description.lower()) |
|
for match in matches: |
|
components.append({ |
|
'type': component_type, |
|
'value': match, |
|
'description': f"{component_type} {match}" |
|
}) |
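# Deduplicate on (type, value) while preserving first-seen order.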
|
|
|
|
|
seen = set() |
|
unique_components = [] |
|
for component in components: |
|
key = f"{component['type']}_{component['value']}" |
|
if key not in seen: |
|
seen.add(key) |
|
unique_components.append(component) |
|
|
|
return unique_components |
|
|
|
except Exception as e: |
|
print(f"β [CIRCUIT] Error extracting components from description: {str(e)}") |
|
return [] |
|
|
|
|
|
|
|
def create_ui(): |
|
app = PDFSearchApp() |
|
|
|
with gr.Blocks(theme=gr.themes.Ocean(), css="footer{display:none !important}") as demo: |
|
|
|
session_state = gr.State(value=None) |
|
user_info_state = gr.State(value=None) |
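# Per-browser-session Gradio state: session_state carries the session id issued at login; user_info_state is reserved for the logged-in user's details.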
|
|
|
gr.Markdown("# Collar Multimodal RAG Demo - Production Ready") |
|
gr.Markdown("Made by Collar - Enhanced with Team Management & Chat History") |
|
|
|
|
|
with gr.Tab("π Authentication"): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
gr.Markdown("### Login") |
|
username_input = gr.Textbox(label="Username", placeholder="Enter username") |
|
password_input = gr.Textbox(label="Password", type="password", placeholder="Enter password") |
|
login_btn = gr.Button("Login", variant="primary") |
|
logout_btn = gr.Button("Logout") |
|
auth_status = gr.Textbox(label="Authentication Status", interactive=False) |
|
current_team = gr.Textbox(label="Current Team", interactive=False) |
|
|
|
with gr.Column(scale=1): |
|
gr.Markdown("### Default Users") |
|
gr.Markdown(""" |
|
**Team A:** admin_team_a / admin123_team_a |
|
**Team B:** admin_team_b / admin123_team_b |
|
""") |
|
|
|
|
|
with gr.Tab("π Document Management"): |
|
with gr.Column(): |
|
gr.Markdown("### Upload Documents to Team Repository") |
|
folder_name_input = gr.Textbox( |
|
label="Folder/Collection Name (Optional)", |
|
placeholder="Enter a name for this document collection" |
|
) |
|
max_pages_input = gr.Slider( |
|
minimum=1, |
|
maximum=10000, |
|
value=20, |
|
step=10, |
|
label="Max pages to extract and index per document" |
|
) |
|
file_input = gr.Files( |
|
label="Upload PPTs/PDFs (Multiple files supported)", |
|
file_count="multiple" |
|
) |
|
upload_btn = gr.Button("Upload to Repository", variant="primary") |
|
upload_status = gr.Textbox(label="Upload Status", interactive=False) |
|
|
|
gr.Markdown("### Team Collections") |
|
refresh_collections_btn = gr.Button("Refresh Collections") |
|
team_collections_display = gr.Textbox( |
|
label="Available Collections", |
|
interactive=False, |
|
lines=5 |
|
) |
|
|
|
|
|
with gr.Tab("π Advanced Query"): |
|
with gr.Column(): |
|
gr.Markdown("### Multi-Page Document Search") |
|
|
|
query_input = gr.Textbox( |
|
label="Enter your query", |
|
placeholder="Ask about any topic in your documents...", |
|
lines=2 |
|
) |
|
num_results = gr.Slider( |
|
minimum=1, |
|
maximum=10, |
|
value=3, |
|
step=1, |
|
label="Number of pages to retrieve and cite" |
|
) |
|
search_btn = gr.Button("Search Documents", variant="primary") |
|
|
|
gr.Markdown("### Results") |
|
llm_answer = gr.Textbox( |
|
label="AI Response with Citations", |
|
interactive=False, |
|
lines=8 |
|
) |
|
cited_pages_display = gr.Textbox( |
|
label="Cited Pages", |
|
interactive=False, |
|
lines=3 |
|
) |
|
path = gr.Textbox(label="Document Paths", interactive=False) |
|
images = gr.Gallery(label="Retrieved Pages", show_label=True, columns=2, rows=2, height="auto") |
|
|
|
|
|
gr.Markdown("### π Export Downloads") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
csv_download = gr.File( |
|
label="π CSV Table", |
|
interactive=False, |
|
visible=True |
|
) |
|
with gr.Column(scale=1): |
|
doc_download = gr.File( |
|
label="π DOC Report", |
|
interactive=False, |
|
visible=True |
|
) |
|
with gr.Column(scale=1): |
|
excel_download = gr.File( |
|
label="π Excel Export", |
|
interactive=False, |
|
visible=True |
|
) |
|
|
|
|
|
with gr.Tab("π¬ Chat History"): |
|
with gr.Column(): |
|
gr.Markdown("### π Conversation History") |
|
gr.Markdown("View and manage your previous conversations with the AI assistant.") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
history_limit = gr.Slider( |
|
minimum=5, |
|
maximum=50, |
|
value=10, |
|
step=5, |
|
label="Number of recent conversations to display" |
|
) |
|
with gr.Column(scale=1): |
|
refresh_history_btn = gr.Button("π Refresh History", variant="secondary") |
|
clear_history_btn = gr.Button("ποΈ Clear History", variant="stop") |
|
|
|
chat_history_display = gr.Markdown( |
|
label="Recent Conversations", |
|
value="π¬ **Welcome to Chat History!**\n\nLog in and start a conversation to see your chat history here." |
|
) |
|
|
|
|
|
with gr.Tab("βοΈ Data Management"): |
|
with gr.Column(): |
|
gr.Markdown("### Collection Management") |
|
choice = gr.Dropdown( |
|
choices=app.display_file_list(), |
|
label="Select Collection to Delete" |
|
) |
|
delete_button = gr.Button("Delete Collection", variant="stop") |
|
delete_status = gr.Textbox(label="Deletion Status", interactive=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Tab("β‘ Circuit Diagram Generator"): |
|
with gr.Column(): |
|
gr.Markdown("### Circuit Diagram Generation") |
|
gr.Markdown("Upload a circuit image to generate a netlist and circuit diagram using AI models.") |
|
|
|
circuit_image_input = gr.Image( |
|
type="pil", |
|
label="Upload Circuit Image", |
|
height=300 |
|
) |
|
generate_circuit_btn = gr.Button("Generate Circuit Diagram", variant="primary") |
|
|
|
gr.Markdown("### Results") |
|
circuit_output = gr.Textbox( |
|
label="Processing Results", |
|
interactive=False, |
|
lines=8 |
|
) |
|
circuit_diagram_output = gr.Image( |
|
label="Generated Circuit Diagram", |
|
height=400 |
|
) |
|
|
|
|
|
|
|
login_btn.click( |
|
fn=app.authenticate_user, |
|
inputs=[username_input, password_input], |
|
outputs=[auth_status, session_state, current_team] |
|
) |
|
|
|
logout_btn.click( |
|
fn=app.logout_user, |
|
inputs=[session_state], |
|
outputs=[auth_status, session_state, current_team] |
|
) |
|
|
|
|
|
upload_btn.click( |
|
fn=app.upload_and_convert, |
|
inputs=[session_state, file_input, max_pages_input, session_state, folder_name_input], |
|
outputs=[upload_status] |
|
) |
|
|
|
refresh_collections_btn.click( |
|
fn=app.get_team_collections, |
|
inputs=[session_state], |
|
outputs=[team_collections_display] |
|
) |
|
|
|
|
|
search_btn.click( |
|
fn=app.search_documents, |
|
inputs=[session_state, query_input, num_results, session_state], |
|
outputs=[path, images, llm_answer, cited_pages_display, csv_download, doc_download, excel_download] |
|
) |
|
|
|
|
|
|
|
|
|
refresh_history_btn.click( |
|
fn=app.get_chat_history, |
|
inputs=[session_state, history_limit], |
|
outputs=[chat_history_display] |
|
) |
|
|
|
clear_history_btn.click( |
|
fn=app.clear_chat_history, |
|
inputs=[session_state], |
|
outputs=[chat_history_display] |
|
) |
|
|
|
|
|
delete_button.click( |
|
fn=app.delete, |
|
inputs=[session_state, choice, session_state], |
|
outputs=[delete_status] |
|
) |
|
|
|
|
|
|
|
|
|
generate_circuit_btn.click( |
|
fn=app.process_circuit_image, |
|
inputs=[circuit_image_input], |
|
outputs=[circuit_output, circuit_diagram_output] |
|
) |
|
|
|
return demo |
|
|
|
if __name__ == "__main__": |
|
demo = create_ui() |
|
|
|
demo.launch() |
|
|
|
|