Kevin Hu
committed on
Commit
·
62a5517
1
Parent(s):
0e469cf
Rebuild graph when it's out of time. (#4607)
Browse files### What problem does this PR solve?
#4543
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
- api/db/services/dialog_service.py +12 -3
- graphrag/search.py +3 -3
- graphrag/utils.py +35 -1
- rag/nlp/search.py +1 -1
- rag/svr/task_executor.py +4 -2
api/db/services/dialog_service.py
CHANGED
|
@@ -17,6 +17,7 @@ import logging
|
|
| 17 |
import binascii
|
| 18 |
import os
|
| 19 |
import json
|
|
|
|
| 20 |
import re
|
| 21 |
from collections import defaultdict
|
| 22 |
from copy import deepcopy
|
|
@@ -353,7 +354,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 353 |
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
|
| 354 |
|
| 355 |
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
| 356 |
-
return {"answer": answer, "reference": refs, "prompt": prompt}
|
| 357 |
|
| 358 |
if stream:
|
| 359 |
last_ans = ""
|
|
@@ -795,5 +796,13 @@ Output:
|
|
| 795 |
if kwd.find("**ERROR**") >= 0:
|
| 796 |
raise Exception(kwd)
|
| 797 |
|
| 798 |
-
|
| 799 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import binascii
|
| 18 |
import os
|
| 19 |
import json
|
| 20 |
+
import json_repair
|
| 21 |
import re
|
| 22 |
from collections import defaultdict
|
| 23 |
from copy import deepcopy
|
|
|
|
| 354 |
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
|
| 355 |
|
| 356 |
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
| 357 |
+
return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)}
|
| 358 |
|
| 359 |
if stream:
|
| 360 |
last_ans = ""
|
|
|
|
| 796 |
if kwd.find("**ERROR**") >= 0:
|
| 797 |
raise Exception(kwd)
|
| 798 |
|
| 799 |
+
try:
|
| 800 |
+
return json_repair.loads(kwd)
|
| 801 |
+
except json_repair.JSONDecodeError:
|
| 802 |
+
try:
|
| 803 |
+
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
|
| 804 |
+
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
| 805 |
+
return json_repair.loads(result)
|
| 806 |
+
except Exception as e:
|
| 807 |
+
logging.exception(f"JSON parsing error: {result} -> {e}")
|
| 808 |
+
raise e
|
graphrag/search.py
CHANGED
|
@@ -251,11 +251,11 @@ class KGSearch(Dealer):
|
|
| 251 |
break
|
| 252 |
|
| 253 |
if ents:
|
| 254 |
-
ents = "\n
|
| 255 |
else:
|
| 256 |
ents = ""
|
| 257 |
if relas:
|
| 258 |
-
relas = "\n
|
| 259 |
else:
|
| 260 |
relas = ""
|
| 261 |
|
|
@@ -296,7 +296,7 @@ class KGSearch(Dealer):
|
|
| 296 |
|
| 297 |
if not txts:
|
| 298 |
return ""
|
| 299 |
-
return "\n
|
| 300 |
|
| 301 |
|
| 302 |
if __name__ == "__main__":
|
|
|
|
| 251 |
break
|
| 252 |
|
| 253 |
if ents:
|
| 254 |
+
ents = "\n---- Entities ----\n{}".format(pd.DataFrame(ents).to_csv())
|
| 255 |
else:
|
| 256 |
ents = ""
|
| 257 |
if relas:
|
| 258 |
+
relas = "\n---- Relations ----\n{}".format(pd.DataFrame(relas).to_csv())
|
| 259 |
else:
|
| 260 |
relas = ""
|
| 261 |
|
|
|
|
| 296 |
|
| 297 |
if not txts:
|
| 298 |
return ""
|
| 299 |
+
return "\n---- Community Report ----\n" + "\n".join(txts)
|
| 300 |
|
| 301 |
|
| 302 |
if __name__ == "__main__":
|
graphrag/utils.py
CHANGED
|
@@ -23,6 +23,7 @@ from networkx.readwrite import json_graph
|
|
| 23 |
|
| 24 |
from api import settings
|
| 25 |
from rag.nlp import search, rag_tokenizer
|
|
|
|
| 26 |
from rag.utils.redis_conn import REDIS_CONN
|
| 27 |
|
| 28 |
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
|
@@ -363,7 +364,7 @@ def get_graph(tenant_id, kb_id):
|
|
| 363 |
res.field[id]["source_id"]
|
| 364 |
except Exception:
|
| 365 |
continue
|
| 366 |
-
return
|
| 367 |
|
| 368 |
|
| 369 |
def set_graph(tenant_id, kb_id, graph, docids):
|
|
@@ -517,3 +518,36 @@ def flat_uniq_list(arr, key):
|
|
| 517 |
res.append(a)
|
| 518 |
return list(set(res))
|
| 519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
from api import settings
|
| 25 |
from rag.nlp import search, rag_tokenizer
|
| 26 |
+
from rag.utils.doc_store_conn import OrderByExpr
|
| 27 |
from rag.utils.redis_conn import REDIS_CONN
|
| 28 |
|
| 29 |
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
|
|
|
| 364 |
res.field[id]["source_id"]
|
| 365 |
except Exception:
|
| 366 |
continue
|
| 367 |
+
return rebuild_graph(tenant_id, kb_id)
|
| 368 |
|
| 369 |
|
| 370 |
def set_graph(tenant_id, kb_id, graph, docids):
|
|
|
|
| 518 |
res.append(a)
|
| 519 |
return list(set(res))
|
| 520 |
|
| 521 |
+
|
| 522 |
+
def rebuild_graph(tenant_id, kb_id):
    """Rebuild the knowledge graph for a knowledge base from stored chunks.

    Pages through the doc store for all chunks tagged as "entity" or
    "relation" in the given KB and reassembles them into a networkx graph.

    Args:
        tenant_id: Tenant whose index is searched (via search.index_name).
        kb_id: Knowledge-base id used to filter chunks.

    Returns:
        (graph, source_ids): the rebuilt nx.Graph and the deduplicated list
        of source document ids, or (None, None) when the KB contains no
        entity/relation chunks at all.
    """
    graph = nx.Graph()
    src_ids = set()
    flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd",
            "weight_int", "knowledge_graph_kwd", "source_id"]
    bs = 256  # page size for the doc-store scan
    # Upper bound is a safety cap; the loop normally exits on a partial page.
    for i in range(0, 10000000, bs):
        es_res = settings.docStoreConn.search(
            flds, [],
            {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
            [],
            OrderByExpr(),
            i, bs, search.index_name(tenant_id), [kb_id]
        )
        tot = settings.docStoreConn.getTotal(es_res)
        if tot == 0:
            # No entity/relation chunks stored for this KB: nothing to rebuild.
            return None, None

        es_res = settings.docStoreConn.getFields(es_res, flds)
        for _, d in es_res.items():
            src_ids.update(d.get("source_id", []))
            if d["knowledge_graph_kwd"] == "entity":
                graph.add_node(d["entity_kwd"], entity_type=d["entity_type_kwd"])
            else:
                graph.add_edge(
                    d["from_entity_kwd"],
                    d["to_entity_kwd"],
                    weight=int(d["weight_int"])
                )

        # A partial page means the result set is exhausted.
        # (Was `< 128`, which mismatched the page size and could force one
        # extra, empty round trip to the doc store.)
        if len(es_res) < bs:
            break

    return graph, list(src_ids)
|
rag/nlp/search.py
CHANGED
|
@@ -483,4 +483,4 @@ class Dealer:
|
|
| 483 |
cnt = np.sum([c for _, c in aggs])
|
| 484 |
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
|
| 485 |
key=lambda x: x[1] * -1)[:topn_tags]
|
| 486 |
-
return {a: c for a, c in tag_fea
|
|
|
|
| 483 |
cnt = np.sum([c for _, c in aggs])
|
| 484 |
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
|
| 485 |
key=lambda x: x[1] * -1)[:topn_tags]
|
| 486 |
+
return {a: max(1, c) for a, c in tag_fea}
|
rag/svr/task_executor.py
CHANGED
|
@@ -327,8 +327,10 @@ def build_chunks(task, progress_callback):
|
|
| 327 |
random.choices(examples, k=2) if len(examples)>2 else examples,
|
| 328 |
topn=topn_tags)
|
| 329 |
if cached:
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
| 332 |
|
| 333 |
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
|
| 334 |
|
|
|
|
| 327 |
random.choices(examples, k=2) if len(examples)>2 else examples,
|
| 328 |
topn=topn_tags)
|
| 329 |
if cached:
|
| 330 |
+
cached = json.dumps(cached)
|
| 331 |
+
if cached:
|
| 332 |
+
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
| 333 |
+
d[TAG_FLD] = json.loads(cached)
|
| 334 |
|
| 335 |
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
|
| 336 |
|