Commit 9de4aae · maxiaolong03 committed
1 Parent(s): eea129f

add files

Files changed:
- app.py (+61 -69)
- bot_requests.py (+40 -38)
app.py CHANGED
@@ -17,6 +17,7 @@
 import argparse
 from collections import namedtuple
 from functools import partial
+import json
 import logging
 import os
 import base64
@@ -30,6 +31,8 @@ os.environ["NO_PROXY"] = "localhost,127.0.0.1"  # Disable proxy
 
 logging.root.setLevel(logging.INFO)
 
+MULTI_MODEL_PREFIX = "ernie-4.5-turbo-vl"
+
 
 def get_args() -> argparse.Namespace:
     """
@@ -38,17 +41,10 @@ def get_args() -> argparse.Namespace:
     The arguments include:
     - Server port and name for the Gradio interface
     - Character limits and retry settings for conversation handling
-    - Model
-    - API keys and other service configurations
+    - Model name to endpoint mappings for the chatbot
 
     Returns:
-        argparse.Namespace: Parsed command line arguments containing
-        - server_port (int): Port number for the demo server (default: 8232)
-        - server_name (str): Hostname/IP for the server (default: "0.0.0.0")
-        - max_char (int): Maximum character limit for messages (default: 8000)
-        - max_retry_num (int): Maximum retry attempts for API calls (default: 3)
-        - eb45t_model_url (str): Endpoint URL for the multimodal model
-        - x1_model_url (str): Endpoint URL for the text inference model
+        argparse.Namespace: Parsed command line arguments containing all the above settings
     """
     parser = ArgumentParser(description="ERNIE models web chat demo.")
@@ -65,19 +61,38 @@ def get_args() -> argparse.Namespace:
         "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
     )
     parser.add_argument(
-        "--
+        "--model_map",
         type=str,
-        default="
-
-
-
-
-
-
-        help="
+        default="""{
+            "ernie-4.5-turbo-128k": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-turbo-32k": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-8k-preview": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-turbo-vl-32k": "https://qianfan.baidubce.com/v2",
+            "ernie-4.5-turbo-vl-32k-preview": "https://qianfan.baidubce.com/v2"
+        }""",
+        help="""JSON string defining model name to endpoint mappings.
+        Required Format:
+        {"model_name": "http://localhost:port/v1", ...}
+
+        Note:
+        - All endpoints must be valid HTTP URLs
+        - At least one model must be specified
+        - Prefix determines model capabilities:
+            * ERNIE-4.5[-*]: Text-only model
+            * ERNIE-4.5-VL[-*]: Multimodal models (image+text)
+        """
     )
 
     args = parser.parse_args()
+    try:
+        args.model_map = json.loads(args.model_map)
+
+        # Validation: Check at least one model exists
+        if len(args.model_map) < 1:
+            raise ValueError("model_map must contain at least one model configuration")
+    except json.JSONDecodeError as e:
+        raise ValueError("Invalid JSON format for --model-map") from e
+
     return args
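As an aside, a minimal launch sketch for the reworked flag; the single-entry map and local endpoint below are placeholders, not the Qianfan defaults shown above, and it assumes app.py is run directly:

# Hypothetical usage sketch for the new --model_map flag.
import json
import subprocess

model_map = {"ernie-4.5-turbo-vl-32k": "http://localhost:8000/v1"}  # placeholder endpoint
subprocess.run(
    ["python", "app.py", "--model_map", json.dumps(model_map)],
    check=True,
)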
@@ -85,7 +100,6 @@ class GradioEvents(object):
     """
     Central handler for all Gradio interface events in the chatbot demo. Provides static methods
     for processing user interactions including:
-    - Streaming chat predictions with reasoning steps
     - Response regeneration
     - Conversation state management
     - Image handling and URL conversion
@@ -127,12 +141,12 @@ class GradioEvents(object):
         temperature: float,
         top_p: float,
         bot_client: BotClient
-    ) ->
+    ) -> str:
         """
         Handles streaming chat interactions by processing user queries and
         generating real-time responses from the bot client. Constructs conversation
         history including system messages, text inputs and image attachments, then
-        streams back model responses
+        streams back model responses.
@@ -147,7 +161,7 @@ class GradioEvents(object):
             bot_client (BotClient): Bot client.
 
         Yields:
-
+            str: Model response.
         """
         conversation = []
         if system_msg:
@@ -174,7 +188,6 @@ class GradioEvents(object):
         else:
             conversation.append({"role": "user", "content": query})
 
-
         try:
             req_data = {"messages": conversation}
             for chunk in bot_client.process_stream(model_name, req_data, max_tokens, temperature, top_p):
@@ -183,12 +196,9 @@ class GradioEvents(object):
 
                 message = chunk.get("choices", [{}])[0].get("delta", {})
                 content = message.get("content", "")
-                reasoning_content = message.get("reasoning_content", "")
 
-                if reasoning_content:
-                    yield {"type": "thinking", "content": reasoning_content}
                 if content:
-                    yield
+                    yield content
 
         except Exception as e:
             raise gr.Error("Exception: " + repr(e))
@@ -209,7 +219,7 @@ class GradioEvents(object):
     ) -> list:
         """
         Processes user queries in a streaming manner by coordinating with the chat stream handler,
-        progressively updates the chatbot state with
+        progressively updates the chatbot state with responses,
         and maintains conversation history. Handles both text and multimodal inputs while preserving
         the interactive chat experience with real-time updates.
@@ -249,35 +259,16 @@ class GradioEvents(object):
             bot_client
         )
 
-        reasoning_content = ""
         response = ""
-
-
-            if not isinstance(new_text, dict):
-                continue
-
-            if new_text.get("type") == "thinking":
-                has_thinking = True
-                reasoning_content += new_text["content"]
-
-            elif new_text.get("type") == "answer":
-                response += new_text["content"]
+        for new_text in new_texts:
+            response += new_text
 
-        # Remove previous
+        # Remove previous message if exists
         if chatbot[-1].get("role") == "assistant":
             chatbot.pop(-1)
-
-        content = ""
-        if has_thinking:
-            content = "**思考过程:**<br>{}<br>".format(reasoning_content)
+
         if response:
-
-            content += "<br><br>**最终回答:**<br>{}".format(response)
-        else:
-            content = response
-
-        if content:
-            chatbot.append({"role": "assistant", "content": content})
+            chatbot.append({"role": "assistant", "content": response})
         yield chatbot
 
         logging.info("History: {}".format(task_history))
@@ -387,7 +378,7 @@ class GradioEvents(object):
         gc.collect()
 
     @staticmethod
-    def toggle_components_visibility(model_name: str) ->
+    def toggle_components_visibility(model_name: str) -> gr.update:
         """
         Toggle visibility of components depending on the selected model name.
@@ -395,13 +386,9 @@ class GradioEvents(object):
             model_name (str): Name of the selected model.
 
         Returns:
-
+            gr.update: An update object representing the visibility of the file button.
         """
-
-        return (
-            gr.update(visible=is_eb45t),  # file_btn
-            gr.update(visible=is_eb45t)  # system_message
-        )
+        return gr.update(visible=model_name.startswith(MULTI_MODEL_PREFIX))  # file_btn
 
 
 def launch_demo(args: argparse.Namespace, bot_client: BotClient):
@@ -426,14 +413,11 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
        /* Insert English prompt text below the SVG icon */
        #file-upload .wrap::after {
            content: "Drag and drop files here or click to upload";
-            font-size:
+            font-size: 15px;
            color: #555;
-            margin-top: 8px;
            white-space: nowrap;
        }
    """
-    model_names = ["eb-45t", "eb-x1"]
-
    with gr.Blocks(css=css) as demo:
        logo_url = GradioEvents.get_image_url("assets/logo.png")
        gr.Markdown("""\
@@ -444,18 +428,20 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
        <center><font size=3>This demo is based on ERNIE models. \
        (本演示基于文心大模型实现。)</center>"""
        )
-        gr.Markdown("""\
-        <center><font size=4>
-        <a href="https://yiyan.baidu.com/">eb-45t</a> |
-        <a href="https://yiyan.baidu.com/">eb-x1</a></center>""")
 
        chatbot = gr.Chatbot(
            label="ERNIE",
            elem_classes="control-height",
            type="messages"
        )
+        model_names = list(args.model_map.keys())
        with gr.Row():
-            model_name = gr.Dropdown(
+            model_name = gr.Dropdown(
+                label="Select Model",
+                choices=model_names,
+                value=model_names[0],
+                allow_custom_value=True
+            )
            file_btn = gr.File(
                label="Image upload (Active only for multimodal models. Accepted formats: PNG, JPEG, JPG)",
                height="80px",
@@ -485,7 +471,7 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
        model_name.change(
            GradioEvents.toggle_components_visibility,
            inputs=model_name,
-            outputs=
+            outputs=file_btn
        )
        model_name.change(
            GradioEvents.reset_state,
@@ -526,6 +512,12 @@ def launch_demo(args: argparse.Namespace, bot_client: BotClient):
            show_progress=True
        )
 
+        demo.load(
+            GradioEvents.toggle_components_visibility,
+            inputs=gr.State(model_names[0]),
+            outputs=file_btn
+        )
+
    demo.queue().launch(
        server_port=args.server_port,
        server_name=args.server_name
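Taken together, the chat_stream and predict_stream changes above reduce the streaming contract to plain text fragments; the typed thinking/answer dicts are gone. A self-contained sketch of the new accumulation pattern (consume and the sample stream are illustrative, not part of the commit):

# Illustrative mirror of the new "response += new_text" loop in predict_stream.
def consume(stream):
    response = ""
    for new_text in stream:  # each item is a str chunk, as chat_stream now yields
        response += new_text
    return response

assert consume(iter(["Hello, ", "world"])) == "Hello, world"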
bot_requests.py CHANGED
@@ -22,7 +22,7 @@ import json
 import jieba
 from openai import OpenAI
 
-
+import requests
 
 class BotClient(object):
     """Client for interacting with various AI models."""
@@ -41,15 +41,16 @@ class BotClient(object):
         self.max_retry_num = getattr(args, 'max_retry_num', 3)
         self.max_char = getattr(args, 'max_char', 8000)
 
-        self.
-        self.x1_model_url = getattr(args, 'x1_model_url', 'x1_model_url')
+        self.model_map = getattr(args, 'model_map', {})
         self.api_key = os.environ.get("API_KEY")
 
-        self.
-        self.qianfan_api_key = getattr(args, 'qianfan_api_key', 'qianfan_api_key')
+        self.embedding_service_url = getattr(args, 'embedding_service_url', 'embedding_service_url')
         self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
 
-        self.
+        self.web_search_service_url = getattr(args, 'web_search_service_url', 'web_search_service_url')
+        self.max_search_results_num = getattr(args, 'max_search_results_num', 15)
+
+        self.qianfan_api_key = os.environ.get("API_KEY")
 
     def call_back(self, host_url: str, req_data: dict) -> dict:
         """
@@ -130,14 +131,9 @@ class BotClient(object):
         Returns:
             dict: Dictionary containing the model's processing results.
         """
-
-            "eb-45t": self.eb45t_model_url,
-            "eb-x1": self.x1_model_url
-        }
-
-        model_url = model_map[model_name]
+        model_url = self.model_map[model_name]
 
-        req_data["model"] =
+        req_data["model"] = model_name
         req_data["max_tokens"] = max_tokens
         req_data["temperature"] = temperature
         req_data["top_p"] = top_p
@@ -157,7 +153,6 @@ class BotClient(object):
                 res = {}
             if len(res) != 0 and "error" not in res:
                 break
-            self.logger.info(json.dumps(res, ensure_ascii=False))
 
         return res
 
@@ -183,13 +178,8 @@ class BotClient(object):
         Yields:
             dict: Dictionary containing the model's processing results.
         """
-
-
-            "eb-x1": self.x1_model_url
-        }
-
-        model_url = model_map[model_name]
-        req_data["model"] = "ernie-4.5-turbo-vl-32k" if "eb-45t" == model_name else "ernie-x1-turbo-32k"
+        model_url = self.model_map[model_name]
+        req_data["model"] = model_name
         req_data["max_tokens"] = max_tokens
         req_data["temperature"] = temperature
         req_data["top_p"] = top_p
@@ -282,7 +272,7 @@ class BotClient(object):
         to_remove = total_units - self.max_char
 
         # 1. Truncate historical messages
-        for i in range(
+        for i in range(len(processed) - 1, 1):
             if to_remove <= 0:
                 break
@@ -362,27 +352,39 @@ class BotClient(object):
         Returns:
             list: A list of floats representing the embedding.
         """
-        client = OpenAI(base_url=self.
+        client = OpenAI(base_url=self.embedding_service_url, api_key=self.qianfan_api_key)
         response = client.embeddings.create(input=[text], model=self.embedding_model)
         return response.data[0].embedding
 
-
+    def get_web_search_res(self, query_list: list) -> list:
         """
-
-
+        Send a request to the AI Search service using the provided API key and service URL.
+
         Args:
-            query_list (list): List of queries to
+            query_list (list): List of queries to send to the AI Search service.
 
         Returns:
-            list: List of
+            list: List of responses from the AI Search service.
         """
-
-
-
-
-
-
-
-
-
-
+        headers = {
+            "Authorization": "Bearer " + self.qianfan_api_key,
+            "Content-Type": "application/json"
+        }
+
+        results = []
+        top_k = self.max_search_results_num // len(query_list)
+        for query in query_list:
+            payload = {
+                "messages": [{"role": "user", "content": query}],
+                "resource_type_filter": [{"type": "web", "top_k": top_k}]
+            }
+            response = requests.post(self.web_search_service_url, headers=headers, json=payload)
+
+            if response.status_code == 200:
+                response = response.json()
+                self.logger.info(response)
+                results.append(response["references"])
+            else:
+                self.logger.info(f"请求失败,状态码: {response.status_code}")
+                self.logger.info(response.text)
+        return results
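The embedding path now builds its OpenAI-compatible client from embedding_service_url and the API_KEY environment variable. A hedged, self-contained sketch of that call pattern (requires openai>=1.0; the base_url and model name are placeholders, not the service defaults):

# Placeholder endpoint and model; API_KEY must be set in the environment.
import os
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key=os.environ["API_KEY"])
response = client.embeddings.create(input=["hello world"], model="embedding-v1")
print(len(response.data[0].embedding))  # embedding dimensionality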
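Finally, a hedged usage sketch for the new get_web_search_res helper. It assumes BotClient accepts the parsed-args namespace as in __init__ above, that API_KEY is exported, and that the search URL points at a compatible AI Search endpoint; the URL below is a placeholder:

# Field names mirror the getattr calls in BotClient.__init__ above.
import argparse
from bot_requests import BotClient

args = argparse.Namespace(
    model_map={"ernie-4.5-turbo-32k": "https://qianfan.baidubce.com/v2"},
    web_search_service_url="http://localhost:8000/ai_search",  # placeholder
    max_search_results_num=15,
)
bot_client = BotClient(args)
for refs in bot_client.get_web_search_res(["ERNIE 4.5"]):
    print(refs)  # the "references" list returned for one query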