Testys commited on
Commit
a223079
·
1 Parent(s): d1f3bf3

Update search_utils.py

Browse files
Files changed (1) hide show
  1. search_utils.py +65 -28
search_utils.py CHANGED
@@ -36,29 +36,46 @@ class MetadataManager:
36
  self.total_docs = max(self.total_docs, end + 1)
37
 
38
  def get_metadata(self, global_indices):
39
- """Retrieve metadata for global indices"""
40
- results = []
41
- shard_groups = {}
42
 
43
- # Organize indices by their respective shards
44
- for idx in global_indices:
 
 
 
 
 
 
 
45
  for (start, end), shard in self.shard_map.items():
46
  if start <= idx <= end:
47
  if shard not in shard_groups:
48
  shard_groups[shard] = []
49
  shard_groups[shard].append(idx - start)
 
50
  break
51
-
52
- # Load and process required shards
 
 
 
53
  for shard, local_indices in shard_groups.items():
54
- if shard not in self.loaded_shards:
55
- self.loaded_shards[shard] = pd.read_parquet(
56
- self.shard_dir / shard,
57
- columns=["title", "summary", "source"]
58
- )
59
- results.append(self.loaded_shards[shard].iloc[local_indices])
60
-
61
- return pd.concat(results).reset_index(drop=True)
 
 
 
 
 
 
62
 
63
  class SemanticSearch:
64
  def __init__(self):
@@ -89,19 +106,39 @@ class SemanticSearch:
89
  return sum(self.shard_sizes[:shard_idx]) + local_idx
90
 
91
  def search(self, query, top_k=5):
92
- """Main search functionality"""
93
- query_embedding = self.model.encode([query], convert_to_numpy=True)
 
 
 
 
 
 
 
 
94
  all_distances = []
95
  all_global_indices = []
96
-
97
- # Search across all shards
98
  for shard_idx, index in enumerate(self.index_shards):
99
- distances, indices = index.search(query_embedding, top_k)
100
- global_indices = [self._global_index(shard_idx, idx) for idx in indices[0]]
101
- all_distances.extend(distances[0])
102
- all_global_indices.extend(global_indices)
103
-
104
- # Process and format results
105
- results = self.metadata_mgr.get_metadata(all_global_indices)
106
- results['similarity'] = 1 - (np.array(all_distances) / 2) # Convert L2 to cosine
107
- return results.sort_values('similarity', ascending=False).head(top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  self.total_docs = max(self.total_docs, end + 1)
37
 
38
  def get_metadata(self, global_indices):
39
+ """Retrieve metadata with validation"""
40
+ if not global_indices:
41
+ return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
42
 
43
+ # Filter valid indices
44
+ valid_indices = [idx for idx in global_indices if 0 <= idx < self.total_docs]
45
+ if not valid_indices:
46
+ return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
47
+
48
+ # Group indices by shard with boundary check
49
+ shard_groups = {}
50
+ for idx in valid_indices:
51
+ found = False
52
  for (start, end), shard in self.shard_map.items():
53
  if start <= idx <= end:
54
  if shard not in shard_groups:
55
  shard_groups[shard] = []
56
  shard_groups[shard].append(idx - start)
57
+ found = True
58
  break
59
+ if not found:
60
+ st.warning(f"Index {idx} out of shard range (0-{self.total_docs-1})")
61
+
62
+ # Load and process shards
63
+ results = []
64
  for shard, local_indices in shard_groups.items():
65
+ try:
66
+ if shard not in self.loaded_shards:
67
+ self.loaded_shards[shard] = pd.read_parquet(
68
+ self.shard_dir / shard,
69
+ columns=["title", "summary", "source"]
70
+ )
71
+
72
+ if local_indices:
73
+ results.append(self.loaded_shards[shard].iloc[local_indices])
74
+ except Exception as e:
75
+ st.error(f"Error loading shard {shard}: {str(e)}")
76
+ continue
77
+
78
+ return pd.concat(results).reset_index(drop=True) if results else pd.DataFrame()
79
 
80
  class SemanticSearch:
81
  def __init__(self):
 
106
  return sum(self.shard_sizes[:shard_idx]) + local_idx
107
 
108
  def search(self, query, top_k=5):
109
+ """Search with validation"""
110
+ if not query or not self.index_shards:
111
+ return pd.DataFrame()
112
+
113
+ try:
114
+ query_embedding = self.model.encode([query], convert_to_numpy=True)
115
+ except Exception as e:
116
+ st.error(f"Query encoding failed: {str(e)}")
117
+ return pd.DataFrame()
118
+
119
  all_distances = []
120
  all_global_indices = []
121
+
122
+ # Search with index validation
123
  for shard_idx, index in enumerate(self.index_shards):
124
+ if index.ntotal == 0:
125
+ continue
126
+
127
+ try:
128
+ distances, indices = index.search(query_embedding, top_k)
129
+ valid_indices = [idx for idx in indices[0] if 0 <= idx < index.ntotal]
130
+ global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
131
+
132
+ all_distances.extend(distances[0][:len(valid_indices)])
133
+ all_global_indices.extend(global_indices)
134
+ except Exception as e:
135
+ st.error(f"Search failed in shard {shard_idx}: {str(e)}")
136
+ continue
137
+
138
+ # Ensure equal array lengths
139
+ min_length = min(len(all_distances), len(all_global_indices))
140
+ return self._process_results(
141
+ np.array(all_distances[:min_length]),
142
+ np.array(all_global_indices[:min_length]),
143
+ top_k
144
+ )