maxiaolong03 committed · Commit a93c636
1 Parent(s): 47fd9da
add files

Browse files:
- app.py: +122 −169
- bot_requests.py: +56 −71
app.py
CHANGED
@@ -15,23 +15,22 @@
 """This file contains the code for the chatbot demo using Gradio."""
 
 import argparse
-
-from functools import partial
+import base64
 import json
 import logging
 import os
-import base64
 from argparse import ArgumentParser
+from collections import namedtuple
+from functools import partial
 
 import gradio as gr
-
 from bot_requests import BotClient
 
 os.environ["NO_PROXY"] = "localhost,127.0.0.1"  # Disable proxy
 
 logging.root.setLevel(logging.INFO)
 
-MULTI_MODEL_PREFIX = "
+MULTI_MODEL_PREFIX = "ERNIE-4.5-VL"
 
 
 def get_args() -> argparse.Namespace:
@@ -48,21 +47,13 @@ def get_args() -> argparse.Namespace:
     """
     parser = ArgumentParser(description="ERNIE models web chat demo.")
 
+    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
+    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
+    parser.add_argument("--max_char", type=int, default=8000, help="Maximum character limit for messages.")
+    parser.add_argument("--max_retry_num", type=int, default=3, help="Maximum retry number for request.")
     parser.add_argument(
-        "--server-port", type=int, default=7860, help="Demo server port."
-    )
-    parser.add_argument(
-        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
-    )
-    parser.add_argument(
-        "--max_char", type=int, default=8000, help="Maximum character limit for messages."
-    )
-    parser.add_argument(
-        "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
-    )
-    parser.add_argument(
-        "--model_map",
-        type=str,
+        "--model_map",
+        type=str,
         default="""{
             "ernie-4.5-turbo-128k-preview": "https://qianfan.baidubce.com/v2",
             "ernie-4.5-21b-a3b": "https://qianfan.baidubce.com/v2",
@@ -80,7 +71,7 @@ def get_args() -> argparse.Namespace:
         - Prefix determines model capabilities:
           * ERNIE-4.5[-*]: Text-only model
           * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
-    """
+    """,
     )
 
     args = parser.parse_args()
@@ -96,7 +87,7 @@ def get_args() -> argparse.Namespace:
     return args
 
 
-class GradioEvents(object):
+class GradioEvents:
     """
     Central handler for all Gradio interface events in the chatbot demo. Provides static methods
     for processing user interactions including:
@@ -104,16 +95,17 @@ class GradioEvents(object):
     - Conversation state management
     - Image handling and URL conversion
     - Component visibility control
-
-    Coordinates with BotClient to interface with backend models while maintaining
+
+    Coordinates with BotClient to interface with backend models while maintaining
     conversation history and handling multimodal inputs.
     """
+
     @staticmethod
     def get_image_url(image_path: str) -> str:
         """
-        Converts an image file at the given path to a base64 encoded data URL
-        that can be used directly in HTML or Gradio interfaces.
-        Reads the image file, encodes it in base64 format, and constructs
+        Converts an image file at the given path to a base64 encoded data URL
+        that can be used directly in HTML or Gradio interfaces.
+        Reads the image file, encodes it in base64 format, and constructs
         a data URL with the appropriate image MIME type.
 
         Args:
@@ -126,26 +118,26 @@ class GradioEvents(object):
         extension = image_path.split(".")[-1]
         with open(image_path, "rb") as image_file:
             base64_image = base64.b64encode(image_file.read()).decode("utf-8")
-        url = "data:image/{};base64,{}".format(extension, base64_image)
+        url = f"data:image/{extension};base64,{base64_image}"
         return url
 
     @staticmethod
     def chat_stream(
-        query: str,
-        task_history: list,
-        image_history: dict,
-        model_name: str,
-        file_url: str,
-        system_msg: str,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        bot_client: BotClient
+        query: str,
+        task_history: list,
+        image_history: dict,
+        model_name: str,
+        file_url: str,
+        system_msg: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        bot_client: BotClient,
     ) -> str:
         """
-        Handles streaming chat interactions by processing user queries and
-        generating real-time responses from the bot client. Constructs conversation
-        history including system messages, text inputs and image attachments, then
+        Handles streaming chat interactions by processing user queries and
+        generating real-time responses from the bot client. Constructs conversation
+        history including system messages, text inputs and image attachments, then
        streams back model responses.
 
        Args:
@@ -169,10 +161,9 @@ class GradioEvents(object):
         for idx, (query_h, response_h) in enumerate(task_history):
             if idx in image_history:
                 content = []
-                content.append({
-                    "type": "image_url",
-                    "image_url": {"url": GradioEvents.get_image_url(image_history[idx])},
-                })
+                content.append(
+                    {"type": "image_url", "image_url": {"url": GradioEvents.get_image_url(image_history[idx])}}
+                )
                 content.append({"type": "text", "text": query_h})
                 conversation.append({"role": "user", "content": content})
             else:
@@ -193,29 +184,29 @@ class GradioEvents(object):
             for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
                 if "error" in chunk:
                     raise Exception(chunk["error"])
-
+
                 message = chunk.get("choices", [{}])[0].get("delta", {})
                 content = message.get("content", "")
-
+
                 if content:
                     yield content
-
+
         except Exception as e:
             raise gr.Error("Exception: " + repr(e))
 
     @staticmethod
     def predict_stream(
-        query: str,
-        chatbot: list,
-        task_history: list,
-        image_history: dict,
-        model: str,
-        file_url: str,
-        system_msg: str,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        bot_client: BotClient
+        query: str,
+        chatbot: list,
+        task_history: list,
+        image_history: dict,
+        model: str,
+        file_url: str,
+        system_msg: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        bot_client: BotClient,
     ) -> list:
         """
         Processes user queries in a streaming manner by coordinating with the chat stream handler,
@@ -240,29 +231,20 @@ class GradioEvents(object):
             list: A list containing the updated chatbot state after processing the user's query.
         """
 
-        logging.info("User: {}".format(query))
-        chatbot.append({"role": "user", "content": query})
-
+        logging.info(f"User: {query}")
+        chatbot.append({"role": "user", "content": query})
+
         # First yield the chatbot with user message
         yield chatbot
 
         new_texts = GradioEvents.chat_stream(
-            query,
-            task_history,
-            image_history,
-            model,
-            file_url,
-            system_msg,
-            max_tokens,
-            temperature,
-            top_p,
-            bot_client
+            query, task_history, image_history, model, file_url, system_msg, max_tokens, temperature, top_p, bot_client
         )
 
         response = ""
-        for new_text in new_texts:
+        for new_text in new_texts:
             response += new_text
-
+
             # Remove previous message if exists
             if chatbot[-1].get("role") == "assistant":
                 chatbot.pop(-1)
@@ -271,26 +253,26 @@ class GradioEvents(object):
         chatbot.append({"role": "assistant", "content": response})
         yield chatbot
 
-        logging.info("History: {}".format(task_history))
-        task_history.append((query, response))
-        logging.info("ERNIE models: {}".format(response))
+        logging.info(f"History: {task_history}")
+        task_history.append((query, response))
+        logging.info(f"ERNIE models: {response}")
 
     @staticmethod
     def regenerate(
-        chatbot: list,
-        task_history: list,
-        image_history: dict,
-        model: str,
-        file_url: str,
-        system_msg: str,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        bot_client: BotClient
+        chatbot: list,
+        task_history: list,
+        image_history: dict,
+        model: str,
+        file_url: str,
+        system_msg: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        bot_client: BotClient,
    ) -> list:
        """
-        Reconstructs the conversation context by removing the last interaction and
-        reprocesses the user's previous query to generate a fresh response. Maintains
+        Reconstructs the conversation context by removing the last interaction and
+        reprocesses the user's previous query to generate a fresh response. Maintains
        consistency in conversation flow while allowing response regeneration.
 
        Args:
@@ -319,26 +301,25 @@ class GradioEvents(object):
         chatbot.pop(-1)
         chatbot.pop(-1)
 
-        for chunk in GradioEvents.predict_stream(
-            item[0],
-            chatbot,
-            task_history,
+        yield from GradioEvents.predict_stream(
+            item[0],
+            chatbot,
+            task_history,
             image_history,
-            model,
+            model,
             file_url,
-            system_msg,
-            max_tokens,
-            temperature,
+            system_msg,
+            max_tokens,
+            temperature,
             top_p,
-            bot_client
-        ):
-            yield chunk
+            bot_client,
+        )
 
     @staticmethod
     def reset_user_input() -> gr.update:
         """
         Reset user input field value to empty string.
-
+
         Returns:
             gr.update: Update object representing the new value of the user input field.
         """
@@ -348,7 +329,7 @@ class GradioEvents(object):
     def reset_state() -> tuple:
         """
         Reset all states including chatbot, task_history, image_history, and file_btn.
-
+
         Returns:
             tuple: A tuple containing the following values:
                 - chatbot (list): An empty list that represents the cleared chatbot state.
@@ -357,19 +338,15 @@ class GradioEvents(object):
                 - file_btn (gr.update): An update object that sets the value of the file button to None.
         """
         GradioEvents.gc()
-
-        reset_result = namedtuple("reset_result",
-                                  ["chatbot",
-                                   "task_history",
-                                   "image_history",
-                                   "file_btn"])
+
+        reset_result = namedtuple("reset_result", ["chatbot", "task_history", "image_history", "file_btn"])
         return reset_result(
             [],  # clear chatbot
             [],  # clear task_history
             {},  # clear image_history
             gr.update(value=None),  # clear file_btn
         )
-
+
     @staticmethod
     def gc():
         """Run garbage collection to free up memory resources."""
@@ -381,10 +358,10 @@ class GradioEvents(object):
     def toggle_components_visibility(model_name: str) -> gr.update:
         """
         Toggle visibility of components depending on the selected model name.
-
+
         Args:
             model_name (str): Name of the selected model.
-
+
         Returns:
             gr.update: An update object representing the visibility of the file button.
         """
@@ -394,7 +371,7 @@ class GradioEvents(object):
 def launch_demo(args: argparse.Namespace, bot_client: BotClient):
     """
     Launch demo program
-
+
     Args:
         args (argparse.Namespace): argparse Namespace object containing parsed command line arguments
         bot_client (BotClient): Bot client instance
@@ -420,34 +397,29 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
     """
     with gr.Blocks(css=css) as demo:
         logo_url = GradioEvents.get_image_url("assets/logo.png")
-        gr.Markdown(
-            …
-        )
+        gr.Markdown(
+            f"""\
+<p align="center"><img src="{logo_url}" \
+style="height: 60px"/><p>"""
+        )
         gr.Markdown(
             """\
 <center><font size=3>This demo is based on ERNIE models. \
 (本演示基于文心大模型实现。)</center>"""
         )
 
-        chatbot = gr.Chatbot(
-            label="ERNIE",
-            elem_classes="control-height",
-            type="messages"
-        )
+        chatbot = gr.Chatbot(label="ERNIE", elem_classes="control-height", type="messages")
         model_names = list(args.model_map.keys())
         with gr.Row():
             model_name = gr.Dropdown(
-                label="Select Model",
-                choices=model_names,
-                value=model_names[0],
-                allow_custom_value=True
+                label="Select Model", choices=model_names, value=model_names[0], allow_custom_value=True
             )
             file_btn = gr.File(
-                label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
-                height="80px",
-                visible=True,
+                label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
+                height="80px",
+                visible=True,
                 file_types=[".png", ".jpeg", "jpg"],
-                elem_id="file-upload"
+                elem_id="file-upload",
             )
         query = gr.Textbox(label="Input", elem_id="text_input")
 
@@ -462,66 +434,46 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
             system_message,
             gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens"),
             gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Temperature"),
-            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Top-p (nucleus sampling)")
+            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Top-p (nucleus sampling)"),
         ]
-
+
         task_history = gr.State([])
         image_history = gr.State({})
-
-        model_name.change(
-            GradioEvents.toggle_components_visibility,
-            inputs=model_name,
-            outputs=file_btn
-        )
+
+        model_name.change(GradioEvents.toggle_components_visibility, inputs=model_name, outputs=file_btn)
         model_name.change(
-            GradioEvents.reset_state,
-            outputs=[chatbot, task_history, image_history, file_btn],
-            show_progress=True
-        )
-        predict_with_clients = partial(
-            GradioEvents.predict_stream,
-            bot_client=bot_client
-        )
-        regenerate_with_clients = partial(
-            GradioEvents.regenerate,
-            bot_client=bot_client
+            GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
         )
+        predict_with_clients = partial(GradioEvents.predict_stream, bot_client=bot_client)
+        regenerate_with_clients = partial(GradioEvents.regenerate, bot_client=bot_client)
         query.submit(
-            predict_with_clients,
-            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
-            outputs=[chatbot],
-            show_progress=True
+            predict_with_clients,
+            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            outputs=[chatbot],
+            show_progress=True,
         )
         query.submit(GradioEvents.reset_user_input, [], [query])
         submit_btn.click(
-            predict_with_clients,
-            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
-            outputs=[chatbot],
+            predict_with_clients,
+            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            outputs=[chatbot],
             show_progress=True,
         )
         submit_btn.click(GradioEvents.reset_user_input, [], [query])
         empty_btn.click(
-            GradioEvents.reset_state,
-            outputs=[chatbot, task_history, image_history, file_btn],
-            show_progress=True
+            GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
        )
        regen_btn.click(
-            regenerate_with_clients,
-            inputs=[chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
-            outputs=[chatbot],
-            show_progress=True
+            regenerate_with_clients,
+            inputs=[chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            outputs=[chatbot],
+            show_progress=True,
        )
 
-        demo.load(
-            GradioEvents.toggle_components_visibility,
-            inputs=gr.State(model_names[0]),
-            outputs=file_btn
-        )
+        demo.load(GradioEvents.toggle_components_visibility, inputs=gr.State(model_names[0]), outputs=file_btn)
 
-        demo.queue().launch(
-            server_port=args.server_port,
-            server_name=args.server_name
-        )
+        demo.queue().launch(server_port=args.server_port, server_name=args.server_name)
 
 
 def main():
     """Main function that runs when this script is executed."""
@@ -529,5 +481,6 @@ def main():
     bot_client = BotClient(args)
     launch_demo(args, bot_client)
 
+
 if __name__ == "__main__":
     main()
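A note on the pattern the new app.py settles on: functools.partial pre-binds bot_client, so the remaining handler parameters line up one-to-one with the Gradio components listed in inputs=[...], and each handler stays a plain generator that Gradio drains to stream chatbot updates. A minimal, self-contained sketch of that wiring (EchoClient and echo_stream are illustrative names, not part of the commit):

from functools import partial

import gradio as gr


class EchoClient:
    """Illustrative stand-in for BotClient."""

    def reply(self, text: str) -> str:
        return text.upper()


def echo_stream(query: str, history: list, bot_client: EchoClient):
    """Generator handler: each yield repaints the Chatbot component."""
    history = history + [{"role": "user", "content": query}]
    yield history
    history = history + [{"role": "assistant", "content": bot_client.reply(query)}]
    yield history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    query = gr.Textbox(label="Input")
    # partial() hides bot_client from Gradio, so inputs=[query, chatbot]
    # matches the two remaining parameters of echo_stream.
    query.submit(partial(echo_stream, bot_client=EchoClient()), inputs=[query, chatbot], outputs=[chatbot])

demo.queue()  # queuing is what lets generator handlers stream partial results
# demo.launch()

Binding the client outside the handler keeps non-UI dependencies out of Gradio's input list, which is what lets the commit collapse each partial(...) call to a single line.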
bot_requests.py
CHANGED
@@ -14,22 +14,23 @@
 
 """BotClient class for interacting with bot models."""
 
-import os
 import argparse
+import json
 import logging
 import traceback
-
+
 import jieba
+import requests
 from openai import OpenAI
 
-import requests
 
-class BotClient(object):
+class BotClient:
     """Client for interacting with various AI models."""
+
     def __init__(self, args: argparse.Namespace):
         """
-        Initializes the BotClient instance by configuring essential parameters from command line arguments
-        including retry limits, character constraints, model endpoints and API credentials while setting up
+        Initializes the BotClient instance by configuring essential parameters from command line arguments
+        including retry limits, character constraints, model endpoints and API credentials while setting up
         default values for missing arguments to ensure robust operation.
 
         Args:
@@ -37,7 +38,7 @@ class BotClient(object):
         Uses getattr() to safely retrieve values with fallback defaults.
         """
         self.logger = logging.getLogger(__name__)
-
+
         self.max_retry_num = getattr(args, 'max_retry_num', 3)
         self.max_char = getattr(args, 'max_char', 8000)
 
@@ -54,8 +55,8 @@ class BotClient(object):
 
     def call_back(self, host_url: str, req_data: dict) -> dict:
         """
-        Executes an HTTP request to the specified endpoint using the OpenAI client, handles the response
-        conversion to a compatible dictionary format, and manages any exceptions that may occur during
+        Executes an HTTP request to the specified endpoint using the OpenAI client, handles the response
+        conversion to a compatible dictionary format, and manages any exceptions that may occur during
         the request process while logging errors appropriately.
 
         Args:
@@ -68,20 +69,18 @@ class BotClient(object):
         """
         try:
             client = OpenAI(base_url=host_url, api_key=self.api_key)
-            response = client.chat.completions.create(
-                **req_data
-            )
-
+            response = client.chat.completions.create(**req_data)
+
             # Convert OpenAI response to compatible format
             return response.model_dump()
 
         except Exception as e:
-            self.logger.error("Stream request failed: {}".format(e))
+            self.logger.error(f"Stream request failed: {e}")
             raise
 
     def call_back_stream(self, host_url: str, req_data: dict) -> dict:
         """
-        Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks
+        Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks
         in real-time while handling any exceptions that may occur during the streaming process.
 
         Args:
@@ -100,25 +99,20 @@ class BotClient(object):
             for chunk in response:
                 if not chunk.choices:
                     continue
-
+
                 # Convert OpenAI response to compatible format
                 yield chunk.model_dump()
 
         except Exception as e:
-            self.logger.error("Stream request failed: {}".format(e))
+            self.logger.error(f"Stream request failed: {e}")
             raise
 
     def process(
-        self,
-        model_name: str,
-        req_data: dict,
-        max_tokens: int=2048,
-        temperature: float=1.0,
-        top_p: float=0.7
+        self, model_name: str, req_data: dict, max_tokens: int = 2048, temperature: float = 1.0, top_p: float = 0.7
     ) -> dict:
         """
-        Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
-        including token limits and sampling settings, truncating messages to fit character limits, making API calls
+        Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
+        including token limits and sampling settings, truncating messages to fit character limits, making API calls
         with built-in retry mechanism, and logging the full request/response cycle for debugging purposes.
 
         Args:
@@ -140,7 +134,7 @@ class BotClient(object):
         req_data["messages"] = self.truncate_messages(req_data["messages"])
         for _ in range(self.max_retry_num):
             try:
-                self.logger.info("[MODEL] {}".format(model_url))
+                self.logger.info(f"[MODEL] {model_url}")
                 self.logger.info("[req_data]====>")
                 self.logger.info(json.dumps(req_data, ensure_ascii=False))
                 res = self.call_back(model_url, req_data)
@@ -153,15 +147,11 @@ class BotClient(object):
                 res = {}
             if len(res) != 0 and "error" not in res:
                 break
-
+
         return res
 
     def process_stream(
-        self, model_name: str,
-        req_data: dict,
-        max_tokens: int=2048,
-        temperature: float=1.0,
-        top_p: float=0.7
+        self, model_name: str, req_data: dict, max_tokens: int = 2048, temperature: float = 1.0, top_p: float = 0.7
     ) -> dict:
         """
         Processes streaming requests by mapping the model name to its endpoint, configuring request parameters,
@@ -184,29 +174,28 @@ class BotClient(object):
         req_data["temperature"] = temperature
         req_data["top_p"] = top_p
         req_data["messages"] = self.truncate_messages(req_data["messages"])
-
+
         last_error = None
         for _ in range(self.max_retry_num):
             try:
-                self.logger.info("[MODEL] {}".format(model_url))
+                self.logger.info(f"[MODEL] {model_url}")
                 self.logger.info("[req_data]====>")
                 self.logger.info(json.dumps(req_data, ensure_ascii=False))
-
-                for chunk in self.call_back_stream(model_url, req_data):
-                    yield chunk
+
+                yield from self.call_back_stream(model_url, req_data)
                 return
-
+
             except Exception as e:
                 last_error = e
-                self.logger.error("Stream request failed (attempt {}/{}): {}".format(_ + 1, self.max_retry_num, e))
-
+                self.logger.error(f"Stream request failed (attempt {_ + 1}/{self.max_retry_num}): {e}")
+
         self.logger.error("All retry attempts failed for stream request")
         yield {"error": str(last_error)}
 
     def cut_chinese_english(self, text: str) -> list:
         """
-        Segments mixed Chinese and English text into individual components using Jieba for Chinese words
-        while preserving English words as whole units, with special handling for Unicode character ranges
+        Segments mixed Chinese and English text into individual components using Jieba for Chinese words
+        while preserving English words as whole units, with special handling for Unicode character ranges
         to distinguish between the two languages.
 
         Args:
@@ -239,10 +228,10 @@ class BotClient(object):
         """
         if not messages:
             return messages
-
+
         processed = []
         total_units = 0
-
+
         for msg in messages:
             # Handle two different content formats
             if isinstance(msg["content"], str):
@@ -251,31 +240,33 @@ class BotClient(object):
                 text_content = msg["content"][1]["text"]
             else:
                 text_content = ""
-
+
             # Calculate unit count after tokenization
             units = self.cut_chinese_english(text_content)
             unit_count = len(units)
-
-            processed.append({
-                "role": msg["role"],
-                "original_content": msg["content"],  # Preserve original content
-                "text_content": text_content,  # Extracted plain text
-                "units": units,
-                "unit_count": unit_count
-            })
+
+            processed.append(
+                {
+                    "role": msg["role"],
+                    "original_content": msg["content"],  # Preserve original content
+                    "text_content": text_content,  # Extracted plain text
+                    "units": units,
+                    "unit_count": unit_count,
+                }
+            )
             total_units += unit_count
-
+
         if total_units <= self.max_char:
             return messages
-
+
         # Number of units to remove
         to_remove = total_units - self.max_char
-
+
         # 1. Truncate historical messages
         for i in range(len(processed) - 1, 1):
             if to_remove <= 0:
                 break
-
+
             # current = processed[i]
             if processed[i]["unit_count"] <= to_remove:
                 processed[i]["text_content"] = ""
@@ -293,7 +284,7 @@ class BotClient(object):
             elif isinstance(processed[i]["original_content"], list):
                 processed[i]["original_content"][1]["text"] = new_text
             to_remove = 0
-
+
         # 2. Truncate system message
         if to_remove > 0:
             system_msg = processed[0]
@@ -313,7 +304,7 @@ class BotClient(object):
             elif isinstance(processed[0]["original_content"], list):
                 processed[0]["original_content"][1]["text"] = new_text
             to_remove = 0
-
+
         # 3. Truncate last message
         if to_remove > 0 and len(processed) > 1:
             last_msg = processed[-1]
@@ -331,15 +322,12 @@ class BotClient(object):
                 last_msg["original_content"] = ""
             elif isinstance(last_msg["original_content"], list):
                 last_msg["original_content"][1]["text"] = ""
-
+
         result = []
         for msg in processed:
             if msg["text_content"]:
-                result.append({
-                    "role": msg["role"],
-                    "content": msg["original_content"]
-                })
-
+                result.append({"role": msg["role"], "content": msg["original_content"]})
+
         return result
 
     def embed_fn(self, text: str) -> list:
@@ -366,17 +354,14 @@ class BotClient(object):
         Returns:
             list: List of responses from the AI Search service.
         """
-        headers = {
-            "Authorization": "Bearer " + self.qianfan_api_key,
-            "Content-Type": "application/json"
-        }
+        headers = {"Authorization": "Bearer " + self.qianfan_api_key, "Content-Type": "application/json"}
 
         results = []
         top_k = self.max_search_results_num // len(query_list)
         for query in query_list:
             payload = {
                 "messages": [{"role": "user", "content": query}],
-                "resource_type_filter": [{"type": "web", "top_k": top_k}]
+                "resource_type_filter": [{"type": "web", "top_k": top_k}],
             }
             response = requests.post(self.web_search_service_url, headers=headers, json=payload)
 
@@ -387,4 +372,4 @@ class BotClient(object):
         else:
             self.logger.info(f"请求失败,状态码: {response.status_code}")
             self.logger.info(response.text)
-        return results
+        return results
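A detail of the new process_stream worth spelling out: the retry loop wraps the whole stream. yield from hands chunks through to the caller, and control only re-enters the try block if the underlying generator raises, so a clean stream reaches return and stops, a failed attempt falls through to the next retry, and after the last attempt the error is surfaced as a terminal {"error": ...} chunk instead of an exception. A runnable sketch of that shape under hypothetical names (flaky_stream stands in for call_back_stream):

import logging
from typing import Iterator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def flaky_stream(state: dict) -> Iterator[dict]:
    """Hypothetical stand-in for call_back_stream: fails once, then streams."""
    if not state["failed_once"]:
        state["failed_once"] = True
        raise ConnectionError("transient network error")
    yield {"delta": "Hello"}
    yield {"delta": ", world"}


def process_stream(state: dict, max_retry_num: int = 3) -> Iterator[dict]:
    # Same shape as BotClient.process_stream: retry the whole stream and
    # surface the last error as a data chunk instead of raising.
    last_error = None
    for attempt in range(max_retry_num):
        try:
            yield from flaky_stream(state)
            return  # stream finished cleanly; do not retry
        except Exception as e:
            last_error = e
            logger.error(f"Stream request failed (attempt {attempt + 1}/{max_retry_num}): {e}")
    yield {"error": str(last_error)}


if __name__ == "__main__":
    for chunk in process_stream({"failed_once": False}):
        print(chunk)  # attempt 1 fails and is logged; attempt 2 yields both chunks

One caveat inherent to retrying a generator wholesale: if an attempt fails after some chunks have already been yielded, the retry replays the stream from the beginning, so the consumer can see the same content twice.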