maxiaolong03 committed
Commit a93c636 · 1 Parent(s): 47fd9da
Files changed (2)
  1. app.py +122 -169
  2. bot_requests.py +56 -71
app.py CHANGED
@@ -15,23 +15,22 @@
 """This file contains the code for the chatbot demo using Gradio."""
 
 import argparse
-from collections import namedtuple
-from functools import partial
+import base64
 import json
 import logging
 import os
-import base64
 from argparse import ArgumentParser
+from collections import namedtuple
+from functools import partial
 
 import gradio as gr
-
 from bot_requests import BotClient
 
 os.environ["NO_PROXY"] = "localhost,127.0.0.1" # Disable proxy
 
 logging.root.setLevel(logging.INFO)
 
-MULTI_MODEL_PREFIX = "ernie-4.5-turbo-vl"
+MULTI_MODEL_PREFIX = "ERNIE-4.5-VL"
 
 
 def get_args() -> argparse.Namespace:
@@ -48,21 +47,13 @@ def get_args() -> argparse.Namespace:
     """
     parser = ArgumentParser(description="ERNIE models web chat demo.")
 
+    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
+    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
+    parser.add_argument("--max_char", type=int, default=8000, help="Maximum character limit for messages.")
+    parser.add_argument("--max_retry_num", type=int, default=3, help="Maximum retry number for request.")
     parser.add_argument(
-        "--server-port", type=int, default=7860, help="Demo server port."
-    )
-    parser.add_argument(
-        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
-    )
-    parser.add_argument(
-        "--max_char", type=int, default=8000, help="Maximum character limit for messages."
-    )
-    parser.add_argument(
-        "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
-    )
-    parser.add_argument(
-        "--model_map",
-        type=str,
+        "--model_map",
+        type=str,
         default="""{
             "ernie-4.5-turbo-128k-preview": "https://qianfan.baidubce.com/v2",
             "ernie-4.5-21b-a3b": "https://qianfan.baidubce.com/v2",
@@ -80,7 +71,7 @@ def get_args() -> argparse.Namespace:
         - Prefix determines model capabilities:
         * ERNIE-4.5[-*]: Text-only model
         * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
-    """
+    """,
     )
 
     args = parser.parse_args()
@@ -96,7 +87,7 @@ def get_args() -> argparse.Namespace:
     return args
 
 
-class GradioEvents(object):
+class GradioEvents:
     """
     Central handler for all Gradio interface events in the chatbot demo. Provides static methods
     for processing user interactions including:
@@ -104,16 +95,17 @@ class GradioEvents(object):
     - Conversation state management
     - Image handling and URL conversion
     - Component visibility control
-
-    Coordinates with BotClient to interface with backend models while maintaining
+
+    Coordinates with BotClient to interface with backend models while maintaining
     conversation history and handling multimodal inputs.
     """
+
     @staticmethod
     def get_image_url(image_path: str) -> str:
         """
-        Converts an image file at the given path to a base64 encoded data URL
-        that can be used directly in HTML or Gradio interfaces.
-        Reads the image file, encodes it in base64 format, and constructs
+        Converts an image file at the given path to a base64 encoded data URL
+        that can be used directly in HTML or Gradio interfaces.
+        Reads the image file, encodes it in base64 format, and constructs
         a data URL with the appropriate image MIME type.
 
         Args:
@@ -126,26 +118,26 @@ class GradioEvents(object):
         extension = image_path.split(".")[-1]
         with open(image_path, "rb") as image_file:
             base64_image = base64.b64encode(image_file.read()).decode("utf-8")
-        url = "data:image/{ext};base64,{img}".format(ext=extension, img=base64_image)
+        url = f"data:image/{extension};base64,{base64_image}"
         return url
 
     @staticmethod
     def chat_stream(
-        query: str,
-        task_history: list,
-        image_history: dict,
-        model_name: str,
-        file_url: str,
-        system_msg: str,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        bot_client: BotClient
+        query: str,
+        task_history: list,
+        image_history: dict,
+        model_name: str,
+        file_url: str,
+        system_msg: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        bot_client: BotClient,
     ) -> str:
         """
-        Handles streaming chat interactions by processing user queries and
-        generating real-time responses from the bot client. Constructs conversation
-        history including system messages, text inputs and image attachments, then
+        Handles streaming chat interactions by processing user queries and
+        generating real-time responses from the bot client. Constructs conversation
+        history including system messages, text inputs and image attachments, then
         streams back model responses.
 
         Args:
@@ -169,10 +161,9 @@ class GradioEvents(object):
         for idx, (query_h, response_h) in enumerate(task_history):
             if idx in image_history:
                 content = []
-                content.append({
-                    "type": "image_url",
-                    "image_url": {"url": GradioEvents.get_image_url(image_history[idx])}
-                })
+                content.append(
+                    {"type": "image_url", "image_url": {"url": GradioEvents.get_image_url(image_history[idx])}}
+                )
                 content.append({"type": "text", "text": query_h})
                 conversation.append({"role": "user", "content": content})
             else:
@@ -193,29 +184,29 @@ class GradioEvents(object):
             for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
                 if "error" in chunk:
                     raise Exception(chunk["error"])
-
+
                 message = chunk.get("choices", [{}])[0].get("delta", {})
                 content = message.get("content", "")
-
+
                 if content:
                     yield content
-
+
         except Exception as e:
             raise gr.Error("Exception: " + repr(e))
 
     @staticmethod
     def predict_stream(
-        query: str,
-        chatbot: list,
-        task_history: list,
-        image_history: dict,
-        model: str,
-        file_url: str,
-        system_msg: str,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        bot_client: BotClient
+        query: str,
+        chatbot: list,
+        task_history: list,
+        image_history: dict,
+        model: str,
+        file_url: str,
+        system_msg: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        bot_client: BotClient,
     ) -> list:
         """
         Processes user queries in a streaming manner by coordinating with the chat stream handler,
@@ -240,29 +231,20 @@ class GradioEvents(object):
             list: A list containing the updated chatbot state after processing the user's query.
         """
 
-        logging.info("User: {}".format(query))
-        chatbot.append({"role": "user", "content": query})
-
+        logging.info(f"User: {query}")
+        chatbot.append({"role": "user", "content": query})
+
        # First yield the chatbot with user message
        yield chatbot
 
        new_texts = GradioEvents.chat_stream(
-            query,
-            task_history,
-            image_history,
-            model,
-            file_url,
-            system_msg,
-            max_tokens,
-            temperature,
-            top_p,
-            bot_client
+            query, task_history, image_history, model, file_url, system_msg, max_tokens, temperature, top_p, bot_client
        )
 
        response = ""
-        for new_text in new_texts:
+        for new_text in new_texts:
            response += new_text
-
+
            # Remove previous message if exists
            if chatbot[-1].get("role") == "assistant":
                chatbot.pop(-1)
@@ -271,26 +253,26 @@ class GradioEvents(object):
            chatbot.append({"role": "assistant", "content": response})
            yield chatbot
 
-        logging.info("History: {}".format(task_history))
-        task_history.append((query, response))
-        logging.info("ERNIE models: {}".format(response))
+        logging.info(f"History: {task_history}")
+        task_history.append((query, response))
+        logging.info(f"ERNIE models: {response}")
 
    @staticmethod
    def regenerate(
-        chatbot: list,
-        task_history: list,
-        image_history: dict,
-        model: str,
-        file_url: str,
-        system_msg: str,
-        max_tokens: int,
-        temperature: float,
-        top_p: float,
-        bot_client: BotClient
+        chatbot: list,
+        task_history: list,
+        image_history: dict,
+        model: str,
+        file_url: str,
+        system_msg: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+        bot_client: BotClient,
    ) -> list:
        """
-        Reconstructs the conversation context by removing the last interaction and
-        reprocesses the user's previous query to generate a fresh response. Maintains
+        Reconstructs the conversation context by removing the last interaction and
+        reprocesses the user's previous query to generate a fresh response. Maintains
        consistency in conversation flow while allowing response regeneration.
 
        Args:
@@ -319,26 +301,25 @@ class GradioEvents(object):
        chatbot.pop(-1)
        chatbot.pop(-1)
 
-        for chunk in GradioEvents.predict_stream(
-            item[0],
-            chatbot,
-            task_history,
+        yield from GradioEvents.predict_stream(
+            item[0],
+            chatbot,
+            task_history,
            image_history,
-            model,
+            model,
            file_url,
-            system_msg,
-            max_tokens,
-            temperature,
+            system_msg,
+            max_tokens,
+            temperature,
            top_p,
-            bot_client
-        ):
-            yield chunk
+            bot_client,
+        )
 
    @staticmethod
    def reset_user_input() -> gr.update:
        """
        Reset user input field value to empty string.
-
+
        Returns:
            gr.update: Update object representing the new value of the user input field.
        """
@@ -348,7 +329,7 @@ class GradioEvents(object):
    def reset_state() -> tuple:
        """
        Reset all states including chatbot, task_history, image_history, and file_btn.
-
+
        Returns:
            tuple: A tuple containing the following values:
            - chatbot (list): An empty list that represents the cleared chatbot state.
@@ -357,19 +338,15 @@ class GradioEvents(object):
            - file_btn (gr.update): An update object that sets the value of the file button to None.
        """
        GradioEvents.gc()
-
-        reset_result = namedtuple("reset_result",
-            ["chatbot",
-            "task_history",
-            "image_history",
-            "file_btn"])
+
+        reset_result = namedtuple("reset_result", ["chatbot", "task_history", "image_history", "file_btn"])
        return reset_result(
            [], # clear chatbot
            [], # clear task_history
            {}, # clear image_history
            gr.update(value=None), # clear file_btn
        )
-
+
    @staticmethod
    def gc():
        """Run garbage collection to free up memory resources."""
@@ -381,10 +358,10 @@ class GradioEvents(object):
    def toggle_components_visibility(model_name: str) -> gr.update:
        """
        Toggle visibility of components depending on the selected model name.
-
+
        Args:
            model_name (str): Name of the selected model.
-
+
        Returns:
            gr.update: An update object representing the visibility of the file button.
        """
@@ -394,7 +371,7 @@
 def launch_demo(args: argparse.Namespace, bot_client: BotClient):
    """
    Launch demo program
-
+
    Args:
        args (argparse.Namespace): argparse Namespace object containing parsed command line arguments
        bot_client (BotClient): Bot client instance
@@ -420,34 +397,29 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
    """
    with gr.Blocks(css=css) as demo:
        logo_url = GradioEvents.get_image_url("assets/logo.png")
-        gr.Markdown("""\
-<p align="center"><img src="{}" \
-style="height: 60px"/><p>""".format(logo_url))
+        gr.Markdown(
+            f"""\
+<p align="center"><img src="{logo_url}" \
+style="height: 60px"/><p>"""
+        )
        gr.Markdown(
            """\
 <center><font size=3>This demo is based on ERNIE models. \
 (本演示基于文心大模型实现。)</center>"""
        )
 
-        chatbot = gr.Chatbot(
-            label="ERNIE",
-            elem_classes="control-height",
-            type="messages"
-        )
+        chatbot = gr.Chatbot(label="ERNIE", elem_classes="control-height", type="messages")
        model_names = list(args.model_map.keys())
        with gr.Row():
            model_name = gr.Dropdown(
-                label="Select Model",
-                choices=model_names,
-                value=model_names[0],
-                allow_custom_value=True
+                label="Select Model", choices=model_names, value=model_names[0], allow_custom_value=True
            )
            file_btn = gr.File(
-                label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
-                height="80px",
-                visible=True,
+                label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
+                height="80px",
+                visible=True,
                file_types=[".png", ".jpeg", "jpg"],
-                elem_id="file-upload"
+                elem_id="file-upload",
            )
        query = gr.Textbox(label="Input", elem_id="text_input")
 
@@ -462,66 +434,46 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
            system_message,
            gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Temperature"),
-            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Top-p (nucleus sampling)")
+            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Top-p (nucleus sampling)"),
        ]
-
+
        task_history = gr.State([])
        image_history = gr.State({})
-
-        model_name.change(
-            GradioEvents.toggle_components_visibility,
-            inputs=model_name,
-            outputs=file_btn
-        )
+
+        model_name.change(GradioEvents.toggle_components_visibility, inputs=model_name, outputs=file_btn)
        model_name.change(
-            GradioEvents.reset_state,
-            outputs=[chatbot, task_history, image_history, file_btn],
-            show_progress=True
-        )
-        predict_with_clients = partial(
-            GradioEvents.predict_stream,
-            bot_client=bot_client
-        )
-        regenerate_with_clients = partial(
-            GradioEvents.regenerate,
-            bot_client=bot_client
-        )
+            GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
+        )
+        predict_with_clients = partial(GradioEvents.predict_stream, bot_client=bot_client)
+        regenerate_with_clients = partial(GradioEvents.regenerate, bot_client=bot_client)
        query.submit(
-            predict_with_clients,
-            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
-            outputs=[chatbot],
-            show_progress=True
+            predict_with_clients,
+            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            outputs=[chatbot],
+            show_progress=True,
        )
        query.submit(GradioEvents.reset_user_input, [], [query])
        submit_btn.click(
-            predict_with_clients,
-            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
-            outputs=[chatbot],
+            predict_with_clients,
+            inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            outputs=[chatbot],
            show_progress=True,
        )
        submit_btn.click(GradioEvents.reset_user_input, [], [query])
        empty_btn.click(
-            GradioEvents.reset_state,
-            outputs=[chatbot, task_history, image_history, file_btn],
-            show_progress=True
+            GradioEvents.reset_state, outputs=[chatbot, task_history, image_history, file_btn], show_progress=True
        )
        regen_btn.click(
-            regenerate_with_clients,
-            inputs=[chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
-            outputs=[chatbot],
-            show_progress=True
+            regenerate_with_clients,
+            inputs=[chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+            outputs=[chatbot],
+            show_progress=True,
        )
 
-        demo.load(
-            GradioEvents.toggle_components_visibility,
-            inputs=gr.State(model_names[0]),
-            outputs=file_btn
-        )
+        demo.load(GradioEvents.toggle_components_visibility, inputs=gr.State(model_names[0]), outputs=file_btn)
+
+        demo.queue().launch(server_port=args.server_port, server_name=args.server_name)
 
-        demo.queue().launch(
-            server_port=args.server_port,
-            server_name=args.server_name
-        )
 
 def main():
    """Main function that runs when this script is executed."""
@@ -529,5 +481,6 @@ def main():
    bot_client = BotClient(args)
    launch_demo(args, bot_client)
 
+
 if __name__ == "__main__":
    main()
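Two patterns in the reformatted app.py are worth noting. `regenerate` now delegates with `yield from` instead of re-yielding chunks in a loop, and the event handlers receive `bot_client` through `functools.partial`, since Gradio only passes component values into handlers. Below is a minimal, self-contained sketch of the `partial` binding pattern; `EchoClient` and `respond` are hypothetical stand-ins for `BotClient` and `predict_stream`, not code from this commit.

```python
from functools import partial

import gradio as gr


class EchoClient:
    """Hypothetical stand-in for BotClient."""

    def reply(self, text: str) -> str:
        return f"echo: {text}"


def respond(query: str, history: list, client: EchoClient) -> tuple:
    # Gradio fills `query` and `history` from components; `client` is not a
    # component, so it must arrive pre-bound via functools.partial.
    history = history + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": client.reply(query)},
    ]
    return history, history


with gr.Blocks() as demo:
    chat = gr.Chatbot(type="messages")
    state = gr.State([])
    box = gr.Textbox(label="Input")
    # Bind the non-component dependency once, at wiring time.
    respond_with_client = partial(respond, client=EchoClient())
    box.submit(respond_with_client, inputs=[box, state], outputs=[chat, state])

if __name__ == "__main__":
    demo.launch()
```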
bot_requests.py CHANGED
@@ -14,22 +14,23 @@
 
 """BotClient class for interacting with bot models."""
 
-import os
 import argparse
+import json
 import logging
 import traceback
-import json
+
 import jieba
+import requests
 from openai import OpenAI
 
-import requests
 
-class BotClient(object):
+class BotClient:
    """Client for interacting with various AI models."""
+
    def __init__(self, args: argparse.Namespace):
        """
-        Initializes the BotClient instance by configuring essential parameters from command line arguments
-        including retry limits, character constraints, model endpoints and API credentials while setting up
+        Initializes the BotClient instance by configuring essential parameters from command line arguments
+        including retry limits, character constraints, model endpoints and API credentials while setting up
        default values for missing arguments to ensure robust operation.
 
        Args:
@@ -37,7 +38,7 @@ class BotClient(object):
            Uses getattr() to safely retrieve values with fallback defaults.
        """
        self.logger = logging.getLogger(__name__)
-
+
        self.max_retry_num = getattr(args, 'max_retry_num', 3)
        self.max_char = getattr(args, 'max_char', 8000)
 
@@ -54,8 +55,8 @@ class BotClient(object):
 
    def call_back(self, host_url: str, req_data: dict) -> dict:
        """
-        Executes an HTTP request to the specified endpoint using the OpenAI client, handles the response
-        conversion to a compatible dictionary format, and manages any exceptions that may occur during
+        Executes an HTTP request to the specified endpoint using the OpenAI client, handles the response
+        conversion to a compatible dictionary format, and manages any exceptions that may occur during
        the request process while logging errors appropriately.
 
        Args:
@@ -68,20 +69,18 @@ class BotClient(object):
        """
        try:
            client = OpenAI(base_url=host_url, api_key=self.api_key)
-            response = client.chat.completions.create(
-                **req_data
-            )
-
+            response = client.chat.completions.create(**req_data)
+
            # Convert OpenAI response to compatible format
            return response.model_dump()
 
        except Exception as e:
-            self.logger.error("Stream request failed: {}".format(e))
+            self.logger.error(f"Stream request failed: {e}")
            raise
 
    def call_back_stream(self, host_url: str, req_data: dict) -> dict:
        """
-        Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks
+        Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks
        in real-time while handling any exceptions that may occur during the streaming process.
 
        Args:
@@ -100,25 +99,20 @@ class BotClient(object):
            for chunk in response:
                if not chunk.choices:
                    continue
-
+
                # Convert OpenAI response to compatible format
                yield chunk.model_dump()
 
        except Exception as e:
-            self.logger.error("Stream request failed: {}".format(e))
+            self.logger.error(f"Stream request failed: {e}")
            raise
 
    def process(
-        self,
-        model_name: str,
-        req_data: dict,
-        max_tokens: int=2048,
-        temperature: float=1.0,
-        top_p: float=0.7
+        self, model_name: str, req_data: dict, max_tokens: int = 2048, temperature: float = 1.0, top_p: float = 0.7
    ) -> dict:
        """
-        Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
-        including token limits and sampling settings, truncating messages to fit character limits, making API calls
+        Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
+        including token limits and sampling settings, truncating messages to fit character limits, making API calls
        with built-in retry mechanism, and logging the full request/response cycle for debugging purposes.
 
        Args:
@@ -140,7 +134,7 @@ class BotClient(object):
        req_data["messages"] = self.truncate_messages(req_data["messages"])
        for _ in range(self.max_retry_num):
            try:
-                self.logger.info("[MODEL] {}".format(model_url))
+                self.logger.info(f"[MODEL] {model_url}")
                self.logger.info("[req_data]====>")
                self.logger.info(json.dumps(req_data, ensure_ascii=False))
                res = self.call_back(model_url, req_data)
@@ -153,15 +147,11 @@ class BotClient(object):
                res = {}
            if len(res) != 0 and "error" not in res:
                break
-
+
        return res
 
    def process_stream(
-        self, model_name: str,
-        req_data: dict,
-        max_tokens: int=2048,
-        temperature: float=1.0,
-        top_p: float=0.7
+        self, model_name: str, req_data: dict, max_tokens: int = 2048, temperature: float = 1.0, top_p: float = 0.7
    ) -> dict:
        """
        Processes streaming requests by mapping the model name to its endpoint, configuring request parameters,
@@ -184,29 +174,28 @@ class BotClient(object):
        req_data["temperature"] = temperature
        req_data["top_p"] = top_p
        req_data["messages"] = self.truncate_messages(req_data["messages"])
-
+
        last_error = None
        for _ in range(self.max_retry_num):
            try:
-                self.logger.info("[MODEL] {}".format(model_url))
+                self.logger.info(f"[MODEL] {model_url}")
                self.logger.info("[req_data]====>")
                self.logger.info(json.dumps(req_data, ensure_ascii=False))
-
-                for chunk in self.call_back_stream(model_url, req_data):
-                    yield chunk
+
+                yield from self.call_back_stream(model_url, req_data)
                return
-
+
            except Exception as e:
                last_error = e
-                self.logger.error("Stream request failed (attempt {}/{}): {}".format(_ + 1, self.max_retry_num, e))
-
+                self.logger.error(f"Stream request failed (attempt {_ + 1}/{self.max_retry_num}): {e}")
+
        self.logger.error("All retry attempts failed for stream request")
        yield {"error": str(last_error)}
 
    def cut_chinese_english(self, text: str) -> list:
        """
-        Segments mixed Chinese and English text into individual components using Jieba for Chinese words
-        while preserving English words as whole units, with special handling for Unicode character ranges
+        Segments mixed Chinese and English text into individual components using Jieba for Chinese words
+        while preserving English words as whole units, with special handling for Unicode character ranges
        to distinguish between the two languages.
 
        Args:
@@ -239,10 +228,10 @@ class BotClient(object):
        """
        if not messages:
            return messages
-
+
        processed = []
        total_units = 0
-
+
        for msg in messages:
            # Handle two different content formats
            if isinstance(msg["content"], str):
@@ -251,31 +240,33 @@ class BotClient(object):
                text_content = msg["content"][1]["text"]
            else:
                text_content = ""
-
+
            # Calculate unit count after tokenization
            units = self.cut_chinese_english(text_content)
            unit_count = len(units)
-
-            processed.append({
-                "role": msg["role"],
-                "original_content": msg["content"], # Preserve original content
-                "text_content": text_content, # Extracted plain text
-                "units": units,
-                "unit_count": unit_count
-            })
+
+            processed.append(
+                {
+                    "role": msg["role"],
+                    "original_content": msg["content"], # Preserve original content
+                    "text_content": text_content, # Extracted plain text
+                    "units": units,
+                    "unit_count": unit_count,
+                }
+            )
            total_units += unit_count
-
+
        if total_units <= self.max_char:
            return messages
-
+
        # Number of units to remove
        to_remove = total_units - self.max_char
-
+
        # 1. Truncate historical messages
        for i in range(len(processed) - 1, 1):
            if to_remove <= 0:
                break
-
+
            # current = processed[i]
            if processed[i]["unit_count"] <= to_remove:
                processed[i]["text_content"] = ""
@@ -293,7 +284,7 @@ class BotClient(object):
                elif isinstance(processed[i]["original_content"], list):
                    processed[i]["original_content"][1]["text"] = new_text
                to_remove = 0
-
+
        # 2. Truncate system message
        if to_remove > 0:
            system_msg = processed[0]
@@ -313,7 +304,7 @@ class BotClient(object):
            elif isinstance(processed[0]["original_content"], list):
                processed[0]["original_content"][1]["text"] = new_text
            to_remove = 0
-
+
        # 3. Truncate last message
        if to_remove > 0 and len(processed) > 1:
            last_msg = processed[-1]
@@ -331,15 +322,12 @@ class BotClient(object):
                last_msg["original_content"] = ""
            elif isinstance(last_msg["original_content"], list):
                last_msg["original_content"][1]["text"] = ""
-
+
        result = []
        for msg in processed:
            if msg["text_content"]:
-                result.append({
-                    "role": msg["role"],
-                    "content": msg["original_content"]
-                })
-
+                result.append({"role": msg["role"], "content": msg["original_content"]})
+
        return result
 
    def embed_fn(self, text: str) -> list:
@@ -366,17 +354,14 @@ class BotClient(object):
        Returns:
            list: List of responses from the AI Search service.
        """
-        headers = {
-            "Authorization": "Bearer " + self.qianfan_api_key,
-            "Content-Type": "application/json"
-        }
+        headers = {"Authorization": "Bearer " + self.qianfan_api_key, "Content-Type": "application/json"}
 
        results = []
        top_k = self.max_search_results_num // len(query_list)
        for query in query_list:
            payload = {
                "messages": [{"role": "user", "content": query}],
-                "resource_type_filter": [{"type": "web", "top_k": top_k}]
+                "resource_type_filter": [{"type": "web", "top_k": top_k}],
            }
            response = requests.post(self.web_search_service_url, headers=headers, json=payload)
 
@@ -387,4 +372,4 @@ class BotClient(object):
        else:
            self.logger.info(f"请求失败,状态码: {response.status_code}")
            self.logger.info(response.text)
-        return results
+        return results
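The retry logic in `process_stream` is the subtle part of this file: the `try/except` wraps the `yield from`, so a failure on any attempt falls through to the next loop iteration, and only after the final attempt does the generator emit a terminal `{"error": ...}` chunk, which `GradioEvents.chat_stream` then raises as an exception. A runnable sketch of that shape, with `fetch_chunks` as a hypothetical stand-in for `call_back_stream`:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ATTEMPTS_BEFORE_SUCCESS = 2  # simulate two transient upstream failures


def fetch_chunks(attempt: int):
    """Hypothetical stand-in for BotClient.call_back_stream."""
    if attempt < ATTEMPTS_BEFORE_SUCCESS:
        raise ConnectionError("upstream dropped the stream")
    yield {"choices": [{"delta": {"content": "Hello"}}]}
    yield {"choices": [{"delta": {"content": ", world"}}]}


def stream_with_retry(max_retry_num: int = 3):
    last_error = None
    for attempt in range(max_retry_num):
        try:
            # On success the whole stream is forwarded and we return early;
            # on failure the except block records the error and we retry.
            yield from fetch_chunks(attempt)
            return
        except Exception as e:
            last_error = e
            logger.error(f"Stream request failed (attempt {attempt + 1}/{max_retry_num}): {e}")
    # All attempts failed: emit one terminal error chunk for the caller.
    yield {"error": str(last_error)}


for chunk in stream_with_retry():
    print(chunk)
```

One caveat of this pattern: if the upstream stream dies after yielding some chunks, a retry replays it from the beginning, so already-forwarded chunks are emitted a second time; callers that accumulate deltas, as `predict_stream` does, would see duplicated text in that case.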