Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -107,9 +107,40 @@ def reset_chat(idx, ld, state):
|
|
| 107 |
gr.update(interactive=False),
|
| 108 |
)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
async def chat_stream(
|
| 111 |
idx, local_data, instruction_txtbox, chat_state,
|
| 112 |
-
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
|
|
|
|
| 113 |
):
|
| 114 |
res = [
|
| 115 |
chat_state["ppmanager_type"].from_json(json.dumps(ppm))
|
|
@@ -121,6 +152,14 @@ async def chat_stream(
|
|
| 121 |
PingPong(instruction_txtbox, "")
|
| 122 |
)
|
| 123 |
prompt = build_prompts(ppm, global_context, ctx_num_lconv)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
async for result in gen_text(
|
| 125 |
prompt, hf_model=MODEL_ID, hf_token=TOKEN,
|
| 126 |
parameters={
|
|
@@ -283,14 +322,15 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
| 283 |
elem_id="global-context"
|
| 284 |
)
|
| 285 |
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
| 294 |
|
| 295 |
gr.Markdown("#### GenConfig for **response** text generation")
|
| 296 |
with gr.Row():
|
|
@@ -315,7 +355,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
| 315 |
).then(
|
| 316 |
chat_stream,
|
| 317 |
[idx, local_data, instruction_txtbox, chat_state,
|
| 318 |
-
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
|
|
|
|
| 319 |
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
|
| 320 |
).then(
|
| 321 |
None, local_data, None,
|
|
@@ -346,7 +387,8 @@ with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
|
|
| 346 |
regen_event = regenerate.click(
|
| 347 |
rollback_last,
|
| 348 |
[idx, local_data, chat_state,
|
| 349 |
-
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv
|
|
|
|
| 350 |
[context_inspector, chatbot, local_data, regenerate]
|
| 351 |
).then(
|
| 352 |
None, local_data, None,
|
|
|
|
| 107 |
gr.update(interactive=False),
|
| 108 |
)
|
| 109 |
|
| 110 |
+
def internet_search(ppmanager, serper_api_key, global_context, ctx_num_lconv, device="cpu"):
|
| 111 |
+
internet_search_ppm = copy.deepcopy(ppm)
|
| 112 |
+
internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, "
|
| 113 |
+
f"give me an appropriate query to answer my question for google search. "
|
| 114 |
+
f"You should not say more than query. You should not say any words except the query."
|
| 115 |
+
|
| 116 |
+
internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
|
| 117 |
+
internet_search_prompt = build_prompts(internet_search_ppm, "", win_size=ctx_num_lconv)
|
| 118 |
+
|
| 119 |
+
instruction = gen_text_none_stream(internet_search_prompt, hf_model=MODEL_ID, hf_token=TOKEN)
|
| 120 |
+
###
|
| 121 |
+
|
| 122 |
+
searcher = SimilaritySearcher.from_pretrained(device=device)
|
| 123 |
+
iss = InternetSearchStrategy(
|
| 124 |
+
searcher,
|
| 125 |
+
instruction=instruction,
|
| 126 |
+
serper_api_key=serper_api_key
|
| 127 |
+
)(ppmanager)
|
| 128 |
+
|
| 129 |
+
step_ppm = None
|
| 130 |
+
while True:
|
| 131 |
+
try:
|
| 132 |
+
step_ppm, _ = next(iss)
|
| 133 |
+
yield "", step_ppm.build_uis()
|
| 134 |
+
except StopIteration:
|
| 135 |
+
break
|
| 136 |
+
|
| 137 |
+
search_prompt = build_prompts(step_ppm, global_context, ctx_num_lconv)
|
| 138 |
+
yield search_prompt, ppmanager.build_uis()
|
| 139 |
+
|
| 140 |
async def chat_stream(
|
| 141 |
idx, local_data, instruction_txtbox, chat_state,
|
| 142 |
+
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
| 143 |
+
internet_option, serper_api_key
|
| 144 |
):
|
| 145 |
res = [
|
| 146 |
chat_state["ppmanager_type"].from_json(json.dumps(ppm))
|
|
|
|
| 152 |
PingPong(instruction_txtbox, "")
|
| 153 |
)
|
| 154 |
prompt = build_prompts(ppm, global_context, ctx_num_lconv)
|
| 155 |
+
|
| 156 |
+
#######
|
| 157 |
+
if internet_option:
|
| 158 |
+
search_prompt = None
|
| 159 |
+
for tmp_prompt, uis in internet_search(ppm, serper_api_key, global_context, ctx_num_lconv):
|
| 160 |
+
search_prompt = tmp_prompt
|
| 161 |
+
yield "", uis, prompt, str(res)
|
| 162 |
+
|
| 163 |
async for result in gen_text(
|
| 164 |
prompt, hf_model=MODEL_ID, hf_token=TOKEN,
|
| 165 |
parameters={
|
|
|
|
| 322 |
elem_id="global-context"
|
| 323 |
)
|
| 324 |
|
| 325 |
+
gr.Markdown("#### Internet search")
|
| 326 |
+
with gr.Row():
|
| 327 |
+
internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
|
| 328 |
+
serper_api_key = gr.Textbox(
|
| 329 |
+
value= os.getenv("SERPER_API_KEY"),
|
| 330 |
+
placeholder="Get one by visiting serper.dev",
|
| 331 |
+
label="Serper api key",
|
| 332 |
+
visible=False
|
| 333 |
+
)
|
| 334 |
|
| 335 |
gr.Markdown("#### GenConfig for **response** text generation")
|
| 336 |
with gr.Row():
|
|
|
|
| 355 |
).then(
|
| 356 |
chat_stream,
|
| 357 |
[idx, local_data, instruction_txtbox, chat_state,
|
| 358 |
+
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
| 359 |
+
internet_option, serper_api_key],
|
| 360 |
[instruction_txtbox, context_inspector, chatbot, local_data, regenerate]
|
| 361 |
).then(
|
| 362 |
None, local_data, None,
|
|
|
|
| 387 |
regen_event = regenerate.click(
|
| 388 |
rollback_last,
|
| 389 |
[idx, local_data, chat_state,
|
| 390 |
+
global_context, res_temp, res_topk, res_rpen, res_mnts, res_sample, ctx_num_lconv,
|
| 391 |
+
internet_option, serper_api_key],
|
| 392 |
[context_inspector, chatbot, local_data, regenerate]
|
| 393 |
).then(
|
| 394 |
None, local_data, None,
|