Optimize the truncation strategy for ChatGPT conversations
Files changed:
- crazy_functions/谷歌检索小助手.py +2 -1
- request_llm/bridge_chatgpt.py +10 -7
- toolbox.py +46 -0
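Summary: instead of wiping the entire conversation history when OpenAI returns a "reduce the length" error, the ChatGPT bridge now blanks out the round that overflowed and shrinks the remaining history with a new toolbox.clip_history helper, which trims the longest entries token by token until the conversation fits within half of the model's context window.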
crazy_functions/谷歌检索小助手.py
CHANGED
@@ -98,7 +98,8 @@ def 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
     history.extend([ "第一批", gpt_say ])
     meta_paper_info_list = meta_paper_info_list[10:]

-    chatbot.append(["状态?", "已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write a \"Related Works\" section about \"你搜索的研究领域\" for me."])
+    chatbot.append(["状态?",
+        "已经全部完成,您可以试试让AI写一个Related Works,例如您可以继续输入Write a \"Related Works\" section about \"你搜索的研究领域\" for me."])
     msg = '正常'
     yield from update_ui(chatbot=chatbot, history=history, msg=msg) # refresh the UI
     res = write_results_to_file(history)
request_llm/bridge_chatgpt.py
CHANGED
@@ -21,7 +21,7 @@ import importlib

 # config_private.py holds private secrets such as API keys and proxy URLs
 # On load, check first whether a private config_private file exists (not tracked by git); if present, it overrides the original config file
-from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys
+from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history
 proxies, API_KEY, TIMEOUT_SECONDS, MAX_RETRY = \
     get_conf('proxies', 'API_KEY', 'TIMEOUT_SECONDS', 'MAX_RETRY')

@@ -145,7 +145,7 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
         yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # refresh the UI
         return

-    history.append(inputs); history.append(" ")
+    history.append(inputs); history.append("")

     retry = 0
     while True:

@@ -198,14 +198,17 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
                 chunk_decoded = chunk.decode()
                 error_msg = chunk_decoded
                 if "reduce the length" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据现已释放, 您可以请再次尝试.")
-                    history = []    # clear the history
+                    if len(history) >= 2: history[-1] = ""; history[-2] = "" # drop the overflowing round: history[-2] is this round's input, history[-1] is this round's output
+                    history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
+                                           max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])//2) # free at least half of the history
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
+                    # history = []    # clear the history
                 elif "does not exist" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在,或者您没有获得体验资格.")
+                    chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
                 elif "Incorrect API key" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由,拒绝服务.")
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务.")
                 elif "exceeded your current quota" in error_msg:
-                    chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由,拒绝服务.")
+                    chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务.")
                 elif "bad forward key" in error_msg:
                     chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
                 elif "Not enough point" in error_msg:
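For orientation, here is a minimal sketch of what the new "reduce the length" branch does, lifted out of the streaming loop. The model_info table below is an assumed stand-in for the real one in request_llm/bridge_all.py, and the function name recover_from_length_error is hypothetical; clip_history is the helper this commit adds to toolbox.py (shown under the next file).

```python
import tiktoken
from toolbox import clip_history  # helper added by this commit, see below

# assumed stand-in for request_llm.bridge_all.model_info
model_info = {
    "gpt-3.5-turbo": {
        "tokenizer": tiktoken.encoding_for_model("gpt-3.5-turbo"),
        "max_token": 4096,
    },
}

def recover_from_length_error(inputs, history, llm_model="gpt-3.5-turbo"):
    # blank out the round that overflowed: history[-2] is this round's
    # input, history[-1] is its (still empty) output placeholder
    if len(history) >= 2:
        history[-1] = ""
        history[-2] = ""
    # shrink the remaining history to at most half of the context window,
    # leaving the other half for the retried input and the reply
    return clip_history(
        inputs=inputs,
        history=history,
        tokenizer=model_info[llm_model]["tokenizer"],
        max_token_limit=model_info[llm_model]["max_token"] // 2,
    )
```

The user can then simply resubmit. If the retry fails again, it is the input itself rather than the history that is too long, which is exactly what the updated chat message warns about.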
toolbox.py
CHANGED
@@ -551,3 +551,49 @@ def run_gradio_in_subpath(demo, auth, port, custom_path):
         return {"message": f"Gradio is running at: {custom_path}"}
     app = gr.mount_gradio_app(app, demo, path=custom_path)
     uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
+
+
+def clip_history(inputs, history, tokenizer, max_token_limit):
+    """
+    Reduce the length of the input/history by clipping.
+    This function searches for the longest entries and clips them, little
+    by little, until the token count of the input/history drops below the threshold.
+    """
+    import numpy as np
+    from request_llm.bridge_all import model_info
+    def get_token_num(txt):
+        return len(tokenizer.encode(txt, disallowed_special=()))
+    input_token_num = get_token_num(inputs)
+    if input_token_num < max_token_limit * 3 / 4:
+        # when the input takes less than 3/4 of the token limit, set aside room for the input when clipping
+        max_token_limit = max_token_limit - input_token_num
+        if max_token_limit < 128:
+            # the remaining budget is too small, just clear the history
+            history = []
+            return history
+    else:
+        # when the input takes more than 3/4 of the token limit, just clear the history
+        history = []
+        return history
+
+    everything = ['']
+    everything.extend(history)
+    n_token = get_token_num('\n'.join(everything))
+    everything_token = [get_token_num(e) for e in everything]
+
+    # granularity of each truncation step
+    delta = max(everything_token) // 16
+
+    while n_token > max_token_limit:
+        where = np.argmax(everything_token)
+        encoded = tokenizer.encode(everything[where], disallowed_special=())
+        clipped_encoded = encoded[:len(encoded)-delta]
+        everything[where] = tokenizer.decode(clipped_encoded)[:-1]    # -1 to remove a possibly illegal trailing char
+        everything_token[where] = get_token_num(everything[where])
+        n_token = get_token_num('\n'.join(everything))
+
+    history = everything[1:]
+    return history
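A brief usage sketch of the new helper, assuming tiktoken is installed and the repository is on the import path; any tokenizer object with compatible encode/decode methods would work the same way. The history contents here are made up for illustration:

```python
import tiktoken
from toolbox import clip_history

tokenizer = tiktoken.get_encoding("cl100k_base")

history = [
    "What papers did we find?",
    "Here are the first ten results ...",
    "paper metadata " * 2000,   # one oversized entry
    "Summarize the common themes.",
    "The common themes are ...",
]
clipped = clip_history(
    inputs="Write a related-works paragraph from the results above.",
    history=history,
    tokenizer=tokenizer,
    max_token_limit=2048,
)
# the longest entry loses roughly 1/16 of its tokens per iteration until
# '\n'.join(history) fits the budget left over after the inputs
print([len(tokenizer.encode(h)) for h in clipped])
```

Two details of the implementation are worth noting: the leading '' element is a placeholder that is stripped off again (history = everything[1:]) before returning, and entries are always clipped from the tail (encoded[:len(encoded)-delta]), so the opening of each long message is what survives.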