orionweller commited on
Commit
e61608e
·
1 Parent(s): a2fb673
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -206,6 +206,8 @@ def load_queries(dataset_name):
206
 
207
 
208
  def evaluate(qrels, results, k_values):
 
 
209
  evaluator = pytrec_eval.RelevanceEvaluator(
210
  qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
211
  )
@@ -232,13 +234,16 @@ def run_evaluation(dataset, postfix):
232
  logger.info(f"Sample input text: {input_texts[0]}")
233
 
234
  q_reps = model.encode(input_texts)
 
235
  logger.info(f"Encoded query representations shape: {q_reps.shape}")
236
 
237
  all_scores, psg_indices = search_queries(dataset, q_reps)
238
 
239
- results = {qid: dict(zip(doc_ids, map(float, scores)))
240
- for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices)}
241
-
 
 
242
  logger.info(f"Number of results: {len(results)}")
243
  logger.info(f"Sample result: {list(results.items())[0]}")
244
 
 
206
 
207
 
208
  def evaluate(qrels, results, k_values):
209
+ qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
210
+ results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
211
  evaluator = pytrec_eval.RelevanceEvaluator(
212
  qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
213
  )
 
234
  logger.info(f"Sample input text: {input_texts[0]}")
235
 
236
  q_reps = model.encode(input_texts)
237
+ logger.info(f"Encoded query first five: {q_reps[0][:5]}")
238
  logger.info(f"Encoded query representations shape: {q_reps.shape}")
239
 
240
  all_scores, psg_indices = search_queries(dataset, q_reps)
241
 
242
+ results = {str(qid): {str(doc_id): float(score) for doc_id, score in zip(doc_ids, scores)}
243
+ for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices)}
244
+ qrels[dataset] = {str(qid): {str(doc_id): rel for doc_id, rel in rels.items()}
245
+ for qid, rels in qrels[dataset].items()}
246
+
247
  logger.info(f"Number of results: {len(results)}")
248
  logger.info(f"Sample result: {list(results.items())[0]}")
249