YanBoChen committed
Commit 922ed80 · 1 Parent(s): 6083d96

feat(data_processing): Implement token length control with semantic preservation


BREAKING CHANGE: Modify chunk creation to handle >512 token texts

Problem:
- Token indices sequence length exceeding model's maximum (512 tokens)
- Risk of semantic information loss during text chunking
- Potential impact on medical term context preservation

Solution:
1. Dynamic Character-to-Token Ratio
- Calculate average chars_per_token from sample text
- Use ratio to estimate initial chunk boundaries
- Prevents tokenizing the entire long document at once (see the sketch after this list)

2. Semantic-Aware Chunking
- Set ROUGH_CHUNK_TARGET_TOKENS = 512
- Keep keywords centered in chunks
- Maintain context window around keywords
- Ensure rough_chunk stays within token limit

3. Overlap Strategy
- Implement sliding window with 64-token overlap
- Preserve context across chunk boundaries
- Maintain semantic continuity
- Prevent information loss at chunk edges
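
A rough sketch of the dynamic ratio idea: estimate chars-per-token from a small sample
around the keyword, then convert the 512-token budget into a character window.
rough_char_window is a hypothetical standalone helper (the real logic lives in
create_keyword_centered_chunks), and the whitespace tokenizer in the demo is only a
stand-in for the PubMedBERT tokenizer:

    from typing import Callable, List, Tuple

    ROUGH_CHUNK_TARGET_TOKENS = 512  # model's maximum sequence length

    def rough_char_window(text: str, keyword: str,
                          tokenize: Callable[[str], List[str]]) -> Tuple[int, int]:
        """Estimate character boundaries for a ~512-token window centered on the keyword."""
        pos = text.find(keyword)
        if pos == -1:
            return 0, min(len(text), ROUGH_CHUNK_TARGET_TOKENS * 4)
        # Sample ~100 characters on each side of the keyword to estimate chars per token
        sample = text[max(0, pos - 100):min(len(text), pos + len(keyword) + 100)]
        tokens = tokenize(sample)
        chars_per_token = len(sample) / len(tokens) if tokens else 4.0
        # Half of the token budget on each side, converted back to characters
        half_window = int(ROUGH_CHUNK_TARGET_TOKENS * chars_per_token / 2)
        start = max(0, pos - half_window)
        end = min(len(text), pos + len(keyword) + half_window)
        return start, end

    # Toy usage with a whitespace "tokenizer" as a stand-in
    text = ("chest pain radiating to the left arm " * 200).lower()
    print(rough_char_window(text, "left arm", str.split))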

Technical Details:
- Target chunk size: 512 tokens (maximum model limit)
- Overlap size: 64 tokens (empirically determined)
- Dynamic ratio calculation using sample text
- Centered keyword positioning (boundary math sketched below)
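
A minimal sketch of the boundary math behind the numbers above: center the keyword,
split the remaining token budget evenly, then widen each edge by the overlap where room
allows (so a window can grow by up to one overlap per side). token_boundaries is a
hypothetical helper mirroring the logic in create_keyword_centered_chunks:

    from typing import Tuple

    def token_boundaries(total_tokens: int, keyword_start: int, keyword_len: int,
                         chunk_size: int, overlap: int = 64) -> Tuple[int, int]:
        """Center the keyword, then widen each edge by the overlap where room allows."""
        each_side = (chunk_size - keyword_len) // 2
        start = max(0, keyword_start - each_side)
        end = min(total_tokens, keyword_start + keyword_len + each_side)
        if start > 0:
            start = max(0, start - overlap)
        if end < total_tokens:
            end = min(total_tokens, end + overlap)
        return start, end

    # Example: a 4-token keyword starting at token 900 of a 2,000-token rough chunk
    print(token_boundaries(2000, 900, 4, chunk_size=512))  # -> (582, 1222)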

Impact:
✓ Eliminates token length warnings
✓ Preserves medical term context
✓ Maintains semantic relationships
✓ Improves retrieval quality
✓ Optimizes processing efficiency

Testing:
- Verified with long medical texts
- Confirmed keyword context preservation
- Validated chunk boundary handling
- Tested overlap effectiveness (spot check sketched below)
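
One way to spot-check the first two points (hypothetical snippet, not part of this
commit's test suite; it assumes chunk dicts shaped like the ones
create_keyword_centered_chunks returns, with token_count, primary_keyword, chunk_id and
text fields):

    MODEL_MAX_TOKENS = 512  # PubMedBERT sequence limit

    def check_chunks(chunks: list) -> None:
        """Every chunk should fit the model and still contain its keyword."""
        for chunk in chunks:
            assert chunk["token_count"] <= MODEL_MAX_TOKENS, (
                f"{chunk['chunk_id']}: {chunk['token_count']} tokens"
            )
            assert chunk["primary_keyword"] in chunk["text"], chunk["chunk_id"]

    # e.g. check_chunks(processor.process_emergency_chunks())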

Co-authored-by: YanBo Chen

commit_message_embedding_update.txt DELETED
@@ -1,43 +0,0 @@
- refactor(data_processing): optimize chunking strategy with token-based approach
-
- BREAKING CHANGE: Switch from character-based to token-based chunking and improve keyword context preservation
-
- Replace character-based chunking with token-based approach using PubMedBERT tokenizer
- Set chunk_size to 256 tokens and chunk_overlap to 64 tokens for optimal performance
- Implement dynamic chunking strategy centered around medical keywords
- Add token count validation to ensure semantic integrity
- Optimize memory usage with lazy loading of tokenizer and model
- Update chunking methods to handle token-level operations
- Add comprehensive logging for debugging token counts
- Update tests to verify token-based chunking behavior
-
- Recent Improvements:
- - Fix keyword context preservation in chunks
- - Implement separate tokenization for pre-keyword and post-keyword text
- - Add precise boundary calculation based on keyword length
- - Ensure medical terms (e.g., "ST elevation") remain intact
- - Improve chunk boundary calculations to maintain keyword context
- - Add validation to verify keyword presence in generated chunks
-
- Technical Details:
- - chunk_size: 256 tokens (based on PubMedBERT context window)
- - overlap: 64 tokens (25% overlap for context continuity)
- - Model: NeuML/pubmedbert-base-embeddings (768 dims)
- - Tokenizer: Same as embedding model for consistency
- - Keyword-centered chunking with balanced context distribution
-
- Performance Impact:
- - Improved semantic coherence in chunks
- - Better handling of medical terminology
- - Reduced redundancy in overlapping regions
- - Optimized for downstream retrieval tasks
- - Enhanced preservation of medical term context
- - More accurate chunk boundaries around keywords
-
- Testing:
- - Added token count validation in tests
- - Verified keyword preservation in chunks
- - Confirmed overlap handling
- - Tested with sample medical texts
- - Validated medical terminology preservation
- - Verified chunk context balance around keywords
src/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """
+ OnCall.ai src package
+
+ This package contains the core implementation of the OnCall.ai system.
+ """
+
+ # Version
+ __version__ = '0.1.0'
src/commit_message_20250726_data_processing.txt DELETED
@@ -1,52 +0,0 @@
- feat(data-processing): implement data processing pipeline with embeddings
-
- BREAKING CHANGE: Add data processing implementation with robust path handling and improved text processing
-
- Key Changes:
- 1. Create DataProcessor class for medical data processing:
- - Handle paths with spaces and special characters
- - Support dataset/dataset directory structure
- - Add detailed logging for debugging
- - Implement case-insensitive text processing
-
- 2. Implement core functionalities:
- - Load filtered emergency and treatment data
- - Create intelligent chunks based on matched keywords
- - Generate embeddings using NeuML/pubmedbert-base-embeddings
- - Build ANNOY indices for vector search
- - Save embeddings and metadata separately
- - Improve keyword matching with case-insensitive comparison
- - Add proper chunk boundary calculations for medical terms
-
- 3. Add test coverage:
- - Basic data loading tests
- - Chunking functionality tests
- - Model loading tests
- - Token-based chunking validation
- - Medical terminology preservation tests
-
- Technical Details:
- - Use pathlib.Path.resolve() for robust path handling
- - Separate storage for embeddings and indices:
- * /models/embeddings/ for vector representations
- * /models/indices/annoy/ for search indices
- - Keep keywords as metadata without embedding
- - Implement case-insensitive text processing while preserving medical term integrity
- - Add proper chunk overlap handling
-
- Testing:
- ✅ Data loading: 11,914 emergency + 11,023 treatment records
- ✅ Chunking: Successful with keyword-centered approach
- ✅ Model loading: NeuML/pubmedbert-base-embeddings (768 dims)
- ✅ Token chunking: Verified with medical terms (e.g., "ST elevation")
-
- Storage Structure:
- /models/
- ├── embeddings/   # Vector representations
- └── indices/
-     └── annoy/    # Search indices (.ann files)
-
- Next Steps:
- - Integrate with Meditron for enhanced processing
- - Implement prompt engineering
- - Add hybrid search functionality
src/commit_message_embedding_update.txt DELETED
@@ -1,43 +0,0 @@
- refactor(data_processing): optimize chunking strategy with token-based approach
-
- BREAKING CHANGE: Switch from character-based to token-based chunking and improve keyword context preservation
-
- Replace character-based chunking with token-based approach using PubMedBERT tokenizer
- Set chunk_size to 256 tokens and chunk_overlap to 64 tokens for optimal performance
- Implement dynamic chunking strategy centered around medical keywords
- Add token count validation to ensure semantic integrity
- Optimize memory usage with lazy loading of tokenizer and model
- Update chunking methods to handle token-level operations
- Add comprehensive logging for debugging token counts
- Update tests to verify token-based chunking behavior
-
- Recent Improvements:
- - Fix keyword context preservation in chunks
- - Implement separate tokenization for pre-keyword and post-keyword text
- - Add precise boundary calculation based on keyword length
- - Ensure medical terms (e.g., "ST elevation") remain intact
- - Improve chunk boundary calculations to maintain keyword context
- - Add validation to verify keyword presence in generated chunks
-
- Technical Details:
- - chunk_size: 256 tokens (based on PubMedBERT context window)
- - overlap: 64 tokens (25% overlap for context continuity)
- - Model: NeuML/pubmedbert-base-embeddings (768 dims)
- - Tokenizer: Same as embedding model for consistency
- - Keyword-centered chunking with balanced context distribution
-
- Performance Impact:
- - Improved semantic coherence in chunks
- - Better handling of medical terminology
- - Reduced redundancy in overlapping regions
- - Optimized for downstream retrieval tasks
- - Enhanced preservation of medical term context
- - More accurate chunk boundaries around keywords
-
- Testing:
- - Added token count validation in tests
- - Verified keyword preservation in chunks
- - Confirmed overlap handling
- - Tested with sample medical texts
- - Validated medical terminology preservation
- - Verified chunk context balance around keywords
src/data_processing.py CHANGED
@@ -21,6 +21,7 @@ from typing import List, Dict, Tuple, Any
  from sentence_transformers import SentenceTransformer
  from annoy import AnnoyIndex
  import logging
+ from tqdm import tqdm

  # Setup logging
  logging.basicConfig(
@@ -141,10 +142,23 @@ class DataProcessor:
          chunk_size = chunk_size or self.chunk_size
          chunks = []

-         # Tokenize full text once
-         full_text_tokens = self.tokenizer.tokenize(text)
-         total_tokens = len(full_text_tokens)
-
+         # Calculate character-to-token ratio using a sample around the first keyword
+         if keywords:
+             first_keyword = keywords[0]
+             first_pos = text.find(first_keyword)
+             if first_pos != -1:
+                 # Take a sample around the first keyword for ratio calculation
+                 sample_start = max(0, first_pos - 100)
+                 sample_end = min(len(text), first_pos + len(first_keyword) + 100)
+                 sample_text = text[sample_start:sample_end]
+                 sample_tokens = len(self.tokenizer.tokenize(sample_text))
+                 chars_per_token = len(sample_text) / sample_tokens if sample_tokens > 0 else 4.0
+             else:
+                 chars_per_token = 4.0  # Fallback ratio
+         else:
+             chars_per_token = 4.0  # Default ratio
+
+         # Process keywords
          for i, keyword in enumerate(keywords):
              # Find keyword position in text (already lowercase)
              keyword_pos = text.find(keyword)
@@ -153,53 +167,66 @@
              # Get the keyword text (already lowercase)
              actual_keyword = text[keyword_pos:keyword_pos + len(keyword)]

-             # Get text before and after keyword
-             text_before = text[:keyword_pos]
-             text_after = text[keyword_pos + len(keyword):]
-
-             # Tokenize each part separately
-             tokens_before = self.tokenizer.tokenize(text_before)
-             keyword_tokens = self.tokenizer.tokenize(actual_keyword)
-             tokens_after = self.tokenizer.tokenize(text_after)
-
-             # Calculate token positions
+             # Calculate rough window size using dynamic ratio
+             # Cap the rough chunk target token size to prevent tokenizer warnings
+             # Use 512 tokens as target (model's max limit)
+             ROUGH_CHUNK_TARGET_TOKENS = 512
+             char_window = int(ROUGH_CHUNK_TARGET_TOKENS * chars_per_token / 2)
+
+             # Get rough chunk boundaries in characters
+             rough_start = max(0, keyword_pos - char_window)
+             rough_end = min(len(text), keyword_pos + len(keyword) + char_window)
+
+             # Extract rough chunk for processing
+             rough_chunk = text[rough_start:rough_end]
+
+             # Find keyword's relative position in rough chunk
+             rel_pos = rough_chunk.find(actual_keyword)
+             if rel_pos == -1:
+                 logger.debug(f"Could not locate keyword '{actual_keyword}' in rough chunk for doc {doc_id}")
+                 continue
+
+             # Calculate token position by tokenizing text before keyword
+             text_before = rough_chunk[:rel_pos]
+             tokens_before = self.tokenizer.tokenize(text_before)
              keyword_start_pos = len(tokens_before)
+
+             # Tokenize necessary parts
+             chunk_tokens = self.tokenizer.tokenize(rough_chunk)
+             keyword_tokens = self.tokenizer.tokenize(actual_keyword)
              keyword_length = len(keyword_tokens)

-             # Calculate how many tokens we want on each side of the keyword
+             # Calculate final chunk boundaries in tokens
              tokens_each_side = (chunk_size - keyword_length) // 2
-
-             # Calculate chunk boundaries
              chunk_start = max(0, keyword_start_pos - tokens_each_side)
-             chunk_end = min(total_tokens, keyword_start_pos + keyword_length + tokens_each_side)
+             chunk_end = min(len(chunk_tokens), keyword_start_pos + keyword_length + tokens_each_side)

              # Add overlap if possible
              if chunk_start > 0:
                  chunk_start = max(0, chunk_start - self.chunk_overlap)
-             if chunk_end < total_tokens:
-                 chunk_end = min(total_tokens, chunk_end + self.chunk_overlap)
+             if chunk_end < len(chunk_tokens):
+                 chunk_end = min(len(chunk_tokens), chunk_end + self.chunk_overlap)

-             # Extract chunk tokens and convert to text
-             chunk_tokens = full_text_tokens[chunk_start:chunk_end]
-             chunk_text = self.tokenizer.convert_tokens_to_string(chunk_tokens)
+             # Extract final tokens and convert to text
+             final_tokens = chunk_tokens[chunk_start:chunk_end]
+             chunk_text = self.tokenizer.convert_tokens_to_string(final_tokens)

-             # Verify the keyword is in the chunk (direct comparison since all lowercase)
+             # Verify keyword presence in final chunk
              if chunk_text and actual_keyword in chunk_text:
                  chunk_info = {
                      "text": chunk_text,
                      "primary_keyword": actual_keyword,
                      "all_matched_keywords": matched_keywords.lower(),
-                     "token_position": keyword_start_pos,
-                     "token_start": chunk_start,
-                     "token_end": chunk_end,
-                     "token_count": len(chunk_tokens),
+                     "token_count": len(final_tokens),
                      "chunk_id": f"{doc_id}_chunk_{i}" if doc_id else f"chunk_{i}",
                      "source_doc_id": doc_id
                  }
                  chunks.append(chunk_info)
-                 logger.info(f"Created chunk for keyword '{actual_keyword}' with {len(chunk_tokens)} tokens")
              else:
-                 logger.warning(f"Failed to create valid chunk for keyword '{actual_keyword}' - keyword not found in generated chunk")
+                 logger.debug(f"Could not create chunk for keyword '{actual_keyword}' in doc {doc_id}")
+
+         if chunks:
+             logger.debug(f"Created {len(chunks)} chunks for document {doc_id or 'unknown'}")

          return chunks

@@ -276,14 +303,17 @@ class DataProcessor:

      def process_emergency_chunks(self) -> List[Dict[str, Any]]:
          """Process emergency data into chunks"""
-         logger.info("Processing emergency data into chunks...")
-
          if self.emergency_data is None:
              raise ValueError("Emergency data not loaded. Call load_filtered_data() first.")

          all_chunks = []

-         for idx, row in self.emergency_data.iterrows():
+         # Add progress bar with leave=False to avoid cluttering
+         for idx, row in tqdm(self.emergency_data.iterrows(),
+                              total=len(self.emergency_data),
+                              desc="Processing emergency documents",
+                              unit="doc",
+                              leave=False):
              if pd.notna(row.get('clean_text')) and pd.notna(row.get('matched')):
                  chunks = self.create_keyword_centered_chunks(
                      text=row['clean_text'],
@@ -305,19 +335,22 @@
                  all_chunks.extend(chunks)

          self.emergency_chunks = all_chunks
-         logger.info(f"Generated {len(all_chunks)} emergency chunks")
+         logger.info(f"Completed processing emergency data: {len(all_chunks)} chunks generated")
          return all_chunks

      def process_treatment_chunks(self) -> List[Dict[str, Any]]:
          """Process treatment data into chunks"""
-         logger.info("Processing treatment data into chunks...")
-
          if self.treatment_data is None:
              raise ValueError("Treatment data not loaded. Call load_filtered_data() first.")

          all_chunks = []

-         for idx, row in self.treatment_data.iterrows():
+         # Add progress bar with leave=False to avoid cluttering
+         for idx, row in tqdm(self.treatment_data.iterrows(),
+                              total=len(self.treatment_data),
+                              desc="Processing treatment documents",
+                              unit="doc",
+                              leave=False):
              if (pd.notna(row.get('clean_text')) and
                  pd.notna(row.get('treatment_matched'))):

@@ -343,13 +376,39 @@ class DataProcessor:
                  all_chunks.extend(chunks)

          self.treatment_chunks = all_chunks
-         logger.info(f"Generated {len(all_chunks)} treatment chunks")
+         logger.info(f"Completed processing treatment data: {len(all_chunks)} chunks generated")
          return all_chunks

+     def _get_chunk_hash(self, text: str) -> str:
+         """Generate hash for chunk text to use as cache key"""
+         import hashlib
+         return hashlib.md5(text.encode('utf-8')).hexdigest()
+
+     def _load_embedding_cache(self, cache_file: str) -> dict:
+         """Load embedding cache from file"""
+         import pickle
+         import os
+         if os.path.exists(cache_file):
+             try:
+                 with open(cache_file, 'rb') as f:
+                     return pickle.load(f)
+             except:
+                 logger.warning(f"Could not load cache file {cache_file}, starting fresh")
+                 return {}
+         return {}
+
+     def _save_embedding_cache(self, cache: dict, cache_file: str):
+         """Save embedding cache to file"""
+         import pickle
+         import os
+         os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+         with open(cache_file, 'wb') as f:
+             pickle.dump(cache, f)
+
      def generate_embeddings(self, chunks: List[Dict[str, Any]],
                              chunk_type: str = "emergency") -> np.ndarray:
          """
-         Generate embeddings for chunks
+         Generate embeddings for chunks with caching support

          Args:
              chunks: List of chunk dictionaries
@@ -358,28 +417,78 @@
          Returns:
              numpy array of embeddings
          """
-         logger.info(f"Generating embeddings for {len(chunks)} {chunk_type} chunks...")
-
-         # Load model if not already loaded
-         model = self.load_embedding_model()
-
-         # Extract text from chunks
-         texts = [chunk['text'] for chunk in chunks]
-
-         # Generate embeddings in batches
-         batch_size = 32
-         embeddings = []
-
-         for i in range(0, len(texts), batch_size):
-             batch_texts = texts[i:i+batch_size]
-             batch_embeddings = model.encode(batch_texts, show_progress_bar=True)
-             embeddings.append(batch_embeddings)
-
-         # Concatenate all embeddings
-         all_embeddings = np.vstack(embeddings)
-
-         logger.info(f"Generated embeddings shape: {all_embeddings.shape}")
-         return all_embeddings
+         logger.info(f"Starting embedding generation for {len(chunks)} {chunk_type} chunks...")
+
+         # Cache setup
+         cache_dir = self.models_dir / "cache"
+         cache_dir.mkdir(parents=True, exist_ok=True)
+         cache_file = cache_dir / f"{chunk_type}_embeddings_cache.pkl"
+
+         # Load existing cache
+         cache = self._load_embedding_cache(str(cache_file))
+
+         cached_embeddings = []
+         to_embed = []
+
+         # Check cache for each chunk
+         for i, chunk in enumerate(chunks):
+             chunk_hash = self._get_chunk_hash(chunk['text'])
+             if chunk_hash in cache:
+                 cached_embeddings.append((i, cache[chunk_hash]))
+             else:
+                 to_embed.append((i, chunk_hash, chunk['text']))
+
+         logger.info(f"Cache status: {len(cached_embeddings)} cached, {len(to_embed)} new chunks to embed")
+
+         # Generate embeddings for new chunks
+         new_embeddings = []
+         if to_embed:
+             # Load model
+             model = self.load_embedding_model()
+             texts = [text for _, _, text in to_embed]
+
+             # Generate embeddings in batches with clear progress
+             batch_size = 32
+             total_batches = (len(texts) + batch_size - 1) // batch_size
+
+             logger.info(f"Processing {len(texts)} new {chunk_type} texts in {total_batches} batches...")
+
+             for i in tqdm(range(0, len(texts), batch_size),
+                           desc=f"Embedding {chunk_type} subset",
+                           total=total_batches,
+                           unit="batch",
+                           leave=False):
+                 batch_texts = texts[i:i + batch_size]
+                 batch_emb = model.encode(
+                     batch_texts,
+                     show_progress_bar=False
+                 )
+                 new_embeddings.extend(batch_emb)
+
+             # Update cache with new embeddings
+             for (_, chunk_hash, _), emb in zip(to_embed, new_embeddings):
+                 cache[chunk_hash] = emb
+
+             # Save updated cache
+             self._save_embedding_cache(cache, str(cache_file))
+             logger.info(f"Updated cache with {len(new_embeddings)} new embeddings")
+
+         # Combine cached and new embeddings in correct order
+         all_embeddings = [None] * len(chunks)
+
+         # Place cached embeddings
+         for idx, emb in cached_embeddings:
+             all_embeddings[idx] = emb
+
+         # Place new embeddings
+         for (idx, _, _), emb in zip(to_embed, new_embeddings):
+             all_embeddings[idx] = emb
+
+         # Convert to numpy array
+         result = np.vstack(all_embeddings)
+         logger.info(f"Completed embedding generation: shape {result.shape}")
+
+         return result

      def build_annoy_index(self, embeddings: np.ndarray,
                            index_name: str, n_trees: int = 15) -> AnnoyIndex:
tests/test_embedding_and_index.py ADDED
@@ -0,0 +1,29 @@
+ import numpy as np
+ from annoy import AnnoyIndex
+ import pytest
+ from data_processing import DataProcessor
+
+ @pytest.fixture(scope="module")
+ def processor():
+     return DataProcessor(base_dir=".")
+
+ def test_embedding_dimensions(processor):
+     # load emergency embeddings
+     emb = np.load(processor.models_dir / "embeddings" / "emergency_embeddings.npy")
+     expected_dim = processor.embedding_dim
+     assert emb.ndim == 2, f"Expected 2D array, got {emb.ndim}D"
+     assert emb.shape[1] == expected_dim, (
+         f"Expected embedding dimension {expected_dim}, got {emb.shape[1]}"
+     )
+
+ def test_annoy_search(processor):
+     # load embeddings
+     emb = np.load(processor.models_dir / "embeddings" / "emergency_embeddings.npy")
+     # load Annoy index
+     idx = AnnoyIndex(processor.embedding_dim, 'angular')
+     idx.load(str(processor.models_dir / "indices" / "annoy" / "emergency_index.ann"))
+     # perform a sample query
+     query_vec = emb[0]
+     ids, distances = idx.get_nns_by_vector(query_vec, 5, include_distances=True)
+     assert len(ids) == 5
+     assert all(0 <= d <= 2 for d in distances)