Kevin Hu committed
Commit · 1f1194f
1 Parent(s): cfb71b4

fix benchmark issue (#3324)

### What problem does this PR solve?

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed:

- rag/benchmark.py (+18 −9)
rag/benchmark.py CHANGED

```diff
@@ -30,6 +30,7 @@ from rag.utils.es_conn import ELASTICSEARCH
 from ranx import evaluate
 import pandas as pd
 from tqdm import tqdm
+from ranx import Qrels, Run
 
 
 class Benchmark:
@@ -50,8 +51,8 @@ class Benchmark:
         query_list = list(qrels.keys())
         for query in query_list:
 
-            ranks = retrievaler.retrieval(query, self.embd_mdl,
-                                          [self.kb.id],
+            ranks = retrievaler.retrieval(query, self.embd_mdl,
+                                          dataset_idxnm, [self.kb.id], 1, 30,
                                           0.0, self.vector_similarity_weight)
             for c in ranks["chunks"]:
                 if "vector" in c:
@@ -105,7 +106,9 @@ class Benchmark:
             for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
                 d = {
                     "id": get_uuid(),
-                    "kb_id": self.kb.id
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
@@ -137,7 +140,10 @@ class Benchmark:
             for rel, text in zip(data.iloc[i]["search_results"]['rank'],
                                  data.iloc[i]["search_results"]['search_context']):
                 d = {
-                    "id": get_uuid()
+                    "id": get_uuid(),
+                    "kb_id": self.kb.id,
+                    "docnm_kwd": "xxxxx",
+                    "doc_id": "ksksks"
                 }
                 tokenize(d, text, "english")
                 docs.append(d)
@@ -182,7 +188,10 @@ class Benchmark:
             text = corpus_total[tmp_data.iloc[i]['docid']]
             rel = tmp_data.iloc[i]['relevance']
             d = {
-                "id": get_uuid()
+                "id": get_uuid(),
+                "kb_id": self.kb.id,
+                "docnm_kwd": "xxxxx",
+                "doc_id": "ksksks"
             }
             tokenize(d, text, 'english')
             docs.append(d)
@@ -204,7 +213,7 @@ class Benchmark:
         for run_i in tqdm(range(len(run_keys)), desc="Calculating ndcg@10 for single query"):
             key = run_keys[run_i]
             keep_result.append({'query': key, 'qrel': qrels[key], 'run': run[key],
-                                'ndcg@10': evaluate({key: qrels[key]}, {key: run[key]}, "ndcg@10")})
+                                'ndcg@10': evaluate(Qrels({key: qrels[key]}), Run({key: run[key]}), "ndcg@10")})
         keep_result = sorted(keep_result, key=lambda kk: kk['ndcg@10'])
         with open(os.path.join(file_path, dataset + 'result.md'), 'w', encoding='utf-8') as f:
             f.write('## Score For Every Query\n')
@@ -222,12 +231,12 @@ class Benchmark:
         if dataset == "ms_marco_v1.1":
             qrels, texts = self.ms_marco_index(file_path, "benchmark_ms_marco_v1.1")
             run = self._get_retrieval(qrels, "benchmark_ms_marco_v1.1")
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
         if dataset == "trivia_qa":
             qrels, texts = self.trivia_qa_index(file_path, "benchmark_trivia_qa")
             run = self._get_retrieval(qrels, "benchmark_trivia_qa")
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
         if dataset == "miracl":
             for lang in ['ar', 'bn', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'ja', 'ko', 'ru', 'sw', 'te', 'th',
@@ -248,7 +257,7 @@ class Benchmark:
                 os.path.join(miracl_corpus, 'miracl-corpus-v1.0-' + lang),
                 "benchmark_miracl_" + lang)
             run = self._get_retrieval(qrels, "benchmark_miracl_" + lang)
-            print(dataset, evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
+            print(dataset, evaluate(Qrels(qrels), Run(run), ["ndcg@10", "map@5", "mrr"]))
             self.save_results(qrels, run, texts, dataset, file_path)
 
 
```
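The substance of the fix is visible in the last four hunks: plain `dict` objects are now wrapped in ranx's `Qrels` and `Run` containers before being handed to `evaluate`. Below is a minimal, self-contained sketch of that usage pattern; the query ids, document ids, relevance labels, and scores are invented placeholders, not values from the benchmark.

```python
# Minimal sketch of the ranx usage this commit adopts.
# All ids, labels, and scores below are made-up illustrations.
from ranx import Qrels, Run, evaluate

# Ground truth: query id -> {doc id: graded relevance label}
qrels = Qrels({
    "q_1": {"d_12": 1, "d_25": 2},
    "q_2": {"d_11": 1},
})

# Retrieval output: query id -> {doc id: retrieval score}
run = Run({
    "q_1": {"d_12": 0.9, "d_23": 0.8, "d_25": 0.7},
    "q_2": {"d_11": 0.95, "d_2": 0.4},
})

# A single metric name returns a float; a list of metrics returns a
# dict, which is what benchmark.py prints for each dataset.
print(evaluate(qrels, run, "ndcg@10"))
print(evaluate(qrels, run, ["ndcg@10", "map@5", "mrr"]))
```

The remaining hunks pad each indexed chunk with `kb_id`, `docnm_kwd`, and `doc_id` fields (with dummy values) and pass `dataset_idxnm` through to `retrievaler.retrieval`, presumably so the benchmark documents satisfy the index schema and index naming that retrieval now expects.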