milwright committed
Commit 7f85357 · 1 Parent(s): 8b344c3

Add vector RAG functionality as modular tool


- Create document_processor.py for parsing PDF, DOCX, TXT, MD files
- Create vector_store.py for FAISS-based embeddings management
- Create rag_tool.py for integrating RAG with chat interface
- Add file upload UI to app.py with toggle for RAG functionality
- Update SPACE_TEMPLATE to include RAG context retrieval
- Add optional vector dependencies to requirements.txt
- Follow existing enable_dynamic_urls pattern for modularity

Files changed (5)
  1. app.py +150 -8
  2. document_processor.py +205 -0
  3. rag_tool.py +208 -0
  4. requirements.txt +8 -1
  5. vector_store.py +246 -0
app.py CHANGED
@@ -7,6 +7,8 @@ from datetime import datetime
  from dotenv import load_dotenv
  import requests
  from bs4 import BeautifulSoup
+ import tempfile
+ from pathlib import Path
  # from scraping_service import get_grounding_context_crawl4ai, fetch_url_content_crawl4ai
  # Temporary mock functions for testing
  def get_grounding_context_crawl4ai(urls):
@@ -15,6 +17,14 @@ def get_grounding_context_crawl4ai(urls):
  def fetch_url_content_crawl4ai(url):
      return f"[Content from {url} would be fetched here]"

+ # Import RAG components
+ try:
+     from rag_tool import RAGTool
+     HAS_RAG = True
+ except ImportError:
+     HAS_RAG = False
+     RAGTool = None
+
  # Load environment variables from .env file
  load_dotenv()

@@ -34,6 +44,8 @@ MODEL = "{model}"
  GROUNDING_URLS = {grounding_urls}
  ACCESS_CODE = "{access_code}"
  ENABLE_DYNAMIC_URLS = {enable_dynamic_urls}
+ ENABLE_VECTOR_RAG = {enable_vector_rag}
+ RAG_DATA = {rag_data_json}

  # Get API key from environment - customizable variable name
  API_KEY = os.environ.get("{api_key_var}")
@@ -108,6 +120,36 @@ def extract_urls_from_text(text):
      url_pattern = r'https?://[^\\s<>"{{}}|\\^`\\[\\]"]+'
      return re.findall(url_pattern, text)

+ # Initialize RAG context if enabled
+ if ENABLE_VECTOR_RAG and RAG_DATA:
+     try:
+         import faiss
+         import numpy as np
+         import base64
+
+         class SimpleRAGContext:
+             def __init__(self, rag_data):
+                 # Deserialize FAISS index (faiss expects a uint8 array, not raw bytes)
+                 index_bytes = base64.b64decode(rag_data['index_base64'])
+                 self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
+
+                 # Restore chunks and mappings
+                 self.chunks = rag_data['chunks']
+                 self.chunk_ids = rag_data['chunk_ids']
+
+             def get_context(self, query, max_chunks=3):
+                 """Get relevant context - simplified version"""
+                 # In production, you'd compute the query embedding here
+                 # For now, return a simple placeholder message
+                 return "\\n\\n[RAG context would be retrieved here based on similarity search]\\n\\n"
+
+         rag_context_provider = SimpleRAGContext(RAG_DATA)
+     except Exception as e:
+         print(f"Failed to initialize RAG: {{e}}")
+         rag_context_provider = None
+ else:
+     rag_context_provider = None
+
  def generate_response(message, history):
      """Generate response using OpenRouter API"""

@@ -117,6 +159,12 @@ def generate_response(message, history):
      # Get grounding context
      grounding_context = get_grounding_context()

+     # Add RAG context if available
+     if ENABLE_VECTOR_RAG and rag_context_provider:
+         rag_context = rag_context_provider.get_context(message)
+         if rag_context:
+             grounding_context += rag_context
+
      # If dynamic URLs are enabled, check message for URLs to fetch
      if ENABLE_DYNAMIC_URLS:
          urls_in_message = extract_urls_from_text(message)
@@ -398,11 +446,16 @@ Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} with Chat U/I Helper

      return readme_content

- def create_requirements():
+ def create_requirements(enable_vector_rag=False):
      """Generate requirements.txt"""
-     return "gradio==4.44.1\nrequests==2.32.3\ncrawl4ai==0.4.245"
+     base_requirements = "gradio==4.44.1\nrequests==2.32.3\ncrawl4ai==0.4.245"
+
+     if enable_vector_rag:
+         base_requirements += "\nfaiss-cpu==1.7.4\nnumpy==1.24.3"
+
+     return base_requirements

- def generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code="", enable_dynamic_urls=False, url1="", url2="", url3="", url4=""):
+ def generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code="", enable_dynamic_urls=False, url1="", url2="", url3="", url4="", enable_vector_rag=False, rag_data=None):
      """Generate deployable zip file"""

      # Process examples
@@ -447,13 +500,15 @@ def generate_zip(name, description, role_purpose, intended_audience, key_tasks,
          'examples': examples_json,
          'grounding_urls': json.dumps(grounding_urls),
          'access_code': access_code or "",
-         'enable_dynamic_urls': enable_dynamic_urls
+         'enable_dynamic_urls': enable_dynamic_urls,
+         'enable_vector_rag': enable_vector_rag,
+         'rag_data_json': json.dumps(rag_data) if rag_data else 'null'
      }

      # Generate files
      app_content = SPACE_TEMPLATE.format(**config)
      readme_content = create_readme(config)
-     requirements_content = create_requirements()
+     requirements_content = create_requirements(enable_vector_rag)

      # Create zip file with clean naming
      filename = f"{name.lower().replace(' ', '_').replace('-', '_')}.zip"
@@ -474,7 +529,55 @@ def generate_zip(name, description, role_purpose, intended_audience, key_tasks,
      return filename

  # Define callback functions outside the interface
- def on_generate(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4):
+ def toggle_rag_section(enable_rag):
+     """Toggle visibility of RAG section"""
+     return gr.update(visible=enable_rag)
+
+ def process_documents(files, current_rag_tool):
+     """Process uploaded documents"""
+     if not files:
+         return "Please upload files first", current_rag_tool
+
+     if not HAS_RAG:
+         return "RAG functionality not available. Please install required dependencies.", current_rag_tool
+
+     try:
+         # Initialize RAG tool if it does not exist yet
+         if not current_rag_tool:
+             current_rag_tool = RAGTool()
+
+         # Process files
+         result = current_rag_tool.process_uploaded_files(files)
+
+         if result['success']:
+             # Create status message
+             status_parts = [f"✅ {result['message']}"]
+
+             # Add file summary
+             if result['summary']['files_processed']:
+                 status_parts.append("\n**Processed files:**")
+                 for file_info in result['summary']['files_processed']:
+                     status_parts.append(f"- {file_info['name']} ({file_info['chunks']} chunks)")
+
+             # Add errors if any
+             if result.get('errors'):
+                 status_parts.append("\n**Errors:**")
+                 for error in result['errors']:
+                     status_parts.append(f"- {error['file']}: {error['error']}")
+
+             # Add index stats
+             if result.get('index_stats'):
+                 stats = result['index_stats']
+                 status_parts.append(f"\n**Index stats:** {stats['total_chunks']} chunks, {stats['dimension']}D embeddings")
+
+             return "\n".join(status_parts), current_rag_tool
+         else:
+             return f"❌ {result['message']}", current_rag_tool
+
+     except Exception as e:
+         return f"❌ Error processing documents: {str(e)}", current_rag_tool
+
+ def on_generate(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4, enable_vector_rag, rag_tool_state):
      if not name or not name.strip():
          return gr.update(value="Error: Please provide a Space Title", visible=True), gr.update(visible=False)

@@ -482,7 +585,12 @@ def on_generate(name, description, role_purpose, intended_audience, key_tasks, a
          return gr.update(value="Error: Please provide a Role and Purpose for the assistant", visible=True), gr.update(visible=False)

      try:
-         filename = generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4)
+         # Get RAG data if enabled
+         rag_data = None
+         if enable_vector_rag and rag_tool_state:
+             rag_data = rag_tool_state.get_serialized_data()
+
+         filename = generate_zip(name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4, enable_vector_rag, rag_data)

          success_msg = f"""**Deployment package ready!**

@@ -790,6 +898,27 @@ with gr.Blocks(title="Chat U/I Helper") as demo:
              value=False,
              info="Allow the assistant to fetch additional URLs mentioned in conversations (uses Crawl4AI)"
          )
+
+         enable_vector_rag = gr.Checkbox(
+             label="Enable Document RAG",
+             value=False,
+             info="Upload documents for context-aware responses (PDF, DOCX, TXT, MD)",
+             visible=HAS_RAG
+         )
+
+         with gr.Column(visible=False) as rag_section:
+             gr.Markdown("### Document Upload")
+             file_upload = gr.File(
+                 label="Upload Documents",
+                 file_types=[".pdf", ".docx", ".txt", ".md"],
+                 file_count="multiple",
+                 type="filepath"
+             )
+             process_btn = gr.Button("Process Documents", variant="secondary")
+             rag_status = gr.Markdown()
+
+         # State to store the RAG tool across interactions
+         rag_tool_state = gr.State(None)

          examples_text = gr.Textbox(
              label="Example Prompts (one per line)",
@@ -878,10 +1007,23 @@ with gr.Blocks(title="Chat U/I Helper") as demo:
          outputs=[url3, url4, add_url_btn, remove_url_btn, url_count]
      )

+     # Connect RAG functionality
+     enable_vector_rag.change(
+         toggle_rag_section,
+         inputs=[enable_vector_rag],
+         outputs=[rag_section]
+     )
+
+     process_btn.click(
+         process_documents,
+         inputs=[file_upload, rag_tool_state],
+         outputs=[rag_status, rag_tool_state]
+     )
+
      # Connect the generate button
      generate_btn.click(
          on_generate,
-         inputs=[name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4],
+         inputs=[name, description, role_purpose, intended_audience, key_tasks, additional_context, model, api_key_var, temperature, max_tokens, examples_text, access_code, enable_dynamic_urls, url1, url2, url3, url4, enable_vector_rag, rag_tool_state],
          outputs=[status, download_file]
      )

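The get_context stub in the SPACE_TEMPLATE above intentionally returns a placeholder: the generated Space only installs faiss-cpu and numpy, so there is no embedding model available to encode the query at runtime. A minimal sketch of what a real lookup could look like, assuming the Space also shipped sentence-transformers with the same all-MiniLM-L6-v2 model used at build time (get_rag_context and provider are hypothetical names, not part of this commit):

    # Hypothetical query path, assuming sentence-transformers is installed in the Space
    from sentence_transformers import SentenceTransformer
    import faiss

    _query_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def get_rag_context(provider, query, max_chunks=3):
        """Embed the query and return the top-matching chunks from a SimpleRAGContext"""
        query_vec = _query_model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_vec)  # cosine similarity via normalized inner product
        scores, indices = provider.index.search(query_vec, max_chunks)
        parts = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or score < 0.3:  # same threshold vector_store.py uses
                continue
            chunk = provider.chunks[provider.chunk_ids[idx]]
            name = chunk.get('metadata', {}).get('file_name', 'Document')
            parts.append(f"[{name} - Relevance: {score:.2f}]\n{chunk['text']}")
        return "\n\n" + "\n\n".join(parts) + "\n\n" if parts else ""

The trade-off is package size: bundling the model keeps queries local, while the shipped placeholder keeps the deployed Space dependency-light.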
document_processor.py ADDED
@@ -0,0 +1,205 @@
+ import os
+ import json
+ from typing import List, Dict, Any, Tuple
+ from pathlib import Path
+ import hashlib
+
+ # Document parsing imports
+ try:
+     import fitz  # PyMuPDF
+     HAS_PYMUPDF = True
+ except ImportError:
+     HAS_PYMUPDF = False
+
+ try:
+     from docx import Document
+     HAS_DOCX = True
+ except ImportError:
+     HAS_DOCX = False
+
+ # Text processing
+ import re
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class DocumentChunk:
+     text: str
+     metadata: Dict[str, Any]
+     chunk_id: str
+
+     def to_dict(self):
+         return {
+             'text': self.text,
+             'metadata': self.metadata,
+             'chunk_id': self.chunk_id
+         }
+
+
+ class DocumentProcessor:
+     def __init__(self, chunk_size: int = 800, chunk_overlap: int = 100):
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+         self.supported_extensions = ['.pdf', '.docx', '.txt', '.md']
+
+     def process_file(self, file_path: str) -> List[DocumentChunk]:
+         """Process a single file and return chunks"""
+         path = Path(file_path)
+
+         if not path.exists():
+             raise FileNotFoundError(f"File not found: {file_path}")
+
+         extension = path.suffix.lower()
+         if extension not in self.supported_extensions:
+             raise ValueError(f"Unsupported file type: {extension}")
+
+         # Extract text based on file type
+         if extension == '.pdf':
+             text = self._extract_pdf_text(file_path)
+         elif extension == '.docx':
+             text = self._extract_docx_text(file_path)
+         elif extension in ['.txt', '.md']:
+             text = self._extract_text_file(file_path)
+         else:
+             raise ValueError(f"Unsupported file type: {extension}")
+
+         # Create chunks
+         chunks = self._create_chunks(text, file_path)
+
+         return chunks
+
+     def _extract_pdf_text(self, file_path: str) -> str:
+         """Extract text from PDF file"""
+         if not HAS_PYMUPDF:
+             raise ImportError("PyMuPDF not installed. Install with: pip install PyMuPDF")
+
+         text_parts = []
+
+         try:
+             with fitz.open(file_path) as pdf:
+                 for page_num in range(len(pdf)):
+                     page = pdf[page_num]
+                     text = page.get_text()
+                     if text.strip():
+                         text_parts.append(f"[Page {page_num + 1}]\n{text}")
+         except Exception as e:
+             raise Exception(f"Error processing PDF: {str(e)}")
+
+         return "\n\n".join(text_parts)
+
+     def _extract_docx_text(self, file_path: str) -> str:
+         """Extract text from DOCX file"""
+         if not HAS_DOCX:
+             raise ImportError("python-docx not installed. Install with: pip install python-docx")
+
+         text_parts = []
+
+         try:
+             doc = Document(file_path)
+
+             for paragraph in doc.paragraphs:
+                 if paragraph.text.strip():
+                     text_parts.append(paragraph.text)
+
+             # Also extract text from tables
+             for table in doc.tables:
+                 for row in table.rows:
+                     row_text = []
+                     for cell in row.cells:
+                         if cell.text.strip():
+                             row_text.append(cell.text.strip())
+                     if row_text:
+                         text_parts.append(" | ".join(row_text))
+
+         except Exception as e:
+             raise Exception(f"Error processing DOCX: {str(e)}")
+
+         return "\n\n".join(text_parts)
+
+     def _extract_text_file(self, file_path: str) -> str:
+         """Extract text from plain text or markdown file"""
+         try:
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 return f.read()
+         except Exception as e:
+             raise Exception(f"Error reading text file: {str(e)}")
+
+     def _create_chunks(self, text: str, file_path: str) -> List[DocumentChunk]:
+         """Create overlapping chunks from text"""
+         chunks = []
+
+         # Clean and normalize text
+         text = re.sub(r'\s+', ' ', text)
+         text = text.strip()
+
+         if not text:
+             return chunks
+
+         # Simple word-based chunking
+         words = text.split()
+
+         for i in range(0, len(words), self.chunk_size - self.chunk_overlap):
+             chunk_words = words[i:i + self.chunk_size]
+             chunk_text = ' '.join(chunk_words)
+
+             # Create chunk ID
+             chunk_id = hashlib.md5(f"{file_path}_{i}_{chunk_text[:50]}".encode()).hexdigest()[:8]
+
+             # Create metadata
+             metadata = {
+                 'file_path': file_path,
+                 'file_name': Path(file_path).name,
+                 'chunk_index': len(chunks),
+                 'start_word': i,
+                 'word_count': len(chunk_words)
+             }
+
+             chunk = DocumentChunk(
+                 text=chunk_text,
+                 metadata=metadata,
+                 chunk_id=chunk_id
+             )
+
+             chunks.append(chunk)
+
+         return chunks
+
+     def process_multiple_files(self, file_paths: List[str]) -> Tuple[List[DocumentChunk], Dict[str, Any]]:
+         """Process multiple files and return chunks with summary"""
+         all_chunks = []
+         summary = {
+             'total_files': 0,
+             'total_chunks': 0,
+             'files_processed': [],
+             'errors': []
+         }
+
+         for file_path in file_paths:
+             try:
+                 chunks = self.process_file(file_path)
+                 all_chunks.extend(chunks)
+
+                 summary['files_processed'].append({
+                     'path': file_path,
+                     'name': Path(file_path).name,
+                     'chunks': len(chunks)
+                 })
+
+             except Exception as e:
+                 summary['errors'].append({
+                     'path': file_path,
+                     'error': str(e)
+                 })
+
+         summary['total_files'] = len(summary['files_processed'])
+         summary['total_chunks'] = len(all_chunks)
+
+         return all_chunks, summary
+
+
+ # Utility function for file size validation
+ def validate_file_size(file_path: str, max_size_mb: float = 10.0) -> bool:
+     """Check if file size is within limits"""
+     size_bytes = os.path.getsize(file_path)
+     size_mb = size_bytes / (1024 * 1024)
+     return size_mb <= max_size_mb
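
A short usage sketch for DocumentProcessor (the file name is hypothetical). With the defaults chunk_size=800 and chunk_overlap=100, the chunking loop strides by 800 - 100 = 700 words, so successive chunks start at words 0, 700, 1400, ... and neighbouring chunks share 100 words:

    from document_processor import DocumentProcessor, validate_file_size

    processor = DocumentProcessor(chunk_size=800, chunk_overlap=100)

    if validate_file_size("notes.md", max_size_mb=10.0):  # hypothetical file
        chunks = processor.process_file("notes.md")
        for chunk in chunks[:3]:
            # start_word advances by chunk_size - chunk_overlap = 700 per chunk
            print(chunk.chunk_id, chunk.metadata['start_word'], chunk.metadata['word_count'])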
rag_tool.py ADDED
@@ -0,0 +1,208 @@
+ import json
+ from typing import List, Dict, Any, Optional, Tuple
+ from document_processor import DocumentProcessor, DocumentChunk
+ from vector_store import VectorStore, SearchResult
+ import os
+ import tempfile
+ from pathlib import Path
+
+
+ class RAGTool:
+     """RAG tool for integrating document search with chat"""
+
+     def __init__(self):
+         self.processor = DocumentProcessor(chunk_size=800, chunk_overlap=100)
+         self.vector_store = VectorStore()
+         self.processed_files = []
+         self.total_chunks = 0
+
+     def process_uploaded_files(self, file_paths: List[str]) -> Dict[str, Any]:
+         """Process uploaded files and build vector index"""
+
+         # Validate files
+         valid_files = []
+         errors = []
+
+         for file_path in file_paths:
+             try:
+                 # Check file size (10MB limit)
+                 size_mb = os.path.getsize(file_path) / (1024 * 1024)
+                 if size_mb > 10:
+                     errors.append({
+                         'file': Path(file_path).name,
+                         'error': f'File too large ({size_mb:.1f}MB). Maximum size is 10MB.'
+                     })
+                     continue
+
+                 valid_files.append(file_path)
+
+             except Exception as e:
+                 errors.append({
+                     'file': Path(file_path).name,
+                     'error': str(e)
+                 })
+
+         if not valid_files:
+             return {
+                 'success': False,
+                 'message': 'No valid files to process',
+                 'errors': errors
+             }
+
+         # Process files
+         all_chunks, summary = self.processor.process_multiple_files(valid_files)
+
+         if not all_chunks:
+             return {
+                 'success': False,
+                 'message': 'No content extracted from files',
+                 'summary': summary
+             }
+
+         # Build vector index
+         chunk_dicts = [chunk.to_dict() for chunk in all_chunks]
+         self.vector_store.build_index(chunk_dicts, show_progress=False)
+
+         # Update stats
+         self.processed_files = summary['files_processed']
+         self.total_chunks = len(all_chunks)
+
+         # Calculate index size
+         index_stats = self.vector_store.get_stats()
+
+         return {
+             'success': True,
+             'message': f'Successfully processed {len(valid_files)} files into {self.total_chunks} chunks',
+             'summary': summary,
+             'index_stats': index_stats,
+             'errors': errors
+         }
+
+     def get_relevant_context(self, query: str, max_chunks: int = 3) -> str:
+         """Get relevant context for a query"""
+         if not self.vector_store.index:
+             return ""
+
+         # Search for relevant chunks
+         results = self.vector_store.search(
+             query=query,
+             top_k=max_chunks,
+             score_threshold=0.3
+         )
+
+         if not results:
+             return ""
+
+         # Format context
+         context_parts = []
+
+         for i, result in enumerate(results, 1):
+             file_name = result.metadata.get('file_name', 'Unknown')
+             context_parts.append(
+                 f"[Document: {file_name} - Relevance: {result.score:.2f}]\n{result.text}"
+             )
+
+         return "\n\n".join(context_parts)
+
+     def get_serialized_data(self) -> Optional[Dict[str, Any]]:
+         """Get serialized data for deployment"""
+         if not self.vector_store.index:
+             return None
+
+         return self.vector_store.serialize()
+
+     def get_deployment_info(self) -> Dict[str, Any]:
+         """Get information for deployment package"""
+         if not self.vector_store.index:
+             return {
+                 'enabled': False,
+                 'message': 'No documents processed'
+             }
+
+         # Estimate package size increase
+         index_stats = self.vector_store.get_stats()
+         estimated_size_mb = (
+             # Index size estimation
+             (index_stats['total_chunks'] * index_stats['dimension'] * 4) / (1024 * 1024) +
+             # Chunks text size estimation
+             (sum(len(chunk['text']) for chunk in self.vector_store.chunks.values()) / (1024 * 1024))
+         ) * 1.5  # Add overhead for base64 encoding
+
+         return {
+             'enabled': True,
+             'total_files': len(self.processed_files),
+             'total_chunks': self.total_chunks,
+             'estimated_size_mb': round(estimated_size_mb, 2),
+             'files': [f['name'] for f in self.processed_files]
+         }
+
+
+ def create_rag_module_for_space(serialized_data: Dict[str, Any]) -> str:
+     """Create a minimal RAG module for the deployed space"""
+
+     return '''# RAG Module for deployed space
+ import numpy as np
+ import faiss
+ import base64
+ import json
+
+ class RAGContext:
+     def __init__(self, serialized_data):
+         # Deserialize FAISS index (faiss expects a uint8 array, not raw bytes)
+         index_bytes = base64.b64decode(serialized_data['index_base64'])
+         self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
+
+         # Restore chunks and mappings
+         self.chunks = serialized_data['chunks']
+         self.chunk_ids = serialized_data['chunk_ids']
+
+     def get_context(self, query_embedding, max_chunks=3):
+         """Get relevant context using a pre-computed embedding"""
+         if not self.index:
+             return ""
+
+         # Normalize and search
+         faiss.normalize_L2(query_embedding)
+         scores, indices = self.index.search(query_embedding, max_chunks)
+
+         # Format results
+         context_parts = []
+
+         for score, idx in zip(scores[0], indices[0]):
+             if idx < 0 or score < 0.3:
+                 continue
+
+             chunk = self.chunks[self.chunk_ids[idx]]
+             file_name = chunk.get('metadata', {}).get('file_name', 'Document')
+
+             context_parts.append(
+                 f"[{file_name} - Relevance: {score:.2f}]\\n{chunk['text']}"
+             )
+
+         return "\\n\\n".join(context_parts) if context_parts else ""
+
+ # Initialize RAG context
+ RAG_DATA = json.loads(\'\'\'{{rag_data_json}}\'\'\')
+ rag_context = RAGContext(RAG_DATA) if RAG_DATA else None
+
+ def get_rag_context(query):
+     """Get relevant context for a query"""
+     if not rag_context:
+         return ""
+
+     # In production, you'd compute the query embedding here
+     # For now, return empty (this would need an embedding service)
+     return ""
+ '''
+
+
+ def format_context_for_prompt(context: str, query: str) -> str:
+     """Format RAG context for inclusion in prompt"""
+     if not context:
+         return ""
+
+     return f"""Relevant context from uploaded documents:
+
+ {context}
+
+ Please use the above context to help answer the user's question: {query}"""
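
A round-trip sketch of how RAGTool ties the modules together (the file paths are hypothetical): validate and chunk the uploads, embed them into a FAISS index, pull context for a query, then serialize everything for the deployment package:

    from rag_tool import RAGTool, format_context_for_prompt

    tool = RAGTool()
    result = tool.process_uploaded_files(["manual.pdf", "faq.md"])  # hypothetical paths
    if result['success']:
        query = "How do I reset my password?"
        context = tool.get_relevant_context(query, max_chunks=3)
        prompt = format_context_for_prompt(context, query)
        # JSON-safe dict that generate_zip() injects into SPACE_TEMPLATE as {rag_data_json}
        rag_data = tool.get_serialized_data()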
requirements.txt CHANGED
@@ -2,4 +2,11 @@ gradio>=4.44.0
  requests>=2.32.3
  beautifulsoup4>=4.12.3
  python-dotenv>=1.0.0
- crawl4ai>=0.4.245
+ crawl4ai>=0.4.245
+
+ # Vector RAG dependencies (optional)
+ sentence-transformers>=2.2.2
+ faiss-cpu>=1.7.4
+ PyMuPDF>=1.23.0
+ python-docx>=0.8.11
+ numpy>=1.24.3
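
These RAG packages are optional by design: app.py and the support modules guard their imports and fall back to feature flags, so the builder still runs without them. The pattern, as used in app.py:

    try:
        from rag_tool import RAGTool
        HAS_RAG = True
    except ImportError:
        HAS_RAG = False  # the Document RAG checkbox is hidden when this is False
        RAGTool = None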
vector_store.py ADDED
@@ -0,0 +1,246 @@
+ import numpy as np
+ import pickle
+ import base64
+ from typing import List, Dict, Any, Tuple, Optional
+ import json
+ from dataclasses import dataclass
+
+ try:
+     from sentence_transformers import SentenceTransformer
+     HAS_SENTENCE_TRANSFORMERS = True
+ except ImportError:
+     HAS_SENTENCE_TRANSFORMERS = False
+
+ try:
+     import faiss
+     HAS_FAISS = True
+ except ImportError:
+     HAS_FAISS = False
+
+
+ @dataclass
+ class SearchResult:
+     chunk_id: str
+     text: str
+     score: float
+     metadata: Dict[str, Any]
+
+
+ class VectorStore:
+     def __init__(self, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
+         self.embedding_model_name = embedding_model
+         self.embedding_model = None
+         self.index = None
+         self.chunks = {}  # chunk_id -> chunk data
+         self.chunk_ids = []  # Ordered list for FAISS index mapping
+         self.dimension = 384  # Default for all-MiniLM-L6-v2
+
+         if HAS_SENTENCE_TRANSFORMERS:
+             self._initialize_model()
+
+     def _initialize_model(self):
+         """Initialize the embedding model"""
+         if not HAS_SENTENCE_TRANSFORMERS:
+             raise ImportError("sentence-transformers not installed")
+
+         self.embedding_model = SentenceTransformer(self.embedding_model_name)
+         # Update dimension based on model
+         self.dimension = self.embedding_model.get_sentence_embedding_dimension()
+
+     def create_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+         """Create embeddings for a list of texts"""
+         if not self.embedding_model:
+             self._initialize_model()
+
+         # Process in batches for efficiency
+         embeddings = []
+
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i:i + batch_size]
+             batch_embeddings = self.embedding_model.encode(
+                 batch,
+                 convert_to_numpy=True,
+                 show_progress_bar=False
+             )
+             embeddings.append(batch_embeddings)
+
+         return np.vstack(embeddings) if embeddings else np.array([])
+
+     def build_index(self, chunks: List[Dict[str, Any]], show_progress: bool = True):
+         """Build FAISS index from chunks"""
+         if not HAS_FAISS:
+             raise ImportError("faiss-cpu not installed")
+
+         # Extract texts and build embeddings
+         texts = [chunk['text'] for chunk in chunks]
+
+         if show_progress:
+             print(f"Creating embeddings for {len(texts)} chunks...")
+
+         embeddings = self.create_embeddings(texts)
+
+         # Build FAISS index
+         if show_progress:
+             print("Building FAISS index...")
+
+         # Use IndexFlatIP for inner product (cosine similarity with normalized vectors)
+         self.index = faiss.IndexFlatIP(self.dimension)
+
+         # Normalize embeddings for cosine similarity
+         faiss.normalize_L2(embeddings)
+
+         # Add to index
+         self.index.add(embeddings)
+
+         # Store chunks and maintain mapping
+         self.chunks = {}
+         self.chunk_ids = []
+
+         for chunk in chunks:
+             chunk_id = chunk['chunk_id']
+             self.chunks[chunk_id] = chunk
+             self.chunk_ids.append(chunk_id)
+
+         if show_progress:
+             print(f"Index built with {len(chunks)} chunks")
+
+     def search(self, query: str, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
+         """Search for similar chunks"""
+         if not self.index or not self.chunks:
+             return []
+
+         # Create query embedding
+         query_embedding = self.create_embeddings([query])
+
+         # Normalize for cosine similarity
+         faiss.normalize_L2(query_embedding)
+
+         # Search
+         scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))
+
+         # Convert to results
+         results = []
+
+         for score, idx in zip(scores[0], indices[0]):
+             if idx < 0 or score < score_threshold:
+                 continue
+
+             chunk_id = self.chunk_ids[idx]
+             chunk = self.chunks[chunk_id]
+
+             result = SearchResult(
+                 chunk_id=chunk_id,
+                 text=chunk['text'],
+                 score=float(score),
+                 metadata=chunk.get('metadata', {})
+             )
+             results.append(result)
+
+         return results
+
+     def serialize(self) -> Dict[str, Any]:
+         """Serialize the vector store for deployment"""
+         if not self.index:
+             raise ValueError("No index to serialize")
+
+         # Serialize FAISS index
+         index_bytes = faiss.serialize_index(self.index)
+         index_base64 = base64.b64encode(index_bytes).decode('utf-8')
+
+         return {
+             'index_base64': index_base64,
+             'chunks': self.chunks,
+             'chunk_ids': self.chunk_ids,
+             'dimension': self.dimension,
+             'model_name': self.embedding_model_name
+         }
+
+     @classmethod
+     def deserialize(cls, data: Dict[str, Any]) -> 'VectorStore':
+         """Deserialize a vector store from deployment data"""
+         if not HAS_FAISS:
+             raise ImportError("faiss-cpu not installed")
+
+         store = cls(embedding_model=data['model_name'])
+
+         # Deserialize FAISS index (faiss expects a uint8 array, not raw bytes)
+         index_bytes = base64.b64decode(data['index_base64'])
+         store.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
+
+         # Restore chunks and mappings
+         store.chunks = data['chunks']
+         store.chunk_ids = data['chunk_ids']
+         store.dimension = data['dimension']
+
+         return store
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get statistics about the vector store"""
+         return {
+             'total_chunks': len(self.chunks),
+             'index_size': self.index.ntotal if self.index else 0,
+             'dimension': self.dimension,
+             'model': self.embedding_model_name
+         }
+
+
+ class LightweightVectorStore:
+     """Lightweight version for deployed spaces without the embedding model"""
+
+     def __init__(self, serialized_data: Dict[str, Any]):
+         if not HAS_FAISS:
+             raise ImportError("faiss-cpu not installed")
+
+         # Deserialize FAISS index (faiss expects a uint8 array, not raw bytes)
+         index_bytes = base64.b64decode(serialized_data['index_base64'])
+         self.index = faiss.deserialize_index(np.frombuffer(index_bytes, dtype=np.uint8))
+
+         # Restore chunks and mappings
+         self.chunks = serialized_data['chunks']
+         self.chunk_ids = serialized_data['chunk_ids']
+         self.dimension = serialized_data['dimension']
+
+         # For query embedding, we'd need to include pre-computed embeddings
+         # or use a lightweight embedding service
+         self.query_embeddings_cache = serialized_data.get('query_embeddings_cache', {})
+
+     def search_with_embedding(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.3) -> List[SearchResult]:
+         """Search using a pre-computed query embedding"""
+         if not self.index or not self.chunks:
+             return []
+
+         # Normalize for cosine similarity
+         faiss.normalize_L2(query_embedding)
+
+         # Search
+         scores, indices = self.index.search(query_embedding, min(top_k, len(self.chunks)))
+
+         # Convert to results
+         results = []
+
+         for score, idx in zip(scores[0], indices[0]):
+             if idx < 0 or score < score_threshold:
+                 continue
+
+             chunk_id = self.chunk_ids[idx]
+             chunk = self.chunks[chunk_id]
+
+             result = SearchResult(
+                 chunk_id=chunk_id,
+                 text=chunk['text'],
+                 score=float(score),
+                 metadata=chunk.get('metadata', {})
+             )
+             results.append(result)
+
+         return results
+
+
+ # Utility functions
+ def estimate_index_size(num_chunks: int, dimension: int = 384) -> float:
+     """Estimate the size of the index in MB"""
+     # Rough estimation: 4 bytes per float * dimension * num_chunks
+     bytes_size = 4 * dimension * num_chunks
+     # Add overhead for index structure and metadata
+     overhead = 1.2
+     return (bytes_size * overhead) / (1024 * 1024)
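
A serialization round-trip sketch for VectorStore (the chunk dicts are hypothetical, shaped like DocumentChunk.to_dict() output), plus the arithmetic behind estimate_index_size:

    from vector_store import VectorStore, estimate_index_size

    chunks = [
        {'chunk_id': 'a1', 'text': 'FAISS stores dense vectors.', 'metadata': {'file_name': 'notes.md'}},
        {'chunk_id': 'b2', 'text': 'Normalized inner product equals cosine similarity.', 'metadata': {'file_name': 'notes.md'}},
    ]

    store = VectorStore()  # downloads all-MiniLM-L6-v2 on first use
    store.build_index(chunks, show_progress=False)
    hits = store.search("how does faiss work?", top_k=2)

    data = store.serialize()                  # base64-encoded index + chunks, JSON-serializable
    restored = VectorStore.deserialize(data)  # rebuilds the index without re-embedding

    # estimate_index_size(1000) = (4 bytes * 384 dims * 1000 chunks * 1.2 overhead) / 1024**2 ≈ 1.76 MB
    print(f"{estimate_index_size(1000):.2f} MB")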