Kevin Hu committed · Commit 41012b3 · 1 Parent(s): 63ae668

add elapsed time of conversation (#2316)
### What problem does this PR solve?

#2315

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
#### api/db/services/dialog_service.py (CHANGED)

The diff removes the self-RAG retry block from `chat()`, instruments `chat()` with wall-clock timers whose readings are appended to the returned prompt as an "### Elapsed" section, and adds a new streaming `ask()` helper for knowledge-base Q&A.
```diff
@@ -18,7 +18,7 @@ import os
 import json
 import re
 from copy import deepcopy
-
+from timeit import default_timer as timer
 from api.db import LLMType, ParserType
 from api.db.db_models import Dialog, Conversation
 from api.db.services.common_service import CommonService
```
```diff
@@ -88,6 +88,7 @@ def llm_id2llm_type(llm_id):
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
+    st = timer()
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=dialog.llm_id)
```
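Taken together, the two hunks above pull in `timeit.default_timer` and take a baseline reading `st` as soon as `chat()` is entered. A minimal standalone sketch of the pattern (the `time.sleep` calls are hypothetical stand-ins for real work, not RAGFlow code):

```python
# Minimal sketch of the timing pattern this PR introduces. default_timer is a
# high-resolution wall clock returning fractional seconds, so the difference
# of two readings times 1000 gives elapsed milliseconds.
import time
from timeit import default_timer as timer

st = timer()             # taken at the top of chat(), before retrieval
time.sleep(0.05)         # stand-in for the retrieval work
retrieval_tm = timer()   # reading taken once retrieval is done
time.sleep(0.10)         # stand-in for LLM generation
done_tm = timer()        # reading taken when the answer is decorated

# Note both deltas are measured from st, so the "LLM" figure in this PR
# also includes the retrieval time.
print("Retrieval: %.1f ms" % ((retrieval_tm - st) * 1000))
print("LLM: %.1f ms" % ((done_tm - st) * 1000))
```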
```diff
@@ -158,25 +159,16 @@ def chat(dialog, messages, stream=True, **kwargs):
                                  doc_ids=attachments,
                                  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-        #self-rag
-        if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
-            questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
-            kbinfos = retr.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
-                                     dialog.similarity_threshold,
-                                     dialog.vector_similarity_weight,
-                                     doc_ids=attachments,
-                                     top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
-            knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-
         chat_logger.info(
             "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+        retrieval_tm = timer()
 
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
         yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
-    kwargs["knowledge"] = "\n".join(knowledges)
+    kwargs["knowledge"] = "\n------\n".join(knowledges)
     gen_conf = dialog.llm_setting
 
     msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
```
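One change in this hunk is easy to miss: retrieved chunks are now joined with a visible `------` divider rather than a bare newline before being formatted into the system prompt. A small sketch with hypothetical chunk text:

```python
# Hypothetical chunk contents, to show the effect of the separator change.
knowledges = ["First retrieved chunk.", "Second retrieved chunk."]

# Before: chunk boundaries were bare newlines and could blur together.
print("\n".join(knowledges))

# After: each chunk is fenced off by a divider the LLM can see.
print("\n------\n".join(knowledges))
# First retrieved chunk.
# ------
# Second retrieved chunk.
```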
```diff
@@ -192,7 +184,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                                        max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,
```
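The only change here is extending the `nonlocal` declaration so the `decorate_answer` closure can see `retrieval_tm` from the enclosing `chat()` scope. A minimal sketch of that closure pattern (names are illustrative, not the real function bodies):

```python
# decorate_answer in chat() is a nested function; nonlocal is required to
# rebind names from the enclosing scope, and also documents which enclosing
# state the helper depends on (plain reads would work without it).
def outer():
    retrieval_tm = 1.234  # bound in the enclosing function, like in chat()

    def decorate_answer(answer):
        nonlocal retrieval_tm
        return answer + " (retrieval at %.3f)" % retrieval_tm

    return decorate_answer("done")

print(outer())  # done (retrieval at 1.234)
```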
```diff
@@ -216,7 +208,9 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs, "prompt": prompt}
+        done_tm = timer()
+        prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
+        return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", "<br/>", prompt)}
 
     if stream:
         last_ans = ""
```
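This hunk is the heart of the PR: a final timer reading, an "### Elapsed" section appended to the prompt string that is returned alongside the answer, and newlines rewritten to `<br/>`, presumably because the frontend renders this field as HTML. With hypothetical timer values:

```python
import re

# Hypothetical timer readings, in seconds.
st, retrieval_tm, done_tm = 0.0, 0.42, 3.87
prompt = "You are a helpful assistant..."  # placeholder system prompt

prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
    (retrieval_tm - st) * 1000, (done_tm - st) * 1000)

# The chat API returns this string in the "prompt" field for display.
print(re.sub(r"\n", "<br/>", prompt))
# You are a helpful assistant...<br/>### Elapsed<br/> - Retrieval: 420.0 ms<br/> - LLM: 3870.0 ms
```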
```diff
@@ -415,4 +409,75 @@ def tts(tts_mdl, text):
     bin = b""
     for chunk in tts_mdl.tts(text):
         bin += chunk
-    return binascii.hexlify(bin).decode("utf-8")
+    return binascii.hexlify(bin).decode("utf-8")
+
+
+def ask(question, kb_ids, tenant_id):
+    kbs = KnowledgebaseService.get_by_ids(kb_ids)
+    embd_nms = list(set([kb.embd_id for kb in kbs]))
+
+    is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retr = retrievaler if not is_kg else kg_retrievaler
+
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+    max_tokens = chat_mdl.max_length
+
+    kbinfos = retr.retrieval(question, embd_mdl, tenant_id, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+
+    used_token_count = 0
+    for i, c in enumerate(knowledges):
+        used_token_count += num_tokens_from_string(c)
+        if max_tokens * 0.97 < used_token_count:
+            knowledges = knowledges[:i]
+            break
+
+    prompt = """
+    Role: You're a smart assistant. Your name is Miss R.
+    Task: Summarize the information from knowledge bases and answer user's question.
+    Requirements and restriction:
+      - DO NOT make things up, especially for numbers.
+      - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
+      - Answer with markdown format text.
+      - Answer in language of user's question.
+      - DO NOT make things up, especially for numbers.
+
+    ### Information from knowledge bases
+    %s
+
+    The above is information from knowledge bases.
+
+    """%"\n".join(knowledges)
+    msg = [{"role": "user", "content": question}]
+
+    def decorate_answer(answer):
+        nonlocal knowledges, kbinfos, prompt
+        answer, idx = retr.insert_citations(answer,
+                                            [ck["content_ltks"]
+                                             for ck in kbinfos["chunks"]],
+                                            [ck["vector"]
+                                             for ck in kbinfos["chunks"]],
+                                            embd_mdl,
+                                            tkweight=0.7,
+                                            vtweight=0.3)
+        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+        recall_docs = [
+            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+        kbinfos["doc_aggs"] = recall_docs
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    answer = ""
+    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
+        answer = ans
+        yield {"answer": answer, "reference": {}}
+    yield decorate_answer(answer)
+
```
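The new `ask()` is a generator: it yields partial `{"answer", "reference"}` dicts while the model streams, then one final dict decorated with citations. A hypothetical caller, assuming `ask()` as defined above is importable (the argument values below are placeholders):

```python
# Sketch of consuming ask(): every intermediate yield carries the partial
# answer with an empty reference dict; the last yield carries the
# citation-decorated answer plus references.
def run_ask(question, kb_ids, tenant_id):
    final = None
    for chunk in ask(question, kb_ids, tenant_id):
        print(chunk["answer"])   # progressively longer partial answers
        final = chunk
    return final                 # {"answer": ..., "reference": {...}}

# answer = run_ask("What is RAGFlow?", ["<kb_id>"], "<tenant_id>")
```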