import os
import re
import json
from typing import List, Dict

import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel
from pypdf import PdfReader

# Constants
CHUNK_SIZE = 300    # maximum chunk size, in tokens
CHUNK_OVERLAP = 50  # overlap between consecutive chunks, in tokens


class DocumentProcessor:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Initialize document processor with embedding model"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def load_pdf(self, pdf_path: str) -> str:
        """Extract text from PDF file"""
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n\n"
        return text

    def load_text(self, text_path: str) -> str:
        """Load text from a file"""
        with open(text_path, "r", encoding="utf-8") as f:
            return f.read()

    def chunk_text(self, text: str) -> List[str]:
        """Split text into token-bounded chunks with overlap"""
        # Simple sentence-based chunking
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = []
        current_size = 0

        for sentence in sentences:
            sentence_token_count = len(self.tokenizer.tokenize(sentence))

            if current_size + sentence_token_count > CHUNK_SIZE and current_chunk:
                # Save current chunk
                chunks.append(" ".join(current_chunk))

                # Start the new chunk with token-based overlap: carry over
                # trailing sentences until about CHUNK_OVERLAP tokens are reused
                overlap, overlap_tokens = [], 0
                for s in reversed(current_chunk):
                    s_token_count = len(self.tokenizer.tokenize(s))
                    if overlap_tokens + s_token_count > CHUNK_OVERLAP:
                        break
                    overlap.insert(0, s)
                    overlap_tokens += s_token_count
                current_chunk = overlap + [sentence]
                current_size = overlap_tokens + sentence_token_count
            else:
                current_chunk.append(sentence)
                current_size += sentence_token_count

        # Add the last chunk if not empty
        if current_chunk:
            chunks.append(" ".join(current_chunk))

        return chunks

    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Convert text chunks to embeddings"""
        embeddings = []
        for text in texts:
            inputs = self.tokenizer(text, return_tensors="pt", padding=True,
                                    truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            # Mean pooling over token embeddings, ignoring padded positions
            token_embeddings = outputs.last_hidden_state
            attention_mask = inputs["attention_mask"]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embeddings_sum = torch.sum(token_embeddings * input_mask_expanded, 1)
            mask_sum = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            embedding = (embeddings_sum / mask_sum).squeeze().cpu().numpy()
            embeddings.append(embedding)

        # FAISS expects contiguous float32 input
        return np.array(embeddings, dtype="float32")
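
# Embedding chunks one at a time keeps memory use low but is slow for long
# documents. Below is a minimal batched sketch using the same mean-pooling
# scheme as DocumentProcessor.get_embeddings; it is not wired into the classes
# in this file, and batch_size is an illustrative choice.
def get_embeddings_batched(processor: DocumentProcessor, texts: List[str],
                           batch_size: int = 32) -> np.ndarray:
    """Sketch: embed texts in padded batches instead of one at a time."""
    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = processor.tokenizer(batch, return_tensors="pt", padding=True,
                                     truncation=True, max_length=512)
        inputs = {k: v.to(processor.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = processor.model(**inputs)
        # Same masked mean pooling as above, applied to the whole batch
        token_embeddings = outputs.last_hidden_state
        mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        all_embeddings.append((summed / counts).cpu().numpy())
    return np.concatenate(all_embeddings, axis=0).astype("float32")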

class RAGSystem:
    def __init__(self):
        """Initialize RAG system"""
        self.doc_processor = DocumentProcessor()
        self.index = None
        self.chunks = []
        self.sources = {}

    def add_document(self, file_path: str, source_id: str):
        """Process and add document to the RAG system"""
        if file_path.lower().endswith('.pdf'):
            text = self.doc_processor.load_pdf(file_path)
        else:
            text = self.doc_processor.load_text(file_path)

        # Chunk the document
        chunks = self.doc_processor.chunk_text(text)

        # Store the chunks with source info
        start_idx = len(self.chunks)
        for i, chunk in enumerate(chunks):
            chunk_id = start_idx + i
            self.chunks.append(chunk)
            self.sources[chunk_id] = {
                'source': source_id,
                'file_path': file_path,
                'chunk_index': i
            }

        # Create or update the FAISS index
        self._update_index()

        return len(chunks)

    def _update_index(self):
        """Rebuild the FAISS index from all stored chunks"""
        embeddings = self.doc_processor.get_embeddings(self.chunks)
        vector_dimension = embeddings.shape[1]

        # Rebuild from scratch each time so vectors already in the index are
        # not added a second time when new documents arrive
        self.index = faiss.IndexFlatL2(vector_dimension)

        # Normalize and add embeddings to the index
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def query(self, query_text: str, top_k: int = 3) -> List[Dict]:
        """Retrieve relevant chunks for a query"""
        # Check if we have a valid index and chunks
        if self.index is None or len(self.chunks) == 0:
            print("Warning: No index or chunks available")
            return []

        # Get query embedding
        query_embedding = self.doc_processor.get_embeddings([query_text])
        faiss.normalize_L2(query_embedding)

        # Search the index - limit to the actual number of chunks we have
        actual_k = min(top_k, len(self.chunks))
        if actual_k == 0:
            return []

        distances, indices = self.index.search(query_embedding, actual_k)

        # Format results
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS returns -1 when there are not enough results, so bounds-check
            if 0 <= idx < len(self.chunks):
                # Fall back to a default source if there's a mismatch
                source_info = self.sources.get(int(idx), {
                    'source': "Unknown",
                    'file_path': "Unknown",
                    'chunk_index': 0
                })
                results.append({
                    'chunk': self.chunks[idx],
                    'score': float(1 / (1 + distances[0][i])),  # Convert L2 distance to a similarity score
                    'source': source_info
                })

        return results

    def get_context_for_query(self, query: str, top_k: int = 3) -> str:
        """Get formatted context for a query to include in a prompt"""
        results = self.query(query, top_k)

        if not results:
            return "No relevant information found."

        context = "RELEVANT INFORMATION:\n\n"
        for i, result in enumerate(results):
            context += f"[{i+1}] From {result['source']['source']}:\n"
            context += f"{result['chunk']}\n\n"

        return context

    def save_index(self, directory: str):
        """Save the FAISS index, chunks, and sources to disk"""
        os.makedirs(directory, exist_ok=True)

        # Save FAISS index
        index_path = os.path.join(directory, "index.faiss")
        faiss.write_index(self.index, index_path)

        # Save chunks
        np.save(os.path.join(directory, "chunks.npy"),
                np.array(self.chunks, dtype=object))

        # Save sources as JSON for more reliable loading; JSON requires string
        # keys, so the integer chunk ids are converted on the way out
        with open(os.path.join(directory, "sources.json"), "w") as f:
            json.dump({str(k): v for k, v in self.sources.items()}, f)

    def load_index(self, directory: str):
        """Load the FAISS index, chunks, and sources from disk"""
        # Load FAISS index
        index_path = os.path.join(directory, "index.faiss")
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Index file not found at {index_path}")
        self.index = faiss.read_index(index_path)

        # Load chunks
        chunks_path = os.path.join(directory, "chunks.npy")
        if not os.path.exists(chunks_path):
            raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
        self.chunks = np.load(chunks_path, allow_pickle=True).tolist()

        # Load sources, trying the JSON format first
        sources_json_path = os.path.join(directory, "sources.json")
        if os.path.exists(sources_json_path):
            with open(sources_json_path, "r") as f:
                sources_dict = json.load(f)
            # Convert string keys back to integers
            self.sources = {int(k): v for k, v in sources_dict.items()}
        else:
            # Legacy support for the old numpy format
            sources_path = os.path.join(directory, "sources.npy")
            if not os.path.exists(sources_path):
                raise FileNotFoundError(
                    f"Sources file not found at either {sources_json_path} or {sources_path}")
            try:
                loaded_data = np.load(sources_path, allow_pickle=True)
                if loaded_data.ndim == 0:
                    # A dict saved directly comes back as a 0-d object array
                    self.sources = loaded_data.item()
                else:
                    # Otherwise treat it as a sequence of per-chunk entries
                    self.sources = dict(enumerate(loaded_data))
            except Exception as e:
                raise RuntimeError(f"Failed to load sources: {e}")
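
# Because the stored vectors are L2-normalized, squared L2 distance and cosine
# similarity give the same ranking (||a - b||^2 = 2 - 2*cos(a, b) for unit
# vectors). A minimal sketch of the inner-product alternative, which returns
# cosine scores directly; this is not what RAGSystem above uses:
def build_cosine_index(embeddings: np.ndarray) -> faiss.IndexFlatIP:
    """Sketch: cosine-similarity search via inner product on unit vectors."""
    embeddings = np.ascontiguousarray(embeddings, dtype="float32")
    faiss.normalize_L2(embeddings)  # in-place normalization to unit length
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)  # search() now returns cosine similarities as scores
    return index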

# Example usage
if __name__ == "__main__":
    # Create RAG system
    rag = RAGSystem()

    # Add documents
    linkedin_count = rag.add_document("me/linkedin.pdf", "LinkedIn Profile")
    summary_count = rag.add_document("me/summary.txt", "Professional Summary")

    print(f"Added {linkedin_count} chunks from LinkedIn profile")
    print(f"Added {summary_count} chunks from Professional Summary")

    # Save the index
    rag.save_index("me/rag_index")

    # Test query
    query = "What are Sagarnil's technical skills?"
    context = rag.get_context_for_query(query)

    print(f"\nQuery: {query}\n")
    print(context)
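    # Round-trip sanity check (illustrative): reload the saved index into a
    # fresh RAGSystem and re-run the same query; the path matches the
    # save_index call above
    rag_reloaded = RAGSystem()
    rag_reloaded.load_index("me/rag_index")
    print(rag_reloaded.get_context_for_query(query))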