YanBoChen committed
Commit 6c249e5 · Parent: 69b7911

WIP: feat(retrieval): implement basic vector retrieval system for medical documents

Files changed (1):
  1. src/retrieval.py  +308 -0
src/retrieval.py ADDED
@@ -0,0 +1,308 @@
+"""
+Basic Retrieval System for OnCall.ai
+
+This module implements the core vector retrieval functionality:
+- Basic vector search
+- Source marking
+- Unified output format
+"""
+
+import numpy as np
+import json
+from pathlib import Path
+from typing import Dict, List, Tuple, Any, Optional
+from sentence_transformers import SentenceTransformer
+from annoy import AnnoyIndex
+import logging
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+class BasicRetrievalSystem:
+    """Basic vector retrieval system for medical documents"""
+
+    def __init__(self, embedding_dim: int = 768):
+        """
+        Initialize the retrieval system
+
+        Args:
+            embedding_dim: Dimension of embeddings (default: 768 for PubMedBERT)
+        """
+        self.embedding_dim = embedding_dim
+        self.embedding_model = None
+        self.emergency_index = None
+        self.treatment_index = None
+        self.emergency_chunks = {}
+        self.treatment_chunks = {}
+
+        # Initialize system
+        self._initialize_system()
+
+    def _initialize_system(self) -> None:
+        """Initialize embeddings, indices and chunks"""
+        try:
+            logger.info("Initializing retrieval system...")
+
+            # Initialize embedding model
+            self.embedding_model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
+            logger.info("Embedding model loaded successfully")
+
+            # Initialize Annoy indices
+            self.emergency_index = AnnoyIndex(self.embedding_dim, 'angular')
+            self.treatment_index = AnnoyIndex(self.embedding_dim, 'angular')
+
+            # Load data
+            base_path = Path("models")
+            self._load_chunks(base_path)
+            self._load_embeddings(base_path)
+            self._build_or_load_indices(base_path)
+
+            logger.info("Retrieval system initialized successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize retrieval system: {e}")
+            raise
+
+    def _load_chunks(self, base_path: Path) -> None:
+        """Load chunk data from JSON files"""
+        try:
+            # Load emergency chunks
+            with open(base_path / "embeddings" / "emergency_chunks.json", 'r') as f:
+                self.emergency_chunks = json.load(f)
+
+            # Load treatment chunks
+            with open(base_path / "embeddings" / "treatment_chunks.json", 'r') as f:
+                self.treatment_chunks = json.load(f)
+
+            logger.info("Chunks loaded successfully")
+
+        except FileNotFoundError as e:
+            logger.error(f"Chunk file not found: {e}")
+            raise
+        except json.JSONDecodeError as e:
+            logger.error(f"Invalid JSON in chunk file: {e}")
+            raise
+
+    def _load_embeddings(self, base_path: Path) -> None:
+        """Load pre-computed embeddings"""
+        try:
+            # Load emergency embeddings
+            self.emergency_embeddings = np.load(
+                base_path / "embeddings" / "emergency_embeddings.npy"
+            )
+
+            # Load treatment embeddings
+            self.treatment_embeddings = np.load(
+                base_path / "embeddings" / "treatment_embeddings.npy"
+            )
+
+            logger.info("Embeddings loaded successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to load embeddings: {e}")
+            raise
+
+    def _build_or_load_indices(self, base_path: Path) -> None:
+        """Build or load Annoy indices"""
+        indices_path = base_path / "indices" / "annoy"
+        emergency_index_path = indices_path / "emergency.ann"
+        treatment_index_path = indices_path / "treatment.ann"
+
+        try:
+            # Emergency index
+            if emergency_index_path.exists():
+                self.emergency_index.load(str(emergency_index_path))
+                logger.info("Loaded existing emergency index")
+            else:
+                self._build_index(
+                    self.emergency_embeddings,
+                    self.emergency_index,
+                    emergency_index_path
+                )
+                logger.info("Built new emergency index")
+
+            # Treatment index
+            if treatment_index_path.exists():
+                self.treatment_index.load(str(treatment_index_path))
+                logger.info("Loaded existing treatment index")
+            else:
+                self._build_index(
+                    self.treatment_embeddings,
+                    self.treatment_index,
+                    treatment_index_path
+                )
+                logger.info("Built new treatment index")
+
+        except Exception as e:
+            logger.error(f"Failed to build/load indices: {e}")
+            raise
+
+    def _build_index(self, embeddings: np.ndarray, index: AnnoyIndex,
+                     save_path: Path, n_trees: int = 15) -> None:
+        """
+        Build and save Annoy index
+
+        Args:
+            embeddings: Embedding vectors
+            index: AnnoyIndex instance
+            save_path: Path to save the index
+            n_trees: Number of trees for Annoy index (default: 15)
+        """
+        try:
+            for i, vec in enumerate(embeddings):
+                index.add_item(i, vec)
+            index.build(n_trees)
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            index.save(str(save_path))
+
+        except Exception as e:
+            logger.error(f"Failed to build index: {e}")
+            raise
+
+    def search(self, query: str, top_k: int = 5) -> Dict[str, Any]:
+        """
+        Perform vector search on both indices
+
+        Args:
+            query: Search query
+            top_k: Number of results to return from each index
+
+        Returns:
+            Dict containing search results and metadata
+        """
+        try:
+            # Get query embedding
+            query_embedding = self.embedding_model.encode([query])[0]
+
+            # Search both indices
+            emergency_results = self._search_index(
+                query_embedding,
+                self.emergency_index,
+                self.emergency_chunks,
+                "emergency",
+                top_k
+            )
+
+            treatment_results = self._search_index(
+                query_embedding,
+                self.treatment_index,
+                self.treatment_chunks,
+                "treatment",
+                top_k
+            )
+
+            results = {
+                "query": query,
+                "emergency_results": emergency_results,
+                "treatment_results": treatment_results,
+                "total_results": len(emergency_results) + len(treatment_results)
+            }
+
+            # Post-process results
+            processed_results = self.post_process_results(results)
+
+            return processed_results
+
+        except Exception as e:
+            logger.error(f"Search failed: {e}")
+            raise
+
+    def _search_index(self, query_embedding: np.ndarray, index: AnnoyIndex,
+                      chunks: Dict, source_type: str, top_k: int) -> List[Dict]:
+        """
+        Search a single index and format results
+
+        Args:
+            query_embedding: Query vector
+            index: AnnoyIndex to search
+            chunks: Chunk data
+            source_type: Type of source ("emergency" or "treatment")
+            top_k: Number of results to return
+
+        Returns:
+            List of formatted results
+        """
+        # Get nearest neighbors
+        indices, distances = index.get_nns_by_vector(
+            query_embedding, top_k, include_distances=True
+        )
+
+        # Format results
+        results = []
+        for idx, distance in zip(indices, distances):
+            chunk_data = chunks[str(idx)]
+            result = {
+                "type": source_type,  # Using 'type' to match metadata
+                "chunk_id": idx,
+                "distance": distance,
+                "text": chunk_data.get("text", ""),
+                "matched": chunk_data.get("matched", ""),
+                "matched_treatment": chunk_data.get("matched_treatment", "")
+            }
+            results.append(result)
+
+        return results
+
+    def post_process_results(self, results: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Post-process search results
+        - Remove duplicates
+        - Sort by distance
+        - Attach processing metadata
+
+        Args:
+            results: Raw search results
+
+        Returns:
+            Processed results
+        """
+        try:
+            emergency_results = results["emergency_results"]
+            treatment_results = results["treatment_results"]
+
+            # Combine all results
+            all_results = emergency_results + treatment_results
+
+            # Remove exact-duplicate chunk texts
+            unique_results = self._remove_duplicates(all_results)
+
+            # Sort by distance (ascending: smaller distance = closer match)
+            sorted_results = sorted(unique_results, key=lambda x: x["distance"])
+
+            return {
+                "query": results["query"],
+                "processed_results": sorted_results,
+                "total_results": len(sorted_results),
+                "processing_info": {
+                    "duplicates_removed": len(all_results) - len(unique_results)
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"Post-processing failed: {e}")
+            raise
+
+    def _remove_duplicates(self, results: List[Dict]) -> List[Dict]:
+        """
+        Remove results whose chunk text is an exact duplicate
+
+        Args:
+            results: List of search results
+
+        Returns:
+            Deduplicated results
+        """
+        seen_texts = set()
+        unique_results = []
+
+        for result in results:
+            text = result["text"]
+            if text not in seen_texts:
+                seen_texts.add(text)
+                unique_results.append(result)
+
+        return unique_results
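
A minimal usage sketch of the module above (illustrative, not part of the commit): it assumes the pre-computed chunk JSON and .npy embedding files already exist under models/embeddings/ as _load_chunks and _load_embeddings expect, and that src/ is importable. The query string and the distance-to-similarity conversion are assumptions for demonstration only.

# Usage sketch -- illustrative assumptions, not part of this commit.
# Assumes models/embeddings/*.json and *.npy exist, and that src/ is on
# sys.path so the module imports as `retrieval`.
from retrieval import BasicRetrievalSystem

retriever = BasicRetrievalSystem()  # loads model, then builds or loads both Annoy indices
response = retriever.search("management of acute chest pain", top_k=5)

print(f"{response['total_results']} results for: {response['query']}")
for hit in response["processed_results"]:
    # Annoy's 'angular' distance d relates to cosine similarity s by
    # d = sqrt(2 * (1 - s)), so smaller distances are closer matches.
    cos_sim = 1.0 - (hit["distance"] ** 2) / 2.0
    print(f"[{hit['type']}] chunk {hit['chunk_id']}: "
          f"distance={hit['distance']:.3f} (cosine ~ {cos_sim:.3f})")
    print(hit["text"][:200])

On first run, _build_or_load_indices builds emergency.ann and treatment.ann from the loaded embeddings and saves them under models/indices/annoy/, so later startups only pay the load cost.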