Kevin Hu
committed on
Commit
·
82bdd9f
1
Parent(s):
c39b5d3
add search TAB backend api (#2375)
Browse files### What problem does this PR solve?
#2247
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/chunk_app.py +17 -4
- api/apps/conversation_app.py +90 -4
- api/db/services/dialog_service.py +1 -1
- api/db/services/llm_service.py +1 -1
- graphrag/search.py +1 -1
- rag/nlp/search.py +7 -5
api/apps/chunk_app.py
CHANGED
|
@@ -58,7 +58,7 @@ def list_chunk():
|
|
| 58 |
}
|
| 59 |
if "available_int" in req:
|
| 60 |
query["available_int"] = int(req["available_int"])
|
| 61 |
-
sres = retrievaler.search(query, search.index_name(tenant_id))
|
| 62 |
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
| 63 |
for id in sres.ids:
|
| 64 |
d = {
|
|
@@ -259,12 +259,25 @@ def retrieval_test():
|
|
| 259 |
size = int(req.get("size", 30))
|
| 260 |
question = req["question"]
|
| 261 |
kb_id = req["kb_id"]
|
|
|
|
| 262 |
doc_ids = req.get("doc_ids", [])
|
| 263 |
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
| 264 |
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
| 265 |
top = int(req.get("top_k", 1024))
|
|
|
|
| 266 |
try:
|
| 267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
if not e:
|
| 269 |
return get_data_error_result(retmsg="Knowledgebase not found!")
|
| 270 |
|
|
@@ -281,9 +294,9 @@ def retrieval_test():
|
|
| 281 |
question += keyword_extraction(chat_mdl, question)
|
| 282 |
|
| 283 |
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
| 284 |
-
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id,
|
| 285 |
similarity_threshold, vector_similarity_weight, top,
|
| 286 |
-
doc_ids, rerank_mdl=rerank_mdl)
|
| 287 |
for c in ranks["chunks"]:
|
| 288 |
if "vector" in c:
|
| 289 |
del c["vector"]
|
|
|
|
| 58 |
}
|
| 59 |
if "available_int" in req:
|
| 60 |
query["available_int"] = int(req["available_int"])
|
| 61 |
+
sres = retrievaler.search(query, search.index_name(tenant_id), highlight=True)
|
| 62 |
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
| 63 |
for id in sres.ids:
|
| 64 |
d = {
|
|
|
|
| 259 |
size = int(req.get("size", 30))
|
| 260 |
question = req["question"]
|
| 261 |
kb_id = req["kb_id"]
|
| 262 |
+
if isinstance(kb_id, str): kb_id = [kb_id]
|
| 263 |
doc_ids = req.get("doc_ids", [])
|
| 264 |
similarity_threshold = float(req.get("similarity_threshold", 0.2))
|
| 265 |
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
|
| 266 |
top = int(req.get("top_k", 1024))
|
| 267 |
+
|
| 268 |
try:
|
| 269 |
+
tenants = UserTenantService.query(user_id=current_user.id)
|
| 270 |
+
for kid in kb_id:
|
| 271 |
+
for tenant in tenants:
|
| 272 |
+
if KnowledgebaseService.query(
|
| 273 |
+
tenant_id=tenant.tenant_id, id=kid):
|
| 274 |
+
break
|
| 275 |
+
else:
|
| 276 |
+
return get_json_result(
|
| 277 |
+
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
|
| 278 |
+
retcode=RetCode.OPERATING_ERROR)
|
| 279 |
+
|
| 280 |
+
e, kb = KnowledgebaseService.get_by_id(kb_id[0])
|
| 281 |
if not e:
|
| 282 |
return get_data_error_result(retmsg="Knowledgebase not found!")
|
| 283 |
|
|
|
|
| 294 |
question += keyword_extraction(chat_mdl, question)
|
| 295 |
|
| 296 |
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
| 297 |
+
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_id, page, size,
|
| 298 |
similarity_threshold, vector_similarity_weight, top,
|
| 299 |
+
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
| 300 |
for c in ranks["chunks"]:
|
| 301 |
if "vector" in c:
|
| 302 |
del c["vector"]
|
api/apps/conversation_app.py
CHANGED
|
@@ -14,19 +14,22 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
import json
|
|
|
|
| 17 |
from copy import deepcopy
|
| 18 |
|
| 19 |
-
from db.services.user_service import UserTenantService
|
| 20 |
from flask import request, Response
|
| 21 |
from flask_login import login_required, current_user
|
| 22 |
|
| 23 |
from api.db import LLMType
|
| 24 |
-
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
| 25 |
-
from api.db.services.
|
| 26 |
-
from api.
|
|
|
|
| 27 |
from api.utils import get_uuid
|
| 28 |
from api.utils.api_utils import get_json_result
|
| 29 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
@manager.route('/set', methods=['POST'])
|
|
@@ -286,3 +289,86 @@ def thumbup():
|
|
| 286 |
|
| 287 |
ConversationService.update_by_id(conv["id"], conv)
|
| 288 |
return get_json_result(data=conv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
import json
|
| 17 |
+
import re
|
| 18 |
from copy import deepcopy
|
| 19 |
|
| 20 |
+
from api.db.services.user_service import UserTenantService
|
| 21 |
from flask import request, Response
|
| 22 |
from flask_login import login_required, current_user
|
| 23 |
|
| 24 |
from api.db import LLMType
|
| 25 |
+
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
|
| 26 |
+
from api.db.services.knowledgebase_service import KnowledgebaseService
|
| 27 |
+
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
| 28 |
+
from api.settings import RetCode, retrievaler
|
| 29 |
from api.utils import get_uuid
|
| 30 |
from api.utils.api_utils import get_json_result
|
| 31 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 32 |
+
from graphrag.mind_map_extractor import MindMapExtractor
|
| 33 |
|
| 34 |
|
| 35 |
@manager.route('/set', methods=['POST'])
|
|
|
|
| 289 |
|
| 290 |
ConversationService.update_by_id(conv["id"], conv)
|
| 291 |
return get_json_result(data=conv)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@manager.route('/ask', methods=['POST'])
@login_required
@validate_request("question", "kb_ids")
def ask_about():
    """Answer a question against the given knowledgebases, streamed as SSE.

    Expects JSON body: {"question": str, "kb_ids": list[str]}. Each partial
    answer from `ask` is emitted as a server-sent-event "data:" line; errors
    are reported in-stream with retcode 500, and a final `data: true`
    sentinel marks the end of the stream.
    """
    req = request.json
    uid = current_user.id

    def stream():
        # NOTE(review): the original declared `nonlocal req, uid`, but both
        # names are only read (never assigned) inside this generator, so the
        # declaration was redundant and has been removed.
        try:
            for ans in ask(req["question"], req["kb_ids"], uid):
                yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            # Surface the failure to the client inside the event stream
            # instead of breaking the SSE connection.
            yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                        "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                       ensure_ascii=False) + "\n\n"
        yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

    resp = Response(stream(), mimetype="text/event-stream")
    # Headers to keep proxies (e.g. nginx) from buffering the event stream.
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
@manager.route('/mindmap', methods=['POST'])
@login_required
@validate_request("question", "kb_ids")
def mindmap():
    """Build a mind map from the chunks retrieved for a question.

    Expects JSON body: {"question": str, "kb_ids": list[str]}. Retrieves the
    top chunks (page 1, size 12, similarity/vector weights 0.3/0.3, no
    aggregations) and feeds their content to MindMapExtractor, returning the
    extractor's output as JSON.
    """
    req = request.json
    kb_ids = req["kb_ids"]
    # Tenant and embedding settings are taken from the FIRST knowledgebase
    # only — presumably all kb_ids share a tenant; verify against callers.
    e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
    if not e:
        return get_data_error_result(retmsg="Knowledgebase not found!")

    embd_mdl = TenantLLMService.model_instance(
        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
    ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
                                  0.3, 0.3, aggs=False)
    # Renamed from `mindmap`: the original local shadowed this function's
    # own name, which would break any recursive/endpoint re-reference.
    extractor = MindMapExtractor(chat_mdl)
    mind_map = extractor([c["content_with_weight"] for c in ranks["chunks"]]).output
    return get_json_result(data=mind_map)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@manager.route('/related_questions', methods=['POST'])
@login_required
@validate_request("question")
def related_questions():
    """Generate related search terms for the given question via the chat LLM.

    Expects JSON body: {"question": str}. Prompts the chat model for 5-10
    related terms and returns them as a JSON list with the leading
    "N. " numbering stripped.
    """
    req = request.json
    question = req["question"]
    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
    prompt = """
Objective: To generate search terms related to the user's search keywords, helping users find more valuable information.
Instructions:
 - Based on the keywords provided by the user, generate 5-10 related search terms.
 - Each search term should be directly or indirectly related to the keyword, guiding the user to find more valuable information.
 - Use common, general terms as much as possible, avoiding obscure words or technical jargon.
 - Keep the term length between 2-4 words, concise and clear.
 - DO NOT translate, use the language of the original keywords.

### Example:
Keywords: Chinese football
Related search terms:
1. Current status of Chinese football
2. Reform of Chinese football
3. Youth training of Chinese football
4. Chinese football in the Asian Cup
5. Chinese football in the World Cup

Reason:
 - When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
 - Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.

"""
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
Keywords: {question}
Related search terms:
    """}], {"temperature": 0.9})
    # BUG FIX: the original pattern r"^[0-9]\. " matched only single-digit
    # numbering, silently dropping item 10 when the model returns ten terms;
    # "+" allows multi-digit list markers.
    return get_json_result(data=[re.sub(r"^[0-9]+\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]+\. ", a)])
|
api/db/services/dialog_service.py
CHANGED
|
@@ -210,7 +210,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 210 |
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
| 211 |
done_tm = timer()
|
| 212 |
prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
|
| 213 |
-
return {"answer": answer, "reference": refs, "prompt":
|
| 214 |
|
| 215 |
if stream:
|
| 216 |
last_ans = ""
|
|
|
|
| 210 |
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
| 211 |
done_tm = timer()
|
| 212 |
prompt += "\n### Elapsed\n - Retrieval: %.1f ms\n - LLM: %.1f ms"%((retrieval_tm-st)*1000, (done_tm-st)*1000)
|
| 213 |
+
return {"answer": answer, "reference": refs, "prompt": prompt}
|
| 214 |
|
| 215 |
if stream:
|
| 216 |
last_ans = ""
|
api/db/services/llm_service.py
CHANGED
|
@@ -190,7 +190,7 @@ class LLMBundle(object):
|
|
| 190 |
tenant_id, llm_type, llm_name, lang=lang)
|
| 191 |
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
| 192 |
tenant_id, llm_type, llm_name)
|
| 193 |
-
self.max_length =
|
| 194 |
for lm in LLMService.query(llm_name=llm_name):
|
| 195 |
self.max_length = lm.max_tokens
|
| 196 |
break
|
|
|
|
| 190 |
tenant_id, llm_type, llm_name, lang=lang)
|
| 191 |
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
| 192 |
tenant_id, llm_type, llm_name)
|
| 193 |
+
self.max_length = 8192
|
| 194 |
for lm in LLMService.query(llm_name=llm_name):
|
| 195 |
self.max_length = lm.max_tokens
|
| 196 |
break
|
graphrag/search.py
CHANGED
|
@@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
|
|
| 23 |
|
| 24 |
|
| 25 |
class KGSearch(Dealer):
|
| 26 |
-
def search(self, req, idxnm, emb_mdl=None):
|
| 27 |
def merge_into_first(sres, title=""):
|
| 28 |
df,texts = [],[]
|
| 29 |
for d in sres["hits"]["hits"]:
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class KGSearch(Dealer):
|
| 26 |
+
def search(self, req, idxnm, emb_mdl=None, highlight=False):
|
| 27 |
def merge_into_first(sres, title=""):
|
| 28 |
df,texts = [],[]
|
| 29 |
for d in sres["hits"]["hits"]:
|
rag/nlp/search.py
CHANGED
|
@@ -79,9 +79,9 @@ class Dealer:
|
|
| 79 |
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
| 80 |
return bqry
|
| 81 |
|
| 82 |
-
def search(self, req, idxnm, emb_mdl=None):
|
| 83 |
qst = req.get("question", "")
|
| 84 |
-
bqry, keywords = self.qryr.question(qst)
|
| 85 |
bqry = self._add_filters(bqry, req)
|
| 86 |
bqry.boost = 0.05
|
| 87 |
|
|
@@ -130,7 +130,7 @@ class Dealer:
|
|
| 130 |
qst, emb_mdl, req.get(
|
| 131 |
"similarity", 0.1), topk)
|
| 132 |
s["knn"]["filter"] = bqry.to_dict()
|
| 133 |
-
if "highlight" in s:
|
| 134 |
del s["highlight"]
|
| 135 |
q_vec = s["knn"]["query_vector"]
|
| 136 |
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
|
@@ -356,7 +356,7 @@ class Dealer:
|
|
| 356 |
rag_tokenizer.tokenize(inst).split(" "))
|
| 357 |
|
| 358 |
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
| 359 |
-
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None):
|
| 360 |
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
| 361 |
if not question:
|
| 362 |
return ranks
|
|
@@ -364,7 +364,7 @@ class Dealer:
|
|
| 364 |
"question": question, "vector": True, "topk": top,
|
| 365 |
"similarity": similarity_threshold,
|
| 366 |
"available_int": 1}
|
| 367 |
-
sres = self.search(req, index_name(tenant_id), embd_mdl)
|
| 368 |
|
| 369 |
if rerank_mdl:
|
| 370 |
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
|
@@ -405,6 +405,8 @@ class Dealer:
|
|
| 405 |
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
|
| 406 |
"positions": sres.field[id].get("position_int", "").split("\t")
|
| 407 |
}
|
|
|
|
|
|
|
| 408 |
if len(d["positions"]) % 5 == 0:
|
| 409 |
poss = []
|
| 410 |
for i in range(0, len(d["positions"]), 5):
|
|
|
|
| 79 |
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
| 80 |
return bqry
|
| 81 |
|
| 82 |
+
def search(self, req, idxnm, emb_mdl=None, highlight=False):
|
| 83 |
qst = req.get("question", "")
|
| 84 |
+
bqry, keywords = self.qryr.question(qst, min_match="30%")
|
| 85 |
bqry = self._add_filters(bqry, req)
|
| 86 |
bqry.boost = 0.05
|
| 87 |
|
|
|
|
| 130 |
qst, emb_mdl, req.get(
|
| 131 |
"similarity", 0.1), topk)
|
| 132 |
s["knn"]["filter"] = bqry.to_dict()
|
| 133 |
+
if not highlight and "highlight" in s:
|
| 134 |
del s["highlight"]
|
| 135 |
q_vec = s["knn"]["query_vector"]
|
| 136 |
es_logger.info("【Q】: {}".format(json.dumps(s)))
|
|
|
|
| 356 |
rag_tokenizer.tokenize(inst).split(" "))
|
| 357 |
|
| 358 |
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
|
| 359 |
+
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
|
| 360 |
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
| 361 |
if not question:
|
| 362 |
return ranks
|
|
|
|
| 364 |
"question": question, "vector": True, "topk": top,
|
| 365 |
"similarity": similarity_threshold,
|
| 366 |
"available_int": 1}
|
| 367 |
+
sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
|
| 368 |
|
| 369 |
if rerank_mdl:
|
| 370 |
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
|
|
|
| 405 |
"vector": self.trans2floats(sres.field[id].get("q_%d_vec" % dim, "\t".join(["0"] * dim))),
|
| 406 |
"positions": sres.field[id].get("position_int", "").split("\t")
|
| 407 |
}
|
| 408 |
+
if highlight:
|
| 409 |
+
d["highlight"] = rmSpace(sres.highlight[id])
|
| 410 |
if len(d["positions"]) % 5 == 0:
|
| 411 |
poss = []
|
| 412 |
for i in range(0, len(d["positions"]), 5):
|