Testys commited on
Commit
5ee0a10
·
1 Parent(s): 17ad44f

Create search_utils.py

Browse files
Files changed (1) hide show
  1. search_utils.py +64 -0
search_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import faiss
4
+ from pathlib import Path
5
+ from sentence_transformers import SentenceTransformer, util
6
+ import streamlit as st
7
+
8
+ class SemanticSearch:
9
+ def __init__(self, shard_dir="compressed_shards"):
10
+ self.shard_dir = Path(shard_dir)
11
+ self.shard_dir.mkdir(exist_ok=True, parents=True)
12
+ self.model = None
13
+ self.index_shards = []
14
+
15
+ @st.cache_resource
16
+ def load_model(_self):
17
+ return SentenceTransformer('all-MiniLM-L6-v2')
18
+
19
+ def initialize_system(self):
20
+ self.model = self.load_model()
21
+ self._load_index_shards()
22
+
23
+ def _load_index_shards(self):
24
+ """Load FAISS shards directly from local directory"""
25
+ for shard_path in sorted(self.shard_dir.glob("*.index")):
26
+ self.index_shards.append(faiss.read_index(str(shard_path)))
27
+
28
+ def search(self, query, top_k=5):
29
+ """Search across all shards"""
30
+ query_embedding = self.model.encode([query], convert_to_numpy=True)
31
+ all_scores = []
32
+ all_indices = []
33
+
34
+ for shard_idx, index in enumerate(self.index_shards):
35
+ distances, indices = index.search(query_embedding, top_k)
36
+ # Convert local indices to global shard offsets
37
+ global_indices = [
38
+ self._calculate_global_index(shard_idx, idx)
39
+ for idx in indices[0]
40
+ ]
41
+ all_scores.extend(distances[0])
42
+ all_indices.extend(global_indices)
43
+
44
+ return self._process_results(np.array(all_scores), np.array(all_indices), top_k)
45
+
46
+ def _calculate_global_index(self, shard_idx, local_idx):
47
+ """Convert shard-local index to global index"""
48
+ # Implement your specific shard indexing logic here
49
+ # Example: return f"{shard_idx}-{local_idx}"
50
+ return local_idx # Simple version if using unique IDs
51
+
52
+ def _process_results(self, distances, indices, top_k):
53
+ """Format search results"""
54
+ results = pd.DataFrame({
55
+ 'global_index': indices,
56
+ 'similarity': 1 - (distances / 2) # L2 to cosine approximation
57
+ })
58
+ return results.sort_values('similarity', ascending=False).head(top_k)
59
+
60
+ def search_with_threshold(self, query, top_k=5, similarity_threshold=0.6):
61
+ """Threshold-filtered search"""
62
+ results = self.search(query, top_k*2)
63
+ filtered = results[results['similarity'] > similarity_threshold].head(top_k)
64
+ return filtered.reset_index(drop=True)