H committed

Commit: 87d8c78
Parent(s): 970a3e8

Fix multiple generate (#1722)

### What problem does this PR solve?
#1625
### Type of change
- [x] New Feature (non-breaking change which adds functionality)

Files changed:

- graph/component/answer.py (+3 -1)
- graph/component/generate.py (+32 -51)
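
The behavioral core of the fix sits in `generate.py`: generation is short-circuited into streaming mode only when the component has exactly one downstream node and that node is an Answer component. A minimal sketch of the patched guard, for orientation before the diffs below (the function and its arguments are illustrative stand-ins, not RAGFlow's actual API):

```python
# Simplified gating logic from the generate.py diff below; `get_component`
# stands in for self._canvas.get_component, and `downstreams` for the
# component's downstream ids.
def should_stream(kwargs, downstreams, get_component):
    return bool(kwargs.get("stream")
                and len(downstreams) == 1
                and get_component(downstreams[0])["obj"].component_name.lower() == "answer")

class FakeAnswer:
    component_name = "Answer"

lookup = lambda _id: {"obj": FakeAnswer()}
print(should_stream({"stream": True}, ["answer:0"], lookup))        # True
print(should_stream({"stream": True}, ["gen:1", "gen:2"], lookup))  # False: multiple downstreams
```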
#### graph/component/answer.py (CHANGED)
```diff
@@ -59,8 +59,10 @@ class Answer(ComponentBase, ABC):
         stream = self.get_stream_input()
         if isinstance(stream, pd.DataFrame):
             res = stream
+            answer = ""
             for ii, row in stream.iterrows():
-                yield row.to_dict()
+                answer += row.to_dict()["content"]
+                yield {"content": answer}
         else:
             for st in stream():
                 res = st
```
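
The patched branch turns a DataFrame of partial rows into a cumulative stream, so each yielded chunk carries the whole answer so far. A standalone sketch of that behavior (the two-row frame is hypothetical test data, not part of the commit):

```python
import pandas as pd

# Hypothetical stand-in for the DataFrame an upstream component streams in.
stream = pd.DataFrame([{"content": "Hello, "}, {"content": "world."}])

def stream_answer(stream):
    # Mirror the patched Answer branch: accumulate row contents and
    # yield the running answer so clients can render incremental updates.
    answer = ""
    for _, row in stream.iterrows():
        answer += row.to_dict()["content"]
        yield {"content": answer}

for chunk in stream_answer(stream):
    print(chunk)
# {'content': 'Hello, '}
# {'content': 'Hello, world.'}
```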
#### graph/component/generate.py (CHANGED)
```diff
@@ -67,6 +67,34 @@ class Generate(ComponentBase):
         cpnts = [para["component_id"] for para in self._param.parameters]
         return cpnts
 
+    def set_cite(self, retrieval_res, answer):
+        answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
+                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
+                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
+                                                             self._canvas.get_embedding_model()), tkweight=0.7,
+                                                   vtweight=0.3)
+        doc_ids = set([])
+        recall_docs = []
+        for i in idx:
+            did = retrieval_res.loc[int(i), "doc_id"]
+            if did in doc_ids: continue
+            doc_ids.add(did)
+            recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
+
+        del retrieval_res["vector"]
+        del retrieval_res["content_ltks"]
+
+        reference = {
+            "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
+            "doc_aggs": recall_docs
+        }
+
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        res = {"content": answer, "reference": reference}
+
+        return res
+
     def _run(self, history, **kwargs):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         prompt = self._param.prompt
@@ -87,9 +115,8 @@ class Generate(ComponentBase):
             prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
 
         downstreams = self._canvas.get_component(self._id)["downstream"]
-        if kwargs.get("stream")
-
-                and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer":
+        if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
+            "obj"].component_name.lower() == "answer":
             return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
 
         if "empty_response" in retrieval_res.columns:
@@ -97,27 +124,8 @@ class Generate(ComponentBase):
 
         ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
                             self._param.gen_conf())
-
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
-            ans, idx = retrievaler.insert_citations(ans,
-                                                    [ck["content_ltks"]
-                                                     for _, ck in retrieval_res.iterrows()],
-                                                    [ck["vector"]
-                                                     for _, ck in retrieval_res.iterrows()],
-                                                    LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                              self._canvas.get_embedding_model()),
-                                                    tkweight=0.7,
-                                                    vtweight=0.3)
-            del retrieval_res["vector"]
-            retrieval_res = retrieval_res.to_dict("records")
-            df = []
-            for i in idx:
-                df.append(retrieval_res[int(i)])
-                r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans)
-                assert r, f"{i} => {ans}"
-                df[-1]["content"] = r.group(1)
-                ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
-            if ans: df.append({"content": ans})
+            df = self.set_cite(retrieval_res, ans)
             return pd.DataFrame(df)
 
         return Generate.be_output(ans)
@@ -138,34 +146,7 @@ class Generate(ComponentBase):
             yield res
 
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
-            answer, idx = retrievaler.insert_citations(answer,
-                                                       [ck["content_ltks"]
-                                                        for _, ck in retrieval_res.iterrows()],
-                                                       [ck["vector"]
-                                                        for _, ck in retrieval_res.iterrows()],
-                                                       LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                                 self._canvas.get_embedding_model()),
-                                                       tkweight=0.7,
-                                                       vtweight=0.3)
-            doc_ids = set([])
-            recall_docs = []
-            for i in idx:
-                did = retrieval_res.loc[int(i), "doc_id"]
-                if did in doc_ids: continue
-                doc_ids.add(did)
-                recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
-
-            del retrieval_res["vector"]
-            del retrieval_res["content_ltks"]
-
-            reference = {
-                "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
-                "doc_aggs": recall_docs
-            }
-
-            if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
-                answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-            res = {"content": answer, "reference": reference}
+            res = self.set_cite(retrieval_res, answer)
             yield res
 
         self.set_output(res)
```
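
The refactor replaces two near-identical citation blocks with a single `set_cite` helper, so `_run` and `stream_output` can no longer drift apart. One detail worth seeing in isolation is the document-level deduplication it performs: each cited chunk maps back to a source document, and each document is reported once, in first-seen order. A self-contained sketch of just that step (the frame and `idx` are hypothetical; the real method obtains `idx` from `retrievaler.insert_citations`):

```python
import pandas as pd

# Hypothetical retrieval results: one row per retrieved chunk.
retrieval_res = pd.DataFrame([
    {"doc_id": "d1", "docnm_kwd": "manual.pdf"},
    {"doc_id": "d2", "docnm_kwd": "faq.md"},
    {"doc_id": "d1", "docnm_kwd": "manual.pdf"},
])
idx = [0, 1, 2]  # chunk indices that ended up cited in the answer

# Same pattern as set_cite: keep each doc_id once, preserving first-seen order.
doc_ids = set()
recall_docs = []
for i in idx:
    did = retrieval_res.loc[int(i), "doc_id"]
    if did in doc_ids:
        continue
    doc_ids.add(did)
    recall_docs.append({"doc_id": did,
                        "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})

print(recall_docs)
# [{'doc_id': 'd1', 'doc_name': 'manual.pdf'},
#  {'doc_id': 'd2', 'doc_name': 'faq.md'}]
```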