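"""A minimal retrieval-augmented generation (RAG) pipeline: documents are
split into overlapping, token-bounded chunks, embedded with a
sentence-transformers model, and indexed in FAISS so that the chunks most
relevant to a query can be retrieved and formatted as prompt context."""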
import json
import os
import re
from typing import Dict, List

import faiss
import numpy as np
import torch
from pypdf import PdfReader
from transformers import AutoModel, AutoTokenizer
# Constants
CHUNK_SIZE = 300  # max tokens per chunk
CHUNK_OVERLAP = 50  # tokens of context carried into the next chunk

class DocumentProcessor:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Initialize document processor with embedding model"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def load_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file"""
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n\n"
        return text

    def load_text(self, text_path: str) -> str:
        """Load text from a plain-text file"""
        with open(text_path, "r", encoding="utf-8") as f:
            return f.read()
    def chunk_text(self, text: str) -> List[str]:
        """Split text into overlapping chunks of roughly CHUNK_SIZE tokens"""
        # Simple sentence-based chunking
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = []
        current_size = 0
        for sentence in sentences:
            sentence_token_count = len(self.tokenizer.tokenize(sentence))
            if current_size + sentence_token_count > CHUNK_SIZE and current_chunk:
                # Save current chunk
                chunks.append(" ".join(current_chunk))
                # Start the new chunk with trailing sentences adding up to at
                # most CHUNK_OVERLAP tokens of carried-over context
                overlap, overlap_size = [], 0
                for prev in reversed(current_chunk):
                    prev_tokens = len(self.tokenizer.tokenize(prev))
                    if overlap_size + prev_tokens > CHUNK_OVERLAP:
                        break
                    overlap.insert(0, prev)
                    overlap_size += prev_tokens
                current_chunk = overlap + [sentence]
                current_size = overlap_size + sentence_token_count
            else:
                current_chunk.append(sentence)
                current_size += sentence_token_count
        # Add the last chunk if not empty
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks
    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Convert text chunks to embeddings"""
        embeddings = []
        for text in texts:
            inputs = self.tokenizer(text, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Use mean pooling to get the sentence embedding
            token_embeddings = outputs.last_hidden_state
            attention_mask = inputs["attention_mask"]
            # Mask padded tokens so they don't contribute to the mean
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embeddings_sum = torch.sum(token_embeddings * input_mask_expanded, 1)
            mask_sum = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            embedding = (embeddings_sum / mask_sum).squeeze().cpu().numpy()
            embeddings.append(embedding)
        return np.array(embeddings)
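
# Quick sanity check for DocumentProcessor (a minimal sketch; "notes.txt" is
# a hypothetical file, and 384 is the embedding width of all-MiniLM-L6-v2):
#   dp = DocumentProcessor()
#   chunks = dp.chunk_text(dp.load_text("notes.txt"))
#   vecs = dp.get_embeddings(chunks)  # shape (len(chunks), 384)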

class RAGSystem:
    def __init__(self):
        """Initialize RAG system"""
        self.doc_processor = DocumentProcessor()
        self.index = None
        self.chunks = []
        self.sources = {}

    def add_document(self, file_path: str, source_id: str):
        """Process and add a document to the RAG system"""
        if file_path.lower().endswith('.pdf'):
            text = self.doc_processor.load_pdf(file_path)
        else:
            text = self.doc_processor.load_text(file_path)
        # Chunk the document
        chunks = self.doc_processor.chunk_text(text)
        # Store the chunks with source info
        start_idx = len(self.chunks)
        for i, chunk in enumerate(chunks):
            chunk_id = start_idx + i
            self.chunks.append(chunk)
            self.sources[chunk_id] = {
                'source': source_id,
                'file_path': file_path,
                'chunk_index': i
            }
        # Rebuild the FAISS index
        self._update_index()
        return len(chunks)
    def _update_index(self):
        """Rebuild the FAISS index from all stored chunks"""
        embeddings = self.doc_processor.get_embeddings(self.chunks)
        vector_dimension = embeddings.shape[1]
        # Normalize so L2 distance on the index tracks cosine similarity
        faiss.normalize_L2(embeddings)
        # Rebuild from scratch; adding to an existing index here would
        # duplicate vectors for chunks that were already indexed
        self.index = faiss.IndexFlatL2(vector_dimension)
        self.index.add(embeddings)
    def query(self, query_text: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the most relevant chunks for a query"""
        # Check that we have a valid index and chunks
        if self.index is None or len(self.chunks) == 0:
            print("Warning: No index or chunks available")
            return []
        # Get the query embedding, normalized the same way as the index
        query_embedding = self.doc_processor.get_embeddings([query_text])
        faiss.normalize_L2(query_embedding)
        # Search the index; never ask for more results than we have chunks
        actual_k = min(top_k, len(self.chunks))
        distances, indices = self.index.search(query_embedding, actual_k)
        # Format results
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS returns -1 for missing results; skip anything out of range
            if 0 <= idx < len(self.chunks):
                # Fall back to a default source if the id is unknown
                source_info = self.sources.get(int(idx), {
                    'source': "Unknown",
                    'file_path': "Unknown",
                    'chunk_index': 0
                })
                results.append({
                    'chunk': self.chunks[idx],
                    'score': float(1 / (1 + distances[0][i])),  # Map L2 distance to a (0, 1] similarity score
                    'source': source_info
                })
        return results
    def get_context_for_query(self, query: str, top_k: int = 3) -> str:
        """Get formatted context for a query to include in a prompt"""
        results = self.query(query, top_k)
        if not results:
            return "No relevant information found."
        context = "RELEVANT INFORMATION:\n\n"
        for i, result in enumerate(results):
            context += f"[{i+1}] From {result['source']['source']}:\n"
            context += f"{result['chunk']}\n\n"
        return context
    def save_index(self, directory: str):
        """Save the FAISS index, chunks, and sources to disk"""
        os.makedirs(directory, exist_ok=True)
        # Save FAISS index
        index_path = os.path.join(directory, "index.faiss")
        faiss.write_index(self.index, index_path)
        # Save chunks
        np.save(os.path.join(directory, "chunks.npy"), np.array(self.chunks, dtype=object))
        # Save sources as JSON for more reliable loading
        with open(os.path.join(directory, "sources.json"), "w") as f:
            # JSON requires string keys, so convert the int chunk ids
            sources_dict = {str(k): v for k, v in self.sources.items()}
            json.dump(sources_dict, f)
    def load_index(self, directory: str):
        """Load the FAISS index, chunks, and sources from disk"""
        # Load FAISS index
        index_path = os.path.join(directory, "index.faiss")
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Index file not found at {index_path}")
        self.index = faiss.read_index(index_path)
        # Load chunks
        chunks_path = os.path.join(directory, "chunks.npy")
        if not os.path.exists(chunks_path):
            raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
        self.chunks = np.load(chunks_path, allow_pickle=True).tolist()
        # Load sources, trying the JSON format first
        sources_json_path = os.path.join(directory, "sources.json")
        if os.path.exists(sources_json_path):
            with open(sources_json_path, "r") as f:
                sources_dict = json.load(f)
            # Convert string keys back to integers
            self.sources = {int(k): v for k, v in sources_dict.items()}
        else:
            # Legacy support for the old numpy format
            sources_path = os.path.join(directory, "sources.npy")
            if not os.path.exists(sources_path):
                raise FileNotFoundError(
                    f"Sources file not found at either {sources_json_path} or {sources_path}"
                )
            try:
                loaded_data = np.load(sources_path, allow_pickle=True)
                if isinstance(loaded_data[0], dict):
                    # Saved as a one-element object array wrapping the dict
                    self.sources = loaded_data[0]
                else:
                    # Otherwise treat it as a plain sequence of source records
                    self.sources = dict(enumerate(loaded_data))
            except Exception as e:
                raise RuntimeError(f"Failed to load sources: {e}")

# Example usage
if __name__ == "__main__":
    # Create RAG system
    rag = RAGSystem()

    # Add documents
    linkedin_count = rag.add_document("me/linkedin.pdf", "LinkedIn Profile")
    summary_count = rag.add_document("me/summary.txt", "Professional Summary")
    print(f"Added {linkedin_count} chunks from LinkedIn profile")
    print(f"Added {summary_count} chunks from Professional Summary")

    # Save the index
    rag.save_index("me/rag_index")

    # Test query
    query = "What are Sagarnil's technical skills?"
    context = rag.get_context_for_query(query)
    print(f"\nQuery: {query}\n")
    print(context)
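
    # Reload into a fresh RAGSystem to verify persistence (a minimal sketch
    # using the save/load methods defined above)
    rag2 = RAGSystem()
    rag2.load_index("me/rag_index")
    print(rag2.get_context_for_query(query))

    # How the retrieved context might feed an LLM prompt; this module ships
    # no LLM client, so "call_llm" here is purely hypothetical:
    # prompt = f"Answer using only the context below.\n\n{context}\nQuestion: {query}"
    # answer = call_llm(prompt)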