Testys committed
Commit 7ccde22 · Parent: b01eb37

Update search_utils.py

Files changed (1):
  search_utils.py  +109 -1
search_utils.py CHANGED
@@ -4,6 +4,62 @@ import faiss
 from pathlib import Path
 from sentence_transformers import SentenceTransformer, util
 import streamlit as st
+import zipfile
+import pandas as pd
+from pathlib import Path
+import streamlit as st
+
+class MetadataManager:
+    def __init__(self):
+        self.shard_dir = Path("metadata_shards")
+        self.shard_map = {}
+        self.loaded_shards = {}
+        self._ensure_unzipped()
+        self._build_shard_map()
+
+    def _ensure_unzipped(self):
+        """Extract metadata shards from zip if needed"""
+        if not self.shard_dir.exists():
+            zip_path = Path("metadata_shards.zip")
+            if zip_path.exists():
+                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+                    zip_ref.extractall(self.shard_dir)
+                st.toast("✅ Successfully extracted metadata shards!", icon="📦")
+            else:
+                raise FileNotFoundError("No metadata shards found!")
+
+    def _build_shard_map(self):
+        """Map index ranges to shard files"""
+        for f in self.shard_dir.glob("*.parquet"):
+            parts = f.stem.split("_")
+            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name
+
+    def get_metadata(self, indices):
+        """Retrieve metadata for specific indices"""
+        results = []
+        shard_groups = {}
+
+        # Group indices by shard
+        for idx in indices:
+            for (start, end), shard in self.shard_map.items():
+                if start <= idx <= end:
+                    if shard not in shard_groups:
+                        shard_groups[shard] = []
+                    shard_groups[shard].append(idx - start)
+                    break
+
+        # Load required shards
+        for shard, local_indices in shard_groups.items():
+            if shard not in self.loaded_shards:
+                self.loaded_shards[shard] = pd.read_parquet(
+                    self.shard_dir / shard,
+                    columns=["title", "summary", "source"]
+                )
+
+            results.append(self.loaded_shards[shard].iloc[local_indices])
+
+        return pd.concat(results).reset_index(drop=True)
+
 
 class SemanticSearch:
     def __init__(self, shard_dir="compressed_shards"):
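
Note: _build_shard_map infers each shard's global row range from its file name, so the parquet files must follow a "<prefix>_<start>_<end>.parquet" pattern. A minimal sketch of that assumption (the file names here are hypothetical; only the underscore-separated shape matters):

from pathlib import Path

# e.g. metadata_shards/shard_0_49999.parquet would cover global rows 0..49999
for f in [Path("shard_0_49999.parquet"), Path("shard_50000_99999.parquet")]:
    parts = f.stem.split("_")             # ["shard", "0", "49999"]
    start, end = int(parts[1]), int(parts[2])
    print(f"{f.name}: rows {start}-{end}")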
@@ -11,6 +67,7 @@ class SemanticSearch:
         self.shard_dir.mkdir(exist_ok=True, parents=True)
         self.model = None
         self.index_shards = []
+        self.metadata_mgr = MetadataManager()
 
     @st.cache_resource
     def load_model(_self):
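
With the manager attached, metadata lookups are cached per shard: the first get_metadata call touching a shard reads its parquet file once into loaded_shards, and later calls reuse the in-memory frame. A usage sketch (the row ids are hypothetical):

mgr = MetadataManager()                  # extracts metadata_shards.zip on first run if needed
hits = mgr.get_metadata([3, 17, 50123])
print(hits[["title", "source"]])

mgr.get_metadata([4, 18])                # same shards: served from loaded_shards,
                                         # no second read_parquet

One caveat of this design: rows come back grouped by shard, in order of first shard encountered, rather than strictly in the order of the input ids.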
@@ -61,4 +118,55 @@ class SemanticSearch:
         """Threshold-filtered search"""
         results = self.search(query, top_k*2)
         filtered = results[results['similarity'] > similarity_threshold].head(top_k)
-        return filtered.reset_index(drop=True)
+        return filtered.reset_index(drop=True)
+
+
+
+
+class MetadataManager:
+    def __init__(self, repo_id, shard_dir="metadata_shards"):
+        self.repo_id = repo_id
+        self.shard_dir = Path(shard_dir)
+        self.shard_map = {}
+        self.loaded_shards = {}
+        self._build_shard_map()
+
+    def _build_shard_map(self):
+        """Map index ranges to shard files"""
+        for f in self.shard_dir.glob("*.parquet"):
+            parts = f.stem.split("_")
+            self.shard_map[(int(parts[1]), int(parts[2]))] = f.name
+
+    def _download_shard(self, shard_name):
+        """Download missing shards on demand"""
+        if not (self.shard_dir/shard_name).exists():
+            hf_hub_download(
+                repo_id=self.repo_id,
+                filename=f"metadata_shards/{shard_name}",
+                local_dir=self.shard_dir,
+                cache_dir="metadata_cache"
+            )
+
+    def get_metadata(self, indices):
+        """Retrieve metadata for specific indices"""
+        results = []
+
+        # Group indices by shard
+        shard_groups = {}
+        for idx in indices:
+            for (start, end), shard in self.shard_map.items():
+                if start <= idx <= end:
+                    if shard not in shard_groups:
+                        shard_groups[shard] = []
+                    shard_groups[shard].append(idx - start)
+                    break
+
+        # Process each required shard
+        for shard, local_indices in shard_groups.items():
+            if shard not in self.loaded_shards:
+                self._download_shard(shard)
+                self.loaded_shards[shard] = pd.read_parquet(self.shard_dir/shard)
+
+            results.append(self.loaded_shards[shard].iloc[local_indices])
+
+        return pd.concat(results).reset_index(drop=True)
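
Note: this hunk re-defines MetadataManager, so at import time the Hub-backed version at the bottom of the file replaces the zip-based one defined earlier. As committed, _download_shard also calls hf_hub_download without the diff ever importing it, so the module needs the huggingface_hub import to run. A minimal sketch of the assumed call (the repo id and shard file name are hypothetical):

from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="user/arxiv-metadata",                     # hypothetical repo id
    filename="metadata_shards/shard_0_49999.parquet",  # hypothetical shard
    local_dir="metadata_shards",
    cache_dir="metadata_cache",
)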
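
Taken together, a downstream caller might combine the two classes along these lines. This is a sketch under assumptions not confirmed by the diff: SemanticSearch.search (implied by the self.search(query, top_k*2) call above) returns a DataFrame whose row index holds global FAISS ids, and the query string is made up:

searcher = SemanticSearch()                    # wires up MetadataManager() in __init__
results = searcher.search("graph neural networks", 5)
meta = searcher.metadata_mgr.get_metadata(results.index.tolist())
print(meta[["title", "summary", "source"]])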