import os
import re
import json
import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModel
from pypdf import PdfReader
from typing import List, Dict

# Chunking parameters (measured in tokenizer tokens)
CHUNK_SIZE = 300  # maximum tokens per chunk
CHUNK_OVERLAP = 50  # trailing tokens carried over into the next chunk
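# With these defaults, consecutive chunks advance through a document by roughly
# CHUNK_SIZE - CHUNK_OVERLAP = 250 tokens (approximate, since chunk boundaries
# are aligned to sentence boundaries).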

class DocumentProcessor:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Initialize document processor with embedding model"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
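        # all-MiniLM-L6-v2 yields 384-dimensional sentence embeddings; another
        # Hugging Face encoder could be substituted via model_name as long as
        # the mean pooling in get_embeddings is appropriate for it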
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        
    def load_pdf(self, pdf_path: str) -> str:
        """Extract text from PDF file"""
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n\n"
        return text
    
    def load_text(self, text_path: str) -> str:
        """Load text from a file"""
        with open(text_path, "r", encoding="utf-8") as f:
            return f.read()
            
    def chunk_text(self, text: str) -> List[str]:
        """Split text into chunks with overlap"""
        # Simple sentence-based chunking
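        # (the lookbehind split keeps the terminating punctuation attached to
        # each sentence instead of consuming it as part of the delimiter)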
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current_chunk = []
        current_size = 0
        
        for sentence in sentences:
            sentence_tokens = self.tokenizer.tokenize(sentence)
            sentence_token_count = len(sentence_tokens)
            
            if current_size + sentence_token_count > CHUNK_SIZE and current_chunk:
                # Save current chunk
                chunks.append(" ".join(current_chunk))
                
                # Start a new chunk, carrying over up to CHUNK_OVERLAP tokens of
                # trailing sentences from the previous chunk as context
                overlap_sentences = []
                overlap_tokens = 0
                for prev_sentence in reversed(current_chunk):
                    prev_tokens = len(self.tokenizer.tokenize(prev_sentence))
                    if overlap_tokens + prev_tokens > CHUNK_OVERLAP:
                        break
                    overlap_sentences.insert(0, prev_sentence)
                    overlap_tokens += prev_tokens
                current_chunk = overlap_sentences + [sentence]
                current_size = overlap_tokens + sentence_token_count
            else:
                current_chunk.append(sentence)
                current_size += sentence_token_count
        
        # Add the last chunk if not empty
        if current_chunk:
            chunks.append(" ".join(current_chunk))
            
        return chunks
    
    def get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Convert text chunks to embeddings"""
        embeddings = []
        
        for text in texts:
            inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model(**inputs)
            
            # Use mean pooling to get sentence embedding
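            # (this mirrors the mean pooling sentence-transformers applies for
            # this model: sum the token vectors weighted by the attention mask,
            # then divide by the number of non-padding tokens)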
            token_embeddings = outputs.last_hidden_state
            attention_mask = inputs["attention_mask"]
            
            # Mask padded tokens
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embeddings_sum = torch.sum(token_embeddings * input_mask_expanded, 1)
            mask_sum = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            embedding = (embeddings_sum / mask_sum).squeeze().cpu().numpy()
            embeddings.append(embedding)
            
        return np.array(embeddings)

class RAGSystem:
    def __init__(self):
        """Initialize RAG system"""
        self.doc_processor = DocumentProcessor()
        self.index = None
        self.chunks = []
        self.sources = {}
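        # self.chunks holds the raw chunk texts; self.sources maps each chunk's
        # position in that list to metadata about the document it came from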
        
    def add_document(self, file_path: str, source_id: str):
        """Process and add document to the RAG system"""
        if file_path.lower().endswith('.pdf'):
            text = self.doc_processor.load_pdf(file_path)
        else:
            text = self.doc_processor.load_text(file_path)
            
        # Chunk the document
        chunks = self.doc_processor.chunk_text(text)
        
        # Store the chunks with source info
        start_idx = len(self.chunks)
        for i, chunk in enumerate(chunks):
            chunk_id = start_idx + i
            self.chunks.append(chunk)
            self.sources[chunk_id] = {
                'source': source_id,
                'file_path': file_path,
                'chunk_index': i
            }
            
        # Rebuild the FAISS index so it includes the new chunks
        self._update_index()
        
        return len(chunks)
    
    def _update_index(self):
        """Rebuild the FAISS index from all stored chunks"""
        if not self.chunks:
            return

        embeddings = self.doc_processor.get_embeddings(self.chunks)
        vector_dimension = embeddings.shape[1]

        # Rebuild from scratch so that chunks indexed by an earlier call to
        # add_document are not added to the index a second time
        self.index = faiss.IndexFlatL2(vector_dimension)

        # Normalize embeddings so that L2 search ranks by cosine similarity
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)
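        # For unit-length vectors, ||a - b||^2 = 2 - 2 * cos(a, b), so ranking
        # by smallest L2 distance is the same as ranking by cosine similarity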
        
    def query(self, query_text: str, top_k: int = 3) -> List[Dict]:
        """Retrieve relevant chunks for a query"""
        # Check if we have a valid index and chunks
        if self.index is None or len(self.chunks) == 0:
            print("Warning: No index or chunks available")
            return []
            
        # Get query embedding
        query_embedding = self.doc_processor.get_embeddings([query_text])
        faiss.normalize_L2(query_embedding)
        
        # Search the index - limit to the actual number of chunks we have
        actual_k = min(top_k, len(self.chunks))
        if actual_k == 0:
            return []
            
        distances, indices = self.index.search(query_embedding, actual_k)
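        # search returns one row per query vector: distances[0] holds squared L2
        # distances and indices[0] the positions of the matches in self.chunks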
        
        # Format results
        results = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads results with -1 when fewer than actual_k vectors are
            # available, so skip anything that does not map to a stored chunk
            if idx >= 0 and idx < len(self.chunks):
                # Safety check that the idx is a valid index in self.sources
                if idx in self.sources:
                    source_info = self.sources[idx]
                else:
                    # Create a default source if there's a mismatch
                    source_info = {
                        'source': "Unknown",
                        'file_path': "Unknown",
                        'chunk_index': 0
                    }
                
                results.append({
                    'chunk': self.chunks[idx],
                    'score': float(1 / (1 + distances[0][i])),  # Convert distance to similarity score
                    'source': source_info
                })
                
        return results
    
    def get_context_for_query(self, query: str, top_k: int = 3) -> str:
        """Get formatted context for a query to include in a prompt"""
        results = self.query(query, top_k)
        
        if not results:
            return "No relevant information found."
            
        context = "RELEVANT INFORMATION:\n\n"
        for i, result in enumerate(results):
            context += f"[{i+1}] From {result['source']['source']}:\n"
            context += f"{result['chunk']}\n\n"
            
        return context
    
    def save_index(self, directory: str):
        """Save the FAISS index, chunks, and source metadata to disk"""
        if self.index is None:
            raise ValueError("Nothing to save: no documents have been indexed yet")

        os.makedirs(directory, exist_ok=True)

        # Save the FAISS index via the standard write_index API
        index_path = os.path.join(directory, "index.faiss")
        faiss.write_index(self.index, index_path)

        # Save chunks as an object array
        np.save(os.path.join(directory, "chunks.npy"), np.array(self.chunks, dtype=object))

        # Save sources as JSON for more reliable loading
        with open(os.path.join(directory, "sources.json"), "w") as f:
            # Convert int keys to strings for JSON serialization
            sources_dict = {str(k): v for k, v in self.sources.items()}
            json.dump(sources_dict, f)
        
    def load_index(self, directory: str):
        """Load the FAISS index, chunks, and source metadata from disk"""
        # Load the FAISS index via the standard read_index API
        index_path = os.path.join(directory, "index.faiss")
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Index file not found at {index_path}")
        self.index = faiss.read_index(index_path)

        # Load chunks
        chunks_path = os.path.join(directory, "chunks.npy")
        if not os.path.exists(chunks_path):
            raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
        self.chunks = np.load(chunks_path, allow_pickle=True).tolist()

        # Load sources, preferring the JSON format written by save_index
        sources_json_path = os.path.join(directory, "sources.json")
        if os.path.exists(sources_json_path):
            with open(sources_json_path, "r") as f:
                sources_dict = json.load(f)
                # Convert string keys back to integers
                self.sources = {int(k): v for k, v in sources_dict.items()}
        else:
            # Legacy support for old numpy format
            sources_path = os.path.join(directory, "sources.npy")
            if not os.path.exists(sources_path):
                raise FileNotFoundError(f"Sources file not found at either {sources_json_path} or {sources_path}")
            
            try:
                # Try loading as array with dict inside
                loaded_data = np.load(sources_path, allow_pickle=True)
                if isinstance(loaded_data[0], dict):
                    self.sources = loaded_data[0]
                else:
                    # Otherwise treat it as a sequence of per-chunk source entries
                    self.sources = dict(enumerate(loaded_data))
            except Exception as e:
                raise RuntimeError(f"Failed to load sources: {e}")


# Example usage
if __name__ == "__main__":
    # Create RAG system
    rag = RAGSystem()
    
    # Add documents
    linkedin_count = rag.add_document("me/linkedin.pdf", "LinkedIn Profile")
    summary_count = rag.add_document("me/summary.txt", "Professional Summary")
    
    print(f"Added {linkedin_count} chunks from LinkedIn profile")
    print(f"Added {summary_count} chunks from Professional Summary")
    
    # Save the index
    rag.save_index("me/rag_index")
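    # To reuse the persisted index in a later run (paths assumed to match the
    # save_index call above), a fresh RAGSystem can load it back:
    #
    #   rag2 = RAGSystem()
    #   rag2.load_index("me/rag_index")
    #   print(rag2.get_context_for_query("What are Sagarnil's technical skills?"))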
    
    # Test query
    query = "What are Sagarnil's technical skills?"
    context = rag.get_context_for_query(query)
    print(f"\nQuery: {query}\n")
    print(context)
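    # The context returned above is intended to be prepended to an LLM prompt,
    # e.g. something along the lines of f"{context}\nQuestion: {query}"
    # (the prompt wording is illustrative, not part of this module).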