maxiaolong03 committed on
Commit eea129f · 1 Parent(s): b8aaafe
Files changed (5)
  1. .gitignore +1 -0
  2. app.py +541 -0
  3. assets/logo.png +0 -0
  4. bot_requests.py +388 -0
  5. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
app.py ADDED
@@ -0,0 +1,541 @@
+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """This file contains the code for the chatbot demo using Gradio."""
+
+ import argparse
+ import base64
+ import logging
+ import os
+ from argparse import ArgumentParser
+ from collections import namedtuple
+ from functools import partial
+
+ import gradio as gr
+
+ from bot_requests import BotClient
+
+ os.environ["NO_PROXY"] = "localhost,127.0.0.1"  # Disable proxy for local requests
+
+ logging.root.setLevel(logging.INFO)
+
+
+ def get_args() -> argparse.Namespace:
+     """
+     Parses and returns command line arguments for configuring the chatbot demo.
+     Sets up an argument parser with default values for server configuration and model endpoints.
+     The arguments include:
+     - Server port and name for the Gradio interface
+     - Character limits and retry settings for conversation handling
+     - Endpoint URL for the multimodal model (eb-45t)
+     - Endpoint URL for the text inference model (eb-x1)
+
+     Returns:
+         argparse.Namespace: Parsed command line arguments containing:
+             - server_port (int): Port number for the demo server (default: 7860)
+             - server_name (str): Hostname/IP for the server (default: "0.0.0.0")
+             - max_char (int): Maximum character limit for messages (default: 8000)
+             - max_retry_num (int): Maximum retry attempts for API calls (default: 3)
+             - eb45t_model_url (str): Endpoint URL for the multimodal model
+             - x1_model_url (str): Endpoint URL for the text inference model
+     """
+     parser = ArgumentParser(description="ERNIE models web chat demo.")
+
+     parser.add_argument(
+         "--server-port", type=int, default=7860, help="Demo server port."
+     )
+     parser.add_argument(
+         "--server-name", type=str, default="0.0.0.0", help="Demo server name."
+     )
+     parser.add_argument(
+         "--max_char", type=int, default=8000, help="Maximum character limit for messages."
+     )
+     parser.add_argument(
+         "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
+     )
+     parser.add_argument(
+         "--eb45t_model_url",
+         type=str,
+         default="https://qianfan.baidubce.com/v2",
+         help="Model URL for multimodal model."
+     )
+     parser.add_argument(
+         "--x1_model_url",
+         type=str,
+         default="https://qianfan.baidubce.com/v2",
+         help="Model URL for text inference model."
+     )
+
+     args = parser.parse_args()
+     return args
+
+
+ class GradioEvents(object):
+     """
+     Central handler for all Gradio interface events in the chatbot demo. Provides static methods
+     for processing user interactions, including:
+     - Streaming chat predictions with reasoning steps
+     - Response regeneration
+     - Conversation state management
+     - Image handling and URL conversion
+     - Component visibility control
+
+     Coordinates with BotClient to interface with backend models while maintaining
+     conversation history and handling multimodal inputs.
+     """
+     @staticmethod
+     def get_image_url(image_path: str) -> str:
+         """
+         Converts an image file at the given path to a base64 encoded data URL
+         that can be used directly in HTML or Gradio interfaces.
+         Reads the image file, encodes it in base64 format, and constructs
+         a data URL with the appropriate image MIME type.
+
+         Args:
+             image_path (str): Path to the image file.
+
+         Returns:
+             str: Base64-encoded data URL for the image.
+         """
+         extension = image_path.split(".")[-1].lower()
+         if extension == "jpg":
+             extension = "jpeg"  # "jpg" is not a registered MIME subtype
+         with open(image_path, "rb") as image_file:
+             base64_image = base64.b64encode(image_file.read()).decode("utf-8")
+         url = "data:image/{ext};base64,{img}".format(ext=extension, img=base64_image)
+         return url
+
+     @staticmethod
+     def chat_stream(
+         query: str,
+         task_history: list,
+         image_history: dict,
+         model_name: str,
+         file_url: str,
+         system_msg: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         bot_client: BotClient
+     ) -> dict:
+         """
+         Handles streaming chat interactions by processing user queries and
+         generating real-time responses from the bot client. Constructs conversation
+         history including system messages, text inputs and image attachments, then
+         streams back model responses with reasoning steps and final answers.
+
+         Args:
+             query (str): User input.
+             task_history (list): Task history.
+             image_history (dict): Image history.
+             model_name (str): Model name.
+             file_url (str): File URL.
+             system_msg (str): System message.
+             max_tokens (int): Maximum tokens.
+             temperature (float): Temperature.
+             top_p (float): Top p.
+             bot_client (BotClient): Bot client.
+
+         Yields:
+             dict: A dictionary containing the event type and its corresponding content.
+         """
+         conversation = []
+         if system_msg:
+             conversation.append({"role": "system", "content": system_msg})
+         for idx, (query_h, response_h) in enumerate(task_history):
+             if idx in image_history:
+                 content = []
+                 content.append({
+                     "type": "image_url",
+                     "image_url": {"url": GradioEvents.get_image_url(image_history[idx])}
+                 })
+                 content.append({"type": "text", "text": query_h})
+                 conversation.append({"role": "user", "content": content})
+             else:
+                 conversation.append({"role": "user", "content": query_h})
+             conversation.append({"role": "assistant", "content": response_h})
+
+         content = []
+         if file_url and (len(image_history) == 0 or file_url != list(image_history.values())[-1]):
+             image_history[len(task_history)] = file_url
+             content.append({"type": "image_url", "image_url": {"url": GradioEvents.get_image_url(file_url)}})
+             content.append({"type": "text", "text": query})
+             conversation.append({"role": "user", "content": content})
+         else:
+             conversation.append({"role": "user", "content": query})
+
+         try:
+             req_data = {"messages": conversation}
+             for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
+                 if "error" in chunk:
+                     raise Exception(chunk["error"])
+
+                 message = chunk.get("choices", [{}])[0].get("delta", {})
+                 content = message.get("content", "")
+                 reasoning_content = message.get("reasoning_content", "")
+
+                 if reasoning_content:
+                     yield {"type": "thinking", "content": reasoning_content}
+                 if content:
+                     yield {"type": "answer", "content": content}
+
+         except Exception as e:
+             raise gr.Error("Exception: " + repr(e))
+
+     @staticmethod
+     def predict_stream(
+         query: str,
+         chatbot: list,
+         task_history: list,
+         image_history: dict,
+         model: str,
+         file_url: str,
+         system_msg: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         bot_client: BotClient
+     ) -> list:
+         """
+         Processes user queries in a streaming manner by coordinating with the chat stream handler,
+         progressively updates the chatbot state with intermediate reasoning steps and final responses,
+         and maintains conversation history. Handles both text and multimodal inputs while preserving
+         the interactive chat experience with real-time updates.
+
+         Args:
+             query (str): The user's query.
+             chatbot (list): The current chatbot state.
+             task_history (list): The task history.
+             image_history (dict): The image history.
+             model (str): The model name.
+             file_url (str): The file URL.
+             system_msg (str): The system message.
+             max_tokens (int): The maximum token length of the generated response.
+             temperature (float): The temperature parameter used by the model.
+             top_p (float): The top_p parameter used by the model.
+             bot_client (BotClient): The bot client.
+
+         Yields:
+             list: The updated chatbot state after each streamed chunk.
+         """
+         logging.info("User: {}".format(query))
+         chatbot.append({"role": "user", "content": query})
+
+         # First yield the chatbot with the user message
+         yield chatbot
+
+         new_texts = GradioEvents.chat_stream(
+             query,
+             task_history,
+             image_history,
+             model,
+             file_url,
+             system_msg,
+             max_tokens,
+             temperature,
+             top_p,
+             bot_client
+         )
+
+         reasoning_content = ""
+         response = ""
+         has_thinking = False
+         for new_text in new_texts:
+             if not isinstance(new_text, dict):
+                 continue
+
+             if new_text.get("type") == "thinking":
+                 has_thinking = True
+                 reasoning_content += new_text["content"]
+             elif new_text.get("type") == "answer":
+                 response += new_text["content"]
+
+             # Remove the previous partial assistant message if it exists
+             if chatbot[-1].get("role") == "assistant":
+                 chatbot.pop(-1)
+
+             content = ""
+             if has_thinking:
+                 content = "**思考过程:**<br>{}<br>".format(reasoning_content)
+             if response:
+                 if has_thinking:
+                     content += "<br><br>**最终回答:**<br>{}".format(response)
+                 else:
+                     content = response
+
+             if content:
+                 chatbot.append({"role": "assistant", "content": content})
+                 yield chatbot
+
+         logging.info("History: {}".format(task_history))
+         task_history.append((query, response))
+         logging.info("ERNIE models: {}".format(response))
+
+     @staticmethod
+     def regenerate(
+         chatbot: list,
+         task_history: list,
+         image_history: dict,
+         model: str,
+         file_url: str,
+         system_msg: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+         bot_client: BotClient
+     ) -> list:
+         """
+         Reconstructs the conversation context by removing the last interaction and
+         reprocesses the user's previous query to generate a fresh response. Maintains
+         consistency in conversation flow while allowing response regeneration.
+
+         Args:
+             chatbot (list): The current chatbot state.
+             task_history (list): The task history.
+             image_history (dict): The image history.
+             model (str): The model name.
+             file_url (str): The file URL.
+             system_msg (str): The system message.
+             max_tokens (int): The maximum token length of the generated response.
+             temperature (float): The temperature parameter used by the model.
+             top_p (float): The top_p parameter used by the model.
+             bot_client (BotClient): The bot client.
+
+         Yields:
+             list: A list containing the updated chatbot state after processing the user's query.
+         """
+         if not task_history:
+             yield chatbot
+             return
+         # Pop the last user query and bot response from task_history
+         item = task_history.pop(-1)
+         if len(task_history) in image_history:
+             del image_history[len(task_history)]
+         while len(chatbot) != 0 and chatbot[-1].get("role") == "assistant":
+             chatbot.pop(-1)
+         chatbot.pop(-1)
+
+         for chunk in GradioEvents.predict_stream(
+             item[0],
+             chatbot,
+             task_history,
+             image_history,
+             model,
+             file_url,
+             system_msg,
+             max_tokens,
+             temperature,
+             top_p,
+             bot_client
+         ):
+             yield chunk
+
+     @staticmethod
+     def reset_user_input() -> gr.update:
+         """
+         Reset user input field value to empty string.
+
+         Returns:
+             gr.update: Update object representing the new value of the user input field.
+         """
+         return gr.update(value="")
+
+     @staticmethod
+     def reset_state() -> tuple:
+         """
+         Reset all states including chatbot, task_history, image_history, and file_btn.
+
+         Returns:
+             tuple: A tuple containing the following values:
+                 - chatbot (list): An empty list that represents the cleared chatbot state.
+                 - task_history (list): An empty list that represents the cleared task history.
+                 - image_history (dict): An empty dictionary that represents the cleared image history.
+                 - file_btn (gr.update): An update object that sets the value of the file button to None.
+         """
+         GradioEvents.gc()
+
+         reset_result = namedtuple("reset_result",
+                                   ["chatbot",
+                                    "task_history",
+                                    "image_history",
+                                    "file_btn"])
+         return reset_result(
+             [],  # clear chatbot
+             [],  # clear task_history
+             {},  # clear image_history
+             gr.update(value=None),  # clear file_btn
+         )
+
+     @staticmethod
+     def gc():
+         """Run garbage collection to free up memory resources."""
+         import gc
+
+         gc.collect()
+
+     @staticmethod
+     def toggle_components_visibility(model_name: str) -> tuple:
+         """
+         Toggle visibility of components depending on the selected model name.
+
+         Args:
+             model_name (str): Name of the selected model.
+
+         Returns:
+             tuple: A tuple containing two updates: one for the file button and another for the system message.
+         """
+         is_eb45t = (model_name == "eb-45t")
+         return (
+             gr.update(visible=is_eb45t),  # file_btn
+             gr.update(visible=is_eb45t)   # system_message
+         )
+
+
+ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
+     """
+     Launch the demo program.
+
+     Args:
+         args (argparse.Namespace): argparse Namespace object containing parsed command line arguments
+         bot_client (BotClient): Bot client instance
+     """
+     css = """
+     /* Hide original Chinese text */
+     #file-upload .wrap {
+         font-size: 0 !important;
+         position: relative;
+         display: flex;
+         flex-direction: column;
+         align-items: center;
+         justify-content: center;
+     }
+
+     /* Insert English prompt text below the SVG icon */
+     #file-upload .wrap::after {
+         content: "Drag and drop files here or click to upload";
+         font-size: 18px;
+         color: #555;
+         margin-top: 8px;
+         white-space: nowrap;
+     }
+     """
+     model_names = ["eb-45t", "eb-x1"]
+
+     with gr.Blocks(css=css) as demo:
+         logo_url = GradioEvents.get_image_url("assets/logo.png")
+         gr.Markdown("""\
+ <p align="center"><img src="{}" \
+ style="height: 60px"/></p>""".format(logo_url))
+         gr.Markdown(
+             """\
+ <center><font size=3>This demo is based on ERNIE models. \
+ (本演示基于文心大模型实现。)</center>"""
+         )
+         gr.Markdown("""\
+ <center><font size=4>
+ <a href="https://yiyan.baidu.com/">eb-45t</a> |
+ &nbsp;<a href="https://yiyan.baidu.com/">eb-x1</a></center>""")
+
+         chatbot = gr.Chatbot(
+             label="ERNIE",
+             elem_classes="control-height",
+             type="messages"
+         )
+         with gr.Row():
+             model_name = gr.Dropdown(label="Select Model", choices=model_names, value="eb-45t", allow_custom_value=True)
+             file_btn = gr.File(
+                 label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
+                 height="80px",
+                 visible=True,
+                 file_types=[".png", ".jpeg", ".jpg"],
+                 elem_id="file-upload"
+             )
+         query = gr.Textbox(label="Input", elem_id="text_input")
+
+         with gr.Row():
+             empty_btn = gr.Button("🧹 Clear History(清除历史)")
+             submit_btn = gr.Button("🚀 Submit(发送)", elem_id="submit-button")
+             regen_btn = gr.Button("🤔️ Regenerate(重试)")
+
+         with gr.Accordion("⚙️ Advanced Config", open=False):  # open=False means collapsed by default
+             system_message = gr.Textbox(value="", label="System message", visible=True)
+             additional_inputs = [
+                 system_message,
+                 gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens"),
+                 gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Temperature"),
+                 gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Top-p (nucleus sampling)")
+             ]
+
+         task_history = gr.State([])
+         image_history = gr.State({})
+
+         model_name.change(
+             GradioEvents.toggle_components_visibility,
+             inputs=model_name,
+             outputs=[file_btn, system_message]
+         )
+         model_name.change(
+             GradioEvents.reset_state,
+             outputs=[chatbot, task_history, image_history, file_btn],
+             show_progress=True
+         )
+         predict_with_clients = partial(
+             GradioEvents.predict_stream,
+             bot_client=bot_client
+         )
+         regenerate_with_clients = partial(
+             GradioEvents.regenerate,
+             bot_client=bot_client
+         )
+         query.submit(
+             predict_with_clients,
+             inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+             outputs=[chatbot],
+             show_progress=True
+         )
+         query.submit(GradioEvents.reset_user_input, [], [query])
+         submit_btn.click(
+             predict_with_clients,
+             inputs=[query, chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+             outputs=[chatbot],
+             show_progress=True,
+         )
+         submit_btn.click(GradioEvents.reset_user_input, [], [query])
+         empty_btn.click(
+             GradioEvents.reset_state,
+             outputs=[chatbot, task_history, image_history, file_btn],
+             show_progress=True
+         )
+         regen_btn.click(
+             regenerate_with_clients,
+             inputs=[chatbot, task_history, image_history, model_name, file_btn] + additional_inputs,
+             outputs=[chatbot],
+             show_progress=True
+         )
+
+     demo.queue().launch(
+         server_port=args.server_port,
+         server_name=args.server_name
+     )
+
+
+ def main():
+     """Main function that runs when this script is executed."""
+     args = get_args()
+     bot_client = BotClient(args)
+     launch_demo(args, bot_client)
+
+
+ if __name__ == "__main__":
+     main()
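
For reference, the conversation payload that `chat_stream` assembles for a multimodal turn follows the OpenAI-style message format used by `BotClient.process_stream`. A minimal sketch (the file path and query below are hypothetical):

```python
import base64

def image_data_url(path: str) -> str:
    # Mirrors GradioEvents.get_image_url: raw image bytes -> base64 data URL.
    ext = path.rsplit(".", 1)[-1].lower()
    ext = "jpeg" if ext == "jpg" else ext
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return "data:image/{};base64,{}".format(ext, encoded)

# One system turn plus one multimodal user turn, as passed to
# BotClient.process_stream(...) in req_data["messages"].
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_data_url("assets/logo.png")}},
            {"type": "text", "text": "What is in this image?"},
        ],
    },
]
```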
assets/logo.png ADDED
bot_requests.py ADDED
@@ -0,0 +1,388 @@
+ # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """BotClient class for interacting with bot models."""
+
+ import argparse
+ import json
+ import logging
+ import os
+ import traceback
+
+ import jieba
+ from openai import OpenAI
+
+ from appbuilder.mcp_server.client import MCPClient
+
+
+ class BotClient(object):
+     """Client for interacting with various AI models."""
+
+     def __init__(self, args: argparse.Namespace):
+         """
+         Initializes the BotClient instance from command line arguments: retry limits,
+         character constraints, model endpoints, and API credentials. getattr() supplies
+         fallback defaults for any missing arguments so the client can still start.
+
+         Args:
+             args (argparse.Namespace): Command line arguments containing configuration parameters.
+                 Uses getattr() to safely retrieve values with fallback defaults.
+         """
+         self.logger = logging.getLogger(__name__)
+
+         self.max_retry_num = getattr(args, 'max_retry_num', 3)
+         self.max_char = getattr(args, 'max_char', 8000)
+
+         self.eb45t_model_url = getattr(args, 'eb45t_model_url', 'eb45t_model_url')
+         self.x1_model_url = getattr(args, 'x1_model_url', 'x1_model_url')
+         self.api_key = os.environ.get("API_KEY")  # Qianfan API key, read from the environment
+
+         # Optional services: the string defaults are placeholders, not working endpoints
+         self.qianfan_url = getattr(args, 'qianfan_url', 'qianfan_url')
+         self.qianfan_api_key = getattr(args, 'qianfan_api_key', 'qianfan_api_key')
+         self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
+
+         self.ai_search_service_url = getattr(args, 'ai_search_service_url', 'ai_search_service_url')
+
+     def call_back(self, host_url: str, req_data: dict) -> dict:
+         """
+         Executes a chat completion request against the specified endpoint using the OpenAI client
+         and converts the response to a plain dictionary. Errors are logged and re-raised.
+
+         Args:
+             host_url (str): The URL to send the request to.
+             req_data (dict): The data to send in the request body.
+
+         Returns:
+             dict: Parsed JSON response from the server.
+
+         Raises:
+             Exception: Re-raised after logging if the request fails.
+         """
+         try:
+             client = OpenAI(base_url=host_url, api_key=self.api_key)
+             response = client.chat.completions.create(
+                 **req_data
+             )
+
+             # Convert OpenAI response to compatible format
+             return response.model_dump()
+
+         except Exception as e:
+             self.logger.error("Request failed: {}".format(e))
+             raise
+
+     def call_back_stream(self, host_url: str, req_data: dict) -> dict:
+         """
+         Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks
+         in real-time while handling any exceptions that may occur during the streaming process.
+
+         Args:
+             host_url (str): The URL to send the request to.
+             req_data (dict): The data to send in the request body.
+
+         Returns:
+             generator: Generator that yields parsed JSON responses from the server.
+         """
+         try:
+             client = OpenAI(base_url=host_url, api_key=self.api_key)
+             response = client.chat.completions.create(
+                 **req_data,
+                 stream=True,
+             )
+             for chunk in response:
+                 if not chunk.choices:
+                     continue
+
+                 # Convert OpenAI response to compatible format
+                 yield chunk.model_dump()
+
+         except Exception as e:
+             self.logger.error("Stream request failed: {}".format(e))
+             raise
+
+     def process(
+         self,
+         model_name: str,
+         req_data: dict,
+         max_tokens: int = 2048,
+         temperature: float = 1.0,
+         top_p: float = 0.7
+     ) -> dict:
+         """
+         Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
+         including token limits and sampling settings, truncating messages to fit character limits, making API calls
+         with a built-in retry mechanism, and logging the full request/response cycle for debugging purposes.
+
+         Args:
+             model_name (str): Name of the model, used to look up the model URL from model_map.
+             req_data (dict): Dictionary containing request data, including information to be processed.
+             max_tokens (int): Maximum number of tokens to generate.
+             temperature (float): Sampling temperature to control the diversity of generated text.
+             top_p (float): Cumulative probability threshold to control the diversity of generated text.
+
+         Returns:
+             dict: Dictionary containing the model's processing results.
+         """
+         model_map = {
+             "eb-45t": self.eb45t_model_url,
+             "eb-x1": self.x1_model_url
+         }
+
+         model_url = model_map[model_name]
+
+         req_data["model"] = "ernie-4.5-turbo-vl-32k" if model_name == "eb-45t" else "ernie-x1-turbo-32k"
+         req_data["max_tokens"] = max_tokens
+         req_data["temperature"] = temperature
+         req_data["top_p"] = top_p
+         req_data["messages"] = self.truncate_messages(req_data["messages"])
+         for _ in range(self.max_retry_num):
+             try:
+                 self.logger.info("[MODEL] {}".format(model_url))
+                 self.logger.info("[req_data]====>")
+                 self.logger.info(json.dumps(req_data, ensure_ascii=False))
+                 res = self.call_back(model_url, req_data)
+                 self.logger.info("model response")
+                 self.logger.info(res)
+                 self.logger.info("-" * 30)
+             except Exception as e:
+                 self.logger.info(e)
+                 self.logger.info(traceback.format_exc())
+                 res = {}
+             if len(res) != 0 and "error" not in res:
+                 break
+             self.logger.info(json.dumps(res, ensure_ascii=False))
+
+         return res
+
+     def process_stream(
+         self,
+         model_name: str,
+         req_data: dict,
+         max_tokens: int = 2048,
+         temperature: float = 1.0,
+         top_p: float = 0.7
+     ) -> dict:
+         """
+         Processes streaming requests by mapping the model name to its endpoint, configuring request parameters,
+         implementing a retry mechanism with logging, and streaming back response chunks in real-time while
+         handling any errors that may occur during the streaming session.
+
+         Args:
+             model_name (str): Name of the model, used to look up the model URL from model_map.
+             req_data (dict): Dictionary containing request data, including information to be processed.
+             max_tokens (int): Maximum number of tokens to generate.
+             temperature (float): Sampling temperature to control the diversity of generated text.
+             top_p (float): Cumulative probability threshold to control the diversity of generated text.
+
+         Yields:
+             dict: Dictionary containing the model's processing results.
+         """
+         model_map = {
+             "eb-45t": self.eb45t_model_url,
+             "eb-x1": self.x1_model_url
+         }
+
+         model_url = model_map[model_name]
+         req_data["model"] = "ernie-4.5-turbo-vl-32k" if model_name == "eb-45t" else "ernie-x1-turbo-32k"
+         req_data["max_tokens"] = max_tokens
+         req_data["temperature"] = temperature
+         req_data["top_p"] = top_p
+         req_data["messages"] = self.truncate_messages(req_data["messages"])
+
+         last_error = None
+         for attempt in range(self.max_retry_num):
+             try:
+                 self.logger.info("[MODEL] {}".format(model_url))
+                 self.logger.info("[req_data]====>")
+                 self.logger.info(json.dumps(req_data, ensure_ascii=False))
+
+                 for chunk in self.call_back_stream(model_url, req_data):
+                     yield chunk
+                 return
+
+             except Exception as e:
+                 last_error = e
+                 self.logger.error("Stream request failed (attempt {}/{}): {}".format(attempt + 1, self.max_retry_num, e))
+
+         self.logger.error("All retry attempts failed for stream request")
+         yield {"error": str(last_error)}
+
+     def cut_chinese_english(self, text: str) -> list:
+         """
+         Segments mixed Chinese and English text into individual components using jieba: English words
+         are kept as whole units, while everything else is split into single characters. Unicode character
+         ranges are used to distinguish between the two scripts.
+
+         Args:
+             text (str): Input string to be segmented.
+
+         Returns:
+             list: A list of segments, where each segment is an English word or a single character.
+         """
+         words = jieba.lcut(text)
+         en_ch_words = []
+
+         for word in words:
+             if word.isalpha() and not any("\u4e00" <= char <= "\u9fff" for char in word):
+                 en_ch_words.append(word)
+             else:
+                 en_ch_words.extend(list(word))
+         return en_ch_words
+
+     def truncate_messages(self, messages: list[dict]) -> list:
+         """
+         Truncates conversation messages to fit within the maximum character limit (self.max_char)
+         by removing content while preserving message structure. The truncation follows a
+         prioritized order: historical messages first, then the system message, and finally the last message.
+
+         Args:
+             messages (list[dict]): List of messages to be truncated.
+
+         Returns:
+             list[dict]: Modified list of messages after truncation.
+         """
+         if not messages:
+             return messages
+
+         processed = []
+         total_units = 0
+
+         for msg in messages:
+             # Handle two different content formats
+             if isinstance(msg["content"], str):
+                 text_content = msg["content"]
+             elif isinstance(msg["content"], list):
+                 # Multimodal content lists are [image_url, text]; take the text part
+                 text_content = msg["content"][1]["text"]
+             else:
+                 text_content = ""
+
+             # Calculate unit count after tokenization
+             units = self.cut_chinese_english(text_content)
+             unit_count = len(units)
+
+             processed.append({
+                 "role": msg["role"],
+                 "original_content": msg["content"],  # Preserve original content
+                 "text_content": text_content,  # Extracted plain text
+                 "units": units,
+                 "unit_count": unit_count
+             })
+             total_units += unit_count
+
+         if total_units <= self.max_char:
+             return messages
+
+         # Number of units to remove
+         to_remove = total_units - self.max_char
+
+         # 1. Truncate historical messages (everything between the system and last messages)
+         for i in range(1, len(processed) - 1):
+             if to_remove <= 0:
+                 break
+
+             if processed[i]["unit_count"] <= to_remove:
+                 processed[i]["text_content"] = ""
+                 to_remove -= processed[i]["unit_count"]
+                 if isinstance(processed[i]["original_content"], str):
+                     processed[i]["original_content"] = ""
+                 elif isinstance(processed[i]["original_content"], list):
+                     processed[i]["original_content"][1]["text"] = ""
+             else:
+                 kept_units = processed[i]["units"][:-to_remove]
+                 new_text = "".join(kept_units)
+                 processed[i]["text_content"] = new_text
+                 if isinstance(processed[i]["original_content"], str):
+                     processed[i]["original_content"] = new_text
+                 elif isinstance(processed[i]["original_content"], list):
+                     processed[i]["original_content"][1]["text"] = new_text
+                 to_remove = 0
+
+         # 2. Truncate system message
+         if to_remove > 0:
+             system_msg = processed[0]
+             if system_msg["unit_count"] <= to_remove:
+                 processed[0]["text_content"] = ""
+                 to_remove -= system_msg["unit_count"]
+                 if isinstance(processed[0]["original_content"], str):
+                     processed[0]["original_content"] = ""
+                 elif isinstance(processed[0]["original_content"], list):
+                     processed[0]["original_content"][1]["text"] = ""
+             else:
+                 kept_units = system_msg["units"][:-to_remove]
+                 new_text = "".join(kept_units)
+                 processed[0]["text_content"] = new_text
+                 if isinstance(processed[0]["original_content"], str):
+                     processed[0]["original_content"] = new_text
+                 elif isinstance(processed[0]["original_content"], list):
+                     processed[0]["original_content"][1]["text"] = new_text
+                 to_remove = 0
+
+         # 3. Truncate last message
+         if to_remove > 0 and len(processed) > 1:
+             last_msg = processed[-1]
+             if last_msg["unit_count"] > to_remove:
+                 kept_units = last_msg["units"][:-to_remove]
+                 new_text = "".join(kept_units)
+                 last_msg["text_content"] = new_text
+                 if isinstance(last_msg["original_content"], str):
+                     last_msg["original_content"] = new_text
+                 elif isinstance(last_msg["original_content"], list):
+                     last_msg["original_content"][1]["text"] = new_text
+             else:
+                 last_msg["text_content"] = ""
+                 if isinstance(last_msg["original_content"], str):
+                     last_msg["original_content"] = ""
+                 elif isinstance(last_msg["original_content"], list):
+                     last_msg["original_content"][1]["text"] = ""
+
+         # Keep only messages that still have text after truncation
+         result = []
+         for msg in processed:
+             if msg["text_content"]:
+                 result.append({
+                     "role": msg["role"],
+                     "content": msg["original_content"]
+                 })
+
+         return result
+
+     def embed_fn(self, text: str) -> list:
+         """
+         Generate an embedding for the given text using the QianFan API.
+
+         Args:
+             text (str): The input text to be embedded.
+
+         Returns:
+             list: A list of floats representing the embedding.
+         """
+         client = OpenAI(base_url=self.qianfan_url, api_key=self.qianfan_api_key)
+         response = client.embeddings.create(input=[text], model=self.embedding_model)
+         return response.data[0].embedding
+
+     async def get_ai_search_res(self, query_list: list) -> list:
+         """
+         Get AI search results for the given queries using the MCPClient.
+
+         Args:
+             query_list (list): List of queries to search for.
+
+         Returns:
+             list: List of search results as strings.
+         """
+         client = MCPClient()  # Create before the try block so cleanup() is always defined
+         try:
+             await client.connect_to_server(service_url=self.ai_search_service_url)
+             result = []
+             for query in query_list:
+                 response = await client.call_tool("AIsearch", {"query": query})
+                 result.append(response.content[0].text)
+         finally:
+             await client.cleanup()
+         return result
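
A minimal usage sketch for `BotClient` outside of Gradio, assuming `API_KEY` holds a valid Qianfan credential (the endpoint URLs mirror the defaults in app.py's `get_args`; the placeholder key string is hypothetical):

```python
import argparse
import os

from bot_requests import BotClient

os.environ.setdefault("API_KEY", "<your-api-key>")  # assumption: credential supplied via environment

args = argparse.Namespace(
    max_retry_num=3,
    max_char=8000,
    eb45t_model_url="https://qianfan.baidubce.com/v2",
    x1_model_url="https://qianfan.baidubce.com/v2",
)
client = BotClient(args)

req_data = {"messages": [{"role": "user", "content": "Hello!"}]}
for chunk in client.process_stream("eb-x1", req_data, max_tokens=256):
    delta = chunk.get("choices", [{}])[0].get("delta", {})
    # eb-x1 streams reasoning_content (thinking) before content (the answer)
    print(delta.get("reasoning_content") or delta.get("content") or "", end="")
```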
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ # Requires Python 3.10-3.12
+ appbuilder_sdk==1.0.6
+ crawl4ai==0.6.3
+ docx==0.2.4
+ faiss-cpu==1.9.0
+ gradio==5.27.1
+ jieba==0.42.1
+ mcp==1.9.4
+ numpy==2.2.6
+ openai==1.88.0
+ pdfplumber==0.11.7
+ python_docx==1.1.2
+ Requests==2.32.4
+ sse-starlette==2.3.6
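
Setup note: per the header comment, the pinned stack targets Python 3.10-3.12. A typical workflow is `pip install -r requirements.txt`, then exporting `API_KEY` with a Qianfan credential (read by `BotClient`), and starting the demo with `python app.py`.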