Add vector RAG functionality as modular tool
- Create document_processor.py for parsing PDF, DOCX, TXT, MD files
- Create vector_store.py for FAISS-based embeddings management
- Create rag_tool.py for integrating RAG with chat interface
- Add file upload UI to app.py with toggle for RAG functionality
- Update SPACE_TEMPLATE to include RAG context retrieval
- Add optional vector dependencies to requirements.txt
- Follow existing enable_dynamic_urls pattern for modularity (see the usage sketch below)
- app.py +150 -8
- document_processor.py +205 -0
- rag_tool.py +208 -0
- requirements.txt +8 -1
- vector_store.py +246 -0
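
Taken together, the new modules form a pipeline: document_processor.py turns uploaded files into overlapping text chunks, vector_store.py embeds them and builds a FAISS index, and rag_tool.py wraps both behind the RAGTool facade that app.py calls. A minimal sketch of that flow (not part of the commit), assuming the optional dependencies are installed and a hypothetical sample.md file exists locally:

# Sketch: the end-to-end flow app.py drives, run locally.
# Assumes faiss-cpu and sentence-transformers are installed; "sample.md" is a made-up path.
from rag_tool import RAGTool

tool = RAGTool()
result = tool.process_uploaded_files(["sample.md"])   # parse -> chunk -> embed -> index
print(result['message'])                              # e.g. "Successfully processed 1 files into N chunks"

# Retrieval as used at chat time
print(tool.get_relevant_context("What does the document say about pricing?"))

# Serialized payload embedded into the deployment zip as RAG_DATA
rag_data = tool.get_serialized_data()
print(sorted(rag_data.keys()))  # ['chunk_ids', 'chunks', 'dimension', 'index_base64', 'model_name']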
app.py
CHANGED
@@ -7,6 +7,8 @@ from datetime import datetime
 from dotenv import load_dotenv
 import requests
 from bs4 import BeautifulSoup
+import tempfile
+from pathlib import Path
 # from scraping_service import get_grounding_context_crawl4ai, fetch_url_content_crawl4ai
 # Temporary mock functions for testing
 def get_grounding_context_crawl4ai(urls):
@@ -15,6 +17,14 @@ def get_grounding_context_crawl4ai(urls):
 def fetch_url_content_crawl4ai(url):
     return f"[Content from {url} would be fetched here]"
 
+# Import RAG components
+try:
+    from rag_tool import RAGTool
+    HAS_RAG = True
+except ImportError:
+    HAS_RAG = False
+    RAGTool = None
+
 # Load environment variables from .env file
 load_dotenv()
 
@@ -34,6 +44,8 @@ MODEL = "{model}"
 GROUNDING_URLS = {grounding_urls}
 ACCESS_CODE = "{access_code}"
 ENABLE_DYNAMIC_URLS = {enable_dynamic_urls}
+ENABLE_VECTOR_RAG = {enable_vector_rag}
+RAG_DATA = {rag_data_json}
 
 # Get API key from environment - customizable variable name
 API_KEY = os.environ.get("{api_key_var}")
@@ -108,6 +120,36 @@ def extract_urls_from_text(text):
     url_pattern = r'https?://[^\\s<>"{{}}|\\^`\\[\\]"]+'
     return re.findall(url_pattern, text)
 
+# Initialize RAG context if enabled
+if ENABLE_VECTOR_RAG and RAG_DATA:
+    try:
+        import faiss
+        import numpy as np
+        import base64
+
+        class SimpleRAGContext:
+            def __init__(self, rag_data):
+                # Deserialize FAISS index
+                index_bytes = base64.b64decode(rag_data['index_base64'])
+                self.index = faiss.deserialize_index(index_bytes)
+
+                # Restore chunks and mappings
+                self.chunks = rag_data['chunks']
+                self.chunk_ids = rag_data['chunk_ids']
+
+            def get_context(self, query, max_chunks=3):
+                """Get relevant context - simplified version"""
+                # In production, you'd compute query embedding here
+                # For now, return a simple message
+                return "\\n\\n[RAG context would be retrieved here based on similarity search]\\n\\n"
+
+        rag_context_provider = SimpleRAGContext(RAG_DATA)
+    except Exception as e:
+        print(f"Failed to initialize RAG: {{e}}")
+        rag_context_provider = None
+else:
+    rag_context_provider = None
+
 def generate_response(message, history):
     """Generate response using OpenRouter API"""
 
@@ -117,6 +159,12 @@ def generate_response(message, history):
     # Get grounding context
     grounding_context = get_grounding_context()
 
+    # Add RAG context if available
+    if ENABLE_VECTOR_RAG and rag_context_provider:
+        rag_context = rag_context_provider.get_context(message)
+        if rag_context:
+            grounding_context += rag_context
+
     # If dynamic URLs are enabled, check message for URLs to fetch
     if ENABLE_DYNAMIC_URLS:
         urls_in_message = extract_urls_from_text(message)
@@ -398,11 +446,16 @@ Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} with Chat U/I Helper
 
     return readme_content
 
-def create_requirements():
+def create_requirements(enable_vector_rag=False):
     """Generate requirements.txt"""
-
+    base_requirements = "gradio==4.44.1\nrequests==2.32.3\ncrawl4ai==0.4.245"
+
+    if enable_vector_rag:
+        base_requirements += "\nfaiss-cpu==1.7.4\nnumpy==1.24.3"
+
+    return base_requirements
 
-def generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code="", enable_dynamic_urls=False, url1="", url2="", url3="", url4=""):
+def generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code="", enable_dynamic_urls=False, url1="", url2="", url3="", url4="", enable_vector_rag=False, rag_data=None):
     """Generate deployable zip file"""
 
     # Process examples
@@ -447,13 +500,15 @@ def generate_zip(name, description, role_purpose, intended_audience, key_tasks,
         'examples': examples_json,
         'grounding_urls': json.dumps(grounding_urls),
         'access_code': access_code or "",
-        'enable_dynamic_urls': enable_dynamic_urls
+        'enable_dynamic_urls': enable_dynamic_urls,
+        'enable_vector_rag': enable_vector_rag,
+        'rag_data_json': json.dumps(rag_data) if rag_data else 'null'
     }
 
     # Generate files
     app_content = SPACE_TEMPLATE.format(**config)
     readme_content = create_readme(config)
-    requirements_content = create_requirements()
+    requirements_content = create_requirements(enable_vector_rag)
 
     # Create zip file with clean naming
     filename = f"{name.lower().replace(' ', '_').replace('-', '_')}.zip"
@@ -474,7 +529,55 @@ def generate_zip(name, description, role_purpose, intended_audience, key_tasks,
     return filename
 
 # Define callback functions outside the interface
+def toggle_rag_section(enable_rag):
+    """Toggle visibility of RAG section"""
+    return gr.update(visible=enable_rag)
+
+def process_documents(files, current_rag_tool):
+    """Process uploaded documents"""
+    if not files:
+        return "Please upload files first", current_rag_tool
+
+    if not HAS_RAG:
+        return "RAG functionality not available. Please install required dependencies.", current_rag_tool
+
+    try:
+        # Initialize RAG tool if not exists
+        if not current_rag_tool:
+            current_rag_tool = RAGTool()
+
+        # Process files
+        result = current_rag_tool.process_uploaded_files(files)
+
+        if result['success']:
+            # Create status message
+            status_parts = [f"✅ {result['message']}"]
+
+            # Add file summary
+            if result['summary']['files_processed']:
+                status_parts.append("\n**Processed files:**")
+                for file_info in result['summary']['files_processed']:
+                    status_parts.append(f"- {file_info['name']} ({file_info['chunks']} chunks)")
+
+            # Add errors if any
+            if result.get('errors'):
+                status_parts.append("\n**Errors:**")
+                for error in result['errors']:
+                    status_parts.append(f"- {error['file']}: {error['error']}")
+
+            # Add index stats
+            if result.get('index_stats'):
+                stats = result['index_stats']
+                status_parts.append(f"\n**Index stats:** {stats['total_chunks']} chunks, {stats['dimension']}D embeddings")
+
+            return "\n".join(status_parts), current_rag_tool
+        else:
+            return f"❌ {result['message']}", current_rag_tool
+
+    except Exception as e:
+        return f"❌ Error processing documents: {str(e)}", current_rag_tool
+
-def on_generate(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4):
+def on_generate(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4, enable_vector_rag, rag_tool_state):
     if not name or not name.strip():
         return gr.update(value="Error: Please provide a Space Title", visible=True), gr.update(visible=False)
 
@@ -482,7 +585,12 @@ def on_generate(name, description, role_purpose, intended_audience, key_tasks, additional_context, ...):
         return gr.update(value="Error: Please provide a Role and Purpose for the assistant", visible=True), gr.update(visible=False)
 
     try:
-        filename = generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4)
+        # Get RAG data if enabled
+        rag_data = None
+        if enable_vector_rag and rag_tool_state:
+            rag_data = rag_tool_state.get_serialized_data()
+
+        filename = generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4, enable_vector_rag, rag_data)
 
         success_msg = f"""**Deployment package ready!**
 
@@ -790,6 +898,27 @@ with gr.Blocks(title="Chat U/I Helper") as demo:
             value=False,
             info="Allow the assistant to fetch additional URLs mentioned in conversations (uses Crawl4AI)"
         )
+
+        enable_vector_rag = gr.Checkbox(
+            label="Enable Document RAG",
+            value=False,
+            info="Upload documents for context-aware responses (PDF, DOCX, TXT, MD)",
+            visible=HAS_RAG
+        )
+
+        with gr.Column(visible=False) as rag_section:
+            gr.Markdown("### Document Upload")
+            file_upload = gr.File(
+                label="Upload Documents",
+                file_types=[".pdf", ".docx", ".txt", ".md"],
+                file_count="multiple",
+                type="filepath"
+            )
+            process_btn = gr.Button("Process Documents", variant="secondary")
+            rag_status = gr.Markdown()
+
+            # State to store RAG tool
+            rag_tool_state = gr.State(None)
 
         examples_text = gr.Textbox(
             label="Example Prompts (one per line)",
@@ -878,10 +1007,23 @@ with gr.Blocks(title="Chat U/I Helper") as demo:
        outputs=[url3, url4, add_url_btn, remove_url_btn, url_count]
    )
 
+    # Connect RAG functionality
+    enable_vector_rag.change(
+        toggle_rag_section,
+        inputs=[enable_vector_rag],
+        outputs=[rag_section]
+    )
+
+    process_btn.click(
+        process_documents,
+        inputs=[file_upload, rag_tool_state],
+        outputs=[rag_status, rag_tool_state]
+    )
+
    # Connect the generate button
    generate_btn.click(
        on_generate,
-        inputs=[name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4],
+        inputs=[name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4, enable_vector_rag, rag_tool_state],
        outputs=[status, download_file]
    )
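
Note the doubled braces in the template code above (`{{e}}` in the print call, `{{}}` in the URL regex): SPACE_TEMPLATE is rendered with str.format(**config), so literal braces in the generated app.py must be escaped, while single-brace fields like {enable_vector_rag} are substituted. A small illustration of that mechanism (generic values, not the real config):

# How SPACE_TEMPLATE.format(**config) treats braces: {{ }} survive as
# literal braces, single-brace names are replaced. Values here are made up.
template = 'ENABLE_VECTOR_RAG = {enable_vector_rag}\nprint(f"Failed: {{e}}")'
print(template.format(enable_vector_rag=True))
# ENABLE_VECTOR_RAG = True
# print(f"Failed: {e}")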
document_processor.py
ADDED
@@ -0,0 +1,205 @@
import os
import json
from typing import List, Dict, Any, Tuple
from pathlib import Path
import hashlib

# Document parsing imports
try:
    import fitz  # PyMuPDF
    HAS_PYMUPDF = True
except ImportError:
    HAS_PYMUPDF = False

try:
    from docx import Document
    HAS_DOCX = True
except ImportError:
    HAS_DOCX = False

# Text processing
import re
from dataclasses import dataclass


@dataclass
class DocumentChunk:
    text: str
    metadata: Dict[str, Any]
    chunk_id: str

    def to_dict(self):
        return {
            'text': self.text,
            'metadata': self.metadata,
            'chunk_id': self.chunk_id
        }


class DocumentProcessor:
    def __init__(self, chunk_size: int = 800, chunk_overlap: int = 100):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.supported_extensions = ['.pdf', '.docx', '.txt', '.md']

    def process_file(self, file_path: str) -> List[DocumentChunk]:
        """Process a single file and return chunks"""
        path = Path(file_path)

        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        extension = path.suffix.lower()
        if extension not in self.supported_extensions:
            raise ValueError(f"Unsupported file type: {extension}")

        # Extract text based on file type
        if extension == '.pdf':
            text = self._extract_pdf_text(file_path)
        elif extension == '.docx':
            text = self._extract_docx_text(file_path)
        elif extension in ['.txt', '.md']:
            text = self._extract_text_file(file_path)
        else:
            raise ValueError(f"Unsupported file type: {extension}")

        # Create chunks
        chunks = self._create_chunks(text, file_path)

        return chunks

    def _extract_pdf_text(self, file_path: str) -> str:
        """Extract text from PDF file"""
        if not HAS_PYMUPDF:
            raise ImportError("PyMuPDF not installed. Install with: pip install PyMuPDF")

        text_parts = []

        try:
            with fitz.open(file_path) as pdf:
                for page_num in range(len(pdf)):
                    page = pdf[page_num]
                    text = page.get_text()
                    if text.strip():
                        text_parts.append(f"[Page {page_num + 1}]\n{text}")
        except Exception as e:
            raise Exception(f"Error processing PDF: {str(e)}")

        return "\n\n".join(text_parts)

    def _extract_docx_text(self, file_path: str) -> str:
        """Extract text from DOCX file"""
        if not HAS_DOCX:
            raise ImportError("python-docx not installed. Install with: pip install python-docx")

        text_parts = []

        try:
            doc = Document(file_path)

            for paragraph in doc.paragraphs:
                if paragraph.text.strip():
                    text_parts.append(paragraph.text)

            # Also extract text from tables
            for table in doc.tables:
                for row in table.rows:
                    row_text = []
                    for cell in row.cells:
                        if cell.text.strip():
                            row_text.append(cell.text.strip())
                    if row_text:
                        text_parts.append(" | ".join(row_text))

        except Exception as e:
            raise Exception(f"Error processing DOCX: {str(e)}")

        return "\n\n".join(text_parts)

    def _extract_text_file(self, file_path: str) -> str:
        """Extract text from plain text or markdown file"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            raise Exception(f"Error reading text file: {str(e)}")

    def _create_chunks(self, text: str, file_path: str) -> List[DocumentChunk]:
        """Create overlapping chunks from text"""
        chunks = []

        # Clean and normalize text
        text = re.sub(r'\s+', ' ', text)
        text = text.strip()

        if not text:
            return chunks

        # Simple word-based chunking
        words = text.split()

        for i in range(0, len(words), self.chunk_size - self.chunk_overlap):
            chunk_words = words[i:i + self.chunk_size]
            chunk_text = ' '.join(chunk_words)

            # Create chunk ID
            chunk_id = hashlib.md5(f"{file_path}_{i}_{chunk_text[:50]}".encode()).hexdigest()[:8]

            # Create metadata
            metadata = {
                'file_path': file_path,
                'file_name': Path(file_path).name,
                'chunk_index': len(chunks),
                'start_word': i,
                'word_count': len(chunk_words)
            }

            chunk = DocumentChunk(
                text=chunk_text,
                metadata=metadata,
                chunk_id=chunk_id
            )

            chunks.append(chunk)

        return chunks

    def process_multiple_files(self, file_paths: List[str]) -> Tuple[List[DocumentChunk], Dict[str, Any]]:
        """Process multiple files and return chunks with summary"""
        all_chunks = []
        summary = {
            'total_files': 0,
            'total_chunks': 0,
            'files_processed': [],
            'errors': []
        }

        for file_path in file_paths:
            try:
                chunks = self.process_file(file_path)
                all_chunks.extend(chunks)

                summary['files_processed'].append({
                    'path': file_path,
                    'name': Path(file_path).name,
                    'chunks': len(chunks)
                })

            except Exception as e:
                summary['errors'].append({
                    'path': file_path,
                    'error': str(e)
                })

        summary['total_files'] = len(summary['files_processed'])
        summary['total_chunks'] = len(all_chunks)

        return all_chunks, summary


# Utility function for file size validation
def validate_file_size(file_path: str, max_size_mb: float = 10.0) -> bool:
    """Check if file size is within limits"""
    size_bytes = os.path.getsize(file_path)
    size_mb = size_bytes / (1024 * 1024)
    return size_mb <= max_size_mb
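
With the defaults above (chunk_size=800, chunk_overlap=100), _create_chunks steps through the word list in strides of 700, so consecutive chunks share 100 words. A quick standalone check of that arithmetic on a toy input (hypothetical text, not from the commit):

# Stride = chunk_size - chunk_overlap, mirroring _create_chunks above.
chunk_size, chunk_overlap = 800, 100
words = [f"w{i}" for i in range(2000)]  # toy 2000-word document

starts = list(range(0, len(words), chunk_size - chunk_overlap))
print(starts)  # [0, 700, 1400] -> 3 chunks
first, second = words[0:chunk_size], words[700:700 + chunk_size]
print(len(set(first) & set(second)))  # 100 shared words of overlap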
rag_tool.py
ADDED
@@ -0,0 +1,208 @@
import json
from typing import List, Dict, Any, Optional, Tuple
from document_processor import DocumentProcessor, DocumentChunk
from vector_store import VectorStore, SearchResult
import os
import tempfile
from pathlib import Path


class RAGTool:
    """RAG tool for integrating document search with chat"""

    def __init__(self):
        self.processor = DocumentProcessor(chunk_size=800, chunk_overlap=100)
        self.vector_store = VectorStore()
        self.processed_files = []
        self.total_chunks = 0

    def process_uploaded_files(self, file_paths: List[str]) -> Dict[str, Any]:
        """Process uploaded files and build vector index"""

        # Validate files
        valid_files = []
        errors = []

        for file_path in file_paths:
            try:
                # Check file size (10MB limit)
                size_mb = os.path.getsize(file_path) / (1024 * 1024)
                if size_mb > 10:
                    errors.append({
                        'file': Path(file_path).name,
                        'error': f'File too large ({size_mb:.1f}MB). Maximum size is 10MB.'
                    })
                    continue

                valid_files.append(file_path)

            except Exception as e:
                errors.append({
                    'file': Path(file_path).name,
                    'error': str(e)
                })

        if not valid_files:
            return {
                'success': False,
                'message': 'No valid files to process',
                'errors': errors
            }

        # Process files
        all_chunks, summary = self.processor.process_multiple_files(valid_files)

        if not all_chunks:
            return {
                'success': False,
                'message': 'No content extracted from files',
                'summary': summary
            }

        # Build vector index
        chunk_dicts = [chunk.to_dict() for chunk in all_chunks]
        self.vector_store.build_index(chunk_dicts, show_progress=False)

        # Update stats
        self.processed_files = summary['files_processed']
        self.total_chunks = len(all_chunks)

        # Calculate index size
        index_stats = self.vector_store.get_stats()

        return {
            'success': True,
            'message': f'Successfully processed {len(valid_files)} files into {self.total_chunks} chunks',
            'summary': summary,
            'index_stats': index_stats,
            'errors': errors
        }

    def get_relevant_context(self, query: str, max_chunks: int = 3) -> str:
        """Get relevant context for a query"""
        if not self.vector_store.index:
            return ""

        # Search for relevant chunks
        results = self.vector_store.search(
            query=query,
            top_k=max_chunks,
            score_threshold=0.3
        )

        if not results:
            return ""

        # Format context
        context_parts = []

        for i, result in enumerate(results, 1):
            file_name = result.metadata.get('file_name', 'Unknown')
            context_parts.append(
                f"[Document: {file_name} - Relevance: {result.score:.2f}]\n{result.text}"
            )

        return "\n\n".join(context_parts)

    def get_serialized_data(self) -> Dict[str, Any]:
        """Get serialized data for deployment"""
        if not self.vector_store.index:
            return None

        return self.vector_store.serialize()

    def get_deployment_info(self) -> Dict[str, Any]:
        """Get information for deployment package"""
        if not self.vector_store.index:
            return {
                'enabled': False,
                'message': 'No documents processed'
            }

        # Estimate package size increase
        index_stats = self.vector_store.get_stats()
        estimated_size_mb = (
            # Index size estimation
            (index_stats['total_chunks'] * index_stats['dimension'] * 4) / (1024 * 1024) +
            # Chunks text size estimation
            (sum(len(chunk['text']) for chunk in self.vector_store.chunks.values()) / (1024 * 1024))
        ) * 1.5  # Add overhead for base64 encoding

        return {
            'enabled': True,
            'total_files': len(self.processed_files),
            'total_chunks': self.total_chunks,
            'estimated_size_mb': round(estimated_size_mb, 2),
            'files': [f['name'] for f in self.processed_files]
        }


def create_rag_module_for_space(serialized_data: Dict[str, Any]) -> str:
    """Create a minimal RAG module for the deployed space"""

    return '''# RAG Module for deployed space
import numpy as np
import faiss
import base64
import json

class RAGContext:
    def __init__(self, serialized_data):
        # Deserialize FAISS index
        index_bytes = base64.b64decode(serialized_data['index_base64'])
        self.index = faiss.deserialize_index(index_bytes)

        # Restore chunks and mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']

    def get_context(self, query_embedding, max_chunks=3):
        """Get relevant context using pre-computed embedding"""
        if not self.index:
            return ""

        # Normalize and search
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, max_chunks)

        # Format results
        context_parts = []

        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < 0.3:
                continue

            chunk = self.chunks[self.chunk_ids[idx]]
            file_name = chunk.get('metadata', {}).get('file_name', 'Document')

            context_parts.append(
                f"[{file_name} - Relevance: {score:.2f}]\\n{chunk['text']}"
            )

        return "\\n\\n".join(context_parts) if context_parts else ""

# Initialize RAG context
RAG_DATA = json.loads(\'\'\'{{rag_data_json}}\'\'\')
rag_context = RAGContext(RAG_DATA) if RAG_DATA else None

def get_rag_context(query):
    """Get relevant context for a query"""
    if not rag_context:
        return ""

    # In production, you'd compute query embedding here
    # For now, return empty (would need embedding service)
    return ""
'''


def format_context_for_prompt(context: str, query: str) -> str:
    """Format RAG context for inclusion in prompt"""
    if not context:
        return ""

    return f"""Relevant context from uploaded documents:

{context}

Please use the above context to help answer the user's question: {query}"""
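
The module emitted by create_rag_module_for_space deliberately stops short: its get_rag_context returns an empty string because the deployed Space has no embedding model to encode queries. One way to close that gap, assuming the Space installs sentence-transformers and reuses the model name recorded in the serialized data, would be a completion like this inside the generated module (a hypothetical sketch, not part of the commit):

# Hypothetical completion of get_rag_context for a deployed Space.
# Assumes sentence-transformers is installed and RAG_DATA['model_name']
# matches the model used at index-build time (all-MiniLM-L6-v2 by default).
from sentence_transformers import SentenceTransformer

_query_encoder = SentenceTransformer(RAG_DATA['model_name'])

def get_rag_context(query):
    if not rag_context:
        return ""
    # Encode the query into the same vector space as the indexed chunks
    embedding = _query_encoder.encode([query], convert_to_numpy=True).astype('float32')
    return rag_context.get_context(embedding, max_chunks=3)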
requirements.txt
CHANGED
@@ -2,4 +2,11 @@ gradio>=4.44.0
 requests>=2.32.3
 beautifulsoup4>=4.12.3
 python-dotenv>=1.0.0
-crawl4ai>=0.4.245
+crawl4ai>=0.4.245
+
+# Vector RAG dependencies (optional)
+sentence-transformers>=2.2.2
+faiss-cpu>=1.7.4
+PyMuPDF>=1.23.0
+python-docx>=0.8.11
+numpy>=1.24.3
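
These ranges cover the helper tool's own environment; each generated Space instead gets a separate, pinned requirements.txt from create_requirements in app.py. The parsing and embedding dependencies stay on the helper side, since documents are processed before the zip is built. For example, with RAG enabled the generated file would read:

# Output of create_requirements(enable_vector_rag=True) as defined in app.py above
print(create_requirements(enable_vector_rag=True))
# gradio==4.44.1
# requests==2.32.3
# crawl4ai==0.4.245
# faiss-cpu==1.7.4
# numpy==1.24.3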
vector_store.py
ADDED
@@ -0,0 +1,246 @@
import numpy as np
import pickle
import base64
from typing import List, Dict, Any, Tuple, Optional
import json
from dataclasses import dataclass

try:
    from sentence_transformers import SentenceTransformer
    HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    HAS_SENTENCE_TRANSFORMERS = False

try:
    import faiss
    HAS_FAISS = True
except ImportError:
    HAS_FAISS = False


@dataclass
class SearchResult:
    chunk_id: str
    text: str
    score: float
    metadata: Dict[str, Any]


class VectorStore:
    def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.embedding_model_name = embedding_model
        self.embedding_model = None
        self.index = None
        self.chunks = {}  # chunk_id -> chunk data
        self.chunk_ids = []  # Ordered list for FAISS index mapping
        self.dimension = 384  # Default for all-MiniLM-L6-v2

        if HAS_SENTENCE_TRANSFORMERS:
            self._initialize_model()

    def _initialize_model(self):
        """Initialize the embedding model"""
        if not HAS_SENTENCE_TRANSFORMERS:
            raise ImportError("sentence-transformers not installed")

        self.embedding_model = SentenceTransformer(self.embedding_model_name)
        # Update dimension based on model
        self.dimension = self.embedding_model.get_sentence_embedding_dimension()

    def create_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Create embeddings for a list of texts"""
        if not self.embedding_model:
            self._initialize_model()

        # Process in batches for efficiency
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_embeddings = self.embedding_model.encode(
                batch,
                convert_to_numpy=True,
                show_progress_bar=False
            )
            embeddings.append(batch_embeddings)

        return np.vstack(embeddings) if embeddings else np.array([])

    def build_index(self, chunks: List[Dict[str, Any]], show_progress: bool = True):
        """Build FAISS index from chunks"""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        # Extract texts and build embeddings
        texts = [chunk['text'] for chunk in chunks]

        if show_progress:
            print(f"Creating embeddings for {len(texts)} chunks...")

        embeddings = self.create_embeddings(texts)

        # Build FAISS index
        if show_progress:
            print("Building FAISS index...")

        # Use IndexFlatIP for inner product (cosine similarity with normalized vectors)
        self.index = faiss.IndexFlatIP(self.dimension)

        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(embeddings)

        # Add to index
        self.index.add(embeddings)

        # Store chunks and maintain mapping
        self.chunks = {}
        self.chunk_ids = []

        for chunk in chunks:
            chunk_id = chunk['chunk_id']
            self.chunks[chunk_id] = chunk
            self.chunk_ids.append(chunk_id)

        if show_progress:
            print(f"Index built with {len(chunks)} chunks")

    def search(self, query: str, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
        """Search for similar chunks"""
        if not self.index or not self.chunks:
            return []

        # Create query embedding
        query_embedding = self.create_embeddings([query])

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))

        # Convert to results
        results = []

        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue

            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]

            result = SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            )
            results.append(result)

        return results

    def serialize(self) -> Dict[str, Any]:
        """Serialize the vector store for deployment"""
        if not self.index:
            raise ValueError("No index to serialize")

        # Serialize FAISS index
        index_bytes = faiss.serialize_index(self.index)
        index_base64 = base64.b64encode(index_bytes).decode('utf-8')

        return {
            'index_base64': index_base64,
            'chunks': self.chunks,
            'chunk_ids': self.chunk_ids,
            'dimension': self.dimension,
            'model_name': self.embedding_model_name
        }

    @classmethod
    def deserialize(cls, data: Dict[str, Any]) -> 'VectorStore':
        """Deserialize a vector store from deployment data"""
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        store = cls(embedding_model=data['model_name'])

        # Deserialize FAISS index
        index_bytes = base64.b64decode(data['index_base64'])
        store.index = faiss.deserialize_index(index_bytes)

        # Restore chunks and mappings
        store.chunks = data['chunks']
        store.chunk_ids = data['chunk_ids']
        store.dimension = data['dimension']

        return store

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        return {
            'total_chunks': len(self.chunks),
            'index_size': self.index.ntotal if self.index else 0,
            'dimension': self.dimension,
            'model': self.embedding_model_name
        }


class LightweightVectorStore:
    """Lightweight version for deployed spaces without embedding model"""

    def __init__(self, serialized_data: Dict[str, Any]):
        if not HAS_FAISS:
            raise ImportError("faiss-cpu not installed")

        # Deserialize FAISS index
        index_bytes = base64.b64decode(serialized_data['index_base64'])
        self.index = faiss.deserialize_index(index_bytes)

        # Restore chunks and mappings
        self.chunks = serialized_data['chunks']
        self.chunk_ids = serialized_data['chunk_ids']
        self.dimension = serialized_data['dimension']

        # For query embedding, we'll need to include pre-computed embeddings
        # or use a lightweight embedding service
        self.query_embeddings_cache = serialized_data.get('query_embeddings_cache', {})

    def search_with_embedding(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
        """Search using pre-computed query embedding"""
        if not self.index or not self.chunks:
            return []

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))

        # Convert to results
        results = []

        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < score_threshold:
                continue

            chunk_id = self.chunk_ids[idx]
            chunk = self.chunks[chunk_id]

            result = SearchResult(
                chunk_id=chunk_id,
                text=chunk['text'],
                score=float(score),
                metadata=chunk.get('metadata', {})
            )
            results.append(result)

        return results


# Utility functions
def estimate_index_size(num_chunks: int, dimension: int = 384) -> float:
    """Estimate the size of the index in MB"""
    # Rough estimation: 4 bytes per float * dimension * num_chunks
    bytes_size = 4 * dimension * num_chunks
    # Add overhead for index structure and metadata
    overhead = 1.2
    return (bytes_size * overhead) / (1024 * 1024)
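
VectorStore relies on a standard FAISS idiom: IndexFlatIP computes inner products, which equal cosine similarities once vectors are L2-normalized, and serialize_index/deserialize_index round-trip the index through bytes for the base64 payload. A self-contained sketch of both, using random vectors rather than real embeddings:

# Sketch of the IndexFlatIP + normalize_L2 idiom and the base64 round-trip.
# Random data stands in for sentence-transformer embeddings.
import base64
import numpy as np
import faiss

dim = 384  # matches all-MiniLM-L6-v2
vectors = np.random.rand(10, dim).astype('float32')
faiss.normalize_L2(vectors)  # unit length -> inner product == cosine similarity

index = faiss.IndexFlatIP(dim)
index.add(vectors)

# Round-trip through base64, as serialize() and the deployed-space loaders do
blob = base64.b64encode(faiss.serialize_index(index)).decode('utf-8')
restored = faiss.deserialize_index(
    np.frombuffer(base64.b64decode(blob), dtype=np.uint8)
)

query = vectors[:1].copy()
scores, ids = restored.search(query, 3)
print(ids[0], scores[0])  # the query itself ranks first with score ~1.0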