Kevin Hu
committed on
Commit
·
692cc99
1
Parent(s):
0b587a0
fix: term weight issue (#3294)
Browse files
### What problem does this PR solve?
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- rag/benchmark.py +34 -5
- rag/nlp/term_weight.py +1 -1
rag/benchmark.py
CHANGED
|
@@ -16,11 +16,15 @@
|
|
| 16 |
import json
|
| 17 |
import os
|
| 18 |
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
| 19 |
from api.db import LLMType
|
| 20 |
from api.db.services.llm_service import LLMBundle
|
| 21 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 22 |
from api.settings import retrievaler
|
| 23 |
from api.utils import get_uuid
|
|
|
|
| 24 |
from rag.nlp import tokenize, search
|
| 25 |
from rag.utils.es_conn import ELASTICSEARCH
|
| 26 |
from ranx import evaluate
|
|
@@ -63,14 +67,34 @@ class Benchmark:
|
|
| 63 |
d["q_%d_vec" % len(v)] = v
|
| 64 |
return docs
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def ms_marco_index(self, file_path, index_name):
|
| 67 |
qrels = defaultdict(dict)
|
| 68 |
texts = defaultdict(dict)
|
| 69 |
docs = []
|
| 70 |
filelist = os.listdir(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
for dir in filelist:
|
| 72 |
data = pd.read_parquet(os.path.join(file_path, dir))
|
| 73 |
-
for i in tqdm(range(len(data)), colour="green", desc="
|
| 74 |
|
| 75 |
query = data.iloc[i]['query']
|
| 76 |
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
|
@@ -82,12 +106,17 @@ class Benchmark:
|
|
| 82 |
texts[d["id"]] = text
|
| 83 |
qrels[query][d["id"]] = int(rel)
|
| 84 |
if len(docs) >= 32:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
docs = []
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
return qrels, texts
|
| 92 |
|
| 93 |
def trivia_qa_index(self, file_path, index_name):
|
|
|
|
| 16 |
import json
|
| 17 |
import os
|
| 18 |
from collections import defaultdict
|
| 19 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 20 |
+
from copy import deepcopy
|
| 21 |
+
|
| 22 |
from api.db import LLMType
|
| 23 |
from api.db.services.llm_service import LLMBundle
|
| 24 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 25 |
from api.settings import retrievaler
|
| 26 |
from api.utils import get_uuid
|
| 27 |
+
from api.utils.file_utils import get_project_base_directory
|
| 28 |
from rag.nlp import tokenize, search
|
| 29 |
from rag.utils.es_conn import ELASTICSEARCH
|
| 30 |
from ranx import evaluate
|
|
|
|
| 67 |
d["q_%d_vec" % len(v)] = v
|
| 68 |
return docs
|
| 69 |
|
| 70 |
+
@staticmethod
|
| 71 |
+
def init_kb(index_name):
|
| 72 |
+
idxnm = search.index_name(index_name)
|
| 73 |
+
if ELASTICSEARCH.indexExist(idxnm):
|
| 74 |
+
ELASTICSEARCH.deleteIdx(search.index_name(index_name))
|
| 75 |
+
|
| 76 |
+
return ELASTICSEARCH.createIdx(idxnm, json.load(
|
| 77 |
+
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
|
| 78 |
+
|
| 79 |
def ms_marco_index(self, file_path, index_name):
|
| 80 |
qrels = defaultdict(dict)
|
| 81 |
texts = defaultdict(dict)
|
| 82 |
docs = []
|
| 83 |
filelist = os.listdir(file_path)
|
| 84 |
+
self.init_kb(index_name)
|
| 85 |
+
|
| 86 |
+
max_workers = int(os.environ.get('MAX_WORKERS', 3))
|
| 87 |
+
exe = ThreadPoolExecutor(max_workers=max_workers)
|
| 88 |
+
threads = []
|
| 89 |
+
|
| 90 |
+
def slow_actions(es_docs, idx_nm):
|
| 91 |
+
es_docs = self.embedding(es_docs)
|
| 92 |
+
ELASTICSEARCH.bulk(es_docs, idx_nm)
|
| 93 |
+
return True
|
| 94 |
+
|
| 95 |
for dir in filelist:
|
| 96 |
data = pd.read_parquet(os.path.join(file_path, dir))
|
| 97 |
+
for i in tqdm(range(len(data)), colour="green", desc="Tokenizing:" + dir):
|
| 98 |
|
| 99 |
query = data.iloc[i]['query']
|
| 100 |
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
|
|
|
| 106 |
texts[d["id"]] = text
|
| 107 |
qrels[query][d["id"]] = int(rel)
|
| 108 |
if len(docs) >= 32:
|
| 109 |
+
threads.append(
|
| 110 |
+
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
| 111 |
docs = []
|
| 112 |
|
| 113 |
+
threads.append(
|
| 114 |
+
exe.submit(slow_actions, deepcopy(docs), search.index_name(index_name)))
|
| 115 |
+
|
| 116 |
+
for i in tqdm(range(len(threads)), colour="red", desc="Indexing:" + dir):
|
| 117 |
+
if not threads[i].result().output:
|
| 118 |
+
print("Indexing error...")
|
| 119 |
+
|
| 120 |
return qrels, texts
|
| 121 |
|
| 122 |
def trivia_qa_index(self, file_path, index_name):
|
rag/nlp/term_weight.py
CHANGED
|
@@ -227,7 +227,7 @@ class Dealer:
|
|
| 227 |
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
| 228 |
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
| 229 |
np.array([ner(t) * postag(t) for t in tks])
|
| 230 |
-
tw = zip(tks, wts)
|
| 231 |
else:
|
| 232 |
for tk in tks:
|
| 233 |
tt = self.tokenMerge(self.pretoken(tk, True))
|
|
|
|
| 227 |
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
|
| 228 |
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
| 229 |
np.array([ner(t) * postag(t) for t in tks])
|
| 230 |
+
tw = list(zip(tks, wts))
|
| 231 |
else:
|
| 232 |
for tk in tks:
|
| 233 |
tt = self.tokenMerge(self.pretoken(tk, True))
|