orionweller committed on
Commit
c74038e
·
1 Parent(s): 0f7ca6c
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -39,6 +39,7 @@ corpus_lookups = {}
39
  queries = {}
40
  q_lookups = {}
41
  qrels = {}
 
42
  datasets = ["scifact"]
43
  current_dataset = "scifact"
44
 
@@ -157,7 +158,7 @@ def load_faiss_index(dataset_name):
157
  return faiss.read_index(index_path)
158
  return None
159
 
160
- def search_queries(dataset_name, q_reps, depth=1000):
161
  faiss_index = load_faiss_index(dataset_name)
162
  if faiss_index is None:
163
  raise ValueError(f"No FAISS index found for dataset {dataset_name}")
@@ -169,6 +170,7 @@ def search_queries(dataset_name, q_reps, depth=1000):
169
 
170
  logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
171
  logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
 
172
 
173
  psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
174
 
@@ -191,15 +193,17 @@ def load_corpus_lookups(dataset_name):
191
  logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][:10]}")
192
 
193
  def load_queries(dataset_name):
194
- global queries, q_lookups, qrels
195
  dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
196
 
197
  queries[dataset_name] = []
 
198
  q_lookups[dataset_name] = {}
199
  qrels[dataset_name] = {}
200
  for query in dataset.queries_iter():
201
  queries[dataset_name].append(query.text)
202
  q_lookups[dataset_name][query.query_id] = query.text
 
203
 
204
  for qrel in dataset.qrels_iter():
205
  if qrel.query_id not in qrels[dataset_name]:
@@ -231,7 +235,7 @@ def evaluate(qrels, results, k_values):
231
 
232
  @spaces.GPU
233
  def run_evaluation(dataset, postfix):
234
- global current_dataset, queries, model
235
  current_dataset = dataset
236
 
237
  input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[current_dataset]]
@@ -248,7 +252,8 @@ def run_evaluation(dataset, postfix):
248
  logging.info(f"Number of queries in q_lookups: {len(q_lookups[dataset])}")
249
  logging.info("Size of all_scores: " + str(len(all_scores)))
250
  logging.info("Size of psg_indices: " + str(len(psg_indices)))
251
- for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices):
 
252
  qid_str = str(qid)
253
  results[qid_str] = {}
254
  for doc_id, score in zip(doc_ids, scores):
 
39
  queries = {}
40
  q_lookups = {}
41
  qrels = {}
42
+ query2qid = {}
43
  datasets = ["scifact"]
44
  current_dataset = "scifact"
45
 
 
158
  return faiss.read_index(index_path)
159
  return None
160
 
161
+ def search_queries(dataset_name, q_reps, depth=100):
162
  faiss_index = load_faiss_index(dataset_name)
163
  if faiss_index is None:
164
  raise ValueError(f"No FAISS index found for dataset {dataset_name}")
 
170
 
171
  logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
172
  logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
173
+
174
 
175
  psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
176
 
 
193
  logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][:10]}")
194
 
195
  def load_queries(dataset_name):
196
+ global queries, q_lookups, qrels, query2qid
197
  dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
198
 
199
  queries[dataset_name] = []
200
+ query2qid[dataset_name] = {}
201
  q_lookups[dataset_name] = {}
202
  qrels[dataset_name] = {}
203
  for query in dataset.queries_iter():
204
  queries[dataset_name].append(query.text)
205
  q_lookups[dataset_name][query.query_id] = query.text
206
+ query2qid[dataset_name][query.text] = query.query_id
207
 
208
  for qrel in dataset.qrels_iter():
209
  if qrel.query_id not in qrels[dataset_name]:
 
235
 
236
  @spaces.GPU
237
  def run_evaluation(dataset, postfix):
238
+ global current_dataset, queries, model, query2qid
239
  current_dataset = dataset
240
 
241
  input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[current_dataset]]
 
252
  logging.info(f"Number of queries in q_lookups: {len(q_lookups[dataset])}")
253
  logging.info("Size of all_scores: " + str(len(all_scores)))
254
  logging.info("Size of psg_indices: " + str(len(psg_indices)))
255
+ for query, scores, doc_ids in zip(queries, all_scores, psg_indices):
256
+ qid = query2qid[dataset][query]
257
  qid_str = str(qid)
258
  results[qid_str] = {}
259
  for doc_id, score in zip(doc_ids, scores):