career_conversation / rag_utils.py
sagarnildass's picture
Upload folder using huggingface_hub
3c18172 verified
import json
import os
import re
from typing import Dict, List, Optional, Tuple

import faiss
import numpy as np
import torch
from pypdf import PdfReader
from transformers import AutoModel, AutoTokenizer
# Chunking parameters read by DocumentProcessor.chunk_text
CHUNK_SIZE = 300 # token budget per chunk; a sentence that would exceed it starts a new chunk
CHUNK_OVERLAP = 50 # overlap carried into the next chunk, declared in tokens -- confirm chunk_text applies it in that unit
class DocumentProcessor:
    """Load documents, split them into overlapping chunks, and embed the chunks.

    Wraps a Hugging Face tokenizer/encoder pair. Embeddings are obtained by
    mean-pooling the encoder's last hidden state over non-padding tokens.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Load the tokenizer and encoder and move the encoder to GPU if available.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # The model is only used for inference; keep dropout & co. disabled.
        self.model.eval()

    def load_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file.

        Pages that yield no text are skipped; every kept page is followed by a
        blank line.

        Args:
            pdf_path: Path of the PDF to read.

        Returns:
            The concatenated page texts.
        """
        reader = PdfReader(pdf_path)
        page_texts = (page.extract_text() for page in reader.pages)
        return "".join(f"{page_text}\n\n" for page_text in page_texts if page_text)

    def load_text(self, text_path: str) -> str:
        """Return the full contents of a UTF-8 text file."""
        with open(text_path, "r", encoding="utf-8") as f:
            return f.read()

    def chunk_text(
        self,
        text: str,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
    ) -> List[str]:
        """Split text into sentence-aligned chunks with a token-based overlap.

        Sentences are accumulated until adding the next one would exceed
        ``chunk_size`` tokens. The finished chunk is then emitted and the next
        chunk is seeded with the trailing sentences of the previous one, as
        long as they fit into ``chunk_overlap`` tokens. A single sentence
        longer than ``chunk_size`` becomes a chunk of its own.

        Args:
            text: Text to split.
            chunk_size: Maximum tokens per chunk; defaults to ``CHUNK_SIZE``.
            chunk_overlap: Maximum tokens repeated from the previous chunk;
                defaults to ``CHUNK_OVERLAP``.

        Returns:
            The chunks in document order; empty if ``text`` has no content.

        Raises:
            ValueError: If ``chunk_size`` < 1 or ``chunk_overlap`` < 0.
        """
        max_tokens = CHUNK_SIZE if chunk_size is None else chunk_size
        overlap_tokens = CHUNK_OVERLAP if chunk_overlap is None else chunk_overlap
        if max_tokens < 1:
            raise ValueError("chunk_size must be at least 1")
        if overlap_tokens < 0:
            raise ValueError("chunk_overlap must not be negative")

        # Simple sentence-based splitting; blank pieces (e.g. from empty
        # input) are dropped so they never become empty chunks.
        pieces = (piece.strip() for piece in re.split(r'(?<=[.!?])\s+', text))
        sentences = [piece for piece in pieces if piece]

        chunks = []
        current = []  # (sentence, token_count) pairs of the chunk being built
        current_size = 0
        for sentence in sentences:
            token_count = len(self.tokenizer.tokenize(sentence))
            if current and current_size + token_count > max_tokens:
                chunks.append(" ".join(s for s, _ in current))
                # Overlap is measured in tokens (not sentences) and shrunk so
                # that overlap + new sentence still fits into one chunk.
                budget = min(overlap_tokens, max_tokens - token_count)
                carried = []
                carried_size = 0
                for prev_sentence, prev_count in reversed(current):
                    if carried_size + prev_count > budget:
                        break
                    carried.append((prev_sentence, prev_count))
                    carried_size += prev_count
                carried.reverse()
                current = carried
                current_size = carried_size
            current.append((sentence, token_count))
            current_size += token_count

        # Flush whatever is left.
        if current:
            chunks.append(" ".join(s for s, _ in current))
        return chunks

    def get_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Convert texts to mean-pooled embeddings.

        Args:
            texts: Texts to embed; each is truncated to 512 tokens.
            batch_size: Number of texts encoded per forward pass.

        Returns:
            A C-contiguous float32 array with one row per text. For empty
            input the array has zero rows and ``hidden_size`` columns.

        Raises:
            ValueError: If ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        texts = list(texts)
        if not texts:
            # Keep the result 2-D so callers can still read ``shape[1]``.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        pooled_batches = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Mean pooling: average token vectors, ignoring padded positions.
            token_embeddings = outputs.last_hidden_state
            mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
            summed = torch.sum(token_embeddings * mask, 1)
            counts = torch.clamp(mask.sum(1), min=1e-9)  # avoid division by zero
            pooled_batches.append((summed / counts).cpu().numpy())

        return np.ascontiguousarray(np.concatenate(pooled_batches, axis=0), dtype=np.float32)
class RAGSystem:
    """Chunk store plus a FAISS index over the chunk embeddings.

    The class keeps three structures aligned: the vector at FAISS position
    ``i`` is the embedding of ``self.chunks[i]``, and ``self.sources[i]``
    records where that chunk came from.
    """

    def __init__(self):
        """Initialize an empty RAG system with a default DocumentProcessor."""
        self.doc_processor = DocumentProcessor()
        self.index = None  # created on the first _update_index call with data
        self.chunks = []
        self.sources = {}

    def add_document(self, file_path: str, source_id: str) -> int:
        """Process a document and add its chunks to the system.

        Files ending in ``.pdf`` (case-insensitive) are read as PDF, anything
        else as UTF-8 text.

        Args:
            file_path: Path of the document to ingest.
            source_id: Human-readable label stored with every chunk.

        Returns:
            The number of chunks added.
        """
        if file_path.lower().endswith('.pdf'):
            text = self.doc_processor.load_pdf(file_path)
        else:
            text = self.doc_processor.load_text(file_path)

        chunks = self.doc_processor.chunk_text(text)

        # Chunk ids are positions in self.chunks, so new ids continue from
        # the current length.
        start_idx = len(self.chunks)
        for i, chunk in enumerate(chunks):
            self.chunks.append(chunk)
            self.sources[start_idx + i] = {
                'source': source_id,
                'file_path': file_path,
                'chunk_index': i
            }

        self._update_index()
        return len(chunks)

    def _update_index(self):
        """Embed the chunks that are not indexed yet and append them to FAISS.

        Only the tail of ``self.chunks`` beyond ``index.ntotal`` is embedded.
        Re-adding already indexed chunks would duplicate vectors and shift
        FAISS positions away from the positions in ``self.chunks``.
        """
        indexed_count = 0 if self.index is None else self.index.ntotal
        pending = self.chunks[indexed_count:]
        if not pending:
            return

        embeddings = np.ascontiguousarray(
            self.doc_processor.get_embeddings(pending), dtype=np.float32
        )
        if self.index is None:
            # Flat (exact) index using L2 distance.
            self.index = faiss.IndexFlatL2(embeddings.shape[1])
        # Vectors are unit-normalized in place before being stored.
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def query(self, query_text: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the chunks closest to a query.

        Args:
            query_text: Natural-language query.
            top_k: Maximum number of results; values <= 0 yield no results.

        Returns:
            A list of dicts with keys ``'chunk'`` (text), ``'score'``
            (``1 / (1 + distance)``, higher is closer) and ``'source'``
            (the stored source info), in the order returned by FAISS.
        """
        if self.index is None or not self.chunks:
            print("Warning: No index or chunks available")
            return []

        # Never ask for more neighbours than chunks exist; bail out before
        # paying for an embedding when nothing can be returned.
        actual_k = min(top_k, len(self.chunks))
        if actual_k <= 0:
            return []

        query_embedding = np.ascontiguousarray(
            self.doc_processor.get_embeddings([query_text]), dtype=np.float32
        )
        faiss.normalize_L2(query_embedding)
        distances, indices = self.index.search(query_embedding, actual_k)

        results = []
        for distance, raw_idx in zip(distances[0], indices[0]):
            idx = int(raw_idx)
            # FAISS pads with -1 when it has fewer hits; also guard against
            # positions that have no matching chunk.
            if not 0 <= idx < len(self.chunks):
                continue
            # Fall back to a placeholder if the source table is out of sync.
            source_info = self.sources.get(idx, {
                'source': "Unknown",
                'file_path': "Unknown",
                'chunk_index': 0
            })
            results.append({
                'chunk': self.chunks[idx],
                'score': float(1 / (1 + distance)),  # distance -> similarity
                'source': source_info
            })
        return results

    def get_context_for_query(self, query: str, top_k: int = 3) -> str:
        """Return retrieved chunks formatted for inclusion in a prompt.

        Args:
            query: Natural-language query.
            top_k: Maximum number of chunks to include.

        Returns:
            A numbered listing of chunks with their source labels, or
            ``"No relevant information found."`` when nothing was retrieved.
        """
        results = self.query(query, top_k)
        if not results:
            return "No relevant information found."

        parts = ["RELEVANT INFORMATION:\n\n"]
        for position, result in enumerate(results, start=1):
            parts.append(f"[{position}] From {result['source']['source']}:\n")
            parts.append(f"{result['chunk']}\n\n")
        return "".join(parts)

    def save_index(self, directory: str):
        """Save the FAISS index, chunks and source table to a directory.

        Writes ``index.faiss``, ``chunks.npy`` and ``sources.json``.

        Args:
            directory: Target directory; created if missing.

        Raises:
            ValueError: If no index has been built yet.
        """
        if self.index is None:
            raise ValueError("Cannot save: no index has been built yet")
        os.makedirs(directory, exist_ok=True)

        index_path = os.path.join(directory, "index.faiss")
        try:
            faiss.write_index(self.index, index_path)
        except AttributeError:
            # Fallback for faiss builds that expose the function elsewhere.
            import faiss.swigfaiss as swigfaiss
            swigfaiss.write_index(self.index, index_path)

        # NOTE: an object array is stored via pickle; the format is kept so
        # existing saved directories and load_index stay compatible.
        np.save(os.path.join(directory, "chunks.npy"), np.array(self.chunks, dtype=object))

        with open(os.path.join(directory, "sources.json"), "w", encoding="utf-8") as f:
            # JSON object keys must be strings; load_index converts them back.
            json.dump({str(k): v for k, v in self.sources.items()}, f)

    def load_index(self, directory: str):
        """Load the FAISS index, chunks and source table from a directory.

        Args:
            directory: Directory previously written by ``save_index``.

        Raises:
            FileNotFoundError: If the index, chunks or sources file is missing.
            RuntimeError: If a legacy ``sources.npy`` cannot be interpreted.
        """
        index_path = os.path.join(directory, "index.faiss")
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Index file not found at {index_path}")
        try:
            self.index = faiss.read_index(index_path)
        except AttributeError:
            # Fallback for faiss builds that expose the function elsewhere.
            import faiss.swigfaiss as swigfaiss
            self.index = swigfaiss.read_index(index_path)

        chunks_path = os.path.join(directory, "chunks.npy")
        if not os.path.exists(chunks_path):
            raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
        # SECURITY: allow_pickle=True executes pickle on load; only use
        # directories from trusted sources.
        self.chunks = np.load(chunks_path, allow_pickle=True).tolist()

        sources_json_path = os.path.join(directory, "sources.json")
        if os.path.exists(sources_json_path):
            with open(sources_json_path, "r", encoding="utf-8") as f:
                sources_dict = json.load(f)
            # Convert string keys back to integer chunk ids.
            self.sources = {int(k): v for k, v in sources_dict.items()}
        else:
            # Legacy support for the old numpy format.
            sources_path = os.path.join(directory, "sources.npy")
            if not os.path.exists(sources_path):
                raise FileNotFoundError(f"Sources file not found at either {sources_json_path} or {sources_path}")
            try:
                # SECURITY: pickle-based load, see note above.
                loaded_data = np.load(sources_path, allow_pickle=True)
                if loaded_data.ndim == 0 and isinstance(loaded_data.item(), dict):
                    # np.save(dict) yields a 0-d object array, which cannot
                    # be indexed with [0].
                    self.sources = loaded_data.item()
                elif isinstance(loaded_data[0], dict):
                    self.sources = loaded_data[0]
                else:
                    self.sources = dict(enumerate(loaded_data))
            except Exception as e:
                raise RuntimeError(f"Failed to load sources: {e}") from e

        if self.index.ntotal != len(self.chunks):
            # Retrieval maps FAISS positions to chunk positions, so a count
            # mismatch means results may be dropped or point at wrong chunks.
            print(
                f"Warning: index holds {self.index.ntotal} vectors but "
                f"{len(self.chunks)} chunks were loaded"
            )
# Demo: build an index from the sample documents, persist it, run one query
if __name__ == "__main__":
    # Fresh retrieval system
    demo_system = RAGSystem()
    # Ingest both sample documents before reporting anything
    linkedin_count = demo_system.add_document("me/linkedin.pdf", "LinkedIn Profile")
    summary_count = demo_system.add_document("me/summary.txt", "Professional Summary")
    print(f"Added {linkedin_count} chunks from LinkedIn profile")
    print(f"Added {summary_count} chunks from Professional Summary")
    # Persist the index together with chunks and source metadata
    demo_system.save_index("me/rag_index")
    # Retrieve and show the prompt context for a sample question
    query = "What are Sagarnil's technical skills?"
    context = demo_system.get_context_for_query(query)
    print(f"\nQuery: {query}\n")
    print(context)