maxiaolong03 committed on
Commit 9de4aae · 1 Parent(s): eea129f
Files changed (2)
  1. app.py +61 -69
  2. bot_requests.py +40 -38
app.py CHANGED
@@ -17,6 +17,7 @@
 import argparse
 from collections import namedtuple
 from functools import partial
+import json
 import logging
 import os
 import base64
@@ -30,6 +31,8 @@ os.environ["NO_PROXY"] = "localhost,127.0.0.1" # Disable proxy
 
 logging.root.setLevel(logging.INFO)
 
+MULTI_MODEL_PREFIX = "ernie-4.5-turbo-vl"
+
 
 def get_args() -> argparse.Namespace:
     """
@@ -38,17 +41,10 @@ def get_args() -> argparse.Namespace:
     The arguments include:
     - Server port and name for the Gradio interface
     - Character limits and retry settings for conversation handling
-    - Model endpoints for different AI services
-    - API keys and other service configurations
+    - Model name to endpoint mappings for the chatbot
 
     Returns:
-        argparse.Namespace: Parsed command line arguments containing:
-            - server_port (int): Port number for the demo server (default: 8232)
-            - server_name (str): Hostname/IP for the server (default: "0.0.0.0")
-            - max_char (int): Maximum character limit for messages (default: 8000)
-            - max_retry_num (int): Maximum retry attempts for API calls (default: 3)
-            - eb45t_model_url (str): Endpoint URL for the multimodal model
-            - x1_model_url (str): Endpoint URL for the text inference model
+        argparse.Namespace: Parsed command line arguments containing all the above settings
     """
     parser = ArgumentParser(description="ERNIE models web chat demo.")
 
@@ -65,19 +61,38 @@
         "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
     )
     parser.add_argument(
-        "--eb45t_model_url",
+        "--model_map",
         type=str,
-        default="https://qianfan.baidubce.com/v2",
-        help="Model URL for multimodal model."
-    )
-    parser.add_argument(
-        "--x1_model_url",
-        type=str,
-        default="https://qianfan.baidubce.com/v2",
-        help="Model URL for text inference model."
+        default="""{
+            "ernie-4.5-turbo-128k": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-turbo-32k": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-8k-preview": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-turbo-vl-32k": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-turbo-vl-32k-preview": "https://qianfan.baidubce.com/v2"
+        }""",
+        help="""JSON string defining model name to endpoint mappings.
+        Required format:
+            {"model_name": "http://localhost:port/v1", ...}
+
+        Note:
+        - All endpoints must be valid HTTP URLs
+        - At least one model must be specified
+        - Prefix determines model capabilities:
+          * ERNIE-4.5[-*]: text-only models
+          * ERNIE-4.5-VL[-*]: multimodal models (image+text)
+        """
     )
 
     args = parser.parse_args()
+    try:
+        args.model_map = json.loads(args.model_map)
+
+        # Validation: check that at least one model exists
+        if len(args.model_map) < 1:
+            raise ValueError("model_map must contain at least one model configuration")
+    except json.JSONDecodeError as e:
+        raise ValueError("Invalid JSON format for --model_map") from e
+
     return args
 
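A quick illustration of what the new --model_map plumbing does at runtime (the model name is taken from the defaults above; the rest is a standalone sketch, not code from this commit):

    import json

    raw = '{"ernie-4.5-turbo-128k": "https://qianfan.baidubce.com/v2"}'
    model_map = json.loads(raw)                 # str -> dict[str, str]
    assert len(model_map) >= 1                  # mirrors the validation in get_args()
    print(model_map["ernie-4.5-turbo-128k"])    # endpoint later consumed by BotClient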
 
@@ -85,7 +100,6 @@ class GradioEvents(object):
     """
     Central handler for all Gradio interface events in the chatbot demo. Provides static methods
     for processing user interactions including:
-    - Streaming chat predictions with reasoning steps
     - Response regeneration
     - Conversation state management
     - Image handling and URL conversion
@@ -127,12 +141,12 @@ class GradioEvents(object):
         temperature: float,
         top_p: float,
         bot_client: BotClient
-    ) -> dict:
+    ) -> str:
         """
         Handles streaming chat interactions by processing user queries and
         generating real-time responses from the bot client. Constructs conversation
         history including system messages, text inputs and image attachments, then
-        streams back model responses with reasoning steps and final answers.
+        streams back model responses.
 
         Args:
             query (str): User input.
@@ -147,7 +161,7 @@ class GradioEvents(object):
             bot_client (BotClient): Bot client.
 
         Yields:
-            dict: A dictionary containing the event type and its corresponding content.
+            str: Model response.
         """
         conversation = []
         if system_msg:
@@ -174,7 +188,6 @@ class GradioEvents(object):
         else:
             conversation.append({"role": "user", "content": query})
 
-
         try:
             req_data = {"messages": conversation}
             for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
@@ -183,12 +196,9 @@ class GradioEvents(object):
 
                 message = chunk.get("choices", [{}])[0].get("delta", {})
                 content = message.get("content", "")
-                reasoning_content = message.get("reasoning_content", "")
 
-                if reasoning_content:
-                    yield {"type": "thinking", "content": reasoning_content}
                 if content:
-                    yield {"type": "answer", "content": content}
+                    yield content
 
         except Exception as e:
             raise gr.Error("Exception: " + repr(e))
@@ -209,7 +219,7 @@ class GradioEvents(object):
     ) -> list:
         """
         Processes user queries in a streaming manner by coordinating with the chat stream handler,
-        progressively updates the chatbot state with intermediate reasoning steps and final responses,
+        progressively updates the chatbot state with responses,
         and maintains conversation history. Handles both text and multimodal inputs while preserving
         the interactive chat experience with real-time updates.
 
@@ -249,35 +259,16 @@ class GradioEvents(object):
             bot_client
         )
 
-        reasoning_content = ""
         response = ""
-        has_thinking = False
         for new_text in new_texts:
-            if not isinstance(new_text, dict):
-                continue
-
-            if new_text.get("type") == "thinking":
-                has_thinking = True
-                reasoning_content += new_text["content"]
-
-            elif new_text.get("type") == "answer":
-                response += new_text["content"]
+            response += new_text
 
-            # Remove previous thinking message if exists
+            # Remove previous message if exists
            if chatbot[-1].get("role") == "assistant":
                chatbot.pop(-1)
 
-            content = ""
-            if has_thinking:
-                content = "**思考过程:**<br>{}<br>".format(reasoning_content)
            if response:
-                if has_thinking:
-                    content += "<br><br>**最终回答:**<br>{}".format(response)
-                else:
-                    content = response
-
-            if content:
-                chatbot.append({"role": "assistant", "content": content})
+                chatbot.append({"role": "assistant", "content": response})
            yield chatbot
 
        logging.info("History: {}".format(task_history))
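Taken together, the hunks above simplify the streaming contract: chat_stream now yields plain text chunks instead of typed dicts, and predict_stream just concatenates them. A standalone sketch of the new contract (stub generator; names are illustrative):

    def chat_stream_stub():
        # Stands in for GradioEvents.chat_stream, which now yields plain str chunks
        for piece in ("Hel", "lo", ", world"):
            yield piece

    response = ""
    for new_text in chat_stream_stub():
        response += new_text    # progressive accumulation, as in predict_stream
    print(response)             # -> "Hello, world"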
@@ -387,7 +378,7 @@ class GradioEvents(object):
         gc.collect()
 
     @staticmethod
-    def toggle_components_visibility(model_name: str) -> tuple:
+    def toggle_components_visibility(model_name: str) -> gr.update:
         """
         Toggle visibility of components depending on the selected model name.
 
@@ -395,13 +386,9 @@ class GradioEvents(object):
             model_name (str): Name of the selected model.
 
         Returns:
-            tuple: A tuple containing two updates: one for the file button and another for the system message.
+            gr.update: An update object representing the visibility of the file button.
         """
-        is_eb45t = (model_name == "eb-45t")
-        return (
-            gr.update(visible=is_eb45t),  # file_btn
-            gr.update(visible=is_eb45t)  # system_message
-        )
+        return gr.update(visible=model_name.startswith(MULTI_MODEL_PREFIX))  # file_btn
 
 
 def launch_demo(args: argparse.Namespace, bot_client: BotClient):
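The toggle now keys off the model name prefix instead of the hard-coded eb-45t alias. Checking the prefix logic against two of the default model names (illustrative sketch):

    MULTI_MODEL_PREFIX = "ernie-4.5-turbo-vl"

    for name in ("ernie-4.5-turbo-vl-32k", "ernie-4.5-turbo-128k"):
        # True -> file upload button shown; False -> hidden
        print(name, name.startswith(MULTI_MODEL_PREFIX))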
@@ -426,14 +413,11 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
     /* Insert English prompt text below the SVG icon */
     #file-upload .wrap::after {
         content: "Drag and drop files here or click to upload";
-        font-size: 18px;
+        font-size: 15px;
         color: #555;
-        margin-top: 8px;
         white-space: nowrap;
     }
     """
-    model_names = ["eb-45t", "eb-x1"]
-
     with gr.Blocks(css=css) as demo:
         logo_url = GradioEvents.get_image_url("assets/logo.png")
         gr.Markdown("""\
@@ -444,18 +428,20 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
     <center><font size=3>This demo is based on ERNIE models. \
     (本演示基于文心大模型实现。)</center>"""
         )
-        gr.Markdown("""\
-    <center><font size=4>
-    <a href="https://yiyan.baidu.com/">eb-45t</a> |
-    &nbsp<a href="https://yiyan.baidu.com/">eb-x1</a></center>""")
 
         chatbot = gr.Chatbot(
             label="ERNIE",
             elem_classes="control-height",
             type="messages"
         )
+        model_names = list(args.model_map.keys())
         with gr.Row():
-            model_name = gr.Dropdown(label="Select Model", choices=model_names, value="eb-45t", allow_custom_value=True)
+            model_name = gr.Dropdown(
+                label="Select Model",
+                choices=model_names,
+                value=model_names[0],
+                allow_custom_value=True
+            )
             file_btn = gr.File(
                 label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
                 height="80px",
@@ -485,7 +471,7 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
         model_name.change(
             GradioEvents.toggle_components_visibility,
             inputs=model_name,
-            outputs=[file_btn, system_message]
+            outputs=file_btn
         )
         model_name.change(
             GradioEvents.reset_state,
@@ -526,6 +512,12 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
             show_progress=True
         )
 
+        demo.load(
+            GradioEvents.toggle_components_visibility,
+            inputs=gr.State(model_names[0]),
+            outputs=file_btn
+        )
+
     demo.queue().launch(
         server_port=args.server_port,
         server_name=args.server_name
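The new demo.load wiring makes the initial visibility of file_btn match the default model on first page load, not only after the dropdown changes. A minimal self-contained sketch of that pattern, assuming the same Gradio Blocks API the rest of the file uses:

    import gradio as gr

    def toggle(model_name: str):
        # Show the upload box only for multimodal (vl-prefixed) models
        return gr.update(visible=model_name.startswith("ernie-4.5-turbo-vl"))

    with gr.Blocks() as demo:
        box = gr.File(label="Image upload")
        # gr.State supplies a constant input; the event runs once per page load
        demo.load(toggle, inputs=gr.State("ernie-4.5-turbo-128k"), outputs=box)

    # demo.launch()  # uncomment to serve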
 
bot_requests.py CHANGED
@@ -22,7 +22,7 @@ import json
 import jieba
 from openai import OpenAI
 
-from appbuilder.mcp_server.client import MCPClient
+import requests
 
 class BotClient(object):
     """Client for interacting with various AI models."""
@@ -41,15 +41,16 @@ class BotClient(object):
         self.max_retry_num = getattr(args, 'max_retry_num', 3)
         self.max_char = getattr(args, 'max_char', 8000)
 
-        self.eb45t_model_url = getattr(args, 'eb45t_model_url', 'eb45t_model_url')
-        self.x1_model_url = getattr(args, 'x1_model_url', 'x1_model_url')
+        self.model_map = getattr(args, 'model_map', {})
         self.api_key = os.environ.get("API_KEY")
 
-        self.qianfan_url = getattr(args, 'qianfan_url', 'qianfan_url')
-        self.qianfan_api_key = getattr(args, 'qianfan_api_key', 'qianfan_api_key')
+        self.embedding_service_url = getattr(args, 'embedding_service_url', 'embedding_service_url')
         self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
 
-        self.ai_search_service_url = getattr(args, 'ai_search_service_url', 'ai_search_service_url')
+        self.web_search_service_url = getattr(args, 'web_search_service_url', 'web_search_service_url')
+        self.max_search_results_num = getattr(args, 'max_search_results_num', 15)
+
+        self.qianfan_api_key = os.environ.get("API_KEY")
 
     def call_back(self, host_url: str, req_data: dict) -> dict:
         """
@@ -130,14 +131,9 @@ class BotClient(object):
         Returns:
             dict: Dictionary containing the model's processing results.
         """
-        model_map = {
-            "eb-45t": self.eb45t_model_url,
-            "eb-x1": self.x1_model_url
-        }
-
-        model_url = model_map[model_name]
+        model_url = self.model_map[model_name]
 
-        req_data["model"] = "ernie-4.5-turbo-vl-32k" if "eb-45t" == model_name else "ernie-x1-turbo-32k"
+        req_data["model"] = model_name
         req_data["max_tokens"] = max_tokens
         req_data["temperature"] = temperature
         req_data["top_p"] = top_p
@@ -157,7 +153,6 @@ class BotClient(object):
                 res = {}
             if len(res) != 0 and "error" not in res:
                 break
-            self.logger.info(json.dumps(res, ensure_ascii=False))
 
         return res
 
@@ -183,13 +178,8 @@ class BotClient(object):
         Yields:
             dict: Dictionary containing the model's processing results.
         """
-        model_map = {
-            "eb-45t": self.eb45t_model_url,
-            "eb-x1": self.x1_model_url
-        }
-
-        model_url = model_map[model_name]
-        req_data["model"] = "ernie-4.5-turbo-vl-32k" if "eb-45t" == model_name else "ernie-x1-turbo-32k"
+        model_url = self.model_map[model_name]
+        req_data["model"] = model_name
         req_data["max_tokens"] = max_tokens
         req_data["temperature"] = temperature
         req_data["top_p"] = top_p
@@ -282,7 +272,7 @@ class BotClient(object):
         to_remove = total_units - self.max_char
 
         # 1. Truncate historical messages
-        for i in range(1, len(processed) - 1):
+        for i in range(len(processed) - 1, 1, -1):
             if to_remove <= 0:
                 break
 
@@ -362,27 +352,39 @@ class BotClient(object):
         Returns:
             list: A list of floats representing the embedding.
         """
-        client = OpenAI(base_url=self.qianfan_url, api_key=self.qianfan_api_key)
+        client = OpenAI(base_url=self.embedding_service_url, api_key=self.qianfan_api_key)
         response = client.embeddings.create(input=[text], model=self.embedding_model)
         return response.data[0].embedding
 
-    async def get_ai_search_res(self, query_list: list) -> list:
+    def get_web_search_res(self, query_list: list) -> list:
         """
-        Get AI search results for the given queries using the MCPClient.
+        Send a request to the AI Search service using the provided API key and service URL.
 
         Args:
-            query_list (list): List of queries to search for.
+            query_list (list): List of queries to send to the AI Search service.
 
         Returns:
-            list: List of search results as strings.
+            list: List of responses from the AI Search service.
         """
-        try:
-            client = MCPClient()
-            await client.connect_to_server(service_url=self.ai_search_service_url)
-            result = []
-            for query in query_list:
-                response = await client.call_tool("AIsearch", {"query": query})
-                result.append(response.content[0].text)
-        finally:
-            await client.cleanup()
-        return result
+        headers = {
+            "Authorization": "Bearer " + self.qianfan_api_key,
+            "Content-Type": "application/json"
+        }
+
+        results = []
+        top_k = self.max_search_results_num // len(query_list)
+        for query in query_list:
+            payload = {
+                "messages": [{"role": "user", "content": query}],
+                "resource_type_filter": [{"type": "web", "top_k": top_k}]
+            }
+            response = requests.post(self.web_search_service_url, headers=headers, json=payload)
+
+            if response.status_code == 200:
+                response = response.json()
+                self.logger.info(response)
+                results.append(response["references"])
+            else:
+                self.logger.info(f"Request failed, status code: {response.status_code}")
+                self.logger.info(response.text)
+        return results
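One detail of the new get_web_search_res worth noting: the overall result budget is split evenly across queries with integer division, so the combined results never exceed max_search_results_num. For example:

    max_search_results_num = 15
    for n_queries in (1, 2, 4):
        top_k = max_search_results_num // n_queries
        print(n_queries, top_k)    # -> (1, 15), (2, 7), (4, 3)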
 
 
 
 
 
 
 
 
 
 
 
 
 