orionweller committed
Commit c837e28 · 1 Parent(s): 05ff7af

load faiss

Files changed (1):
  1. app.py +37 -38
app.py CHANGED
@@ -7,7 +7,6 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModel, set_seed
 from peft import PeftModel
-from tevatron.retriever.searcher import FaissFlatSearcher
 import logging
 import os
 import json
@@ -47,7 +46,6 @@ current_dataset = "scifact"
 def log_system_info():
     logger.info("System Information:")
     logger.info(f"Python version: {sys.version}")
-    # logger.info(f"Platform: {platform.platform()}")

     logger.info("\nPackage Versions:")
     logger.info(f"torch: {torch.__version__}")
@@ -55,7 +53,6 @@ def log_system_info():
     logger.info(f"peft: {peft.__version__}")
     logger.info(f"faiss: {faiss.__version__}")
     logger.info(f"gradio: {gr.__version__}")
-    # logger.info(f"pytrec_eval: {pytrec_eval.__version__}")
     logger.info(f"ir_datasets: {ir_datasets.__version__}")

     if torch.cuda.is_available():
@@ -70,11 +67,8 @@ def log_system_info():
         logger.info("\nCUDA Information:")
         logger.info("CUDA available: No")

-
 log_system_info()

-
-
 def pool(last_hidden_states, attention_mask, pool_type="last"):
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
@@ -151,18 +145,45 @@ class RepLlamaModel:
         self.model = self.model.cpu()
         return np.concatenate(all_embeddings, axis=0)

-def load_faiss_index(dataset_name):
-    index_path = f"{dataset_name}/faiss_index.bin"
-    if os.path.exists(index_path):
-        logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
-        return faiss.read_index(index_path)
-    return None
+def load_corpus_embeddings(dataset_name):
+    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
+    index_files = glob.glob(corpus_path)
+    index_files.sort(key=lambda x: int(x.split('.')[-2]))
+
+    all_embeddings = []
+    corpus_lookups = []
+
+    for file in index_files:
+        with open(file, 'rb') as f:
+            embeddings, p_lookup = pickle.load(f)
+        all_embeddings.append(embeddings)
+        corpus_lookups.extend(p_lookup)
+
+    all_embeddings = np.concatenate(all_embeddings, axis=0)
+    logger.info(f"Loaded corpus embeddings for {dataset_name}. Shape: {all_embeddings.shape}")
+
+    return all_embeddings, corpus_lookups
+
+def create_faiss_index(embeddings):
+    dimension = embeddings.shape[1]
+    index = faiss.IndexFlatIP(dimension)
+    index.add(embeddings)
+    logger.info(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
+    return index
+
+def load_or_create_faiss_index(dataset_name):
+    embeddings, corpus_lookups = load_corpus_embeddings(dataset_name)
+    index = create_faiss_index(embeddings)
+    return index, corpus_lookups
+
+def initialize_faiss_and_corpus(dataset_name):
+    global corpus_lookups
+    index, corpus_lookups[dataset_name] = load_or_create_faiss_index(dataset_name)
+    logger.info(f"Initialized FAISS index and corpus lookups for {dataset_name}")
+    return index

 def search_queries(dataset_name, q_reps, depth=1000):
-    faiss_index = load_faiss_index(dataset_name)
-    if faiss_index is None:
-        raise ValueError(f"No FAISS index found for dataset {dataset_name}")
+    faiss_index = initialize_faiss_and_corpus(dataset_name)

     logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
@@ -171,28 +192,11 @@ def search_queries(dataset_name, q_reps, depth=1000):
     logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
     logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
-

     psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]

     return all_scores, np.array(psg_indices)

-def load_corpus_lookups(dataset_name):
-    global corpus_lookups
-    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
-    index_files = glob.glob(corpus_path)
-    # sort them
-    index_files.sort(key=lambda x: int(x.split('.')[-2]))
-
-    corpus_lookups[dataset_name] = []
-    for file in index_files:
-        with open(file, 'rb') as f:
-            _, p_lookup = pickle.load(f)
-        corpus_lookups[dataset_name] += p_lookup
-
-    logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
-    logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][:10]}")
-
 def load_queries(dataset_name):
     global queries, q_lookups, qrels, query2qid
     dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
@@ -214,7 +218,6 @@ def load_queries(dataset_name):
     logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
     logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")

-
 def evaluate(qrels, results, k_values):
     qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
     results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
@@ -273,7 +276,6 @@ def run_evaluation(dataset, postfix):
     logger.info(f"Number of results: {len(results)}")
     logger.info(f"Sample result: {list(results.items())[0]}")

-    # Add these lines
     logger.info(f"Number of queries in qrels: {len(qrels[dataset])}")
     logger.info(f"Sample qrel: {list(qrels[dataset].items())[0]}")
     logger.info(f"Number of queries in results: {len(results)}")
@@ -293,13 +295,10 @@ def run_evaluation(dataset, postfix):
 def gradio_interface(dataset, postfix):
     return run_evaluation(dataset, postfix)

-
 if model is None:
     model = RepLlamaModel(model_name_or_path=CUR_MODEL)
-load_corpus_lookups(current_dataset)
 load_queries(current_dataset)

-
 # Create Gradio interface
 iface = gr.Interface(
     fn=gradio_interface,
@@ -318,4 +317,4 @@ iface = gr.Interface(
 )

 # Launch the interface
-iface.launch()
+iface.launch(share=False)
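For reference, the pattern the new code follows is standard FAISS usage: concatenate the pickled embedding shards, build an exact inner-product index, and search it with the query embeddings. Below is a minimal self-contained sketch, not part of the commit; the shard paths mirror the diff, the (embeddings, lookup) pickle layout is assumed from it, and the float32 conversion is added because FAISS requires it.

# Minimal sketch of the retrieval path introduced above (hypothetical
# shard names; assumes each pickle holds an (embeddings, lookup) pair).
import glob
import pickle

import faiss
import numpy as np

def build_flat_index(dataset_name):
    # Shards are named corpus_emb.<i>.pkl; sort numerically by <i>.
    shards = sorted(glob.glob(f"{dataset_name}/corpus_emb.*.pkl"),
                    key=lambda p: int(p.split('.')[-2]))
    embeddings, lookup = [], []
    for path in shards:
        with open(path, 'rb') as f:
            emb, ids = pickle.load(f)  # assumed (embeddings, lookup) pair
        embeddings.append(emb)
        lookup.extend(ids)
    # FAISS expects a contiguous float32 matrix of shape (n_docs, dim).
    matrix = np.ascontiguousarray(np.concatenate(embeddings, axis=0),
                                  dtype=np.float32)
    index = faiss.IndexFlatIP(matrix.shape[1])  # exact inner-product search
    index.add(matrix)
    return index, lookup

# Usage: scores, idxs = index.search(q_reps, depth), where q_reps is a
# (num_queries, dim) float32 array; map idxs back to doc ids via lookup.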
 