Spaces:
Sleeping
Sleeping
Commit
·
e61608e
1
Parent(s):
a2fb673
added
Browse files
app.py
CHANGED
@@ -206,6 +206,8 @@ def load_queries(dataset_name):
|
|
206 |
|
207 |
|
208 |
def evaluate(qrels, results, k_values):
|
|
|
|
|
209 |
evaluator = pytrec_eval.RelevanceEvaluator(
|
210 |
qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
|
211 |
)
|
@@ -232,13 +234,16 @@ def run_evaluation(dataset, postfix):
|
|
232 |
logger.info(f"Sample input text: {input_texts[0]}")
|
233 |
|
234 |
q_reps = model.encode(input_texts)
|
|
|
235 |
logger.info(f"Encoded query representations shape: {q_reps.shape}")
|
236 |
|
237 |
all_scores, psg_indices = search_queries(dataset, q_reps)
|
238 |
|
239 |
-
results = {qid:
|
240 |
-
|
241 |
-
|
|
|
|
|
242 |
logger.info(f"Number of results: {len(results)}")
|
243 |
logger.info(f"Sample result: {list(results.items())[0]}")
|
244 |
|
|
|
206 |
|
207 |
|
208 |
def evaluate(qrels, results, k_values):
|
209 |
+
qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
|
210 |
+
results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
|
211 |
evaluator = pytrec_eval.RelevanceEvaluator(
|
212 |
qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
|
213 |
)
|
|
|
234 |
logger.info(f"Sample input text: {input_texts[0]}")
|
235 |
|
236 |
q_reps = model.encode(input_texts)
|
237 |
+
logger.info(f"Encoded query first five: {q_reps[0][:5]}")
|
238 |
logger.info(f"Encoded query representations shape: {q_reps.shape}")
|
239 |
|
240 |
all_scores, psg_indices = search_queries(dataset, q_reps)
|
241 |
|
242 |
+
results = {str(qid): {str(doc_id): float(score) for doc_id, score in zip(doc_ids, scores)}
|
243 |
+
for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices)}
|
244 |
+
qrels[dataset] = {str(qid): {str(doc_id): rel for doc_id, rel in rels.items()}
|
245 |
+
for qid, rels in qrels[dataset].items()}
|
246 |
+
|
247 |
logger.info(f"Number of results: {len(results)}")
|
248 |
logger.info(f"Sample result: {list(results.items())[0]}")
|
249 |
|