Spaces:
Sleeping
Sleeping
add QA
Browse files- .gitattributes +1 -0
- app.py +160 -0
- config.json +5 -5
- docker-compose.yml +12 -0
- logs/bio_rag_2025-08-25.log +0 -0
- python-services/Retrieve/.gitignore +9 -0
- python-services/Retrieve/Dockerfile +41 -0
- python-services/Retrieve/MCP_USAGE.md +207 -0
- python-services/Retrieve/README_ENV.md +99 -0
- python-services/Retrieve/bio_agent/rewrite_agent.py +255 -0
- python-services/Retrieve/bio_requests/chat_request.py +17 -0
- python-services/Retrieve/bio_requests/rag_request.py +44 -0
- python-services/Retrieve/config/2023JCR(完整).xlsx +3 -0
- python-services/Retrieve/config/app_config_dev.yaml +60 -0
- python-services/Retrieve/config/global_storage.py +121 -0
- python-services/Retrieve/dto/bio_document.py +111 -0
- python-services/Retrieve/main.py +93 -0
- python-services/Retrieve/readme.md +284 -0
- python-services/Retrieve/requirements.txt +19 -0
- python-services/Retrieve/routers/mcp_sensor.py +81 -0
- python-services/Retrieve/routers/sensor.py +83 -0
- python-services/Retrieve/search_service/base_search.py +28 -0
- python-services/Retrieve/search_service/pubmed_search.py +197 -0
- python-services/Retrieve/search_service/web_search.py +163 -0
- python-services/Retrieve/service/__init__.py +0 -0
- python-services/Retrieve/service/chat.py +468 -0
- python-services/Retrieve/service/pubmed_api.py +164 -0
- python-services/Retrieve/service/pubmed_async_api.py +195 -0
- python-services/Retrieve/service/pubmed_xml_parse.py +232 -0
- python-services/Retrieve/service/query_rewrite.py +354 -0
- python-services/Retrieve/service/rag.py +54 -0
- python-services/Retrieve/service/rerank.py +60 -0
- python-services/Retrieve/service/web_search.py +406 -0
- python-services/Retrieve/utils/bio_logger.py +253 -0
- python-services/Retrieve/utils/http_util.py +275 -0
- python-services/Retrieve/utils/i18n_context.py +125 -0
- python-services/Retrieve/utils/i18n_messages.py +262 -0
- python-services/Retrieve/utils/i18n_types.py +12 -0
- python-services/Retrieve/utils/i18n_util.py +302 -0
- python-services/Retrieve/utils/snowflake_id.py +252 -0
- python-services/Retrieve/utils/token_util.py +63 -0
- requirements.txt +18 -1
- requirements_back.txt +15 -0
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.xlsx filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -5,6 +5,10 @@ import json
|
|
5 |
import os
|
6 |
import platform
|
7 |
import time
|
|
|
|
|
|
|
|
|
8 |
|
9 |
if platform.system() == "Windows":
|
10 |
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
@@ -88,6 +92,112 @@ def save_config_to_json(config):
|
|
88 |
st.error(f"Error saving settings file: {str(e)}")
|
89 |
return False
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
# Initialize login session variables
|
92 |
if "authenticated" not in st.session_state:
|
93 |
st.session_state.authenticated = False
|
@@ -843,6 +953,14 @@ async def initialize_session(mcp_config=None):
|
|
843 |
# Load settings from config.json file
|
844 |
mcp_config = load_config_from_json()
|
845 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
# Validate MCP configuration before connecting
|
847 |
st.info("🔍 Validating MCP server configurations...")
|
848 |
config_errors = []
|
@@ -1369,6 +1487,32 @@ with st.sidebar:
|
|
1369 |
# Action buttons section
|
1370 |
st.subheader("🔄 Actions")
|
1371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1372 |
# Reset conversation button
|
1373 |
if st.button("Reset Conversation", use_container_width=True, type="primary"):
|
1374 |
# Reset thread_id
|
@@ -1435,3 +1579,19 @@ if user_query:
|
|
1435 |
st.warning(
|
1436 |
"⚠️ MCP server and agent are not initialized. Please click the 'Apply Settings' button in the left sidebar to initialize."
|
1437 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import os
|
6 |
import platform
|
7 |
import time
|
8 |
+
import subprocess
|
9 |
+
import threading
|
10 |
+
import signal
|
11 |
+
import sys
|
12 |
|
13 |
if platform.system() == "Windows":
|
14 |
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
|
|
92 |
st.error(f"Error saving settings file: {str(e)}")
|
93 |
return False
|
94 |
|
95 |
+
def start_retrieve_service():
|
96 |
+
"""
|
97 |
+
启动 Retrieve 服务作为后台进程
|
98 |
+
"""
|
99 |
+
try:
|
100 |
+
# 检查服务是否已经在运行
|
101 |
+
if "retrieve_process" in st.session_state and st.session_state.retrieve_process:
|
102 |
+
try:
|
103 |
+
# 检查进程是否还在运行
|
104 |
+
if st.session_state.retrieve_process.poll() is None:
|
105 |
+
st.info("✅ Retrieve 服务已经在运行")
|
106 |
+
return True
|
107 |
+
except:
|
108 |
+
pass
|
109 |
+
|
110 |
+
# 启动服务
|
111 |
+
st.info("🚀 正在启动 Retrieve 服务...")
|
112 |
+
|
113 |
+
# 构建命令 - 使用 cwd 参数设置工作目录
|
114 |
+
cmd = ["python", "main.py"]
|
115 |
+
|
116 |
+
# 启动进程
|
117 |
+
process = subprocess.Popen(
|
118 |
+
cmd,
|
119 |
+
stdout=subprocess.PIPE,
|
120 |
+
stderr=subprocess.PIPE,
|
121 |
+
text=True,
|
122 |
+
bufsize=1,
|
123 |
+
universal_newlines=True,
|
124 |
+
cwd="python-services/Retrieve" # 设置工作目录
|
125 |
+
)
|
126 |
+
|
127 |
+
# 存储进程引用
|
128 |
+
st.session_state.retrieve_process = process
|
129 |
+
st.session_state.retrieve_started = True
|
130 |
+
|
131 |
+
# 启动后台线程来监控进程输出
|
132 |
+
def monitor_process():
|
133 |
+
try:
|
134 |
+
while process.poll() is None:
|
135 |
+
# 读取输出
|
136 |
+
output = process.stdout.readline()
|
137 |
+
if output:
|
138 |
+
st.info(f"Retrieve 服务: {output.strip()}")
|
139 |
+
|
140 |
+
# 检查错误输出
|
141 |
+
error = process.stderr.readline()
|
142 |
+
if error:
|
143 |
+
st.warning(f"Retrieve 服务错误: {error.strip()}")
|
144 |
+
|
145 |
+
time.sleep(0.1)
|
146 |
+
|
147 |
+
# 进程结束
|
148 |
+
st.warning(f"Retrieve 服务已停止,退出码: {process.returncode}")
|
149 |
+
st.session_state.retrieve_started = False
|
150 |
+
|
151 |
+
except Exception as e:
|
152 |
+
st.error(f"监控 Retrieve 服务时出错: {str(e)}")
|
153 |
+
|
154 |
+
# 启动监控线程
|
155 |
+
monitor_thread = threading.Thread(target=monitor_process, daemon=True)
|
156 |
+
monitor_thread.start()
|
157 |
+
|
158 |
+
# 等待一下确保服务启动
|
159 |
+
time.sleep(2)
|
160 |
+
|
161 |
+
# 检查服务是否成功启动
|
162 |
+
if process.poll() is None:
|
163 |
+
st.success("✅ Retrieve 服务启动成功")
|
164 |
+
return True
|
165 |
+
else:
|
166 |
+
st.error("❌ Retrieve 服务启动失败")
|
167 |
+
return False
|
168 |
+
|
169 |
+
except Exception as e:
|
170 |
+
st.error(f"启动 Retrieve 服务时出错: {str(e)}")
|
171 |
+
return False
|
172 |
+
|
173 |
+
def stop_retrieve_service():
|
174 |
+
"""
|
175 |
+
停止 Retrieve 服务
|
176 |
+
"""
|
177 |
+
try:
|
178 |
+
if "retrieve_process" in st.session_state and st.session_state.retrieve_process:
|
179 |
+
process = st.session_state.retrieve_process
|
180 |
+
if process.poll() is None:
|
181 |
+
# 发送终止信号
|
182 |
+
process.terminate()
|
183 |
+
|
184 |
+
# 等待进程结束
|
185 |
+
try:
|
186 |
+
process.wait(timeout=5)
|
187 |
+
except subprocess.TimeoutExpired:
|
188 |
+
# 强制杀死进程
|
189 |
+
process.kill()
|
190 |
+
|
191 |
+
st.success("✅ Retrieve 服务已停止")
|
192 |
+
else:
|
193 |
+
st.info("Retrieve 服务已经停止")
|
194 |
+
|
195 |
+
st.session_state.retrieve_started = False
|
196 |
+
st.session_state.retrieve_process = None
|
197 |
+
|
198 |
+
except Exception as e:
|
199 |
+
st.error(f"停止 Retrieve 服务时出错: {str(e)}")
|
200 |
+
|
201 |
# Initialize login session variables
|
202 |
if "authenticated" not in st.session_state:
|
203 |
st.session_state.authenticated = False
|
|
|
953 |
# Load settings from config.json file
|
954 |
mcp_config = load_config_from_json()
|
955 |
|
956 |
+
# 自动启动 Retrieve 服务(如果配置中存在)
|
957 |
+
if "bio-qa-chat" in mcp_config:
|
958 |
+
st.info("🚀 检测到 bio-qa-chat 服务,正在启动...")
|
959 |
+
if start_retrieve_service():
|
960 |
+
st.success("✅ Retrieve 服务启动成功")
|
961 |
+
else:
|
962 |
+
st.warning("⚠️ Retrieve 服务启动失败,但继续初始化其他服务")
|
963 |
+
|
964 |
# Validate MCP configuration before connecting
|
965 |
st.info("🔍 Validating MCP server configurations...")
|
966 |
config_errors = []
|
|
|
1487 |
# Action buttons section
|
1488 |
st.subheader("🔄 Actions")
|
1489 |
|
1490 |
+
# Retrieve 服务控制按钮
|
1491 |
+
st.subheader("🔧 Retrieve 服务控制")
|
1492 |
+
|
1493 |
+
col1, col2 = st.columns(2)
|
1494 |
+
|
1495 |
+
with col1:
|
1496 |
+
if st.button("🚀 启动服务", use_container_width=True, type="primary"):
|
1497 |
+
if start_retrieve_service():
|
1498 |
+
st.success("✅ 服务启动成功")
|
1499 |
+
else:
|
1500 |
+
st.error("❌ 服务启动失败")
|
1501 |
+
st.rerun()
|
1502 |
+
|
1503 |
+
with col2:
|
1504 |
+
if st.button("🛑 停止服务", use_container_width=True, type="secondary"):
|
1505 |
+
stop_retrieve_service()
|
1506 |
+
st.rerun()
|
1507 |
+
|
1508 |
+
# 显示服务状态
|
1509 |
+
if st.session_state.get("retrieve_started", False):
|
1510 |
+
st.success("🟢 Retrieve 服务运行中")
|
1511 |
+
else:
|
1512 |
+
st.warning("🔴 Retrieve 服务未运行")
|
1513 |
+
|
1514 |
+
st.divider()
|
1515 |
+
|
1516 |
# Reset conversation button
|
1517 |
if st.button("Reset Conversation", use_container_width=True, type="primary"):
|
1518 |
# Reset thread_id
|
|
|
1579 |
st.warning(
|
1580 |
"⚠️ MCP server and agent are not initialized. Please click the 'Apply Settings' button in the left sidebar to initialize."
|
1581 |
)
|
1582 |
+
|
1583 |
+
# 应用退出时的清理逻辑
|
1584 |
+
def cleanup_on_exit():
|
1585 |
+
"""应用退出时清理资源"""
|
1586 |
+
try:
|
1587 |
+
if "retrieve_process" in st.session_state and st.session_state.retrieve_process:
|
1588 |
+
stop_retrieve_service()
|
1589 |
+
except:
|
1590 |
+
pass
|
1591 |
+
|
1592 |
+
# 注册清理函数
|
1593 |
+
import atexit
|
1594 |
+
atexit.register(cleanup_on_exit)
|
1595 |
+
|
1596 |
+
# 注意:在 Streamlit 中不能使用信号处理器,因为它在子线程中运行
|
1597 |
+
# 清理逻辑通过 atexit 和页面刷新时的状态检查来处理
|
config.json
CHANGED
@@ -27,16 +27,16 @@
|
|
27 |
],
|
28 |
"transport": "stdio"
|
29 |
},
|
30 |
-
"qa": {
|
31 |
-
"transport": "sse",
|
32 |
-
"url": "http://10.15.56.148:9230/sse"
|
33 |
-
},
|
34 |
"review_generate": {
|
35 |
"transport": "sse",
|
36 |
"url": "http://10.15.56.148:8000/review"
|
37 |
},
|
38 |
-
"
|
39 |
"transport": "streamable_http",
|
40 |
"url": "http://127.0.0.1:7860/gradio_api/mcp/"
|
|
|
|
|
|
|
|
|
41 |
}
|
42 |
}
|
|
|
27 |
],
|
28 |
"transport": "stdio"
|
29 |
},
|
|
|
|
|
|
|
|
|
30 |
"review_generate": {
|
31 |
"transport": "sse",
|
32 |
"url": "http://10.15.56.148:8000/review"
|
33 |
},
|
34 |
+
"get_200_words": {
|
35 |
"transport": "streamable_http",
|
36 |
"url": "http://127.0.0.1:7860/gradio_api/mcp/"
|
37 |
+
},
|
38 |
+
"bio-qa-chat": {
|
39 |
+
"transport": "sse",
|
40 |
+
"url": "http://10.15.56.148:9487/sse"
|
41 |
}
|
42 |
}
|
docker-compose.yml
CHANGED
@@ -15,6 +15,18 @@ services:
|
|
15 |
networks:
|
16 |
- mcp-network
|
17 |
restart: unless-stopped
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
networks:
|
20 |
mcp-network:
|
|
|
15 |
networks:
|
16 |
- mcp-network
|
17 |
restart: unless-stopped
|
18 |
+
depends_on:
|
19 |
+
- retrieve-service
|
20 |
+
|
21 |
+
# Retrieve服务
|
22 |
+
retrieve-service:
|
23 |
+
build: ./python-services/Retrieve
|
24 |
+
container_name: retrieve-service
|
25 |
+
ports:
|
26 |
+
- "9487:9487"
|
27 |
+
networks:
|
28 |
+
- mcp-network
|
29 |
+
restart: unless-stopped
|
30 |
|
31 |
networks:
|
32 |
mcp-network:
|
logs/bio_rag_2025-08-25.log
ADDED
File without changes
|
python-services/Retrieve/.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
logs/*
|
2 |
+
*.pyc
|
3 |
+
py_milvus_test.py
|
4 |
+
test_vector_search.py
|
5 |
+
.vscode/settings.json
|
6 |
+
service/Qwen3-Reranker-0.6B
|
7 |
+
test_model_api.py
|
8 |
+
test/logs
|
9 |
+
.conda/*
|
python-services/Retrieve/Dockerfile
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from python:3.11-slim as builder
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# 首先只复制依赖文件
|
6 |
+
COPY requirements.txt .
|
7 |
+
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
8 |
+
RUN pip install -U crawl4ai
|
9 |
+
# 运行安装后设置
|
10 |
+
RUN crawl4ai-setup
|
11 |
+
|
12 |
+
# Verify your installation
|
13 |
+
RUN crawl4ai-doctor
|
14 |
+
|
15 |
+
# RUN python -m playwright install --with-deps chromium
|
16 |
+
# 第二阶段
|
17 |
+
#from python:3.11-slim
|
18 |
+
|
19 |
+
#WORKDIR /app
|
20 |
+
|
21 |
+
# 从构建阶段复制已安装的依赖
|
22 |
+
#COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
23 |
+
#COPY --from=builder /ms-playwright /ms-playwright
|
24 |
+
|
25 |
+
|
26 |
+
# 复制应用代码
|
27 |
+
COPY . .
|
28 |
+
|
29 |
+
|
30 |
+
# 声明端口
|
31 |
+
EXPOSE 9487
|
32 |
+
|
33 |
+
USER root
|
34 |
+
|
35 |
+
|
36 |
+
# 3. 设置缓存路径并赋予权限
|
37 |
+
|
38 |
+
# 4. 切换非root用户(避免权限问题)
|
39 |
+
# RUN useradd -m appuser && chown -R appuser:appuser /app
|
40 |
+
# USER appuser
|
41 |
+
CMD ["python", "main.py"]
|
python-services/Retrieve/MCP_USAGE.md
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MCP 包装服务使用说明
|
2 |
+
|
3 |
+
## 概述
|
4 |
+
|
5 |
+
这个服务使用 `FastApiMCP` 将生物医学RAG服务包装成MCP(Model Context Protocol)工具,可以通过MCP客户端调用。
|
6 |
+
|
7 |
+
## 服务配置
|
8 |
+
|
9 |
+
在 `main.py` 中,服务被包装为:
|
10 |
+
|
11 |
+
```python
|
12 |
+
mcp = FastApiMCP(app, name="bio qa mcp", include_operations=["bio_qa_stream_chat"])
|
13 |
+
mcp.mount_sse()
|
14 |
+
```
|
15 |
+
|
16 |
+
## 可用的MCP操作
|
17 |
+
|
18 |
+
### 1. bio_qa_stream_chat
|
19 |
+
|
20 |
+
这是主要的生物医学问答操作,提供流式RAG问答服务。
|
21 |
+
|
22 |
+
## 调用方式
|
23 |
+
|
24 |
+
### 方式1: 通过MCP客户端调用
|
25 |
+
|
26 |
+
#### 1.1 配置MCP客户端
|
27 |
+
|
28 |
+
在你的MCP客户端配置中添加:
|
29 |
+
|
30 |
+
```json
|
31 |
+
{
|
32 |
+
"bio_qa_mcp": {
|
33 |
+
"url": "http://localhost:9487/sse",
|
34 |
+
"transport": "sse"
|
35 |
+
}
|
36 |
+
}
|
37 |
+
```
|
38 |
+
|
39 |
+
#### 1.2 调用示例
|
40 |
+
|
41 |
+
```python
|
42 |
+
# 使用MCP客户端调用
|
43 |
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
44 |
+
|
45 |
+
# 配置MCP服务器
|
46 |
+
mcp_config = {
|
47 |
+
"bio_qa_mcp": {
|
48 |
+
"url": "http://localhost:9487/sse",
|
49 |
+
"transport": "sse"
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
# 创建客户端
|
54 |
+
client = MultiServerMCPClient(mcp_config)
|
55 |
+
|
56 |
+
# 获取工具
|
57 |
+
tools = await client.get_tools()
|
58 |
+
|
59 |
+
# 使用工具
|
60 |
+
# 工具名称: bio_qa_stream_chat
|
61 |
+
# 参数: query (问题), lang (语言,可选,默认"en")
|
62 |
+
```
|
63 |
+
|
64 |
+
### 方式2: 直接HTTP调用
|
65 |
+
|
66 |
+
#### 2.1 直接调用API端点
|
67 |
+
|
68 |
+
```bash
|
69 |
+
# 调用生物医学问答接口
|
70 |
+
curl -X POST "http://localhost:9487/mcp/bio_qa" \
|
71 |
+
-H "Content-Type: application/x-www-form-urlencoded" \
|
72 |
+
-d "query=什么是糖尿病?&lang=zh"
|
73 |
+
```
|
74 |
+
|
75 |
+
#### 2.2 Python requests调用
|
76 |
+
|
77 |
+
```python
|
78 |
+
import requests
|
79 |
+
|
80 |
+
# 调用接口
|
81 |
+
response = requests.post(
|
82 |
+
"http://localhost:9487/mcp/bio_qa",
|
83 |
+
data={
|
84 |
+
"query": "什么是糖尿病?",
|
85 |
+
"lang": "zh"
|
86 |
+
}
|
87 |
+
)
|
88 |
+
|
89 |
+
# 处理流式响应
|
90 |
+
for line in response.iter_lines():
|
91 |
+
if line:
|
92 |
+
print(line.decode('utf-8'))
|
93 |
+
```
|
94 |
+
|
95 |
+
## 参数说明
|
96 |
+
|
97 |
+
### bio_qa_stream_chat 操作
|
98 |
+
|
99 |
+
- **query** (必需): 问题内容
|
100 |
+
- **lang** (可选): 语言设置
|
101 |
+
- `"zh"`: 中文
|
102 |
+
- `"en"`: 英文(默认)
|
103 |
+
|
104 |
+
## 响应格式
|
105 |
+
|
106 |
+
服务返回流式响应(Server-Sent Events),格式为:
|
107 |
+
|
108 |
+
```
|
109 |
+
data: {"type": "result", "content": "回答内容..."}
|
110 |
+
data: {"type": "result", "content": "更多内容..."}
|
111 |
+
data: {"type": "done", "content": "完成"}
|
112 |
+
```
|
113 |
+
|
114 |
+
## 使用场景
|
115 |
+
|
116 |
+
### 1. 在LangChain中使用
|
117 |
+
|
118 |
+
```python
|
119 |
+
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
120 |
+
from langchain_openai import ChatOpenAI
|
121 |
+
|
122 |
+
# 创建代理
|
123 |
+
llm = ChatOpenAI(model="gpt-4")
|
124 |
+
agent = create_openai_functions_agent(llm, tools, prompt)
|
125 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools)
|
126 |
+
|
127 |
+
# 执行问答
|
128 |
+
result = await agent_executor.ainvoke({
|
129 |
+
"input": "请帮我查询关于糖尿病的相关信息"
|
130 |
+
})
|
131 |
+
```
|
132 |
+
|
133 |
+
### 2. 在Streamlit应用中使用
|
134 |
+
|
135 |
+
```python
|
136 |
+
import streamlit as st
|
137 |
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
138 |
+
|
139 |
+
# 初始化MCP客户端
|
140 |
+
@st.cache_resource
|
141 |
+
def get_mcp_client():
|
142 |
+
config = {
|
143 |
+
"bio_qa_mcp": {
|
144 |
+
"url": "http://localhost:9487/sse",
|
145 |
+
"transport": "sse"
|
146 |
+
}
|
147 |
+
}
|
148 |
+
return MultiServerMCPClient(config)
|
149 |
+
|
150 |
+
# 使用
|
151 |
+
client = get_mcp_client()
|
152 |
+
tools = await client.get_tools()
|
153 |
+
```
|
154 |
+
|
155 |
+
## 部署说明
|
156 |
+
|
157 |
+
### 1. 启动服务
|
158 |
+
|
159 |
+
```bash
|
160 |
+
cd python-services/Retrieve
|
161 |
+
python main.py
|
162 |
+
```
|
163 |
+
|
164 |
+
服务将在 `http://localhost:9487` 启动。
|
165 |
+
|
166 |
+
### 2. 环境变量配置
|
167 |
+
|
168 |
+
确保设置了必要的环境变量:
|
169 |
+
|
170 |
+
```bash
|
171 |
+
export ENVIRONMENT=prod
|
172 |
+
export QA_LLM_MAIN_API_KEY=your-api-key
|
173 |
+
export QA_LLM_MAIN_BASE_URL=your-api-url
|
174 |
+
# ... 其他配置
|
175 |
+
```
|
176 |
+
|
177 |
+
### 3. 网络访问
|
178 |
+
|
179 |
+
- 本地访问: `http://localhost:9487`
|
180 |
+
- 远程访问: `http://your-server-ip:9487`
|
181 |
+
|
182 |
+
## 故障排除
|
183 |
+
|
184 |
+
### 常见问题
|
185 |
+
|
186 |
+
1. **连接失败**: 检查服务是否启动,端口是否正确
|
187 |
+
2. **认证错误**: 检查API密钥配置
|
188 |
+
3. **流式响应中断**: 检查网络连接稳定性
|
189 |
+
|
190 |
+
### 日志查看
|
191 |
+
|
192 |
+
服务会记录详细的日志信息,包括:
|
193 |
+
- 请求处理时间
|
194 |
+
- 错误信息
|
195 |
+
- 操作状态
|
196 |
+
|
197 |
+
## 扩展功能
|
198 |
+
|
199 |
+
### 添加新的MCP操作
|
200 |
+
|
201 |
+
1. 在 `routers/mcp_sensor.py` 中添加新的路由
|
202 |
+
2. 在 `main.py` 的 `include_operations` 中添加操作名称
|
203 |
+
3. 重新启动服务
|
204 |
+
|
205 |
+
### 自定义响应格式
|
206 |
+
|
207 |
+
可以修改 `ChatService` 来定制响应格式,满足不同的MCP客户端需求。
|
python-services/Retrieve/README_ENV.md
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 混合配置说明
|
2 |
+
|
3 |
+
## 概述
|
4 |
+
|
5 |
+
本项目采用混合配置方式:
|
6 |
+
- **大部分配置**:从YAML配置文件加载(`app_config_dev.yaml` 或 `app_config_prod.yaml`)
|
7 |
+
- **敏感配置**:API密钥和base_url从环境变量加载,覆盖YAML文件中的值
|
8 |
+
|
9 |
+
## 环境变量列表
|
10 |
+
|
11 |
+
### 基础配置
|
12 |
+
- `ENVIRONMENT`: 环境类型,可选值:`dev`(开发环境)或 `prod`(生产环境),默认为 `dev`
|
13 |
+
|
14 |
+
### API密钥和Base URL配置(从环境变量加载)
|
15 |
+
|
16 |
+
#### QA LLM 主模型
|
17 |
+
- `QA_LLM_MAIN_API_KEY`: API密钥
|
18 |
+
- `QA_LLM_MAIN_BASE_URL`: API基础URL
|
19 |
+
|
20 |
+
#### QA LLM 备用模型
|
21 |
+
- `QA_LLM_BACKUP_API_KEY`: API密钥
|
22 |
+
- `QA_LLM_BACKUP_BASE_URL`: API基础URL
|
23 |
+
|
24 |
+
#### Rewrite LLM 备用模型 (GPT-4o)
|
25 |
+
- `REWRITE_LLM_BACKUP_API_KEY`: API密钥
|
26 |
+
- `REWRITE_LLM_BACKUP_BASE_URL`: API基础URL
|
27 |
+
|
28 |
+
#### Rewrite LLM 主模型
|
29 |
+
- `REWRITE_LLM_MAIN_API_KEY`: API密钥
|
30 |
+
- `REWRITE_LLM_MAIN_BASE_URL`: API基础URL
|
31 |
+
|
32 |
+
#### Web搜索服务
|
33 |
+
- `SERPER_API_KEY`: Serper API密钥(用于网络搜索)
|
34 |
+
|
35 |
+
## 其他配置(从YAML文件加载)
|
36 |
+
|
37 |
+
以下配置仍然从YAML文件加载,包括:
|
38 |
+
- 模型名称
|
39 |
+
- max_tokens
|
40 |
+
- temperature
|
41 |
+
- recall配置
|
42 |
+
- qa-topk配置
|
43 |
+
- qa-prompt-max-token配置
|
44 |
+
- chat配置
|
45 |
+
|
46 |
+
## 使用方法
|
47 |
+
|
48 |
+
### 1. 设置环境变量
|
49 |
+
|
50 |
+
```bash
|
51 |
+
# 设置环境
|
52 |
+
export ENVIRONMENT=prod
|
53 |
+
|
54 |
+
# 设置API密钥和base_url
|
55 |
+
export QA_LLM_MAIN_API_KEY=your-actual-api-key
|
56 |
+
export QA_LLM_MAIN_BASE_URL=https://your-api-endpoint.com
|
57 |
+
|
58 |
+
export REWRITE_LLM_BACKUP_API_KEY=your-gpt4o-api-key
|
59 |
+
export REWRITE_LLM_BACKUP_BASE_URL=https://api.openai.com/v1
|
60 |
+
|
61 |
+
# 设置Web搜索API密钥
|
62 |
+
export SERPER_API_KEY=your-serper-api-key
|
63 |
+
|
64 |
+
# ... 其他API配置
|
65 |
+
```
|
66 |
+
|
67 |
+
### 2. 在代码中使用
|
68 |
+
|
69 |
+
```python
|
70 |
+
from config.global_storage import get_model_config
|
71 |
+
|
72 |
+
# 获取配置
|
73 |
+
config = get_model_config()
|
74 |
+
|
75 |
+
# 使用配置(API密钥和base_url来自环境变量,其他来自YAML)
|
76 |
+
model_name = config['qa-llm']['main']['model'] # 来自YAML
|
77 |
+
api_key = config['qa-llm']['main']['api_key'] # 来自环境变量
|
78 |
+
base_url = config['qa-llm']['main']['base_url'] # 来自环境变量
|
79 |
+
```
|
80 |
+
|
81 |
+
## 配置优先级
|
82 |
+
|
83 |
+
1. **环境变量**:API密钥和base_url(最高优先级)
|
84 |
+
2. **YAML文件**:其他所有配置(基础配置)
|
85 |
+
|
86 |
+
## 优势
|
87 |
+
|
88 |
+
1. **安全性**: 敏感信息(API密钥)从环境变量加载,不会出现在代码或配置文件中
|
89 |
+
2. **灵活性**: 可以轻松切换不同环境的API端点
|
90 |
+
3. **维护性**: 大部分配置仍在YAML文件中,便于管理和版本控制
|
91 |
+
4. **部署友好**: 生产环境只需要设置环境变量即可
|
92 |
+
|
93 |
+
## 注意事项
|
94 |
+
|
95 |
+
1. 如果环境变量未设置,将使用YAML文件中的默认值
|
96 |
+
2. 确保 `.env` 文件已添加到 `.gitignore` 中
|
97 |
+
3. 生产环境建议使用环境变量而不是 `.env` 文件
|
98 |
+
4. YAML文件中的API密钥和base_url值会被环境变量覆盖
|
99 |
+
5. 对于Web搜索服务,如果未设置 `SERPER_API_KEY`,将使用代码中的默认密钥(不推荐用于生产环境)
|
python-services/Retrieve/bio_agent/rewrite_agent.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Any, List
|
3 |
+
from agents import Agent, OpenAIChatCompletionsModel, Runner
|
4 |
+
from agents.agent_output import AgentOutputSchemaBase
|
5 |
+
from openai import AsyncOpenAI
|
6 |
+
from config.global_storage import get_model_config
|
7 |
+
from utils.bio_logger import bio_logger as logger
|
8 |
+
from typing import List, Dict
|
9 |
+
from pydantic import BaseModel, Field,ConfigDict
|
10 |
+
|
11 |
+
|
12 |
+
class DateRange(BaseModel):
|
13 |
+
# model_config = ConfigDict(strict=True)
|
14 |
+
model_config = ConfigDict(strict=True, extra="forbid",json_schema_extra={"required": ["start", "end"]})
|
15 |
+
start: str = Field('', description="Start date in YYYY-MM-DD format")
|
16 |
+
end: str = Field('', description="End date in YYYY-MM-DD format")
|
17 |
+
|
18 |
+
class Journal(BaseModel):
|
19 |
+
# model_config = ConfigDict(strict=True)
|
20 |
+
model_config = ConfigDict(strict=True, extra="forbid",json_schema_extra={"required": ["name", "EISSN"]})
|
21 |
+
name: str = Field(..., description="Journal name")
|
22 |
+
EISSN: str = Field(..., description="Journal EISSN")
|
23 |
+
|
24 |
+
class AuthorFilter(BaseModel):
|
25 |
+
# model_config = ConfigDict(strict=True)
|
26 |
+
model_config = ConfigDict(strict=True, extra="forbid",json_schema_extra={"required": ["name", "first_author", "last_author"]})
|
27 |
+
name: str = Field("", description="Author name to filter")
|
28 |
+
first_author: bool = Field(False, description="Is first author?")
|
29 |
+
last_author: bool = Field(False, description="Is last author?")
|
30 |
+
|
31 |
+
|
32 |
+
class Filters(BaseModel):
|
33 |
+
# model_config = ConfigDict(strict=True)
|
34 |
+
model_config = ConfigDict(strict=True, extra="forbid",json_schema_extra={"required": ["date_range", "article_types", "languages", "subjects", "journals", "author"]})
|
35 |
+
date_range: DateRange = Field(...,default_factory=DateRange)
|
36 |
+
article_types: List[str] = Field(...,default_factory=list)
|
37 |
+
languages: List[str] = Field(["English"],)
|
38 |
+
subjects: List[str] = Field(...,default_factory=list)
|
39 |
+
journals: List[str] = Field([""])
|
40 |
+
author: AuthorFilter = Field(...,default_factory=AuthorFilter)
|
41 |
+
|
42 |
+
class RewriteJsonOutput(BaseModel):
|
43 |
+
model_config = ConfigDict(strict=True, extra="forbid",json_schema_extra={"required": ["category", "key_words", "key_journals", "queries", "filters"]})
|
44 |
+
category: str = Field(..., description="Query category")
|
45 |
+
key_words: List[str] = Field(...,default_factory=list)
|
46 |
+
key_journals: List[Journal] = Field(...,default_factory=list)
|
47 |
+
queries: List[str] = Field(...,default_factory=list)
|
48 |
+
filters: Filters = Field(...,default_factory=Filters)
|
49 |
+
|
50 |
+
|
51 |
+
class SimpleJsonOutput(BaseModel):
|
52 |
+
key_words: List[str] = Field(...,default_factory=list)
|
53 |
+
|
54 |
+
|
55 |
+
class RewriteJsonOutputSchema(AgentOutputSchemaBase):
|
56 |
+
def is_plain_text(self):
|
57 |
+
return False
|
58 |
+
def name(self):
|
59 |
+
return "RewriteJsonOutput"
|
60 |
+
def json_schema(self):
|
61 |
+
return RewriteJsonOutput.model_json_schema()
|
62 |
+
def is_strict_json_schema(self):
|
63 |
+
return True
|
64 |
+
def validate_json(self, json_data: Dict[str, Any]) -> bool:
|
65 |
+
try:
|
66 |
+
if isinstance(json_data, str):
|
67 |
+
json_data = json.loads(json_data)
|
68 |
+
return RewriteJsonOutput.model_validate(json_data)
|
69 |
+
|
70 |
+
except Exception as e:
|
71 |
+
logger.error(f"Validation error: {e}")
|
72 |
+
# return False
|
73 |
+
def parse(self, json_data: Dict[str, Any]) -> Any:
|
74 |
+
if isinstance(json_data, str):
|
75 |
+
json_data = json.loads(json_data)
|
76 |
+
return json_data
|
77 |
+
|
78 |
+
class RewriteAgent:
|
79 |
+
def __init__(self):
|
80 |
+
self.model_config = get_model_config()
|
81 |
+
self.agent_name = "rewrite agent"
|
82 |
+
self.selected_model = OpenAIChatCompletionsModel(
|
83 |
+
model=self.model_config["rewrite-llm"]["main"]["model"],
|
84 |
+
openai_client=AsyncOpenAI(
|
85 |
+
api_key=self.model_config["rewrite-llm"]["main"]["api_key"],
|
86 |
+
base_url=self.model_config["rewrite-llm"]["main"]["base_url"],
|
87 |
+
timeout=120.0,
|
88 |
+
max_retries=2,
|
89 |
+
),
|
90 |
+
)
|
91 |
+
|
92 |
+
# self.openai_client = AsyncOpenAI(
|
93 |
+
# api_key=self.model_config["llm"]["api_key"],
|
94 |
+
# base_url=self.model_config["llm"]["base_url"],
|
95 |
+
# )
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
async def rewrite_query(self, query: str,INSTRUCTIONS: str,simple_version=False) -> List[str]:
|
100 |
+
try:
|
101 |
+
logger.info(f"Rewriting query with main configuration.")
|
102 |
+
if not simple_version:
|
103 |
+
rewrite_agent = Agent(
|
104 |
+
name=self.agent_name,
|
105 |
+
instructions=' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.',
|
106 |
+
model=self.selected_model,
|
107 |
+
output_type=RewriteJsonOutputSchema(), # Use the Pydantic model for structured output
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
rewrite_agent = Agent(
|
111 |
+
name=self.agent_name,
|
112 |
+
instructions=' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.',
|
113 |
+
model=self.selected_model,
|
114 |
+
output_type=SimpleJsonOutput, # Use the Pydantic model for structured output
|
115 |
+
)
|
116 |
+
result = await Runner.run(rewrite_agent, input=INSTRUCTIONS + 'Here is the question: '+query)
|
117 |
+
# completion = await self.openai_client.chat.completions.create(
|
118 |
+
# model=self.model_config["llm"]["model"],
|
119 |
+
# messages=[
|
120 |
+
# # {
|
121 |
+
# # "role": "system",
|
122 |
+
# # "content": "You are a helpful assistant.",
|
123 |
+
# # },
|
124 |
+
# {
|
125 |
+
# "role": "user",
|
126 |
+
# "content": INSTRUCTIONS +' Here is the question: ' + query,
|
127 |
+
# },
|
128 |
+
# ],
|
129 |
+
# temperature=self.model_config["llm"]["temperature"],
|
130 |
+
# # max_tokens=self.model_config["llm"]["max_tokens"],
|
131 |
+
# )
|
132 |
+
try:
|
133 |
+
# query_result = self.parse_json_output(completion.choices[0].message.content)
|
134 |
+
query_result = self.parse_json_output(result.final_output.model_dump_json())
|
135 |
+
# query_result = self.parse_json_output(completion.model_dump_json())
|
136 |
+
except Exception as e:
|
137 |
+
# print(completion.choices[0].message.content)
|
138 |
+
logger.error(f"Failed to parse JSON output: {e}")
|
139 |
+
return query_result
|
140 |
+
except Exception as main_error:
|
141 |
+
self.selected_model_backup = OpenAIChatCompletionsModel(
|
142 |
+
model=self.model_config["rewrite-llm"]["backup"]["model"],
|
143 |
+
openai_client=AsyncOpenAI(
|
144 |
+
api_key=self.model_config["rewrite-llm"]["backup"]["api_key"],
|
145 |
+
base_url=self.model_config["rewrite-llm"]["backup"]["base_url"],
|
146 |
+
timeout=120.0,
|
147 |
+
max_retries=2,
|
148 |
+
),
|
149 |
+
)
|
150 |
+
logger.error(f"Error with main model: {main_error}", exc_info=main_error)
|
151 |
+
logger.info("Trying backup model for rewriting query.")
|
152 |
+
if not simple_version:
|
153 |
+
rewrite_agent = Agent(
|
154 |
+
name=self.agent_name,
|
155 |
+
instructions=' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.',
|
156 |
+
model=self.selected_model_backup,
|
157 |
+
output_type=RewriteJsonOutputSchema(), # Use the Pydantic model for structured output
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
rewrite_agent = Agent(
|
161 |
+
name=self.agent_name,
|
162 |
+
instructions=' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.',
|
163 |
+
model=self.selected_model_backup,
|
164 |
+
output_type=SimpleJsonOutput, # Use the Pydantic model for structured output
|
165 |
+
)
|
166 |
+
result = await Runner.run(rewrite_agent, input=INSTRUCTIONS + 'Here is the question: '+query)
|
167 |
+
# completion = await self.openai_client.chat.completions.create(
|
168 |
+
# model=self.model_config["llm"]["model"],
|
169 |
+
# messages=[
|
170 |
+
# # {
|
171 |
+
# # "role": "system",
|
172 |
+
# # "content": "You are a helpful assistant.",
|
173 |
+
# # },
|
174 |
+
# {
|
175 |
+
# "role": "user",
|
176 |
+
# "content": INSTRUCTIONS +' Here is the question: ' + query,
|
177 |
+
# },
|
178 |
+
# ],
|
179 |
+
# temperature=self.model_config["llm"]["temperature"],
|
180 |
+
# # max_tokens=self.model_config["llm"]["max_tokens"],
|
181 |
+
# )
|
182 |
+
try:
|
183 |
+
# query_result = self.parse_json_output(completion.choices[0].message.content)
|
184 |
+
query_result = self.parse_json_output(result.final_output.model_dump_json())
|
185 |
+
# query_result = self.parse_json_output(completion.model_dump_json())
|
186 |
+
except Exception as e:
|
187 |
+
# print(completion.choices[0].message.content)
|
188 |
+
logger.error(f"Failed to parse JSON output: {e}")
|
189 |
+
return query_result
|
190 |
+
|
191 |
+
def parse_json_output(self, output: str) -> Any:
|
192 |
+
"""Take a string output and parse it as JSON"""
|
193 |
+
# First try to load the string as JSON
|
194 |
+
try:
|
195 |
+
return json.loads(output)
|
196 |
+
except json.JSONDecodeError as e:
|
197 |
+
logger.info(f"Output is not valid JSON: {output}")
|
198 |
+
logger.error(f"Failed to parse output as direct JSON: {e}")
|
199 |
+
|
200 |
+
# If that fails, assume that the output is in a code block - remove the code block markers and try again
|
201 |
+
parsed_output = output
|
202 |
+
if "```" in parsed_output:
|
203 |
+
try:
|
204 |
+
parts = parsed_output.split("```")
|
205 |
+
if len(parts) >= 3:
|
206 |
+
parsed_output = parts[1]
|
207 |
+
if parsed_output.startswith("json") or parsed_output.startswith(
|
208 |
+
"JSON"
|
209 |
+
):
|
210 |
+
parsed_output = parsed_output[4:].strip()
|
211 |
+
return json.loads(parsed_output)
|
212 |
+
except (IndexError, json.JSONDecodeError) as e:
|
213 |
+
logger.error(f"Failed to parse output from code block: {e}")
|
214 |
+
|
215 |
+
# As a last attempt, try to manually find the JSON object in the output and parse it
|
216 |
+
parsed_output = self.find_json_in_string(output)
|
217 |
+
if parsed_output:
|
218 |
+
try:
|
219 |
+
return json.loads(parsed_output)
|
220 |
+
except json.JSONDecodeError as e:
|
221 |
+
logger.error(f"Failed to parse extracted JSON: {e}")
|
222 |
+
logger.error(f"Extracted JSON: {parsed_output}")
|
223 |
+
return {"queries": []}
|
224 |
+
else:
|
225 |
+
logger.error("No valid JSON found in the output:{output}")
|
226 |
+
# If all fails, raise an error
|
227 |
+
return {"queries": []}
|
228 |
+
|
229 |
+
def find_json_in_string(self, string: str) -> str:
|
230 |
+
"""
|
231 |
+
Method to extract all text in the left-most brace that appears in a string.
|
232 |
+
Used to extract JSON from a string (note that this function does not validate the JSON).
|
233 |
+
|
234 |
+
Example:
|
235 |
+
string = "bla bla bla {this is {some} text{{}and it's sneaky}} because {it's} confusing"
|
236 |
+
output = "{this is {some} text{{}and it's sneaky}}"
|
237 |
+
"""
|
238 |
+
stack = 0
|
239 |
+
start_index = None
|
240 |
+
|
241 |
+
for i, c in enumerate(string):
|
242 |
+
if c == "{":
|
243 |
+
if stack == 0:
|
244 |
+
start_index = i # Start index of the first '{'
|
245 |
+
stack += 1 # Push to stack
|
246 |
+
elif c == "}":
|
247 |
+
stack -= 1 # Pop stack
|
248 |
+
if stack == 0:
|
249 |
+
# Return the substring from the start of the first '{' to the current '}'
|
250 |
+
return (
|
251 |
+
string[start_index : i + 1] if start_index is not None else ""
|
252 |
+
)
|
253 |
+
|
254 |
+
# If no complete set of braces is found, return an empty string
|
255 |
+
return ""
|
python-services/Retrieve/bio_requests/chat_request.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
|
3 |
+
|
4 |
+
class ChatRequest(BaseModel):
    """Request payload for the QA / streaming-chat endpoints."""

    # Natural-language question to answer.
    query: str = Field(default="", description="Search query")

    # Include web search results as answer context.
    is_web: bool = Field(
        default=False, description="Whether to use web search, default is False"
    )

    # Include PubMed literature as answer context.
    is_pubmed: bool = Field(
        default=True, description="Whether to use pubmed search, default is True"
    )

    # Language of the generated response ("zh" or "en").
    language: str = Field(
        default="en", description="Response language (zh/en), default is English"
    )
|
python-services/Retrieve/bio_requests/rag_request.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
RAG request class, used to encapsulate the parameters of RAG requests
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import List, Optional
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
|
9 |
+
class RagRequest(BaseModel):
    """
    RAG request class, used to encapsulate the parameters of RAG requests
    """

    # Natural-language search query.
    query: str = Field(default="", description="Search query")

    # Number of documents returned to the caller after retrieval/reranking.
    top_k: int = Field(default=5, ge=1, description="Number of results to return")

    # PubMed-only search mode; ignored for other data sources.
    search_type: Optional[str] = Field(
        default="keyword",
        description="Type of search to perform (keyword or advanced), please note that if data_source is not ['pubmed'], this field will be ignored",
    )

    # Whether to run LLM query rewriting/splitting before retrieval.
    is_rewrite: Optional[bool] = Field(
        default=True, description="Whether the query is a subquery of a larger query"
    )

    # Backends to query, e.g. ["pubmed"], ["web"], or both.
    data_source: List[str] = Field(
        default=["pubmed"],
        description="Data source to search in (e.g., pubmed, web)",
    )

    # Per-subquery PubMed result count; only relevant when is_rewrite is True.
    pubmed_topk: int = Field(
        default=30,
        description="Number of results to return from one specific pubmed search, only used when is_rewrite is True",
    )

    # Whether to rerank retrieved documents (PubMed source only).
    is_rerank: Optional[bool] = Field(
        default=True,
        description="Whether to use reranker to rerank the results, only used when data_source is ['pubmed']",
    )

    # Language of response messages ("zh" or "en").
    language: Optional[str] = Field(
        default="en", description="Response language (zh/en), default is English"
    )
|
python-services/Retrieve/config/2023JCR(完整).xlsx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:346311258d5c7843558c36d874a95a1603ff9f38c5ec32c9b58e93f41f71b023
|
3 |
+
size 1922687
|
python-services/Retrieve/config/app_config_dev.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
qa-llm:
|
2 |
+
main:
|
3 |
+
model: deepseek-r1
|
4 |
+
api_key: sk-sk-*************
|
5 |
+
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
6 |
+
max_tokens: 1024
|
7 |
+
temperature: 0.7
|
8 |
+
backup:
|
9 |
+
model: qwen-plus-latest
|
10 |
+
api_key: sk-sk-*************
|
11 |
+
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
12 |
+
max_tokens: 1024
|
13 |
+
temperature: 0.7
|
14 |
+
|
15 |
+
rewrite-llm:
|
16 |
+
backup:
|
17 |
+
model: gpt-4o
|
18 |
+
api_key: sk-**********
|
19 |
+
base_url: https://openai.sohoyo.io/v1
|
20 |
+
max_tokens: 1024
|
21 |
+
temperature: 0.7
|
22 |
+
main:
|
23 |
+
model: qwen-plus-latest
|
24 |
+
api_key: sk-sk-*************
|
25 |
+
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
26 |
+
max_tokens: 1024
|
27 |
+
temperature: 0.7
|
28 |
+
|
29 |
+
recall:
|
30 |
+
pubmed_topk: 30
|
31 |
+
es_topk: 30
|
32 |
+
|
33 |
+
qa-topk:
|
34 |
+
personal_vector: 40
|
35 |
+
pubmed: 10
|
36 |
+
web: 5
|
37 |
+
|
38 |
+
qa-prompt-max-token:
|
39 |
+
max_tokens: 120000
|
40 |
+
|
41 |
+
|
42 |
+
chat:
|
43 |
+
rag_prompt: |
|
44 |
+
# The following contents are the search results related to the user's message:
|
45 |
+
{search_results}
|
46 |
+
In the search results I provide to you, each result is formatted as [document X begin]...[document X end], where X represents the numerical index of each article.
|
47 |
+
When responding, please keep the following points in mind:
|
48 |
+
- Today is {cur_date}.
|
49 |
+
- Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
|
50 |
+
- If all the search results are irrelevant, please answer the question by yourself professionally and concisely.
|
51 |
+
- The search results may focus only on a few points, use the information it provided, but do not favor those points in your answer, reason and answer by yourself all-sidedly with full consideration.
|
52 |
+
- For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
|
53 |
+
- If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
|
54 |
+
- For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
|
55 |
+
- Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
|
56 |
+
- Your answer should synthesize information from multiple relevant documents.
|
57 |
+
- Unless the user requests otherwise, your response should be in the same language as the user's question.
|
58 |
+
# The user's message is:
|
59 |
+
{question}
|
60 |
+
- The content should be concise and direct, and you MUST include proper citations using ONLY "[bdd-rag-citation:X]" format reference marks to indicate the sources of your information. Do NOT use any other citation formats such as [document X], [Author, Year], or parenthetical bibliographical references.
|
python-services/Retrieve/config/global_storage.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""全局配置存储模块,提供配置文件的加载和缓存功能,API密钥和base_url从环境变量加载。"""
|
2 |
+
|
3 |
+
import os
|
4 |
+
from typing import Any, Dict, Optional
|
5 |
+
|
6 |
+
import yaml
|
7 |
+
|
8 |
+
|
9 |
+
class ConfigManager:
    """Singleton configuration manager.

    Loads the environment-specific YAML application config once, caches it,
    and lets environment variables override LLM API keys and base URLs.
    """

    _instance = None
    # Cached configuration; populated lazily on first get_config() call.
    _config: Optional[Dict[str, Any]] = None

    # Maps (config section, sub-section) -> environment-variable prefix.
    # For each entry, <PREFIX>_API_KEY / <PREFIX>_BASE_URL override the
    # corresponding YAML values when the variables are set and non-empty.
    _ENV_OVERRIDES = (
        ("qa-llm", "main", "QA_LLM_MAIN"),
        ("qa-llm", "backup", "QA_LLM_BACKUP"),
        ("rewrite-llm", "backup", "REWRITE_LLM_BACKUP"),
        ("rewrite-llm", "main", "REWRITE_LLM_MAIN"),
    )

    def __new__(cls):
        # Classic singleton: every ConfigManager() call yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_config(self) -> Dict[str, Any]:
        """Return the configuration dict, loading it on first access.

        Returns:
            Dict containing all configuration values.
        """
        if self._config is None:
            self._config = self._load_config()
        return self._config

    def _get_environment(self) -> str:
        """Return the deployment environment: 'prod' or 'dev' (default)."""
        return os.getenv("ENVIRONMENT", "dev").lower()

    def _get_config_path(self) -> str:
        """Return the YAML config-file path for the current environment."""
        if self._get_environment() == "prod":
            return "config/app_config_prod.yaml"
        return "config/app_config_dev.yaml"

    def _load_config(self) -> Dict[str, Any]:
        """Load the YAML config file and apply environment-variable overrides.

        Returns:
            The parsed configuration dict, annotated with the environment
            name and with API keys / base URLs overridden from env vars.

        Raises:
            FileNotFoundError: if the config file does not exist.
            ValueError: if the file is not valid YAML.
        """
        config_path = self._get_config_path()
        try:
            with open(config_path, "r", encoding="utf-8") as file:
                config = yaml.safe_load(file)
            # Record which environment this config was loaded for.
            config["environment"] = self._get_environment()

            # Apply env-var overrides for API keys and base URLs.
            self._override_api_configs(config)

            return config
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"配置文件未找到: {config_path}") from exc
        except yaml.YAMLError as exc:
            raise ValueError(f"配置文件格式错误: {exc}") from exc

    def _override_api_configs(self, config: Dict[str, Any]) -> None:
        """Override api_key / base_url values from environment variables.

        The original implementation repeated the same override logic four
        times; this drives it from the _ENV_OVERRIDES table instead.

        Args:
            config: Parsed configuration dict, mutated in place.
        """
        for section, sub, prefix in self._ENV_OVERRIDES:
            sub_config = config.get(section, {}).get(sub)
            if not isinstance(sub_config, dict):
                continue
            api_key = os.getenv(f"{prefix}_API_KEY")
            if api_key:
                sub_config["api_key"] = api_key
            base_url = os.getenv(f"{prefix}_BASE_URL")
            if base_url:
                sub_config["base_url"] = base_url
|
109 |
+
|
110 |
+
|
111 |
+
# Global singleton configuration-manager instance.
_config_manager = ConfigManager()


def get_model_config() -> Dict[str, Any]:
    """Return the full cached application configuration.

    Returns:
        Dict containing all configuration values (see ConfigManager.get_config).
    """
    return _config_manager.get_config()
|
python-services/Retrieve/dto/bio_document.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Optional
|
3 |
+
from utils.snowflake_id import snowflake_id_str
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
class BaseBioDocument:
    """
    Base class for biomedical documents.

    Holds the fields shared by every search/source type.
    """

    # Unique internal id; defaults to a freshly generated snowflake id.
    bio_id: Optional[str] = field(default_factory=snowflake_id_str)
    # Document title.
    title: Optional[str] = None
    # Main body text / content snippet.
    text: Optional[str] = None
    # Origin of the document, e.g. "pubmed", "web", "personal_vector".
    source: Optional[str] = None
    # Identifier of the document within its source system.
    source_id: Optional[str] = None
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
class PubMedDocument(BaseBioDocument):
    """
    PubMed academic-literature document.

    Extends the base document with scholarly-article metadata.
    """

    # Article abstract.
    abstract: Optional[str] = None
    # Author list (formatted string).
    authors: Optional[str] = None
    # Digital Object Identifier.
    doi: Optional[str] = None
    # Journal name.
    journal: Optional[str] = None
    # Publication date string.
    pub_date: Optional[str] = None
    # Journal impact-factor score, if known.
    if_score: Optional[float] = None
    # Link to the article.
    url: Optional[str] = None

    def __post_init__(self):
        # Default the source tag unless the caller supplied one explicitly.
        if self.source is None:
            self.source = "pubmed"
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
class PersonalDocument(BaseBioDocument):
    """
    Personal vector-search document.

    Extends the base document with fields specific to a user's own files.
    """

    # Relevance/similarity score from the vector search.
    if_score: Optional[float] = None
    # Identifier of the source document in the vector store.
    doc_id: Optional[str] = None
    # Chunk index within the source document.
    index: Optional[int] = 0
    # Owner of the document.
    user_id: Optional[str] = None
    # Original file name the chunk was extracted from.
    file_name: Optional[str] = None

    def __post_init__(self):
        # Default the source tag unless the caller supplied one explicitly.
        if self.source is None:
            self.source = "personal_vector"
|
56 |
+
|
57 |
+
|
58 |
+
@dataclass
class WebDocument(BaseBioDocument):
    """
    Web search-result document.

    Extends the base document with web-page specific fields.
    """

    # Page URL.
    url: Optional[str] = None
    # Search-result snippet / page description.
    description: Optional[str] = None

    def __post_init__(self):
        # Default the source tag unless the caller supplied one explicitly.
        if self.source is None:
            self.source = "web"
|
71 |
+
|
72 |
+
|
73 |
+
# Kept for backward compatibility with callers of the original BioDocument.
@dataclass
class BioDocument(BaseBioDocument):
    """
    Biomedical document (backward-compatible catch-all).

    Contains every possible field; prefer the specialised document
    types (PubMedDocument, PersonalDocument, WebDocument) for new code.
    """

    abstract: Optional[str] = None
    authors: Optional[str] = None
    doi: Optional[str] = None
    journal: Optional[str] = None
    pub_date: Optional[str] = None
    if_score: Optional[float] = None
    url: Optional[str] = None
    doc_id: Optional[str] = None
|
89 |
+
|
90 |
+
|
91 |
+
# Factory: build the document subclass that matches a given source type.
def create_bio_document(source: str, **kwargs) -> BaseBioDocument:
    """
    Create the document object appropriate for *source*.

    Args:
        source: Origin type ("pubmed", "personal_vector", "web").
        **kwargs: Field values forwarded to the document constructor.

    Returns:
        The matching document instance; unknown sources fall back to the
        generic BioDocument.
    """
    constructors = {
        "pubmed": PubMedDocument,
        "personal_vector": PersonalDocument,
        "web": WebDocument,
    }
    document_cls = constructors.get(source, BioDocument)
    return document_cls(**kwargs)
|
python-services/Retrieve/main.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""生物医学RAG服务主程序入口。"""
|
2 |
+
|
3 |
+
import importlib
|
4 |
+
import pkgutil
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
|
9 |
+
# 加载环境变量
|
10 |
+
load_dotenv()
|
11 |
+
|
12 |
+
import uvicorn
|
13 |
+
from asgi_correlation_id import CorrelationIdMiddleware, correlation_id
|
14 |
+
from fastapi import FastAPI, Request
|
15 |
+
from fastapi_mcp import FastApiMCP
|
16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
17 |
+
|
18 |
+
from routers import sensor, mcp_sensor
|
19 |
+
from utils.bio_logger import bio_logger as logger
|
20 |
+
|
21 |
+
# 调试:验证环境变量是否加载
|
22 |
+
logger.info(f"SERPER_API_KEY loaded: {'Yes' if os.getenv('SERPER_API_KEY') else 'No'}")
|
23 |
+
|
24 |
+
|
25 |
+
app = FastAPI(
    docs_url=None,  # disable the Swagger UI docs endpoint
    redoc_url=None,  # disable the ReDoc docs endpoint
    openapi_url=None,  # disable the OpenAPI schema endpoint
    debug=False,  # disable debug mode
)

# Added first so every request gets a correlation id before other middleware runs.
app.add_middleware(CorrelationIdMiddleware)
# CORS configuration.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers
app.include_router(sensor.router)
app.include_router(mcp_sensor.router)  # MCP routes
|
46 |
+
|
47 |
+
|
48 |
+
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """HTTP middleware: log each request's URL, status code, and duration."""
    started_at = time.time()

    logger.info(f"Request started | URL: {request.url}")

    response = await call_next(request)
    elapsed = time.time() - started_at

    logger.info(
        f"Request completed | "
        f"Status: {response.status_code} | "
        f"Time: {elapsed:.2f}s"
    )

    return response
|
65 |
+
|
66 |
+
|
67 |
+
def dynamic_import_subclasses(parent_dir: str) -> None:
    """Import every Python module found directly under *parent_dir*.

    Used at startup to trigger module-level registration side effects
    (e.g. search back-end implementations).

    Args:
        parent_dir: Package directory to scan for modules.
    """
    for module_info in pkgutil.iter_modules([parent_dir]):
        imported = importlib.import_module(f"{parent_dir}.{module_info.name}")
        logger.info(f"Imported: {imported.__name__}")
|
76 |
+
|
77 |
+
|
78 |
+
# Add the MCP server to the FastAPI app, exposing only the listed operations
# as MCP tools.
mcp = FastApiMCP(
    app,
    name="bio qa mcp",
    include_operations=["bio_qa_stream_chat"]
)

# Mount the MCP server onto the FastAPI app.
# Serves the SSE endpoint at /mcp/sse.
mcp.mount_sse()


if __name__ == "__main__":
    logger.info("Starting Bio RAG Server...")
    # Import search_service modules so their implementations self-register.
    dynamic_import_subclasses("search_service")
    uvicorn.run(app, host="0.0.0.0", port=9487)
|
python-services/Retrieve/readme.md
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Bio RAG Server
|
2 |
+
|
3 |
+
一个基于FastAPI的生物医学检索增强生成(RAG)服务,支持PubMed文献检索、Web搜索和向量数据库查询,提供智能问答和文档检索功能。
|
4 |
+
|
5 |
+
## 🚀 功能特性
|
6 |
+
|
7 |
+
- **多源数据检索**: 支持PubMed、Web搜索、个人向量数据库等多种数据源
|
8 |
+
- **智能问答**: 基于大语言模型的RAG问答,支持流式响应
|
9 |
+
- **查询重写**: 智能查询拆分和重写,提高检索精度
|
10 |
+
- **主备切换**: 支持LLM服务的主备配置,自动故障转移
|
11 |
+
- **流式响应**: 实时流式聊天响应,提升用户体验
|
12 |
+
- **国际化支持**: 支持中英文切换,包含87个国际化消息,涵盖8种消息类型
|
13 |
+
- **日志追踪**: 完整的请求追踪和日志记录
|
14 |
+
- **CORS支持**: 跨域请求支持,便于前端集成
|
15 |
+
|
16 |
+
## 🏗️ 系统架构
|
17 |
+
|
18 |
+
```
|
19 |
+
bio_rag_server/
|
20 |
+
├── bio_agent/ # AI代理相关
|
21 |
+
├── bio_requests/ # 请求模型定义
|
22 |
+
├── config/ # 配置文件
|
23 |
+
├── dto/ # 数据传输对象
|
24 |
+
├── routers/ # API路由
|
25 |
+
├── search_service/ # 搜索服务
|
26 |
+
├── service/ # 核心业务服务
|
27 |
+
├── utils/ # 工具类
|
28 |
+
└── test/ # 测试文件
|
29 |
+
```
|
30 |
+
|
31 |
+
## 📋 环境要求
|
32 |
+
|
33 |
+
- Python 3.8+
|
34 |
+
- OpenAI API 或兼容的LLM服务
|
35 |
+
|
36 |
+
## 🛠️ 安装部署
|
37 |
+
|
38 |
+
### 1. 克隆项目
|
39 |
+
|
40 |
+
```bash
|
41 |
+
git clone <repository-url>
|
42 |
+
cd bio_rag_server-1
|
43 |
+
```
|
44 |
+
|
45 |
+
### 2. 安装依赖
|
46 |
+
|
47 |
+
```bash
|
48 |
+
pip install -r requirements.txt
|
49 |
+
```
|
50 |
+
|
51 |
+
### 3. 配置环境
|
52 |
+
|
53 |
+
复制并修改配置文件 `config/app_config.yaml`:
|
54 |
+
|
55 |
+
```yaml
|
56 |
+
|
57 |
+
llm:
|
58 |
+
model: gpt-4o
|
59 |
+
api_key: your-openai-api-key
|
60 |
+
base_url: https://api.openai.com/v1
|
61 |
+
max_tokens: 1024
|
62 |
+
temperature: 0.7
|
63 |
+
|
64 |
+
qa-llm:
|
65 |
+
main:
|
66 |
+
model: deepseek-r1
|
67 |
+
api_key: your-main-api-key
|
68 |
+
base_url: https://your-main-endpoint/v1
|
69 |
+
max_tokens: 1024
|
70 |
+
temperature: 0.7
|
71 |
+
backup:
|
72 |
+
model: qwen-plus-latest
|
73 |
+
api_key: your-backup-api-key
|
74 |
+
base_url: https://your-backup-endpoint/v1
|
75 |
+
max_tokens: 1024
|
76 |
+
temperature: 0.7
|
77 |
+
```
|
78 |
+
|
79 |
+
### 4. 启动服务
|
80 |
+
|
81 |
+
```bash
|
82 |
+
python main.py
|
83 |
+
```
|
84 |
+
|
85 |
+
或使用Docker:
|
86 |
+
|
87 |
+
```bash
|
88 |
+
docker build -t bio-rag-server .
|
89 |
+
docker run -p 9487:9487 bio-rag-server
|
90 |
+
```
|
91 |
+
|
92 |
+
服务将在 `http://localhost:9487` 启动。
|
93 |
+
|
94 |
+
## 📚 API 文档
|
95 |
+
|
96 |
+
### 1. 文档检索 API
|
97 |
+
|
98 |
+
**端点**: `POST /retrieve`
|
99 |
+
|
100 |
+
**请求体**:
|
101 |
+
```json
|
102 |
+
{
|
103 |
+
"query": "cancer treatment",
|
104 |
+
"top_k": 5,
|
105 |
+
"search_type": "keyword",
|
106 |
+
"is_rewrite": true,
|
107 |
+
"data_source": ["pubmed"],
|
108 |
+
"user_id": "user123",
|
109 |
+
"pubmed_topk": 30
|
110 |
+
}
|
111 |
+
```
|
112 |
+
|
113 |
+
**响应**:
|
114 |
+
```json
|
115 |
+
[
|
116 |
+
{
|
117 |
+
"title": "Cancer Treatment Advances",
|
118 |
+
"abstract": "Recent advances in cancer treatment...",
|
119 |
+
"url": "https://pubmed.ncbi.nlm.nih.gov/...",
|
120 |
+
"score": 0.95
|
121 |
+
}
|
122 |
+
]
|
123 |
+
```
|
124 |
+
|
125 |
+
### 2. 流式聊天 API
|
126 |
+
|
127 |
+
**端点**: `POST /stream-chat`
|
128 |
+
|
129 |
+
**请求体**:
|
130 |
+
```json
|
131 |
+
{
|
132 |
+
"query": "What are the latest treatments for breast cancer?",
|
133 |
+
"is_web": true,
|
134 |
+
"is_pubmed": true,
|
135 |
+
"language": "en" // 可选:响应语言 (zh/en)
|
136 |
+
}
|
137 |
+
```
|
138 |
+
|
139 |
+
**响应**: Server-Sent Events (SSE) 流式响应
|
140 |
+
|
141 |
+
### 3. 国际化支持
|
142 |
+
|
143 |
+
所有API接口都支持国际化,通过 `language` 参数指定响应语言:
|
144 |
+
|
145 |
+
- `zh` (默认): 中文响应
|
146 |
+
- `en`: 英文响应
|
147 |
+
|
148 |
+
**响应格式示例**:
|
149 |
+
```json
|
150 |
+
{
|
151 |
+
"success": true,
|
152 |
+
"data": [...],
|
153 |
+
"message": "搜索成功", // 或 "Search successful"
|
154 |
+
"language": "zh"
|
155 |
+
}
|
156 |
+
```
|
157 |
+
|
158 |
+
**错误响应格式**:
|
159 |
+
```json
|
160 |
+
{
|
161 |
+
"success": false,
|
162 |
+
"error": {
|
163 |
+
"code": 500,
|
164 |
+
"message": "搜索失败", // 或 "Search failed"
|
165 |
+
"language": "zh",
|
166 |
+
"details": "具体错误信息"
|
167 |
+
}
|
168 |
+
}
|
169 |
+
```
|
170 |
+
|
171 |
+
## 🔧 配置说明
|
172 |
+
|
173 |
+
### 数据源配置
|
174 |
+
|
175 |
+
- **pubmed**: PubMed文献数据库
|
176 |
+
- **web**: Web搜索
|
177 |
+
|
178 |
+
|
179 |
+
### LLM配置
|
180 |
+
|
181 |
+
支持主备配置,当主配置失败时自动切换到备用配置:
|
182 |
+
|
183 |
+
```yaml
|
184 |
+
qa-llm:
|
185 |
+
main:
|
186 |
+
model: deepseek-r1
|
187 |
+
api_key: main-api-key
|
188 |
+
base_url: main-endpoint
|
189 |
+
backup:
|
190 |
+
model: qwen-plus-latest
|
191 |
+
api_key: backup-api-key
|
192 |
+
base_url: backup-endpoint
|
193 |
+
```
|
194 |
+
|
195 |
+
## 🧪 测试
|
196 |
+
|
197 |
+
### 基本功能测试
|
198 |
+
|
199 |
+
运行测试用例:
|
200 |
+
|
201 |
+
```bash
|
202 |
+
cd test
|
203 |
+
python client.py
|
204 |
+
```
|
205 |
+
|
206 |
+
### 国际化功能测试
|
207 |
+
|
208 |
+
```bash
|
209 |
+
# 基本国际化功能测试
|
210 |
+
python test/test_i18n.py
|
211 |
+
|
212 |
+
# Label国际化功能测试
|
213 |
+
python test/test_label_i18n.py
|
214 |
+
|
215 |
+
# 新的消息文件结构测试
|
216 |
+
python test/test_i18n_messages.py
|
217 |
+
|
218 |
+
# 运行客户端测试示例
|
219 |
+
python test/client_test.py
|
220 |
+
```
|
221 |
+
|
222 |
+
### 使用示例
|
223 |
+
|
224 |
+
```python
|
225 |
+
import requests
|
226 |
+
|
227 |
+
# 中文检索
|
228 |
+
response_zh = requests.post("http://localhost:9487/retrieve", json={
|
229 |
+
"query": "人工智能",
|
230 |
+
"language": "zh"
|
231 |
+
})
|
232 |
+
|
233 |
+
# 英文检索
|
234 |
+
response_en = requests.post("http://localhost:9487/retrieve", json={
|
235 |
+
"query": "artificial intelligence",
|
236 |
+
"language": "en"
|
237 |
+
})
|
238 |
+
```
|
239 |
+
|
240 |
+
## 📊 监控和日志
|
241 |
+
|
242 |
+
- 日志文件位置: `logs/bio_rag_YYYY-MM-DD.log`
|
243 |
+
- 请求追踪: 每个请求都有唯一的correlation_id
|
244 |
+
- 性能监控: 自动记录请求处理时间
|
245 |
+
|
246 |
+
## 🔒 安全特性
|
247 |
+
|
248 |
+
- API密钥配置化管理
|
249 |
+
- 请求日志记录
|
250 |
+
- CORS配置
|
251 |
+
- 错误处理和安全异常
|
252 |
+
|
253 |
+
## 🤝 贡献指南
|
254 |
+
|
255 |
+
1. Fork 项目
|
256 |
+
2. 创建功能分支 (`git checkout -b feature/AmazingFeature`)
|
257 |
+
3. 提交更改 (`git commit -m 'Add some AmazingFeature'`)
|
258 |
+
4. 推送到分支 (`git push origin feature/AmazingFeature`)
|
259 |
+
5. 打开 Pull Request
|
260 |
+
|
261 |
+
## 📄 许可证
|
262 |
+
|
263 |
+
本项目采用 MIT 许可证 - 查看 [LICENSE](LICENSE) 文件了解详情。
|
264 |
+
|
265 |
+
## 🆘 支持
|
266 |
+
|
267 |
+
如有问题或建议,请:
|
268 |
+
|
269 |
+
1. 查看 [Issues](../../issues) 页面
|
270 |
+
2. 创建新的 Issue
|
271 |
+
3. 联系项目维护者
|
272 |
+
|
273 |
+
## 🗺️ 路线图
|
274 |
+
|
275 |
+
- [ ] 支持更多数据源
|
276 |
+
- [ ] 增加用户认证和权限管理
|
277 |
+
- [ ] 优化向量搜索性能
|
278 |
+
- [ ] 添加更多LLM模型支持
|
279 |
+
- [ ] 实现缓存机制
|
280 |
+
- [ ] 增加API限流功能
|
281 |
+
|
282 |
+
---
|
283 |
+
|
284 |
+
**注意**: 请确保在使用前正确配置所有必要的API密钥和服务端点。
|
python-services/Retrieve/requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
asgi_correlation_id==4.3.4
|
2 |
+
fastapi==0.115.12
|
3 |
+
uvicorn==0.34.0
|
4 |
+
loguru==0.7.3
|
5 |
+
pyyaml==6.0.2
|
6 |
+
httpx==0.28.1
|
7 |
+
requests==2.32.3
|
8 |
+
biopython==1.85
|
9 |
+
openpyxl==3.1.5
|
10 |
+
openai==1.86.0
|
11 |
+
openai-agents==0.0.17
|
12 |
+
pandas==2.2.3
|
13 |
+
pymilvus==2.5.8
|
14 |
+
crawl4ai==0.7.0
|
15 |
+
aiohttp==3.11.18
|
16 |
+
beautifulsoup4==4.12.3
|
17 |
+
tiktoken==0.9.0
|
18 |
+
fastapi-mcp==0.4.0
|
19 |
+
python-dotenv
|
python-services/Retrieve/routers/mcp_sensor.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from asgi_correlation_id import correlation_id
|
2 |
+
from fastapi import APIRouter
|
3 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
4 |
+
|
5 |
+
from utils.bio_logger import bio_logger as logger
|
6 |
+
from utils.i18n_util import (
|
7 |
+
get_language,
|
8 |
+
create_error_response,
|
9 |
+
)
|
10 |
+
from utils.i18n_context import with_language
|
11 |
+
|
12 |
+
from bio_requests.chat_request import ChatRequest
|
13 |
+
|
14 |
+
from service.chat import ChatService
|
15 |
+
|
16 |
+
router = APIRouter(prefix="/mcp", tags=["MCP"])
|
17 |
+
|
18 |
+
|
19 |
+
@router.post("/bio_qa", response_model=None, operation_id="bio_qa_stream_chat")
async def bio_qa(query: str, lang: str = "en"):
    """
    Biomedical question-answering endpoint backed by the RAG chat service.

    Args:
        query: The question text.
        lang: Response language; "zh" for Chinese, "en" for English.

    Returns:
        A text/event-stream StreamingResponse on success, or a JSON error
        payload with HTTP 500 when the chat service cannot be started.
    """

    logger.info(f"{correlation_id.get()} Bio QA for {query}")
    chat_request = ChatRequest(query=query, language=lang)
    # Normalize the requested language code.
    language = get_language(chat_request.language)

    # Bind the i18n language for the current request via the context manager.
    # NOTE(review): the context manager exits when this function returns,
    # which is before the StreamingResponse body is consumed — confirm that
    # with_language keeps the language bound for the generator's lifetime.
    with with_language(language):
        try:
            chat_service = ChatService()
            return StreamingResponse(
                chat_service.generate_stream(chat_request),
                media_type="text/event-stream",
                headers={
                    "Connection": "keep-alive",
                    "Cache-Control": "no-cache",
                },
            )
        except Exception as e:
            logger.error(f"{correlation_id.get()} Stream chat error: {e}")
            error_response = create_error_response(
                error_key="service_unavailable",
                details=str(e),
                error_code=500,
            )
            return JSONResponse(content=error_response, status_code=500)
|
52 |
+
|
53 |
+
|
54 |
+
# Endpoint required by the MCP protocol: advertise the tools this server exposes.
@router.get("/tools")
async def list_tools():
    """Return the MCP tool manifest for this service."""
    bio_qa_tool = {
        "name": "bio_qa_stream_chat",
        "description": "生物医学问答服务,提供RAG问答功能",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "问题内容",
                },
                "lang": {
                    "type": "string",
                    "description": "语言设置,zh代表中文,en代表英文",
                    "enum": ["zh", "en"],
                    "default": "en",
                },
            },
            "required": ["query"],
        },
    }
    return {"tools": [bio_qa_tool]}
|
python-services/Retrieve/routers/sensor.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""API路由模块"""
|
2 |
+
|
3 |
+
from asgi_correlation_id import correlation_id
|
4 |
+
from fastapi import APIRouter
|
5 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
6 |
+
|
7 |
+
from utils.bio_logger import bio_logger as logger
|
8 |
+
from utils.i18n_util import (
|
9 |
+
get_language,
|
10 |
+
create_success_response,
|
11 |
+
create_error_response,
|
12 |
+
)
|
13 |
+
from utils.i18n_context import with_language
|
14 |
+
from bio_requests.rag_request import RagRequest
|
15 |
+
from bio_requests.chat_request import ChatRequest
|
16 |
+
from service.rag import RagService
|
17 |
+
from service.chat import ChatService
|
18 |
+
|
19 |
+
router = APIRouter()
|
20 |
+
|
21 |
+
|
22 |
+
@router.post("/retrieve")
async def search(rag_request: RagRequest) -> JSONResponse:
    """Document retrieval endpoint supporting multi-source search.

    Runs RagService.multi_query under the request's language context and
    returns the matched documents in a localized success envelope; on
    failure returns a localized error payload with HTTP 500.
    """

    logger.info(f"{correlation_id.get()} Searching for {rag_request}")

    # Resolve the requested language code to the internal language value.
    language = get_language(rag_request.language)

    # Bind the language to the context so the i18n responses below localize.
    with with_language(language):
        try:
            rag_assistant = RagService()
            documents = await rag_assistant.multi_query(rag_request)

            logger.info(f"{correlation_id.get()} Found {len(documents)} documents")
            # Serialize documents via their instance attribute dicts.
            results = [document.__dict__ for document in documents]

            # Localized success envelope.
            response_data = create_success_response(
                data=results, message_key="search_success"
            )

            return JSONResponse(content=response_data)

        except Exception as e:
            logger.error(f"{correlation_id.get()} Search error: {e}")
            error_response = create_error_response(
                error_key="search_failed", details=str(e), error_code=500
            )
            return JSONResponse(content=error_response, status_code=500)
|
53 |
+
|
54 |
+
|
55 |
+
@router.post("/stream-chat")
async def stream_chat(chat_request: ChatRequest):
    """Streaming chat endpoint providing the RAG Q&A service.

    Returns a text/event-stream StreamingResponse on success, or a JSON
    error payload with HTTP 500 if the chat service cannot be started.
    """

    logger.info(f"{correlation_id.get()} Streaming chat for {chat_request}")

    # Resolve the requested language code to the internal language value.
    language = get_language(chat_request.language)

    # Bind the language to the context so downstream i18n lookups localize.
    with with_language(language):
        try:
            chat_service = ChatService()
            return StreamingResponse(
                chat_service.generate_stream(chat_request),
                media_type="text/event-stream",
                headers={
                    "Connection": "keep-alive",
                    "Cache-Control": "no-cache",
                },
            )
        except Exception as e:
            # NOTE(review): errors raised inside the generator after streaming
            # has begun are handled in generate_stream, not here.
            logger.error(f"{correlation_id.get()} Stream chat error: {e}")
            error_response = create_error_response(
                error_key="service_unavailable",
                details=str(e),
                error_code=500,
            )
            return JSONResponse(content=error_response, status_code=500)
|
python-services/Retrieve/search_service/base_search.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
from bio_requests.rag_request import RagRequest
|
4 |
+
from dto.bio_document import BaseBioDocument
|
5 |
+
|
6 |
+
|
7 |
+
class BaseSearchService:
    """Base class for all search backends.

    Every subclass is auto-registered at definition time via
    __init_subclass__, so callers can enumerate the available concrete
    search services through get_subclasses().
    """

    # Shared registry of every subclass that has been defined.
    _registry = []

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Record each new subclass on the base class for later discovery.
        BaseSearchService._registry.append(cls)

    @classmethod
    def get_subclasses(cls):
        """Return the list of registered search-service subclasses."""
        return cls._registry

    def __init__(self):
        # Subclasses override this with their own data-source identifier.
        self.data_source = "Base"

    async def filter_search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Run search() only when this service's source was requested."""
        if self.data_source not in rag_request.data_source:
            return []
        return await self.search(rag_request)

    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Default implementation; subclasses override. Returns no documents."""
        return []
|
python-services/Retrieve/search_service/pubmed_search.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import re
|
3 |
+
import time
|
4 |
+
from typing import Dict, List
|
5 |
+
|
6 |
+
from dto.bio_document import BaseBioDocument, create_bio_document
|
7 |
+
from search_service.base_search import BaseSearchService
|
8 |
+
from bio_requests.rag_request import RagRequest
|
9 |
+
from utils.bio_logger import bio_logger as logger
|
10 |
+
|
11 |
+
|
12 |
+
from service.query_rewrite import QueryRewriteService
|
13 |
+
from service.pubmed_api import PubMedApi
|
14 |
+
from service.pubmed_async_api import PubMedAsyncApi
|
15 |
+
from config.global_storage import get_model_config
|
16 |
+
|
17 |
+
|
18 |
+
class PubMedSearchService(BaseSearchService):
    """Search service that retrieves article metadata from PubMed."""

    def __init__(self):
        self.query_rewrite_service = QueryRewriteService()
        self.model_config = get_model_config()

        # Default result limits from config; may be overridden per request
        # inside get_query_list.
        self.pubmed_topk = self.model_config["recall"]["pubmed_topk"]
        self.es_topk = self.model_config["recall"]["es_topk"]
        self.data_source = "pubmed"

    async def get_query_list(self, rag_request: RagRequest) -> List[Dict]:
        """Build the list of {query_item, search_type} dicts for a request.

        When is_rewrite is set, the query is split/rewritten by the rewrite
        service, falling back to the "simple" splitter when the first pass
        yields nothing; otherwise the raw query is used as-is.

        NOTE(review): this mutates self.pubmed_topk / self.es_topk per
        request, so a shared service instance is not safe under concurrent
        requests — confirm each request uses its own instance.
        NOTE(review): in the rewrite branch, es_topk is assigned from
        rag_request.pubmed_topk (not an ES-specific field) — verify intended.
        """
        if rag_request.is_rewrite:
            query_list = await self.query_rewrite_service.query_split(rag_request.query)
            logger.info(f"length of query_list after query_split: {len(query_list)}")
            if len(query_list) == 0:
                logger.info("query_list is empty, use query_split_for_simple")
                query_list = await self.query_rewrite_service.query_split_for_simple(
                    rag_request.query
                )
                logger.info(
                    f"length of query_list after query_split_for_simple: {len(query_list)}"
                )
            self.pubmed_topk = rag_request.pubmed_topk
            self.es_topk = rag_request.pubmed_topk
        else:
            self.pubmed_topk = rag_request.top_k
            self.es_topk = rag_request.top_k
            query_list = [
                {
                    "query_item": rag_request.query,
                    "search_type": rag_request.search_type,
                }
            ]
        return query_list

    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Asynchronously search PubMed for all (rewritten) queries.

        Fans the queries out concurrently, collects article IDs, then
        fetches article details in batches. Per-query failures are logged
        and skipped rather than failing the whole search.
        """
        if not rag_request.query:
            return []

        start_time = time.time()
        query_list = await self.get_query_list(rag_request)

        # Concurrency via asyncio instead of an explicit thread pool.
        articles_id_list = []
        # NOTE(review): es_articles is never populated in this method —
        # presumably a leftover from an Elasticsearch path; confirm.
        es_articles = []

        try:
            # Build one search coroutine per query, using PubMedApi's
            # search_database under the hood.
            async_tasks = []
            for query in query_list:
                task = self._search_pubmed_with_sync_api(
                    query["query_item"], self.pubmed_topk, query["search_type"]
                )
                async_tasks.append((query, task))

            # Run all searches concurrently; exceptions are returned in-place.
            results = await asyncio.gather(
                *[task for _, task in async_tasks], return_exceptions=True
            )

            # Collect IDs; log and skip failed queries.
            for i, (query, _) in enumerate(async_tasks):
                result = results[i]

                if isinstance(result, Exception):
                    logger.error(f"Error in search pubmed: {result}")
                else:
                    articles_id_list.extend(result)

        except Exception as e:
            logger.error(f"Error in concurrent PubMed search: {e}")

        # Fetch full article details for the collected IDs.
        pubmed_docs = await self.fetch_article_details(articles_id_list)

        # Merge result sets.
        all_results = []
        all_results.extend(pubmed_docs)
        all_results.extend(es_articles)

        logger.info(
            f"""Finished searching PubMed, query:{rag_request.query},
            total articles: {len(articles_id_list)}, total time: {time.time() - start_time:.2f}s"""
        )
        return all_results

    async def _search_pubmed_with_sync_api(
        self, query: str, top_k: int, search_type: str
    ) -> List[str]:
        """
        Wrap the synchronous PubMedApi.search_database in an executor so
        multiple searches can run concurrently.

        Args:
            query: search query
            top_k: number of results to return
            search_type: search type

        Returns:
            List of article IDs.
        """
        try:
            # Run the blocking search_database call in the default thread pool.
            loop = asyncio.get_event_loop()
            pubmed_api = PubMedApi()

            id_list = await loop.run_in_executor(
                None,  # default thread pool
                pubmed_api.search_database,
                query,
                top_k,
                search_type,
            )
            return id_list
        except Exception as e:
            logger.error(f"Error in PubMed search for query '{query}': {e}")
            raise e

    async def fetch_article_details(
        self, articles_id_list: List[str]
    ) -> List[BaseBioDocument]:
        """Fetch article details from PubMed for the given article IDs."""
        if not articles_id_list:
            return []

        # De-duplicate IDs (note: set() does not preserve order).
        articles_id_list = list(set(articles_id_list))

        # Split the IDs into batches of group_size per efetch call.
        group_size = 80
        articles_id_groups = [
            articles_id_list[i : i + group_size]
            for i in range(0, len(articles_id_list), group_size)
        ]

        try:
            # Fetch all batches concurrently.
            batch_tasks = []
            for ids in articles_id_groups:
                pubmed_async_api = PubMedAsyncApi()
                task = pubmed_async_api.fetch_details(id_list=ids)
                batch_tasks.append(task)

            task_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            fetch_results = []
            for result in task_results:
                if isinstance(result, Exception):
                    logger.error(f"Error in fetch_details: {result}")
                    continue
                fetch_results.extend(result)

        except Exception as e:
            logger.error(f"Error in concurrent fetch_details: {e}")
            return []

        # Convert raw records into BioDocument objects.
        all_results = [
            create_bio_document(
                title=result["title"],
                abstract=result["abstract"],
                authors=self.process_authors(result["authors"]),
                doi=result["doi"],
                source=self.data_source,
                source_id=result["pmid"],
                pub_date=result["pub_date"],
                journal=result["journal"],
                text=result["abstract"],
                url=f'https://pubmed.ncbi.nlm.nih.gov/{result["pmid"]}',
            )
            for result in fetch_results
        ]
        return all_results

    def process_authors(self, author_list: List[Dict]) -> str:
        """Join author dicts into a single "Forename Lastname, ..." string."""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )
|
python-services/Retrieve/search_service/web_search.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Web search service for retrieving and processing web content.
|
3 |
+
|
4 |
+
This module provides functionality to search the web using Serper API
|
5 |
+
and extract content from web pages using crawl4ai.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import asyncio
|
9 |
+
import os
|
10 |
+
from typing import List, Optional
|
11 |
+
|
12 |
+
from bio_requests.rag_request import RagRequest
|
13 |
+
from dto.bio_document import BaseBioDocument, create_bio_document
|
14 |
+
from search_service.base_search import BaseSearchService
|
15 |
+
from service.web_search import SerperClient, scrape_urls, url_to_fit_contents
|
16 |
+
from utils.bio_logger import bio_logger as logger
|
17 |
+
|
18 |
+
|
19 |
+
class WebSearchService(BaseSearchService):
    """
    Web search service that retrieves content from web pages.

    This service uses Serper API for web search and crawl4ai for content extraction.
    """

    def __init__(self):
        """Initialize the web search service."""
        self.data_source = "web"
        self._serper_client: Optional[SerperClient] = None
        self._max_results = 5
        self._content_length_limit = 40000  # ~10k tokens

    @property
    def serper_client(self) -> SerperClient:
        """Lazy initialization of SerperClient."""
        if self._serper_client is None:
            # Read the API key from the environment.
            api_key = os.getenv("SERPER_API_KEY")
            if not api_key:
                # NOTE(review): message claims a default key is used, but no
                # default is applied here — SerperClient receives api_key=None
                # in this case; confirm intended behavior.
                logger.warning("SERPER_API_KEY environment variable not set, using default key")

            self._serper_client = SerperClient(api_key=api_key)
        return self._serper_client

    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """
        Perform web search and extract content from search results.

        Args:
            rag_request: The RAG request containing the search query

        Returns:
            List of BaseBioDocument objects with extracted web content
            (empty list on any failure — errors are logged, not raised).
        """
        try:
            query = rag_request.query
            logger.info(f"Starting web search for query: {query}")

            # Search for URLs using Serper
            url_results = await self.search_serper(query, rag_request.top_k)

            if not url_results:
                logger.info(f"No search results found for query: {query}")
                return []

            # Extract content from URLs
            search_results = await self.enrich_url_results_with_contents(url_results)

            logger.info(f"Web search completed. Found {len(search_results)} documents")
            return search_results

        except Exception as e:
            logger.error(f"Error during web search: {str(e)}", exc_info=e)
            return []

    async def enrich_url_results_with_contents(
        self, results: List
    ) -> List[BaseBioDocument]:
        """
        Extract content from URLs and create BaseBioDocument objects.

        Args:
            results: List of search results with URLs

        Returns:
            List of BaseBioDocument objects with extracted content; results
            whose extraction failed are logged and skipped.
        """
        try:
            # Extract content for all URLs concurrently.
            tasks = [self._extract_content_from_url(res) for res in results]
            contents = await asyncio.gather(*tasks, return_exceptions=True)

            enriched_results = []
            for res, content in zip(results, contents):
                # Handle exceptions from content extraction
                if isinstance(content, Exception):
                    logger.error(f"Failed to extract content from {res.url}: {content}")
                    continue

                # Truncate page text to keep the prompt within budget.
                bio_doc = create_bio_document(
                    title=res.title,
                    url=res.url,
                    text=str(content)[: self._content_length_limit],
                    source=self.data_source,
                )
                enriched_results.append(bio_doc)

            return enriched_results

        except Exception as e:
            logger.error(f"Error enriching URL results: {str(e)}", exc_info=e)
            return []

    async def _extract_content_from_url(self, res) -> str:
        """
        Extract content from a single URL with error handling.

        Args:
            res: Search result object containing URL information

        Returns:
            Extracted content as string; on failure an "Error extracting
            content: ..." placeholder string is returned instead of raising.
        """
        try:
            return await url_to_fit_contents(res)
        except Exception as e:
            logger.error(f"Error extracting content from {res.url}: {str(e)}")
            return f"Error extracting content: {str(e)}"

    async def search_serper(
        self, query: str, max_results: Optional[int] = None
    ) -> List:
        """
        Perform web search using Serper API.

        Args:
            query: Search query string
            max_results: Maximum number of results to return
                (defaults to self._max_results when None)

        Returns:
            List of scraped search results with URLs (empty on failure).
        """
        try:
            max_results = max_results or self._max_results
            logger.info(f"Searching Serper for: {query} (max_results: {max_results})")

            search_results = await self.serper_client.search(
                query, filter_for_relevance=True, max_results=max_results
            )

            if not search_results:
                logger.info(f"No search results from Serper for query: {query}")
                return []

            # Scrape content from URLs
            results = await scrape_urls(search_results)

            logger.info(f"Serper search completed. Found {len(results)} results")
            return results

        except Exception as e:
            logger.error(f"Error in Serper search: {str(e)}", exc_info=e)
            return []
|
python-services/Retrieve/service/__init__.py
ADDED
File without changes
|
python-services/Retrieve/service/chat.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""生物医学聊天服务模块,提供RAG问答和流式响应功能。"""
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
from typing import Any, AsyncGenerator, List
|
7 |
+
|
8 |
+
from openai import AsyncOpenAI
|
9 |
+
from openai.types.chat import ChatCompletionMessageParam
|
10 |
+
|
11 |
+
from bio_requests.chat_request import ChatRequest
|
12 |
+
from bio_requests.rag_request import RagRequest
|
13 |
+
from config.global_storage import get_model_config
|
14 |
+
from search_service.pubmed_search import PubMedSearchService
|
15 |
+
|
16 |
+
from search_service.web_search import WebSearchService
|
17 |
+
from service.query_rewrite import QueryRewriteService
|
18 |
+
from service.rerank import RerankService
|
19 |
+
from utils.bio_logger import bio_logger as logger
|
20 |
+
from utils.i18n_util import get_error_message, get_label_message
|
21 |
+
from utils.token_util import num_tokens_from_messages, num_tokens_from_text
|
22 |
+
from utils.snowflake_id import snowflake_id_str
|
23 |
+
|
24 |
+
|
25 |
+
class ChatService:
|
26 |
+
"""生物医学聊天服务,提供RAG问答和流式响应功能。"""
|
27 |
+
|
28 |
+
    def __init__(self):
        """Wire up the search/rewrite/rerank services and load the model config."""
        self.pubmed_search_service = PubMedSearchService()
        self.web_search_service = WebSearchService()
        self.query_rewrite_service = QueryRewriteService()

        # NOTE(review): a single RagRequest instance is stored here and
        # mutated per call — not safe if one ChatService serves concurrent
        # requests; confirm a fresh ChatService is created per request.
        self.rag_request = RagRequest()
        self.rerank_service = RerankService()
        self.model_config = get_model_config()
|
36 |
+
|
37 |
+
    def _initialize_rag_request(self, chat_request: ChatRequest) -> None:
        """Copy the chat query onto the shared RAG request."""
        self.rag_request.query = chat_request.query
|
40 |
+
|
41 |
+
    async def generate_stream(self, chat_request: ChatRequest):
        """
        Generate a stream of messages for the chat request.

        Yields, in order: a PubMed search-task block, a web search-task
        block, a citation block, then the LLM answer chunks. On error a
        single localized SSE error line is yielded and the stream ends.

        Args:
            chat_request: the chat request
        """

        start_time = time.time()

        try:
            # Seed the shared RAG request with this chat's query.
            self._initialize_rag_request(chat_request)

            # PubMed search
            logger.info("QA-RAG: Start search pubmed...")

            pubmed_results = await self._search_pubmed(chat_request)

            pubmed_task_text = self._generate_pubmed_search_task_text(pubmed_results)
            yield pubmed_task_text
            logger.info(
                f"QA-RAG: Finished search pubmed, length: {len(pubmed_results)}"
            )

            # Web search
            web_results = []

            logger.info("QA-RAG: Start search web...")

            web_urls, task_text = await self._search_web()
            logger.info("QA-RAG: Finished search web...")

            # Fetch and attach page contents for the found URLs.
            web_results = (
                await self.web_search_service.enrich_url_results_with_contents(web_urls)
            )

            yield task_text

            # Build the LLM prompt messages and the citation mapping.
            messages, citation_list = self._create_messages(
                pubmed_results, web_results, chat_request
            )
            citation_text = self._generate_citation_text(citation_list)
            yield citation_text
            # Stream the chat-completion chunks to the client.
            async for content in self._stream_chat_completion(messages):
                yield content

            logger.info(
                f"Finished search and chat, query: [{chat_request.query}], total time: {time.time() - start_time:.2f}s"
            )

        except Exception as e:
            logger.error(f"Error occurred: {e}")
            # Return a localized error message using the context language.
            error_msg = get_error_message("llm_service_error")
            yield f"data: {error_msg}\n\n"
            return
|
100 |
+
|
101 |
+
    def _generate_citation_text(self, citation_list: List[Any]) -> str:
        """Render the citation list as a fenced ``bdd-resource-lookup`` block
        understood by the chat frontend."""

        return f"""
```bdd-resource-lookup
{json.dumps(citation_list)}
```
"""
|
109 |
+
|
110 |
+
async def _search_pubmed(self, chat_request: ChatRequest) -> List[Any]:
|
111 |
+
"""执行PubMed搜索"""
|
112 |
+
try:
|
113 |
+
logger.info(f"query: {chat_request.query}, Using pubmed search...")
|
114 |
+
self.rag_request.top_k = self.model_config["qa-topk"]["pubmed"]
|
115 |
+
self.rag_request.pubmed_topk = self.model_config["qa-topk"]["pubmed"]
|
116 |
+
|
117 |
+
start_search_time = time.time()
|
118 |
+
pubmed_results = await self.pubmed_search_service.search(self.rag_request)
|
119 |
+
end_search_time = time.time()
|
120 |
+
|
121 |
+
logger.info(
|
122 |
+
f"length of pubmed_results: {len(pubmed_results)},time used:{end_search_time - start_search_time:.2f}s"
|
123 |
+
)
|
124 |
+
pubmed_results = pubmed_results[0 : self.rag_request.top_k]
|
125 |
+
logger.info(f"length of pubmed_results after rerank: {len(pubmed_results)}")
|
126 |
+
|
127 |
+
end_rerank_time = time.time()
|
128 |
+
logger.info(
|
129 |
+
f"Reranked {len(pubmed_results)} results,time used:{end_rerank_time - end_search_time:.2f}s"
|
130 |
+
)
|
131 |
+
|
132 |
+
return pubmed_results
|
133 |
+
except Exception as e:
|
134 |
+
logger.error(f"error in search pubmed: {e}")
|
135 |
+
return []
|
136 |
+
|
137 |
+
    async def _search_web(self) -> tuple[List[Any], str]:
        """Run the web retrieval step.

        Tries a rewritten query first; falls back to the original query
        when the rewrite yields nothing or fails. Returns the URL results
        and the rendered search-task text for the stream.
        """
        web_topk = self.model_config["qa-topk"]["web"]
        try:
            # Try to obtain a rewritten query for web search.
            query_list = await self.query_rewrite_service.query_split_for_web(
                self.rag_request.query
            )
            # Safely take the first rewritten query; None if the list is empty.
            serper_query = (
                query_list[0].get("query_item", "").strip() if query_list else None
            )
            # Fall back to the original query when the rewrite is empty.
            if not serper_query:
                serper_query = self.rag_request.query
            # Execute the search with the final query.
            url_results = await self.web_search_service.search_serper(
                query=serper_query, max_results=web_topk
            )
        except Exception as e:
            logger.error(f"error in query rewrite web or serper retrieval: {e}")
            # On error, search with the original query instead.
            url_results = await self.web_search_service.search_serper(
                query=self.rag_request.query, max_results=web_topk
            )

        # Render the task text shown in the chat stream.
        task_text = self._generate_web_search_task_text(url_results)
        return url_results, task_text
|
166 |
+
|
167 |
+
|
168 |
+
    def _generate_pubmed_search_task_text(self, pubmed_results: List[Any]) -> str:
        """Build the PubMed search-task block shown in the chat stream."""
        docs = [
            {
                "docId": result.bio_id,
                "url": result.url,
                "title": result.title,
                "description": result.text,
                "author": result.authors,
                # Journal citation: "<title>.<year>.<start>-<end>.doi:<doi>";
                # the page range is included only when both pages are present.
                "JournalInfo": result.journal.get("title", "")
                + "."
                + result.journal.get("year", "")
                + "."
                + (
                    result.journal.get("start_page", "")
                    + "-"
                    + result.journal.get("end_page", "")
                    + "."
                    if result.journal.get("start_page")
                    and result.journal.get("end_page")
                    else ""
                )
                + "doi:"
                + result.doi,
                "PMID": result.source_id,
            }
            for result in pubmed_results
        ]
        label = get_label_message("pubmed_search")
        return self._generate_task_text(label, "pubmed", docs)
|
198 |
+
|
199 |
+
    def _generate_web_search_task_text(self, url_results: List[Any]) -> str:
        """Build the web search-task block shown in the chat stream."""
        web_docs = [
            {
                # Web results have no stable ID, so mint a snowflake ID.
                "docId": snowflake_id_str(),
                "url": url_result.url,
                "title": url_result.title,
                "description": url_result.description,
            }
            for url_result in url_results
        ]

        logger.info(f"URL Results: {web_docs}")

        label = get_label_message("web_search")

        return self._generate_task_text(label, "webSearch", web_docs)
|
216 |
+
|
217 |
+
    def _generate_task_text(self, label, source, bio_docs: List[Any]):
        """Wrap a search-task payload in a fenced ``bdd-chat-agent-task`` block."""
        task = {
            "type": "search",
            "label": label,
            "hoverable": True,
            "handler": "QASearch",
            "status": "running",
            "handlerParam": {"source": source, "bioDocs": bio_docs},
        }
        return f"""
```bdd-chat-agent-task
{json.dumps(task)}
```
"""
|
232 |
+
|
233 |
+
def _build_document_texts(
|
234 |
+
self, pubmed_results: List[Any], web_results: List[Any]
|
235 |
+
) -> tuple[str, str, List[Any]]:
|
236 |
+
"""构建文档文本"""
|
237 |
+
# 个人向量搜索结果
|
238 |
+
citation_list = []
|
239 |
+
temp_doc_list = []
|
240 |
+
|
241 |
+
# pubmed结果
|
242 |
+
pubmed_offset = 0
|
243 |
+
for idx, doc in enumerate(pubmed_results):
|
244 |
+
_idx = idx + 1 + pubmed_offset
|
245 |
+
temp_doc_list.append(
|
246 |
+
"[document {idx} begin] title: {title}. content: {abstract} [document {idx} end]".format(
|
247 |
+
idx=_idx, title=doc.title, abstract=doc.abstract
|
248 |
+
)
|
249 |
+
)
|
250 |
+
citation_list.append(
|
251 |
+
{"source": "pubmed", "docId": doc.bio_id, "citation": _idx}
|
252 |
+
)
|
253 |
+
pubmed_texts = "\n".join(temp_doc_list)
|
254 |
+
|
255 |
+
temp_doc_list = []
|
256 |
+
# 联网搜索结果
|
257 |
+
web_offset = pubmed_offset + len(pubmed_results)
|
258 |
+
for idx, doc in enumerate(web_results):
|
259 |
+
_idx = idx + 1 + web_offset
|
260 |
+
temp_doc_list.append(
|
261 |
+
"[document {idx} begin] title: {title}. content: {content} [document {idx} end]".format(
|
262 |
+
idx=_idx, title=doc.title, content=doc.text
|
263 |
+
)
|
264 |
+
)
|
265 |
+
citation_list.append(
|
266 |
+
{"source": "webSearch", "docId": doc.bio_id, "citation": _idx}
|
267 |
+
)
|
268 |
+
web_texts = "\n".join(temp_doc_list)
|
269 |
+
|
270 |
+
return pubmed_texts, web_texts, citation_list
|
271 |
+
|
272 |
+
    def _truncate_documents_to_token_limit(
        self,
        pubmed_texts: str,
        web_texts: str,
        chat_request: ChatRequest,
    ) -> tuple[List[ChatCompletionMessageParam], int]:
        """Truncate the document texts so the final prompt fits max_tokens.

        Documents are dropped from the tail of the PubMed list first, then
        from the web list, re-measuring the full prompt after each pass.
        Returns the prompt messages and the final token count.
        """
        pubmed_list = pubmed_texts.split("\n")
        web_list = web_texts.split("\n")

        today = datetime.date.today()
        openai_client_rag_prompt = self.model_config["chat"]["rag_prompt"]
        max_tokens = self.model_config["qa-prompt-max-token"]["max_tokens"]
        pubmed_token_limit = max_tokens
        web_token_limit = 60000
        personal_vector_token_limit = 80000
        # NOTE(review): the 3rd and 4th conditions below duplicate the 1st
        # and 2nd, so those branches are unreachable. They look like they
        # were meant for the case where the personal-vector source is
        # disabled (e.g. `not is_personal and ...`) — confirm intended
        # conditions.
        # NOTE(review): personal_vector_token_limit is assigned but never
        # used in this method.
        if chat_request.is_pubmed and chat_request.is_web:
            personal_vector_token_limit = 40000
            pubmed_token_limit = 20000
            web_token_limit = 60000
        elif chat_request.is_pubmed and not chat_request.is_web:
            personal_vector_token_limit = 80000
            pubmed_token_limit = 40000
            web_token_limit = 0
        elif chat_request.is_pubmed and chat_request.is_web:
            personal_vector_token_limit = 0
            pubmed_token_limit = 60000
            web_token_limit = 60000
        elif chat_request.is_pubmed and not chat_request.is_web:
            personal_vector_token_limit = 0
            pubmed_token_limit = 120000
            web_token_limit = 0

        def calculate_num_tokens(
            pubmed_list: List[str], web_list: List[str]
        ) -> tuple[int, List[ChatCompletionMessageParam]]:
            # Merge both document lists into the prompt and count its tokens.
            docs_text = "\n".join(pubmed_list + web_list)

            pt = (
                openai_client_rag_prompt.replace("{search_results}", docs_text)
                .replace("{cur_date}", str(today))
                .replace("{question}", chat_request.query)
            )
            messages: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": pt}
            ]
            # Token count of the assembled messages.
            num_tokens = num_tokens_from_messages(messages)
            return num_tokens, messages

        # NOTE(review): if both lists are emptied and the prompt boilerplate
        # alone still exceeds max_tokens, this loop never terminates —
        # confirm max_tokens always exceeds the bare prompt size.
        while True:
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
            if num_tokens <= max_tokens:
                break
            # Over budget: truncate each source down to its own limit.
            logger.info(
                f"start truncate documents to token limit: max_tokens: {max_tokens}"
            )
            logger.info(
                f"pubmed_token_limit: {pubmed_token_limit}, web_token_limit: {web_token_limit}, personal_vector_token_limit: {personal_vector_token_limit}"
            )

            # Drop PubMed documents from the tail until within the PubMed limit.
            while True:
                if num_tokens_from_text("\n".join(pubmed_list)) > pubmed_token_limit:
                    pubmed_list.pop()
                else:
                    break

            # Re-measure after PubMed truncation; stop if now within budget.
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
            if num_tokens <= max_tokens:
                break

            # Drop web documents from the tail until within the web limit.
            while True:
                if num_tokens_from_text("\n".join(web_list)) > web_token_limit:
                    web_list.pop()
                else:
                    break

            # Re-measure after web truncation; stop if now within budget.
            num_tokens, messages = calculate_num_tokens(pubmed_list, web_list)
            if num_tokens <= max_tokens:
                break

        logger.info(f"Final token count: {num_tokens}")
        return messages, num_tokens
|
359 |
+
|
360 |
+
def _create_messages(
    self,
    pubmed_results: List[Any],
    web_results: List[Any],
    chat_request: ChatRequest,
) -> tuple[List[ChatCompletionMessageParam], List[Any]]:
    """Build the chat messages and citation list for a QA request.

    When nothing was retrieved, the raw user query is sent as-is with an
    empty citation list; otherwise retrieved documents are formatted into
    the RAG prompt and truncated to the model's token budget.
    """
    if not pubmed_results and not web_results:
        # Nothing retrieved: fall back to the bare question, no RAG prompt.
        logger.info(f"No results found for query: {chat_request.query}")
        fallback_messages: List[ChatCompletionMessageParam] = [
            {"role": "user", "content": chat_request.query}
        ]
        token_count = num_tokens_from_messages(fallback_messages)
        logger.info(f"Total tokens: {token_count}")
        return fallback_messages, []

    # Format retrieved documents into prompt text plus citation metadata.
    pubmed_texts, web_texts, citation_list = self._build_document_texts(
        pubmed_results, web_results
    )

    # Trim the documents until the rendered prompt fits the token budget.
    messages, _ = self._truncate_documents_to_token_limit(
        pubmed_texts, web_texts, chat_request
    )

    return messages, citation_list
|
388 |
+
|
389 |
+
async def _stream_chat_completion(
    self, messages: List[ChatCompletionMessageParam]
) -> AsyncGenerator[bytes, None]:
    """Stream a chat completion, trying the qa-llm "main" configuration and
    falling back to "backup" on any error.

    Yields UTF-8 encoded content chunks as the model produces them.
    """

    async def create_stream_with_config(
        qa_config: dict, config_name: str
    ) -> AsyncGenerator[bytes, None]:
        """Create a streaming response using the given qa-llm configuration."""
        try:
            logger.info(f"Using qa-llm {config_name} configuration")

            client = AsyncOpenAI(
                api_key=qa_config["api_key"],
                base_url=qa_config["base_url"],
            )

            chat_start_time = time.time()

            # Open the streaming chat completion.
            stream = await client.chat.completions.create(
                model=qa_config["model"],
                messages=messages,
                stream=True,
                temperature=qa_config["temperature"],
                max_tokens=qa_config["max_tokens"],
            )

            # NOTE: logged when the stream is opened, before chunks arrive.
            logger.info(
                f"Finished chat completion with {config_name} config, total time: {time.time() - chat_start_time:.2f}s"
            )

            is_start_answer = False
            # Forward each non-empty content delta as UTF-8 bytes.
            async for chunk in stream:
                if chunk.choices and (content := chunk.choices[0].delta.content):
                    if not is_start_answer:
                        # Flag flips on the first content chunk; currently
                        # only tracked, not otherwise used.
                        is_start_answer = True

                    yield content.encode("utf-8")

        except Exception as e:
            logger.info(f"qa-llm {config_name} configuration failed: {e}")
            raise e

    async def with_fallback(main_func, backup_func):
        """Higher-order helper: stream from main_func, switching to
        backup_func if it fails.

        NOTE(review): if the main stream fails after yielding some chunks,
        the backup stream restarts from the beginning, so consumers may see
        partially duplicated output — confirm this is acceptable downstream.
        """
        try:
            async for content in main_func():
                yield content
        except Exception as main_error:
            logger.info("Main config failed, falling back to backup configuration")
            try:
                async for content in backup_func():
                    yield content
            except Exception as backup_error:
                logger.error(
                    f"Both main and backup qa-llm configurations failed. "
                    f"Main error: {main_error}, Backup error: {backup_error}"
                )
                raise backup_error

    # Generator factories for the main and backup configurations.
    async def main_stream():
        logger.info("Using main qa-llm configuration")
        async for content in create_stream_with_config(
            self.model_config["qa-llm"]["main"], "main"
        ):
            yield content

    async def backup_stream():
        logger.info("Using backup qa-llm configuration")
        async for content in create_stream_with_config(
            self.model_config["qa-llm"]["backup"], "backup"
        ):
            yield content

    # Drive the stream through the fallback logic.
    async for content in with_fallback(main_stream, backup_stream):
        yield content
|
python-services/Retrieve/service/pubmed_api.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import Dict, List
|
3 |
+
from Bio import Entrez
|
4 |
+
import requests
|
5 |
+
from config.global_storage import get_model_config
|
6 |
+
from dto.bio_document import PubMedDocument
|
7 |
+
from service.pubmed_xml_parse import PubmedXmlParse
|
8 |
+
from utils.bio_logger import bio_logger as logger
|
9 |
+
|
10 |
+
# Pool of NCBI E-utilities accounts; callers pick a pseudo-random entry per
# request so API rate limits are spread across keys.
# NOTE(review): live-looking API keys are committed in source control — move
# them to environment variables / config and rotate these credentials.
PUBMED_ACCOUNT = [
    {"email": "[email protected]", "api_key": "60eb67add17f39aa588a43e30bb7fce98809"},
    {"email": "[email protected]", "api_key": "fd9bb5b827c95086b9c2d579df20beca2708"},
    {"email": "[email protected]", "api_key": "026586b79437a2b21d1e27d8c3f339230208"},
    {"email": "[email protected]", "api_key": "bca0489d8fe314bfdbb1f7bfe63fb5d76e09"},
]
|
16 |
+
|
17 |
+
|
18 |
+
class PubMedApi:
    """Synchronous PubMed client: ESearch for PMIDs, then EFetch for details.

    EFetch XML is parsed by ``PubmedXmlParse`` and wrapped in
    ``PubMedDocument`` objects.
    """

    # Article types excluded from keyword searches (non-research content).
    # Hoisted to a class constant so the list is not rebuilt on every call.
    EXCLUDED_ARTICLE_TYPES = [
        "Address",
        "Bibliography",
        "Biography",
        "Books and Documents",
        "Clinical Conference",
        "Clinical Study",
        "Collected Works",
        "Comment",
        "Congress",
        "Consensus Development Conference",
        "Consensus Development Conference, NIH",
        "Dictionary",
        "Directory",
        "Duplicate Publication",
        "Editorial",
        "Festschrift",
        "Government Document",
        "Guideline",
        "Interactive Tutorial",
        "Interview",
        "Lecture",
        "Legal Case",
        "Legislation",
        "Letter",
        "News",
        "Newspaper Article",
        "Patient Education Handout",
        "Periodical Index",
        "Personal Narrative",
        "Practice Guideline",
        "Published Erratum",
        "Technical Report",
        "Video-Audio Media",
        "Webcast",
    ]

    # Hard cap on a single HTTP call so a stalled NCBI request cannot hang.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self):
        self.pubmed_xml_parse = PubmedXmlParse()
        self.model_config = get_model_config()

    def pubmed_search_function(
        self, query: str, top_k: int, search_type: str
    ) -> List["PubMedDocument"]:
        """Search PubMed and return parsed documents.

        :param query: query string (keywords or an advanced PubMed expression)
        :param top_k: maximum number of results to return
        :param search_type: "keyword" or "advanced"
        :raises Exception: re-raises any search/fetch failure after logging
        """
        try:
            start_time = time.time()
            logger.info(
                f'Trying to search PubMed for "{query}", top_k={top_k}, search_type={search_type}'
            )
            id_list = self.search_database(query, retmax=top_k, search_type=search_type)
            records = self.fetch_details(id_list, db="pubmed", rettype="abstract")

            end_search_pubmed_time = time.time()
            logger.info(
                f'Finished searching PubMed for "{query}", took {end_search_pubmed_time - start_time:.2f} seconds, found {len(records)} results'
            )

            return [
                PubMedDocument(
                    title=result["title"],
                    abstract=result["abstract"],
                    authors=self.process_authors(result["authors"]),
                    doi=result["doi"],
                    source="pubmed",
                    source_id=result["pmid"],
                    pub_date=result["pub_date"],
                    journal=result["journal"],
                    text=result["abstract"],
                )
                for result in records
            ]
        except Exception as e:
            logger.error(f"Error searching PubMed query: {query} error: {e}")
            # Bare raise preserves the original traceback.
            raise

    def process_authors(self, author_list: List[Dict]) -> str:
        """Join author dicts into a single "Forename Lastname, ..." string."""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )

    # Search the database (ESearch).
    def search_database(
        self, query: str, retmax: int, search_type: str = "keyword"
    ) -> List[str]:
        """
        Return the list of PubMed record ids matching *query*.

        :param search_type: "keyword" (abstract/article-type filters added) or
            "advanced" (query passed through verbatim)
        :param query: query string
        :param retmax: maximum number of ids to return
        :raises ValueError: on an unknown *search_type*
        """
        if search_type not in ("keyword", "advanced"):
            raise ValueError("search_type must be either 'keyword' or 'advanced'")

        start_time = time.time()
        db = "pubmed"
        # Pick a pseudo-random account from the pool to spread API rate limits.
        random_index = int((time.time() * 1000) % len(PUBMED_ACCOUNT))
        random_pubmed_account = PUBMED_ACCOUNT[random_index]
        Entrez.email = random_pubmed_account["email"]
        Entrez.api_key = random_pubmed_account["api_key"]

        if search_type == "keyword":
            # Exclude non-research article types and require the "fha" filter
            # (presumably free-full-text/has-abstract — TODO confirm).
            art_type = (
                "(" + " OR ".join(f'"{j}"[Filter]' for j in self.EXCLUDED_ARTICLE_TYPES) + ")"
            )
            query = "( " + query + ")"
            query += " AND (fha[Filter]) NOT " + art_type

        handle = Entrez.esearch(
            db=db, term=query, usehistory="y", sort="relevance", retmax=retmax
        )
        results = Entrez.read(handle)
        handle.close()
        id_list = results["IdList"]
        logger.info(
            f"Finished searching PubMed id, took {time.time() - start_time:.2f} seconds, found {len(id_list)} results,query: {query}"
        )
        logger.info(
            f"Search type:{search_type} PubMed search query: {query}, id_list: {id_list}"
        )
        if len(id_list) == 0:
            return []
        return id_list

    def fetch_details(self, id_list, db="pubmed", rettype="abstract"):
        """Fetch and parse article details for *id_list* via EFetch.

        Best-effort: returns a list of article dicts, or [] on any failure.
        """
        start_time = time.time()
        try:
            ids = ",".join(id_list)
            server = "efetch"

            # Pick a pseudo-random account from the pool.
            random_index = int((time.time() * 1000) % len(PUBMED_ACCOUNT))
            random_pubmed_account = PUBMED_ACCOUNT[random_index]
            api_key = random_pubmed_account["api_key"]
            url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/{server}.fcgi?db={db}&id={ids}&retmode=xml&api_key={api_key}&rettype={rettype}"
            # Explicit timeout: requests.get has none by default and would
            # otherwise hang forever on a stalled connection.
            response = requests.get(url, timeout=self.REQUEST_TIMEOUT_SECONDS)
            # Surface HTTP errors instead of parsing an error page as XML;
            # the except below turns them into the [] best-effort result.
            response.raise_for_status()
            articles = self.pubmed_xml_parse.parse_pubmed_xml(response.text)
            logger.info(
                f"pubmed_async_http fetch detail, Time taken: {time.time() - start_time}"
            )
            return articles
        except Exception as e:
            logger.error(f"Error fetching details for id_list: {id_list}, error: {e}")
            return []
|
python-services/Retrieve/service/pubmed_async_api.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import time
|
3 |
+
from typing import Dict, List
|
4 |
+
import aiohttp
|
5 |
+
|
6 |
+
from config.global_storage import get_model_config
|
7 |
+
from dto.bio_document import PubMedDocument
|
8 |
+
from service.pubmed_xml_parse import PubmedXmlParse
|
9 |
+
from utils.bio_logger import bio_logger as logger
|
10 |
+
|
11 |
+
# Pool of NCBI E-utilities accounts; callers pick a pseudo-random entry per
# request so API rate limits are spread across keys.
# NOTE(review): live-looking API keys are committed in source control — move
# them to environment variables / config and rotate these credentials.
PUBMED_ACCOUNT = [
    {"email": "[email protected]", "api_key": "60eb67add17f39aa588a43e30bb7fce98809"},
    {"email": "[email protected]", "api_key": "fd9bb5b827c95086b9c2d579df20beca2708"},
    {"email": "[email protected]", "api_key": "026586b79437a2b21d1e27d8c3f339230208"},
    {"email": "[email protected]", "api_key": "bca0489d8fe314bfdbb1f7bfe63fb5d76e09"},
]
|
17 |
+
|
18 |
+
|
19 |
+
class PubMedAsyncApi:
    """Async PubMed client: ESearch for PMIDs, then EFetch for article details.

    EFetch XML is parsed by ``PubmedXmlParse`` and wrapped in
    ``PubMedDocument`` objects.
    """

    # Article types excluded from keyword searches (non-research content).
    # Hoisted to a class constant so the list is not rebuilt on every call.
    EXCLUDED_ARTICLE_TYPES = [
        "Address",
        "Bibliography",
        "Biography",
        "Books and Documents",
        "Clinical Conference",
        "Clinical Study",
        "Collected Works",
        "Comment",
        "Congress",
        "Consensus Development Conference",
        "Consensus Development Conference, NIH",
        "Dictionary",
        "Directory",
        "Duplicate Publication",
        "Editorial",
        "Festschrift",
        "Government Document",
        "Guideline",
        "Interactive Tutorial",
        "Interview",
        "Lecture",
        "Legal Case",
        "Legislation",
        "Letter",
        "News",
        "Newspaper Article",
        "Patient Education Handout",
        "Periodical Index",
        "Personal Narrative",
        "Practice Guideline",
        "Published Erratum",
        "Technical Report",
        "Video-Audio Media",
        "Webcast",
    ]

    # Overall per-request timeout so a stalled NCBI call cannot hang forever.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self):
        self.pubmed_xml_parse = PubmedXmlParse()
        self.model_config = get_model_config()

    async def pubmed_search_function(
        self, query: str, top_k: int, search_type: str
    ) -> List["PubMedDocument"]:
        """Search PubMed and return parsed documents.

        :param query: query string (keywords or an advanced PubMed expression)
        :param top_k: maximum number of results to return
        :param search_type: "keyword" or "advanced"
        :raises Exception: re-raises any search/fetch failure after logging
        """
        try:
            start_time = time.time()
            logger.info(
                f'Trying to search PubMed for "{query}", top_k={top_k}, search_type={search_type}'
            )
            id_list = await self.search_database(
                query, db="pubmed", retmax=top_k, search_type=search_type
            )
            articles = await self.fetch_details(
                id_list, db="pubmed", rettype="abstract"
            )

            end_search_pubmed_time = time.time()
            logger.info(
                f'Finished searching PubMed for "{query}", took {end_search_pubmed_time - start_time:.2f} seconds, found {len(articles)} results'
            )

            return [
                PubMedDocument(
                    title=result["title"],
                    abstract=result["abstract"],
                    authors=self.process_authors(result["authors"]),
                    doi=result["doi"],
                    source="pubmed",
                    source_id=result["pmid"],
                    pub_date=result["pub_date"],
                    journal=result["journal"],
                )
                for result in articles
            ]
        except Exception as e:
            logger.error(f"Error searching PubMed query: {query} error: {e}")
            # Bare raise preserves the original traceback.
            raise

    def process_authors(self, author_list: List[Dict]) -> str:
        """Join author dicts into a single "Forename Lastname, ..." string."""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )

    # Search the database (ESearch).
    async def search_database(
        self, query: str, db: str, retmax: int, search_type: str = "keyword"
    ) -> List[Dict]:
        """Return the list of PubMed record ids matching *query*.

        :param search_type: "keyword" (abstract/article-type filters added) or
            "advanced" (query passed through verbatim)
        :raises ValueError: on an unknown *search_type*
        """
        if search_type not in ["keyword", "advanced"]:
            raise ValueError("search_type must be one of 'keyword' or 'advanced'")

        if search_type == "keyword":
            # Exclude non-research article types and require the "fha" filter
            # (presumably free-full-text/has-abstract — TODO confirm).
            art_type = (
                "(" + " OR ".join(f'"{j}"[Filter]' for j in self.EXCLUDED_ARTICLE_TYPES) + ")"
            )
            query = "( " + query + ")"
            query += " AND (fha[Filter]) NOT " + art_type

        id_list = await self.esearch(query=query, retmax=retmax)

        if len(id_list) == 0:
            return []

        return id_list

    async def esearch(self, query=None, retmax=10):
        """Call ESearch and return the id list for *query*."""
        start_time = time.time()
        db = "pubmed"
        server = "esearch"
        # Pick a pseudo-random account from the pool to spread API rate limits.
        random_index = int((time.time() * 1000) % len(PUBMED_ACCOUNT))
        random_pubmed_account = PUBMED_ACCOUNT[random_index]

        api_key = random_pubmed_account["api_key"]
        # NOTE(review): query is interpolated without URL-encoding; aiohttp's
        # URL handling requotes most characters, but explicit quoting of the
        # term parameter would be safer — confirm with real queries.
        url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/{server}.fcgi?db={db}&term={query}&retmode=json&api_key={api_key}&sort=relevance&retmax={retmax}"
        response = await self.async_http_get(url=url)

        id_list = response["esearchresult"]["idlist"]
        logger.info(
            f"pubmed_async_http get id_list, search Time taken: {time.time() - start_time}s"
        )

        return id_list

    async def async_http_get(self, url: str):
        """GET *url* and return the JSON body, retrying up to 3 times.

        :raises Exception: when all attempts fail
        """
        # Explicit timeout: a ClientSession without one can hang indefinitely.
        timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECONDS)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try_time = 1
            while try_time < 4:
                async with session.get(url) as response:
                    if response.status == 200:
                        return await response.json()
                    else:
                        logger.error(
                            f"{url},try_time:{try_time},Error: {response.status}"
                        )
                        try_time += 1
                        # Brief back-off before retrying.
                        await asyncio.sleep(0.5)
            raise Exception(f"Failed to fetch data from {url} after 3 attempts")

    async def async_http_get_text(self, url: str, params=None):
        """GET *url* and return the text body, retrying up to 3 times.

        :raises Exception: when all attempts fail
        """
        # Explicit timeout: a ClientSession without one can hang indefinitely.
        timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECONDS)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try_time = 1
            while try_time < 4:
                async with session.get(url, params=params) as response:
                    if response.status == 200:
                        return await response.text()
                    else:
                        logger.error(
                            f"{url},try_time:{try_time},Error: {response.status}"
                        )
                        try_time += 1
                        # Brief back-off before retrying.
                        await asyncio.sleep(0.5)
            raise Exception(f"Failed to fetch data from {url} after 3 attempts")

    # Fetch detailed records (EFetch).
    async def fetch_details(self, id_list, db="pubmed", rettype="abstract"):
        """Fetch and parse article details for *id_list* via EFetch.

        Best-effort: returns a list of article dicts, or [] on any failure.
        """
        start_time = time.time()
        try:
            ids = ",".join(id_list)
            server = "efetch"

            # Pick a pseudo-random account from the pool.
            random_index = int((time.time() * 1000) % len(PUBMED_ACCOUNT))
            random_pubmed_account = PUBMED_ACCOUNT[random_index]
            api_key = random_pubmed_account["api_key"]
            url = f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/{server}.fcgi?db={db}&id={ids}&retmode=xml&api_key={api_key}&rettype={rettype}"
            response = await self.async_http_get_text(url=url)
            articles = self.pubmed_xml_parse.parse_pubmed_xml(response)
            logger.info(
                f"pubmed_async_http fetch detail, Time taken: {time.time() - start_time}"
            )
            return articles
        except Exception as e:
            logger.error(f"Error fetching details for id_list: {id_list}, error: {e}")
            return []
|
python-services/Retrieve/service/pubmed_xml_parse.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import xml.etree.ElementTree as ET
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
class PubmedXmlParse:
    """Parse PubMed EFetch XML payloads into plain Python dicts.

    Each article dict carries pmid/pmcid, title, abstract, journal info,
    authors, publication date, keywords, DOI, MeSH terms and references.
    """

    # Compiled once: matches any XML tag, used to strip inline markup.
    _TAG_RE = re.compile(r"<.*?>")

    def __init__(self):
        pass

    def remove_xml_tags(self, text):
        """Remove XML tags from *text* and return the plain text."""
        return self._TAG_RE.sub("", text)

    @staticmethod
    def _find_text(elem, path, default=""):
        """Text of the first element matching *path* under *elem*, or *default*.

        Also returns *default* when the element exists but has no text, so
        callers never see None.
        """
        node = elem.find(path)
        if node is not None and node.text is not None:
            return node.text
        return default

    # Parse the EFetch XML payload.
    def parse_pubmed_xml(self, xml_data):
        """Parse an EFetch XML string and return a list of article dicts."""
        root = ET.ElementTree(ET.fromstring(xml_data)).getroot()

        articles = []
        # One entry per PubmedArticle element.
        for article in root.findall(".//PubmedArticle"):
            medline = article.find("MedlineCitation")
            articles.append(
                {
                    "pmid": self._find_text(article, ".//ArticleId[@IdType='pubmed']"),
                    "pmcid": self._find_text(
                        article, ".//PubmedData/ArticleIdList/ArticleId[@IdType='pmc']"
                    ),
                    "title": self._parse_title(article),
                    "abstract": self._parse_abstract(article),
                    "journal": self._parse_journal(article),
                    "authors": self._parse_authors(article),
                    "pub_date": {
                        "year": self._find_text(
                            article, ".//Journal/JournalIssue/PubDate/Year"
                        ),
                        "month": self._find_text(
                            article, ".//Journal/JournalIssue/PubDate/Month"
                        ),
                        "day": self._find_text(
                            article, ".//Journal/JournalIssue/PubDate/Day"
                        ),
                    },
                    # findall always returns a list, so no None-guard is needed.
                    "keywords": [
                        k.text for k in medline.findall(".//KeywordList/Keyword")
                    ],
                    "doi": self.parse_doi(medline.find("Article"), article),
                    "mesh_terms": [
                        self.parse_mesh(m)
                        for m in medline.findall("MeshHeadingList/MeshHeading")
                    ],
                    "references": [
                        self.parse_reference(r)
                        for r in article.findall(".//PubmedData/ReferenceList/Reference")
                    ],
                }
            )

        return articles

    def _parse_title(self, article):
        """Plain-text ArticleTitle (inline markup such as <i> is stripped)."""
        title_elem = article.find(".//ArticleTitle")
        if title_elem is None:
            return ""
        # Serialize the element so inline child tags are preserved, then strip
        # all markup to get the plain text.
        raw = ET.tostring(title_elem, encoding="unicode", method="xml")
        raw = raw.replace("<ArticleTitle>", "").replace("</ArticleTitle>", "")
        return self.remove_xml_tags(raw).strip()

    def _parse_abstract(self, article):
        """All AbstractText sections joined with spaces; "" when absent."""
        sections = article.findall(".//AbstractText")
        if not sections:
            return ""
        return " ".join(s.text if s.text is not None else "" for s in sections)

    def _parse_authors(self, article):
        """List of author dicts (lastname/forename/initials/affiliation)."""
        return [
            {
                "lastname": self._find_text(author, ".//LastName"),
                "forename": self._find_text(author, ".//ForeName"),
                "initials": self._find_text(author, ".//Initials"),
                "affiliation": self._find_text(author, ".//AffiliationInfo/Affiliation"),
            }
            for author in article.findall(".//Author")
        ]

    def _parse_journal(self, article):
        """Journal metadata dict (issn/title/abbreviation/pages/volume/...)."""
        return {
            "issn": self._find_text(article, ".//Journal/ISSN"),
            "title": self._find_text(article, ".//Journal/Title"),
            "abbreviation": self._find_text(article, ".//Journal/ISOAbbreviation"),
            "startPage": self._find_text(article, ".//Pagination/StartPage"),
            "endPage": self._find_text(article, ".//Pagination/EndPage"),
            "volume": self._find_text(article, ".//Journal/JournalIssue/Volume"),
            "issue": self._find_text(article, ".//Journal/JournalIssue/Issue"),
            "year": self._find_text(article, ".//Journal/JournalIssue/PubDate/Year"),
        }

    def parse_doi(self, article, article_elem) -> str:
        """DOI from the Article's ELocationID, falling back to ArticleIdList.

        Fixed: the original fell off the end (returning None) when an
        ELocationID existed but was empty, and never tried the ArticleIdList
        fallback in that case.
        """
        node = article.find(".//ELocationID[@EIdType='doi']")
        if node is not None and node.text:
            return node.text
        node = article_elem.find(".//ArticleIdList/ArticleId[@IdType='doi']")
        if node is not None and node.text:
            return node.text
        return ""

    def parse_mesh(self, mesh_elem):
        """Parse one MeshHeading into descriptor + qualifier names."""
        return {
            "descriptor": self._find_text(mesh_elem, ".//DescriptorName"),
            # Fixed: the original called q.find(".//QualifierName") on the
            # QualifierName element itself (searching descendants only), which
            # always produced "" — the qualifier's own text is what's wanted.
            "qualifiers": [
                q.text if q.text is not None else ""
                for q in mesh_elem.findall(".//QualifierName")
            ],
        }

    def parse_reference(self, reference_elem):
        """Parse one Reference into citation text and identifiers."""
        return {
            "citation": self._find_text(reference_elem, "Citation"),
            "doi": self._find_text(reference_elem, ".//ArticleId[@IdType='doi']"),
            "pmid": self._find_text(reference_elem, ".//ArticleId[@IdType='pubmed']"),
            "pmcid": self._find_text(reference_elem, ".//ArticleId[@IdType='pmcid']"),
        }
|
python-services/Retrieve/service/query_rewrite.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
import time
from datetime import datetime

from bio_agent.rewrite_agent import RewriteAgent
from utils.bio_logger import bio_logger as logger
|
5 |
+
|
6 |
+
# Prompt templates for the rewrite agent. Both are f-strings, so literal JSON
# braces in the embedded examples are escaped as {{ }}.

# Full rewrite prompt: classifies the question (Review vs Question_Answer),
# extracts 3-6 key words, proposes authoritative journals (name + EISSN),
# expands the question into short PubMed queries, and emits a "filters"
# section consumed by build_pubmed_filter_query().
# NOTE(review): typos ("Frist", "relavant") are left untouched here because
# this text is sent to the model at runtime — fixing wording changes behavior.
INSTRUCTIONS_rewrite = f"""
You are a research expert with strong skills in question categorization and optimizing PubMed searches.

Frist, classify the research question into exactly one of the following categories:
- Review: Queries that summarize existing knowledge or literature on a topic.
- Question_Answer: Queries that seek specific answers to scientific questions.


Secondly, extract the 3-6 key words of the research question. The key words should be the most important terms or phrases that capture the essence of the research question. These key words should be relevant to the topic and can be used to generate search queries. These key words should be relavant to medicine, biology, health, disease.

Thirdly,using the given keywords, please identify at least 60 leading authoritative journals in this field, including their names and EISSNs. It would be ok to include journals that are not strictly in the field of medicine, biology, health, or disease, but are relevant to the topic and the journals should be well-known and respected in their respective fields. The EISSN is the electronic International Standard Serial Number for the journal.

Next, break down this research question into specific search queries for PubMed that comprehensively cover all important aspects of the topic. Generate as many search queries as necessary to ensure thorough coverage - don't limit yourself to a fixed number.

Each query should:
1. Be concise (3-6 words maximum)
2. Focus on a specific aspect of the research question
3. Use appropriate scientific terminology
4. Be suitable for a scientific database search
5. Collectively cover the full breadth of the research topic

If the query's type is review, generate additional queries (10-20) to ensure thorough coverage. If the query's type is question-answer, fewer queries (5-10) may be sufficient.

Avoid long phrases, questions, or full sentences, as these are not effective for database searches.

Examples of good queries:
- "CRISPR cancer therapy"
- "tau protein Alzheimer's"
- "microbiome obesity metabolism"

Then, construct the final PubMed search query based on the following filters:
- "date_range": {{"start": "YYYY/MM/DD", "end": "YYYY/MM/DD",}}, only populate this field if the query contains phrases like "the past x years" or "the last x years"; otherwise, leave blank as default.
- "article_types": [],array of publication types, only if user specify some publication types, otherwise leave blank as default.
- "languages": [],array of language filters,if user do not specify, use English as default.
- "subjects": [],if user do not specify, use human as default.
- "journals": [], if user do not specify, use [] as default.
- "author": [{{"name": string, "first_author": boolean, "last_author": boolean}}], if user do not specify, use {{}} as default.


IMPORTANT: Your output MUST be a valid JSON object with a "queries" field containing an array of strings. For example:
```
{{ "category": "Review",
"key_words":["CRISPR", "cancer", "therapy"],
"key_journals":[{{"name":"Nature","EISSN":"1476-4687"}}],
"queries": [
"CRISPR cancer therapy",
"tau protein Alzheimer's",
"microbiome obesity metabolism"
],
"filters": {{"date_range": {{"start": "2019/01/01", "end": "2024/01/01"}},
"article_types": [],
"languages": ["English"],
"subjects": ["human"],
"journals": [],
"author": {{"name": "", "first_author": false, "last_author": false}}
}}
}}

Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only.If you are not sure about the output, output an empty array.

"""

# Minimal fallback prompt: only extracts 3-6 key words. Used by the keyword-only
# rewrite paths (query_split_for_simple / query_split_for_web fallback).
SIMPLE_INSTRUCTIONS_rewrite = f"""
You are a research expert with strong skills in question categorization and optimizing PubMed searches.
Extract the 3-6 key words of the research question. The key words should be the most important terms or phrases that capture the essence of the research question. These key words should be relevant to the topic and can be used to generate search queries. These key words should be relavant to medicine, biology, health, disease.
IMPORTANT: Your output MUST be a valid JSON object. For example:
```
{{
"key_words":["CRISPR", "cancer", "therapy"],
}}

Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only.If you are not sure about the output, output an empty array.
"""
|
80 |
+
|
81 |
+
|
82 |
+
def build_pubmed_filter_query(data):
    """Build the filter clause of a PubMed advanced-search expression.

    Args:
        data: Rewrite-agent output dict. Its optional "filters" entry may
            contain "date_range", "article_types", "languages", "journals"
            and "author" sections (see INSTRUCTIONS_rewrite).

    Returns:
        The individual filter sub-expressions joined with " AND ", or ""
        when no filter applies (callers concatenate this onto their base
        query string).
    """
    # Robustness fix: the simple-prompt fallback produces a dict without a
    # "filters" key, which previously raised KeyError here. `or {}` also
    # tolerates an explicit null.
    filter_cfg = data.get("filters") or {}

    filters = []

    # Publication-date range; open ends default to "forever ago" / today.
    date_range = filter_cfg.get("date_range") or {}
    if date_range.get("start") or date_range.get("end"):
        start_date = date_range.get("start", "1000/01/01")
        end_date = date_range.get("end", datetime.now().strftime("%Y/%m/%d"))
        filters.append(
            f'("{start_date}"[Date - Publication] : "{end_date}"[Date - Publication])'
        )

    # Publication types (e.g. "Review", "Clinical Trial"), OR-ed together.
    article_types = filter_cfg.get("article_types", [])
    if article_types:
        type_filter = " OR ".join(f'"{at}"[Publication Type]' for at in article_types)
        filters.append(f"({type_filter})")

    # Languages, OR-ed together.
    languages = filter_cfg.get("languages", [])
    if languages:
        lang_filter = " OR ".join(f'"{lang}"[Language]' for lang in languages)
        filters.append(f"({lang_filter})")

    # Subject (MeSH) filtering is deliberately disabled for now:
    # subjects = filter_cfg.get("subjects", [])
    # if subjects:
    #     subj_filter = " OR ".join([f'"{subj}"[MeSH Terms]' for subj in subjects])
    #     filters.append(f"({subj_filter})")

    # Journal restriction, OR-ed together.
    journal_names = filter_cfg.get("journals", [])
    if journal_names:
        journal_filter = " OR ".join(
            f'"{journal}"[Journal]' for journal in journal_names
        )
        filters.append(f"({journal_filter})")

    # Author filter: restrict to first/last author position when requested,
    # otherwise match the name in any author position.
    author = filter_cfg.get("author") or {}
    if author and author.get("name"):
        author_query = []
        if author.get("first_author", False):
            author_query.append(f'"{author["name"]}"[Author - First]')
        if author.get("last_author", False):
            author_query.append(f'"{author["name"]}"[Author - Last]')
        if not author.get("first_author", False) and not author.get("last_author", False):
            author_query.append(f'"{author["name"]}"[Author]')
        if author_query:
            filters.append(f"({' OR '.join(author_query)})")

    # Empty string means "no extra filtering" (the old `base_query` variable
    # was always "" — the explicit fallback is now spelled out).
    return " AND ".join(filters) if filters else ""
|
142 |
+
|
143 |
+
|
144 |
+
class QueryRewriteService:
    """Rewrites a free-text research question into search queries.

    Delegates the LLM call to RewriteAgent and post-processes its JSON output
    into ``{"query_item": ..., "search_type": ...}`` entries understood by the
    search services.
    """

    def __init__(self):
        self.rewrite_agent = RewriteAgent()

    async def query_split(self, query: str):
        """Expand *query* into a list of PubMed advanced-search queries.

        Tries the full rewrite prompt up to 3 times; on repeated failure falls
        back to the keyword-only prompt (also up to 3 times).

        Returns:
            A list of ``{"query_item", "search_type"}`` dicts, or ``[]`` when
            every attempt fails.
        """
        start_time = time.time()
        query_list = []
        queries = []
        key_journals = {"name": "", "EISSN": ""}
        category = "Review"
        try_count = 0
        while try_count < 3:
            try:
                query_dict = await self.rewrite_agent.rewrite_query(
                    query,
                    INSTRUCTIONS_rewrite
                    + ' Please note: Today is '
                    + datetime.now().strftime("%Y/%m/%d")
                    + '.',
                )
                logger.info(f"query_dict: {query_dict}")
                if (
                    "queries" not in query_dict
                    or "key_journals" not in query_dict
                    or "category" not in query_dict
                ):
                    logger.error(f"Invalid JSON structure, {query_dict}")
                    raise ValueError("Invalid JSON structure")
                queries = query_dict.get("queries")
                key_journals = query_dict.get("key_journals")
                category = query_dict.get("category")
                key_words = query_dict.get("key_words")

                # Build "<EISSN>[Journal]" clauses for the LLM-suggested
                # journals, then always append a fixed high-impact whitelist.
                journal_list = []
                for journal in key_journals:
                    journal_list.append(journal.get("EISSN", ""))
                journal_list = [
                    f'("{journal_EISSN}"[Journal])' for journal_EISSN in journal_list
                ]
                journal_list += [
                    "(Nature[Journal])",
                    "(Science[Journal])",
                    "(Nature Reviews Methods Primers[Journal])",
                    "(Innovation[Journal])",
                    "(National Science Review[Journal])",
                    "(Nature Communications[Journal])",
                    "(Science Bulletin[Journal])",
                    "(Science Advances[Journal])",
                    "(BMJ[Journal])",
                ]

                # The filter clause is the same for every generated entry.
                filter_clause = build_pubmed_filter_query(query_dict)
                # Key-word based query; intentionally not journal-restricted
                # (see the commented-out journal join in the original draft).
                keyword_query = {
                    "query_item": "( "
                    + ' '.join(key_words)
                    + ") AND (fha[Filter]) AND "
                    + filter_clause,
                    "search_type": "advanced",
                }
                if category == "Review":
                    # Reviews get broad coverage: one key-word query plus one
                    # journal-restricted query per LLM sub-query.
                    for sub_query in queries:
                        query_list.append(dict(keyword_query))
                        query_list.append(
                            {
                                "query_item": "( "
                                + sub_query.strip()
                                + " ) AND ("
                                + " OR ".join(journal_list)
                                + ") AND (fha[Filter]) AND "
                                + filter_clause,
                                "search_type": "advanced",
                            }
                        )
                else:
                    # Question-answer queries only need the key-word search.
                    query_list.append(keyword_query)
                logger.info(
                    f"Original query: {query}, count: {len(query_list)}, wait time: {time.time() - start_time:.2f}s, rewrite result: {query_list}"
                )
                return query_list
            except Exception as e:
                logger.error(f"Error in query rewrite: {e},trying again...", exc_info=e)
                try_count += 1
                # BUGFIX: time.sleep() blocked the event loop inside this
                # coroutine; use a non-blocking backoff instead.
                await asyncio.sleep(0.1)

        # Full prompt failed 3 times — fall back to the keyword-only rewrite.
        new_try_count = 0
        logger.info(f"Error in query rewrite,trying a simple version again...")
        while new_try_count < 3:
            try:
                # BUGFIX: this fallback previously re-sent the full
                # INSTRUCTIONS_rewrite prompt without the simple-output flag,
                # contradicting its own log message; it now matches
                # query_split_for_simple.
                query_dict = await self.rewrite_agent.rewrite_query(
                    query,
                    SIMPLE_INSTRUCTIONS_rewrite
                    + ' Please note: Today is '
                    + datetime.now().strftime("%Y/%m/%d")
                    + '.',
                    True,
                )
                logger.info(f"query_dict: {query_dict}")
                if "key_words" not in query_dict:
                    logger.error(f"SIMPLE_version:Invalid JSON structure, {query_dict}")
                    raise ValueError("SIMPLE_version:Invalid JSON structure")
                key_words = query_dict.get("key_words")
                # The simple prompt emits no "filters" section; default it so
                # build_pubmed_filter_query cannot raise KeyError.
                query_dict.setdefault("filters", {})
                query_list.append(
                    {
                        "query_item": "( "
                        + ' '.join(key_words)
                        + " ) AND (fha[Filter]) AND "
                        + build_pubmed_filter_query(query_dict),
                        "search_type": "advanced",
                    }
                )
                logger.info(
                    f"SIMPLE_version: Original query: {query}, count: {len(query_list)}, wait time: {time.time() - start_time:.2f}s, rewrite result: {query_list}"
                )
                return query_list
            except Exception as e:
                logger.error(f"SIMPLE_version: Error in query rewrite: {e}")
                new_try_count += 1
                await asyncio.sleep(0.1)
        return []

    async def query_split_for_web(self, query: str):
        """
        For web use, only return the key words.

        Returns a single-entry list ``[{"query_item": "<kw1 kw2 ...>"}]``;
        on repeated failure returns ``[{"query_item": ""}]``.
        """
        start_time = time.time()
        query_list = []
        try_count = 0
        while try_count < 3:
            try:
                query_dict = await self.rewrite_agent.rewrite_query(
                    query,
                    INSTRUCTIONS_rewrite
                    + ' Please note: Today is '
                    + datetime.now().strftime("%Y/%m/%d")
                    + '.',
                    True,
                )
                logger.info(f"query_dict: {query_dict}")
                if "key_words" not in query_dict:
                    logger.error(f"SIMPLE_version for web:Invalid JSON structure, {query_dict}")
                    raise ValueError("SIMPLE_version for web:Invalid JSON structure")
                key_words = query_dict.get("key_words")
                # Web search takes the bare key words — no PubMed filter syntax.
                query_list.append(
                    {
                        "query_item": ' '.join(key_words)
                    }
                )
                logger.info(
                    f"SIMPLE_version for web: Original query: {query}, count: {len(query_list)}, wait time: {time.time() - start_time:.2f}s, rewrite result: {query_list}"
                )
                return query_list
            except Exception as e:
                logger.error(f"SIMPLE_version: Error in query rewrite: {e}")
                try_count += 1
                # Non-blocking backoff (was time.sleep).
                await asyncio.sleep(0.1)
        return [{"query_item": ""}]

    async def query_split_for_simple(self, query: str):
        """
        For simple use, only return the key words.

        Uses the keyword-only prompt and tags the result as a plain
        ``"keyword"`` search. Falls back to ``[{"query_item": ""}]``.
        """
        start_time = time.time()
        query_list = []
        try_count = 0
        while try_count < 3:
            try:
                query_dict = await self.rewrite_agent.rewrite_query(
                    query,
                    SIMPLE_INSTRUCTIONS_rewrite
                    + ' Please note: Today is '
                    + datetime.now().strftime("%Y/%m/%d")
                    + '.',
                    True,
                )
                logger.info(f"query_dict: {query_dict}")
                if "key_words" not in query_dict:
                    logger.error(f"SIMPLE_version for simple:Invalid JSON structure, {query_dict}")
                    raise ValueError("SIMPLE_version for simple:Invalid JSON structure")
                key_words = query_dict.get("key_words")
                query_list.append(
                    {
                        "query_item": ' '.join(key_words),
                        "search_type": "keyword",
                    }
                )
                logger.info(
                    f"SIMPLE_version for simple: Original query: {query}, count: {len(query_list)}, wait time: {time.time() - start_time:.2f}s, rewrite result: {query_list}"
                )
                return query_list
            except Exception as e:
                logger.error(f"SIMPLE_version for simple: Error in query rewrite: {e}")
                try_count += 1
                await asyncio.sleep(0.1)
        return [{"query_item": ""}]
|
python-services/Retrieve/service/rag.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import time
|
3 |
+
from typing import List
|
4 |
+
from service.rerank import RerankService
|
5 |
+
from search_service.base_search import BaseSearchService
|
6 |
+
from utils.bio_logger import bio_logger as logger
|
7 |
+
|
8 |
+
from dto.bio_document import BaseBioDocument
|
9 |
+
|
10 |
+
from bio_requests.rag_request import RagRequest
|
11 |
+
|
12 |
+
|
13 |
+
class RagService:
    """Fans a RAG request out to every registered search backend and merges the hits."""

    def __init__(self):
        self.rerank_service = RerankService()
        # Instantiate every registered BaseSearchService subclass so each
        # data source participates in retrieval.
        self.search_services = [cls() for cls in BaseSearchService.get_subclasses()]
        logger.info(
            f"Loaded search services: {[service.__class__.__name__ for service in self.search_services]}"
        )

    async def multi_query(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Query all search services concurrently, optionally rerank, return top_k docs."""
        t_start = time.time()
        pending = [
            svc.filter_search(rag_request=rag_request) for svc in self.search_services
        ]
        outcomes = await asyncio.gather(*pending, return_exceptions=True)

        all_results = []
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                # A single failing backend must not sink the whole retrieval.
                logger.error(f"Error in search service: {outcome}")
            else:
                all_results.extend(outcome)

        t_search = time.time()
        logger.info(
            f"Found {len(all_results)} results in total,time used:{t_search - t_start:.2f}s"
        )

        # Guard clause: without rerank, just truncate the merged results.
        if not rag_request.is_rerank:
            logger.info("RerankService: is_rerank is False, skip rerank")
            return all_results[0 : rag_request.top_k]

        logger.info("RerankService: is_rerank is True")
        reranked = await self.rerank_service.rerank(
            rag_request=rag_request, documents=all_results
        )
        logger.info(
            f"Reranked {len(reranked)} results,time used:{time.time() - t_search:.2f}s"
        )
        return reranked[0 : rag_request.top_k]
|
python-services/Retrieve/service/rerank.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from bio_requests.rag_request import RagRequest
|
3 |
+
from dto.bio_document import BaseBioDocument
|
4 |
+
from utils.bio_logger import bio_logger as logger
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
# Journal impact-factor table (2023 JCR), loaded once at import time as a
# module-level cache shared by all RerankService instances.
# NOTE(review): this raises at import if the Excel file is missing — confirm
# the deployment layout ships config/ next to the working directory.
df = pd.read_excel("config/2023JCR(完整).xlsx")

# Keep only the columns used for scoring: print ISSN, 5-year impact factor
# ("5年IF"), and electronic ISSN.
df = df[["ISSN", "5年IF", "EISSN"]]

# Coerce the 5-year IF to numeric; unparseable entries fall back to 0.01
# (the same floor used for unknown journals during sorting).
df["5年IF"] = pd.to_numeric(df["5年IF"], errors="coerce").fillna(0.01)
|
16 |
+
|
17 |
+
|
18 |
+
class RerankService:
    """Sorts retrieved PubMed documents by journal 5-year impact factor (2023 JCR)."""

    def __init__(self):
        # Module-level JCR table (ISSN / EISSN / 5-year IF); kept as an
        # attribute so tests can swap in a fixture.
        self.df = df

    def _lookup_if(self, issn):
        """Return the 5-year IF for *issn*, matching ISSN then EISSN, or None."""
        if not issn:
            return None
        hit = self.df.loc[self.df["ISSN"] == issn, "5年IF"].values
        if hit.size == 0:
            # The document may carry the electronic ISSN instead.
            hit = self.df.loc[self.df["EISSN"] == issn, "5年IF"].values
        return hit[0] if hit.size > 0 else None

    async def rerank(
        self, rag_request: RagRequest, documents: List[BaseBioDocument] = []
    ) -> List[BaseBioDocument]:
        """Score, de-duplicate and sort *documents* by journal impact factor.

        Only applies when the request targets PubMed; other sources are
        returned unchanged. (The [] default is safe here: *documents* is
        rebound, never mutated.)
        """
        if not rag_request.data_source or "pubmed" not in rag_request.data_source:
            logger.info("RerankService: data_source is not pubmed, skip rerank")
            return documents
        logger.info("RerankService: start rerank")

        # Step 1: attach an impact-factor score to every document.
        for document in documents:
            # Robustness fix: tolerate a missing/None journal mapping or an
            # absent "issn" key (previously raised KeyError/TypeError).
            issn = (document.journal or {}).get("issn")
            document.if_score = self._lookup_if(issn)

        # Step 2: de-duplicate by bio_id (dict preserves insertion order;
        # the last occurrence of a duplicate wins).
        documents = list({doc.bio_id: doc for doc in documents}.values())

        # Step 3: sort by 5-year IF descending; unknown journals score 0.01.
        return sorted(
            documents,
            key=lambda d: d.if_score if d.if_score is not None else 0.01,
            reverse=True,
        )
|
python-services/Retrieve/service/web_search.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import ssl
|
4 |
+
import aiohttp
|
5 |
+
import asyncio
|
6 |
+
from agents import function_tool
|
7 |
+
|
8 |
+
# from ..workers.baseclass import ResearchAgent, ResearchRunner
|
9 |
+
# from ..workers.utils.parse_output import create_type_parser
|
10 |
+
from typing import List, Union, Optional
|
11 |
+
from bs4 import BeautifulSoup
|
12 |
+
from dotenv import load_dotenv
|
13 |
+
from pydantic import BaseModel, Field
|
14 |
+
from crawl4ai import *
|
15 |
+
|
16 |
+
# Read environment configuration from a local .env file before any os.getenv.
load_dotenv()
CONTENT_LENGTH_LIMIT = 10000  # Trim scraped content to this length to avoid large context / token limit issues
# Which backend powers web search: "serper" (default) or "openai".
SEARCH_PROVIDER = os.getenv("SEARCH_PROVIDER", "serper").lower()
|
19 |
+
|
20 |
+
|
21 |
+
# ------- DEFINE TYPES -------
|
22 |
+
|
23 |
+
|
24 |
+
class ScrapeResult(BaseModel):
    """A scraped webpage: search-result metadata plus extracted page text."""

    url: str = Field(description="The URL of the webpage")
    text: str = Field(description="The full text content of the webpage")
    title: str = Field(description="The title of the webpage")
    description: str = Field(description="A short description of the webpage")
|
29 |
+
|
30 |
+
|
31 |
+
class WebpageSnippet(BaseModel):
    """A single search-engine hit, before its page content is fetched."""

    url: str = Field(description="The URL of the webpage")
    title: str = Field(description="The title of the webpage")
    description: Optional[str] = Field(description="A short description of the webpage")
|
35 |
+
|
36 |
+
|
37 |
+
class SearchResults(BaseModel):
    """Structured output schema for the (currently disabled) relevance-filter agent."""

    results_list: List[WebpageSnippet]
|
39 |
+
|
40 |
+
|
41 |
+
# ------- DEFINE TOOL -------
|
42 |
+
|
43 |
+
# Add a module-level variable to store the singleton instance
|
44 |
+
_serper_client = None
|
45 |
+
|
46 |
+
|
47 |
+
@function_tool
async def web_search(query: str) -> Union[List[ScrapeResult], str]:
    """Perform a web search for a given query and get back the URLs along with their titles, descriptions and text contents.

    Args:
        query: The search query

    Returns:
        List of ScrapeResult objects which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result
        On failure, a human-readable error string instead.
    """
    # Guard clause: with the OpenAI provider the agents module's built-in
    # WebSearchTool is used instead, so this tool should never run.
    if SEARCH_PROVIDER == "openai":
        return f"The web_search function is not used when SEARCH_PROVIDER is set to 'openai'. Please check your configuration."

    global _serper_client
    try:
        # Lazily create the shared SerperClient singleton on first use.
        if _serper_client is None:
            _serper_client = SerperClient()

        snippets = await _serper_client.search(
            query, filter_for_relevance=True, max_results=5
        )
        return await scrape_urls(snippets)
    except Exception as e:
        # Tool output is fed back to the model, so surface errors as text.
        return f"Sorry, I encountered an error while searching: {str(e)}"
|
81 |
+
|
82 |
+
|
83 |
+
# ------- DEFINE AGENT FOR FILTERING SEARCH RESULTS BY RELEVANCE -------
|
84 |
+
|
85 |
+
FILTER_AGENT_INSTRUCTIONS = f"""
|
86 |
+
You are a search result filter. Your task is to analyze a list of SERP search results and determine which ones are relevant
|
87 |
+
to the original query based on the link, title and snippet. Return only the relevant results in the specified format.
|
88 |
+
|
89 |
+
- Remove any results that refer to entities that have similar names to the queried entity, but are not the same.
|
90 |
+
- E.g. if the query asks about a company "Amce Inc, acme.com", remove results with "acmesolutions.com" or "acme.net" in the link.
|
91 |
+
|
92 |
+
Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only:
|
93 |
+
{SearchResults.model_json_schema()}
|
94 |
+
"""
|
95 |
+
|
96 |
+
# selected_model = fast_model
|
97 |
+
#
|
98 |
+
# filter_agent = ResearchAgent(
|
99 |
+
# name="SearchFilterAgent",
|
100 |
+
# instructions=FILTER_AGENT_INSTRUCTIONS,
|
101 |
+
# model=selected_model,
|
102 |
+
# output_type=SearchResults if model_supports_structured_output(selected_model) else None,
|
103 |
+
# output_parser=create_type_parser(SearchResults) if not model_supports_structured_output(selected_model) else None
|
104 |
+
# )
|
105 |
+
|
106 |
+
# ------- DEFINE UNDERLYING TOOL LOGIC -------
|
107 |
+
|
108 |
+
# Create a shared connector
|
109 |
+
# Shared TLS context for all outbound scraping requests.
# NOTE(review): certificate verification is disabled and SECLEVEL lowered so
# misconfigured/legacy sites can still be scraped — this is insecure for any
# sensitive traffic; confirm the trade-off is intentional.
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.set_ciphers(
    "DEFAULT:@SECLEVEL=1"
)  # Add this line to allow older cipher suites
|
115 |
+
|
116 |
+
|
117 |
+
class SerperClient:
    """A client for the Serper API to perform Google searches."""

    def __init__(self, api_key: str = None):
        # Explicit key wins; otherwise fall back to the environment.
        self.api_key = api_key or os.getenv("SERPER_API_KEY")
        if not self.api_key:
            raise ValueError(
                "No API key provided. Set SERPER_API_KEY environment variable."
            )

        self.url = "https://google.serper.dev/search"
        self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}

    async def search(
        self, query: str, filter_for_relevance: bool = True, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Perform a Google search using Serper API and fetch basic details for top results.

        Args:
            query: The search query
            filter_for_relevance: When True, drop PubMed/PMC links and
                unreachable URLs before truncating to max_results.
            max_results: Maximum number of results to return

        Returns:
            A list of WebpageSnippet objects (url, title, description).
        """
        # autocorrect is disabled so scientific terms are searched verbatim.
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.post(
                self.url, headers=self.headers, json={"q": query, "autocorrect": False}
            ) as response:
                response.raise_for_status()
                results = await response.json()
                results_list = [
                    WebpageSnippet(
                        url=result.get("link", ""),
                        title=result.get("title", ""),
                        description=result.get("snippet", ""),
                    )
                    for result in results.get("organic", [])
                ]

                if not results_list:
                    return []

                if not filter_for_relevance:
                    return results_list[:max_results]

                # return results_list[:max_results]

                return await self._filter_results(results_list, query, max_results=max_results)

    async def _filter_results(
        self, results: List[WebpageSnippet], query: str, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Drop PubMed/PMC links and unreachable URLs, then cap at max_results."""
        # get rid of pubmed source data (those are served by the PubMed search path)
        filtered_results = [
            res
            for res in results
            if "pmc.ncbi.nlm.nih.gov" not in res.url
            and "pubmed.ncbi.nlm.nih.gov" not in res.url
        ]

        # NOTE(review): LLM-based relevance filtering is disabled; kept for reference.
        # # get rid of unrelated data
        # serialized_results = [result.model_dump() if isinstance(result, WebpageSnippet) else result for result in
        #                       filtered_results]
        #
        # user_prompt = f"""
        # Original search query: {query}
        #
        # Search results to analyze:
        # {json.dumps(serialized_results, indent=2)}
        #
        # Return {max_results} search results or less.
        # """
        #
        # try:
        #     result = await ResearchRunner.run(filter_agent, user_prompt)
        #     output = result.final_output_as(SearchResults)
        #     return output.results_list
        # except Exception as e:
        #     print("Error filtering urls:", str(e))
        #     return filtered_results[:max_results]

        async def fetch_url(session, url):
            # Probe a URL with a short GET; any error counts as unreachable.
            try:
                async with session.get(url, timeout=5) as response:
                    return response.status == 200
            except Exception as e:
                print(f"Error accessing {url}: {str(e)}")
                return False  # False means the URL is not reachable

        async def filter_unreachable_urls(results):
            # Probe all candidates concurrently; keep only the reachable ones.
            async with aiohttp.ClientSession() as session:
                tasks = [fetch_url(session, res.url) for res in results]
                reachable = await asyncio.gather(*tasks)
                return [
                    res for res, can_access in zip(results, reachable) if can_access
                ]

        reachable_results = await filter_unreachable_urls(filtered_results)

        # Return the first `max_results` or less if there are not enough reachable results
        return reachable_results[:max_results]
|
220 |
+
|
221 |
+
|
222 |
+
async def scrape_urls(items: List[WebpageSnippet]) -> List[ScrapeResult]:
    """Fetch text content from provided URLs.

    Args:
        items: List of SearchEngineResult items to extract content from

    Returns:
        List of ScrapeResult objects which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result
        Entries whose fetch raised an exception are silently dropped.
    """
    connector = aiohttp.TCPConnector(ssl=ssl_context)
    async with aiohttp.ClientSession(connector=connector) as session:
        # One fetch task per non-empty URL, all executed concurrently.
        fetch_tasks = [
            fetch_and_process_url(session, snippet)
            for snippet in items
            if snippet.url
        ]
        outcomes = await asyncio.gather(*fetch_tasks, return_exceptions=True)
        # Keep only successful scrapes; exceptions are filtered out here.
        return [outcome for outcome in outcomes if isinstance(outcome, ScrapeResult)]
|
248 |
+
|
249 |
+
|
250 |
+
def _scrape_error(item: WebpageSnippet, reason: str) -> ScrapeResult:
    """Build a ScrapeResult whose text field carries an error message."""
    return ScrapeResult(
        url=item.url,
        title=item.title,
        description=item.description,
        text=f"Error fetching content: {reason}",
    )


async def fetch_and_process_url(
    session: aiohttp.ClientSession, item: WebpageSnippet
) -> ScrapeResult:
    """Fetch a single URL and convert its HTML to plain text.

    Never raises: every failure mode (restricted file extension, non-200
    status, network/timeout error) is reported in-band via the ``text``
    field of the returned ScrapeResult, so ``asyncio.gather`` callers do
    not need per-task error handling.

    Args:
        session: Shared aiohttp session to issue the request on.
        item: Search-result snippet providing url/title/description.

    Returns:
        ScrapeResult with extracted text, or an error message in ``text``.
    """
    if not is_valid_url(item.url):
        return _scrape_error(item, "URL contains restricted file extension")

    try:
        async with session.get(item.url, timeout=8) as response:
            if response.status != 200:
                # Report HTTP failures in-band rather than raising.
                return _scrape_error(item, f"HTTP {response.status}")

            content = await response.text()
            # html_to_text is CPU-bound parsing; run it in a thread pool
            # so it does not block the event loop.
            text_content = await asyncio.get_event_loop().run_in_executor(
                None, html_to_text, content
            )
            # Trim content to avoid exceeding the downstream token limit.
            return ScrapeResult(
                url=item.url,
                title=item.title,
                description=item.description,
                text=text_content[:CONTENT_LENGTH_LIMIT],
            )
    except Exception as e:
        # Timeouts, connection resets, decode errors, etc.
        return _scrape_error(item, str(e))
|
296 |
+
|
297 |
+
|
298 |
+
def html_to_text(html_content: str) -> str:
    """
    Strips out all of the unnecessary elements from the HTML context to prepare it for text extraction / LLM processing.
    """
    # lxml is the fastest parser backend available to BeautifulSoup.
    document = BeautifulSoup(html_content, "lxml")

    # Only content-bearing tags are kept; scripts, nav, markup are dropped.
    content_tags = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote")

    lines = []
    for node in document.find_all(content_tags):
        stripped = node.get_text(strip=True)
        if stripped:
            lines.append(stripped)

    return "\n".join(lines)
|
316 |
+
|
317 |
+
|
318 |
+
# File extensions we never want to scrape (binary, media, or non-content).
_RESTRICTED_EXTENSIONS = frozenset(
    {
        ".pdf", ".doc", ".xls", ".ppt",
        ".zip", ".rar", ".7z",
        ".txt", ".js", ".xml", ".css",
        ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg", ".webp",
        ".mp3", ".mp4", ".avi", ".mov", ".wmv", ".flv",
        ".wma", ".wav", ".m4a", ".m4v", ".m4b", ".m4p", ".m4u",
    }
)


def is_valid_url(url: str) -> bool:
    """Check that a URL does not point at a restricted file extension.

    The check is performed on the parsed URL *path* only and is
    case-insensitive. The previous implementation matched raw substrings
    anywhere in the URL, so hosts such as "node.js.org" or dots inside
    query parameters were false positives.

    Args:
        url: The URL to validate.

    Returns:
        True when the URL looks like scrapeable page content.
    """
    from posixpath import splitext
    from urllib.parse import urlsplit

    path = urlsplit(url).path
    extension = splitext(path)[1].lower()
    return extension not in _RESTRICTED_EXTENSIONS
|
358 |
+
|
359 |
+
|
360 |
+
async def url_to_contents(url):
    """Crawl a single URL and return its content rendered as markdown."""
    async with AsyncWebCrawler() as crawler:
        crawl_result = await crawler.arun(url=url)
        return crawl_result.markdown
|
368 |
+
|
369 |
+
|
370 |
+
async def url_to_fit_contents(res):
    """Crawl ``res.url`` with a pruning content filter and return at most
    40k characters of filtered ("fit") markdown.

    Falls back to the coarse pre-extracted ``res.text`` on timeout or any
    other crawl failure, so callers always get usable text.
    """

    str_fit_max = 40000  # ~40,000 chars is roughly 10,000 tokens; five results together stay under 50k tokens

    browser_config = BrowserConfig(
        headless=True,
        verbose=True,
    )
    run_config = CrawlerRunConfig(
        cache_mode=CacheMode.DISABLED,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(
                threshold=1.0, threshold_type="fixed", min_word_threshold=0
            )
        ),
        # markdown_generator=DefaultMarkdownGenerator(
        #     content_filter=BM25ContentFilter(user_query="WHEN_WE_FOCUS_BASED_ON_A_USER_QUERY", bm25_threshold=1.0)
        # ),
    )

    try:
        async with AsyncWebCrawler(config=browser_config) as crawler:
            # Bound the crawl with asyncio.wait_for so one slow page cannot stall the batch.
            result = await asyncio.wait_for(
                crawler.arun(url=res.url, config=run_config), timeout=15  # hard timeout in seconds
            )
            print(f"char before filtering {len(result.markdown.raw_markdown)}.")
            print(f"char after filtering {len(result.markdown.fit_markdown)}.")
            return result.markdown.fit_markdown[
                :str_fit_max
            ]  # on success, return at most the first str_fit_max characters
    except asyncio.TimeoutError:
        print(f"Timeout occurred while accessing {res.url}.")  # log the timeout
        return res.text[:str_fit_max]  # on timeout, fall back to the coarse extracted text
    except Exception as e:
        print(f"Exception occurred: {str(e)}")  # log any other failure
        return res.text[:str_fit_max]  # on any other error, fall back to the coarse extracted text
|
python-services/Retrieve/utils/bio_logger.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
BioLogger - A comprehensive logging utility for the bio RAG server.
|
3 |
+
|
4 |
+
This module provides a centralized logging system with correlation ID support,
|
5 |
+
structured logging, and configurable output handlers.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import sys
|
9 |
+
import traceback
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Any, Optional
|
12 |
+
|
13 |
+
from asgi_correlation_id import correlation_id
|
14 |
+
from loguru import logger
|
15 |
+
|
16 |
+
|
17 |
+
class BioLogger:
    """
    Enhanced logging utility with correlation ID support and structured logging.

    Wraps loguru with automatic correlation-ID binding (via
    asgi_correlation_id) plus convenience helpers for performance,
    API-call and database-operation logging.
    """

    def __init__(self, log_dir: str = "logs", max_retention_days: int = 30):
        """
        Initialize the BioLogger.

        Args:
            log_dir: Directory to store log files
            max_retention_days: Maximum number of days to retain log files
        """
        self.log_dir = Path(log_dir)
        self.max_retention_days = max_retention_days
        self._setup_logging()

    def _setup_logging(self) -> None:
        """Configure loguru with a terminal handler and per-level file handlers."""
        # Remove loguru's default stderr handler so we fully control output.
        logger.remove()

        # Create log directory if missing.
        self.log_dir.mkdir(exist_ok=True)

        # Options shared by every sink.
        common = dict(
            format=self._get_format_string(),
            backtrace=True,
            diagnose=True,
        )

        # Terminal handler.
        logger.add(sys.stderr, level="INFO", colorize=True, **common)

        # INFO-and-above file handler.
        logger.add(
            str(self.log_dir / "bio_rag_{time:YYYY-MM-DD}.log"),
            level="INFO",
            rotation="1 day",
            retention=f"{self.max_retention_days} days",
            compression="zip",
            **common,
        )

        # ERROR-and-above handler writes to a *separate* file. The previous
        # implementation registered both handlers on the same path, which
        # duplicated every ERROR record in that file (the INFO sink already
        # captures ERROR) and made two rotating/compressing sinks compete
        # for a single file.
        logger.add(
            str(self.log_dir / "bio_rag_error_{time:YYYY-MM-DD}.log"),
            level="ERROR",
            rotation="1 day",
            retention=f"{self.max_retention_days} days",
            compression="zip",
            **common,
        )

    def _get_format_string(self) -> str:
        """Get the log format string with correlation ID."""
        return "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | [CID:{extra[correlation_id]}] | {name}:{function}:{line} | {message}"

    def _get_correlation_id(self) -> str:
        """Get the current correlation ID or return SYSTEM."""
        return correlation_id.get() or "SYSTEM"

    def _bind_logger(self):
        """Bind logger with the current correlation ID."""
        return logger.bind(correlation_id=self._get_correlation_id())

    def debug(self, message: str, **kwargs: Any) -> None:
        """Log a debug message with optional context data."""
        self._bind_logger().debug(message, **kwargs)

    def info(self, message: str, **kwargs: Any) -> None:
        """Log an info message with optional context data."""
        self._bind_logger().info(message, **kwargs)

    def warning(self, message: str, **kwargs: Any) -> None:
        """Log a warning message with optional context data."""
        self._bind_logger().warning(message, **kwargs)

    def error(
        self, message: str, exc_info: Optional[Exception] = None, **kwargs: Any
    ) -> None:
        """
        Log an error message with optional exception information.

        Args:
            message: The error message
            exc_info: Optional exception object for detailed error tracking
            **kwargs: Additional context data
        """
        if exc_info is not None:
            self._bind_logger().error(
                self._format_exception_details(message, exc_info), **kwargs
            )
        else:
            self._bind_logger().error(message, **kwargs)

    def critical(
        self, message: str, exc_info: Optional[Exception] = None, **kwargs: Any
    ) -> None:
        """
        Log a critical error message.

        Args:
            message: The critical error message
            exc_info: Optional exception object for detailed error tracking
            **kwargs: Additional context data
        """
        if exc_info is not None:
            self._bind_logger().critical(
                self._format_exception_details(message, exc_info), **kwargs
            )
        else:
            self._bind_logger().critical(message, **kwargs)

    def _format_exception_details(self, message: str, exc_info: Exception) -> str:
        """
        Format exception details (type, message, stack trace) for logging.

        Args:
            message: The base error message
            exc_info: The exception object

        Returns:
            Formatted multi-line error details string
        """
        exc_type = exc_info.__class__.__name__
        exc_message = str(exc_info)

        # Walk the traceback (if any) into human-readable frames.
        stack_trace = []
        if exc_info.__traceback__:
            for tb in traceback.extract_tb(exc_info.__traceback__):
                stack_trace.append(
                    f"  File: {tb.filename}, "
                    f"Line: {tb.lineno}, "
                    f"Function: {tb.name}"
                )

        error_details = [
            f"Error Message: {message}",
            f"Exception Type: {exc_type}",
            f"Exception Details: {exc_message}",
        ]
        if stack_trace:
            error_details.append("Stack Trace:")
            error_details.extend(stack_trace)

        return "\n".join(error_details)

    def log_performance(self, operation: str, duration: float, **kwargs: Any) -> None:
        """
        Log performance metrics.

        Args:
            operation: Name of the operation
            duration: Duration in seconds
            **kwargs: Additional performance metrics
        """
        message = f"Performance: {operation} took {duration:.3f}s"
        if kwargs:
            metrics = ", ".join(f"{k}={v}" for k, v in kwargs.items())
            message += f" | {metrics}"
        self.info(message)

    def log_api_call(
        self, method: str, url: str, status_code: int, duration: float
    ) -> None:
        """
        Log API call details; 4xx/5xx statuses are logged at ERROR level.

        Args:
            method: HTTP method
            url: API endpoint URL
            status_code: HTTP status code
            duration: Request duration in seconds
        """
        message = f"API Call: {method} {url} -> {status_code} ({duration:.3f}s)"
        if status_code >= 400:
            self.error(message)
        else:
            self.info(message)

    def log_database_operation(
        self, operation: str, table: str, duration: float, **kwargs: Any
    ) -> None:
        """
        Log database operation details.

        Args:
            operation: Database operation (SELECT, INSERT, etc.)
            table: Table name
            duration: Operation duration in seconds
            **kwargs: Additional operation details
        """
        message = f"Database: {operation} on {table} took {duration:.3f}s"
        if kwargs:
            details = ", ".join(f"{k}={v}" for k, v in kwargs.items())
            message += f" | {details}"
        self.info(message)
|
250 |
+
|
251 |
+
|
252 |
+
# Module-level singleton; import this instead of instantiating BioLogger directly.
bio_logger = BioLogger()
|
python-services/Retrieve/utils/http_util.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
HTTP utility functions for making synchronous and asynchronous HTTP requests.
|
3 |
+
|
4 |
+
This module provides a unified interface for HTTP operations using httpx,
|
5 |
+
with proper error handling, timeout configuration, and retry logic.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import asyncio
|
9 |
+
import os
|
10 |
+
import time
|
11 |
+
import uuid
|
12 |
+
from typing import Any, Dict, Optional
|
13 |
+
|
14 |
+
import httpx
|
15 |
+
|
16 |
+
from utils.bio_logger import bio_logger as logger
|
17 |
+
|
18 |
+
|
19 |
+
class HTTPError(Exception):
    """Raised when an HTTP request fails.

    Attributes:
        status_code: HTTP status code (0 when no response was received).
        message: Human-readable description of the failure.
        url: The URL that was requested.
    """

    def __init__(self, status_code: int, message: str, url: str):
        self.status_code = status_code
        self.message = message
        self.url = url
        summary = f"HTTP {status_code}: {message} for {url}"
        super().__init__(summary)
|
27 |
+
|
28 |
+
|
29 |
+
def _create_timeout(timeout: float = 10.0) -> httpx.Timeout:
    """Build an httpx timeout configuration with a 5-second connect phase."""
    connect_timeout = 5.0
    return httpx.Timeout(timeout, connect=connect_timeout)
|
32 |
+
|
33 |
+
|
34 |
+
def _handle_response(response: httpx.Response, url: str) -> Any:
    """Return the JSON payload of a 200 response, or raise HTTPError."""
    if response.status_code != 200:
        # Log and surface any non-success status as a typed exception.
        logger.error(f"HTTP request failed: {response.status_code} for {url}")
        raise HTTPError(
            status_code=response.status_code,
            message=f"Request failed with status {response.status_code}",
            url=url,
        )
    return response.json()
|
45 |
+
|
46 |
+
|
47 |
+
async def async_http_get(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    timeout: float = 10.0,
    headers: Optional[Dict[str, str]] = None,
) -> Any:
    """Perform an asynchronous GET request and return the decoded JSON body.

    Args:
        url: Target URL.
        params: Optional query parameters.
        timeout: Total request timeout in seconds.
        headers: Optional request headers.

    Returns:
        Parsed JSON payload of a successful (200) response.

    Raises:
        HTTPError: If the server replies with a non-200 status.
        httpx.RequestError: On network-level failures.
    """
    started = time.time()
    async with httpx.AsyncClient(timeout=_create_timeout(timeout)) as http_client:
        reply = await http_client.get(url=url, params=params, headers=headers)
        elapsed = time.time() - started

        # Record method, endpoint, status and latency for observability.
        logger.log_api_call("GET", url, reply.status_code, elapsed)

        return _handle_response(reply, url)
|
80 |
+
|
81 |
+
|
82 |
+
def http_get(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    timeout: float = 10.0,
    headers: Optional[Dict[str, str]] = None,
) -> Any:
    """Perform a blocking GET request and return the decoded JSON body.

    Args:
        url: Target URL.
        params: Optional query parameters.
        timeout: Total request timeout in seconds.
        headers: Optional request headers.

    Returns:
        Parsed JSON payload of a successful (200) response.

    Raises:
        HTTPError: If the server replies with a non-200 status.
        httpx.RequestError: On network-level failures.
    """
    started = time.time()
    with httpx.Client(timeout=_create_timeout(timeout)) as http_client:
        reply = http_client.get(url=url, params=params, headers=headers)
        elapsed = time.time() - started

        # Record method, endpoint, status and latency for observability.
        logger.log_api_call("GET", url, reply.status_code, elapsed)

        return _handle_response(reply, url)
|
115 |
+
|
116 |
+
|
117 |
+
def http_post(
    url: str, data: Any, headers: Optional[Dict[str, Any]] = None, timeout: float = 10.0
) -> Any:
    """Perform a blocking POST request with a JSON body and return the JSON reply.

    Args:
        url: Target URL.
        data: Payload serialized as the JSON request body.
        headers: Optional request headers.
        timeout: Total request timeout in seconds.

    Returns:
        Parsed JSON payload of a successful (200) response.

    Raises:
        HTTPError: If the server replies with a non-200 status.
        httpx.RequestError: On network-level failures.
    """
    started = time.time()
    with httpx.Client(timeout=_create_timeout(timeout)) as http_client:
        reply = http_client.post(url=url, json=data, headers=headers)
        elapsed = time.time() - started

        # Record method, endpoint, status and latency for observability.
        logger.log_api_call("POST", url, reply.status_code, elapsed)

        return _handle_response(reply, url)
|
147 |
+
|
148 |
+
|
149 |
+
async def async_http_post(
    url: str,
    data: Any,
    headers: Optional[Dict[str, Any]] = None,
    timeout: float = 10.0,
    max_retries: int = 3,
    retry_delay: float = 0.5,
) -> Any:
    """
    Make an asynchronous HTTP POST request with retry logic.

    Both non-200 responses and network errors are retried up to
    ``max_retries`` times with a fixed ``retry_delay`` between attempts.

    Args:
        url: The URL to make the request to
        data: The data to send in the request body
        headers: Optional headers to include in the request
        timeout: Request timeout in seconds
        max_retries: Maximum number of retry attempts
        retry_delay: Delay between retries in seconds

    Returns:
        The JSON response from the server

    Raises:
        HTTPError: If the request fails after all retries
        httpx.RequestError: If there's a network error
    """
    timeout_config = _create_timeout(timeout)

    # One client is reused across all attempts (connection pooling).
    async with httpx.AsyncClient(timeout=timeout_config) as client:
        for attempt in range(1, max_retries + 1):
            try:
                start_time = time.time()
                response = await client.post(url=url, json=data, headers=headers)
                duration = time.time() - start_time

                # Log the API call
                logger.log_api_call("POST", url, response.status_code, duration)

                if response.status_code == 200:
                    return response.json()
                else:
                    # Non-200: retry after a delay, or give up on the last attempt.
                    logger.error(
                        f"HTTP POST failed (attempt {attempt}/{max_retries}): "
                        f"{response.status_code} for {url}"
                    )
                    if attempt < max_retries:
                        await asyncio.sleep(retry_delay)
                    else:
                        raise HTTPError(
                            status_code=response.status_code,
                            message=f"Request failed after {max_retries} attempts",
                            url=url,
                        )
            except httpx.RequestError as e:
                # Network-level failure (DNS, connect, timeout): same retry policy.
                logger.error(f"Network error on attempt {attempt}: {e}")
                if attempt < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    raise HTTPError(
                        status_code=0,
                        message=f"Network error after {max_retries} attempts: {str(e)}",
                        url=url,
                    ) from e

    # Only reachable when max_retries < 1 (the loop body never ran);
    # every executed attempt either returns or raises above.
    raise HTTPError(
        status_code=0,
        message=f"Failed to fetch data from {url} after {max_retries} attempts",
        url=url,
    )
|
218 |
+
|
219 |
+
|
220 |
+
def _infer_extension(file_url: str) -> str:
    """Best-effort file extension from a URL's path; defaults to "bin".

    The previous implementation split the raw URL on ".", which produced
    path fragments such as "com/file" for extension-less URLs and broke
    the generated filename (it contained a "/").
    """
    from posixpath import splitext
    from urllib.parse import urlsplit

    extension = splitext(urlsplit(file_url).path)[1].lstrip(".")
    return extension or "bin"


def download_file(
    file_url: str, directory_path: str, timeout: int = 60, verify_ssl: bool = True
) -> Optional[str]:
    """
    Download a file from a URL to a local directory under a random name.

    Args:
        file_url: The URL of the file to download
        directory_path: The directory to save the file in
        timeout: Request timeout in seconds
        verify_ssl: Whether to verify SSL certificates

    Returns:
        The path to the downloaded file, or None if the download failed
        (non-200 status, timeout, network error, or unexpected error).
    """
    random_filename = f"{uuid.uuid4()}.{_infer_extension(file_url)}"

    # Create the target directory if it doesn't exist.
    os.makedirs(directory_path, exist_ok=True)
    file_path = os.path.join(directory_path, random_filename)

    try:
        with httpx.Client(timeout=timeout, verify=verify_ssl) as client:
            # Stream to disk in chunks to avoid buffering large files in memory.
            with client.stream("GET", file_url) as response:
                if response.status_code == 200:
                    with open(file_path, "wb") as file:
                        for chunk in response.iter_bytes(chunk_size=8192):
                            file.write(chunk)
                    logger.info(f"Successfully downloaded file to {file_path}")
                    return file_path
                logger.error(
                    f"Download failed with status code: {response.status_code}"
                )
                return None
    except httpx.TimeoutException:
        logger.error("Download request timed out")
        return None
    except httpx.RequestError as e:
        logger.error(f"Download request failed: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during download: {e}")
        return None
|
266 |
+
|
267 |
+
|
268 |
+
# Backward compatibility functions
|
269 |
+
async def async_http_post_legacy(url: str, params: dict) -> Any:
    """Backward-compatible wrapper around :func:`async_http_post`.

    Keeps the old ``(url, params)`` interface while delegating to the
    new implementation.
    """
    return await async_http_post(url=url, data=params)
|
python-services/Retrieve/utils/i18n_context.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
国际化上下文管理器
|
3 |
+
提供更优雅的语言设置方式,避免在函数间传递language参数
|
4 |
+
"""
|
5 |
+
|
6 |
+
import contextvars
|
7 |
+
from utils.i18n_types import Language
|
8 |
+
|
9 |
+
# 创建上下文变量
|
10 |
+
_language_context = contextvars.ContextVar("language", default=Language.ENGLISH)
|
11 |
+
|
12 |
+
|
13 |
+
class I18nContext:
    """Static accessors for the per-context language setting."""

    @staticmethod
    def set_language(language: Language) -> None:
        """Set the language for the current execution context.

        Args:
            language: Language enum value to activate.
        """
        _language_context.set(language)

    @staticmethod
    def get_language() -> Language:
        """Return the language bound to the current execution context."""
        return _language_context.get()

    @staticmethod
    def reset_language() -> None:
        """Restore the default language (English)."""
        _language_context.set(Language.ENGLISH)

    @staticmethod
    def get_language_value() -> str:
        """Return the current language as its plain string value."""
        return _language_context.get().value
|
50 |
+
|
51 |
+
|
52 |
+
class I18nContextManager:
    """``with``-style scope that temporarily switches the active language."""

    def __init__(self, language: Language):
        """Remember the language to activate on entry.

        Args:
            language: Language to activate inside the ``with`` block.
        """
        self.language = language
        self._previous_language = None

    def __enter__(self):
        # Save whatever language is currently active so __exit__ can restore it.
        self._previous_language = I18nContext.get_language()
        I18nContext.set_language(self.language)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the saved language (only if __enter__ actually ran).
        if self._previous_language is not None:
            I18nContext.set_language(self._previous_language)
|
75 |
+
|
76 |
+
|
77 |
+
# 便捷函数
|
78 |
+
def set_language(language: Language) -> None:
    """Module-level shortcut for :meth:`I18nContext.set_language`."""
    I18nContext.set_language(language)
|
81 |
+
|
82 |
+
|
83 |
+
def get_language() -> Language:
    """Module-level shortcut for :meth:`I18nContext.get_language`."""
    return I18nContext.get_language()
|
86 |
+
|
87 |
+
|
88 |
+
def reset_language() -> None:
    """Module-level shortcut for :meth:`I18nContext.reset_language`."""
    I18nContext.reset_language()
|
91 |
+
|
92 |
+
|
93 |
+
def with_language(language: Language):
    """Return a context manager that activates *language* inside a ``with`` block.

    Args:
        language: Language to activate.

    Returns:
        An :class:`I18nContextManager` bound to *language*.
    """
    return I18nContextManager(language)
|
104 |
+
|
105 |
+
|
106 |
+
# 装饰器,用于自动设置语言
|
107 |
+
def with_language_decorator(language: Language):
    """Decorator factory: run the wrapped function inside a language context.

    Args:
        language: Language to activate for the duration of each call.

    Returns:
        A decorator that wraps the target function in I18nContextManager.
    """
    from functools import wraps

    def decorator(func):
        # functools.wraps preserves the wrapped function's name, docstring
        # and signature metadata — the previous version lost them, which
        # breaks introspection and debugging of decorated functions.
        @wraps(func)
        def wrapper(*args, **kwargs):
            with I18nContextManager(language):
                return func(*args, **kwargs)

        return wrapper

    return decorator
|
python-services/Retrieve/utils/i18n_messages.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
国际化消息配置文件
|
3 |
+
包含所有错误消息、成功消息、状态消息和UI标签消息的中英文映射
|
4 |
+
"""
|
5 |
+
|
6 |
+
from utils.i18n_types import Language
|
7 |
+
|
8 |
+
|
9 |
+
# 错误消息国际化
|
10 |
+
ERROR_MESSAGES = {
|
11 |
+
Language.CHINESE: {
|
12 |
+
"invalid_request": "无效的请求参数",
|
13 |
+
"search_failed": "搜索失败",
|
14 |
+
"no_results": "未找到相关结果",
|
15 |
+
"service_unavailable": "服务暂时不可用",
|
16 |
+
"internal_error": "内部服务器错误",
|
17 |
+
"invalid_language": "不支持的语言设置",
|
18 |
+
"query_too_long": "查询内容过长",
|
19 |
+
"rate_limit_exceeded": "请求频率过高,请稍后重试",
|
20 |
+
"authentication_failed": "认证失败",
|
21 |
+
"permission_denied": "权限不足",
|
22 |
+
"resource_not_found": "资源未找到",
|
23 |
+
"network_error": "网络连接错误",
|
24 |
+
"timeout_error": "请求超时",
|
25 |
+
"invalid_format": "数据格式错误",
|
26 |
+
"missing_required_field": "缺少必需字段",
|
27 |
+
"invalid_user_id": "无效的用户ID",
|
28 |
+
"search_service_error": "搜索服务错误",
|
29 |
+
"llm_service_error": "语言模型服务错误",
|
30 |
+
"embedding_service_error": "向量化服务错误",
|
31 |
+
"database_error": "数据库错误",
|
32 |
+
},
|
33 |
+
Language.ENGLISH: {
|
34 |
+
"invalid_request": "Invalid request parameters",
|
35 |
+
"search_failed": "Search failed",
|
36 |
+
"no_results": "No relevant results found",
|
37 |
+
"service_unavailable": "Service temporarily unavailable",
|
38 |
+
"internal_error": "Internal server error",
|
39 |
+
"invalid_language": "Unsupported language setting",
|
40 |
+
"query_too_long": "Query content too long",
|
41 |
+
"rate_limit_exceeded": "Request rate exceeded, please try again later",
|
42 |
+
"authentication_failed": "Authentication failed",
|
43 |
+
"permission_denied": "Permission denied",
|
44 |
+
"resource_not_found": "Resource not found",
|
45 |
+
"network_error": "Network connection error",
|
46 |
+
"timeout_error": "Request timeout",
|
47 |
+
"invalid_format": "Invalid data format",
|
48 |
+
"missing_required_field": "Missing required field",
|
49 |
+
"invalid_user_id": "Invalid user ID",
|
50 |
+
"search_service_error": "Search service error",
|
51 |
+
"llm_service_error": "Language model service error",
|
52 |
+
"embedding_service_error": "Embedding service error",
|
53 |
+
"database_error": "Database error",
|
54 |
+
},
|
55 |
+
}
|
56 |
+
|
57 |
+
# 成功消息国际化
|
58 |
+
SUCCESS_MESSAGES = {
|
59 |
+
Language.CHINESE: {
|
60 |
+
"search_success": "搜索成功",
|
61 |
+
"chat_success": "聊天服务正常",
|
62 |
+
"health_check_ok": "服务运行正常",
|
63 |
+
"results_found": "找到相关结果",
|
64 |
+
"processing_complete": "处理完成",
|
65 |
+
},
|
66 |
+
Language.ENGLISH: {
|
67 |
+
"search_success": "Search successful",
|
68 |
+
"chat_success": "Chat service normal",
|
69 |
+
"health_check_ok": "Service running normally",
|
70 |
+
"results_found": "Relevant results found",
|
71 |
+
"processing_complete": "Processing complete",
|
72 |
+
},
|
73 |
+
}
|
74 |
+
|
75 |
+
# 状态消息国际化
|
76 |
+
STATUS_MESSAGES = {
|
77 |
+
Language.CHINESE: {
|
78 |
+
"processing": "正在处理",
|
79 |
+
"searching": "正在搜索",
|
80 |
+
"generating": "正在生成回答",
|
81 |
+
"completed": "已完成",
|
82 |
+
"failed": "处理失败",
|
83 |
+
},
|
84 |
+
Language.ENGLISH: {
|
85 |
+
"processing": "Processing",
|
86 |
+
"searching": "Searching",
|
87 |
+
"generating": "Generating answer",
|
88 |
+
"completed": "Completed",
|
89 |
+
"failed": "Processing failed",
|
90 |
+
},
|
91 |
+
}
|
92 |
+
|
93 |
+
# UI标签消息国际化
|
94 |
+
LABEL_MESSAGES = {
|
95 |
+
Language.CHINESE: {
|
96 |
+
"web_search_start": "正在调用 Browser 进行内容检索,所需时间较长,请等待...",
|
97 |
+
"web_search": "正在调用 Browser 进行内容检索",
|
98 |
+
"personal_search_start": "正在调用 个人知识库 进行内容检索,所需时间较长,请等待...",
|
99 |
+
"personal_search": "正在调用 个人知识库 进行内容检索",
|
100 |
+
"pubmed_search_start": "正在调用 PubMed 进行内容检索,所需时间较长,请等待...",
|
101 |
+
"pubmed_search": "正在调用 PubMed 进行内容检索",
|
102 |
+
"generating_answer": "正在生成回答",
|
103 |
+
"processing": "正在处理",
|
104 |
+
"personal_search_description": "片段 {index}",
|
105 |
+
},
|
106 |
+
Language.ENGLISH: {
|
107 |
+
"web_search_start": "Retrieving content from Browser, this may take a while, please wait...",
|
108 |
+
"web_search": "Retrieving content from Browser",
|
109 |
+
"personal_search_start": "Retrieving content from Personal Knowledge Base, this may take a while, please wait...",
|
110 |
+
"personal_search": "Retrieving content from Personal Knowledge Base",
|
111 |
+
"pubmed_search_start": "Retrieving content from PubMed, this may take a while, please wait...",
|
112 |
+
"pubmed_search": "Retrieving content from PubMed",
|
113 |
+
"generating_answer": "Generating answer",
|
114 |
+
"processing": "Processing",
|
115 |
+
"personal_search_description": "Chunk {index} from this reference.",
|
116 |
+
},
|
117 |
+
}
|
118 |
+
|
119 |
+
# 系统消息国际化
|
120 |
+
SYSTEM_MESSAGES = {
|
121 |
+
Language.CHINESE: {
|
122 |
+
"welcome": "欢迎使用生物医学RAG服务",
|
123 |
+
"service_start": "服务已启动",
|
124 |
+
"service_stop": "服务已停止",
|
125 |
+
"connection_established": "连接已建立",
|
126 |
+
"connection_lost": "连接已断开",
|
127 |
+
"maintenance_mode": "系统维护中",
|
128 |
+
"updating": "系统更新中",
|
129 |
+
"backup_restore": "备份恢复中",
|
130 |
+
},
|
131 |
+
Language.ENGLISH: {
|
132 |
+
"welcome": "Welcome to Biomedical RAG Service",
|
133 |
+
"service_start": "Service started",
|
134 |
+
"service_stop": "Service stopped",
|
135 |
+
"connection_established": "Connection established",
|
136 |
+
"connection_lost": "Connection lost",
|
137 |
+
"maintenance_mode": "System under maintenance",
|
138 |
+
"updating": "System updating",
|
139 |
+
"backup_restore": "Backup restoring",
|
140 |
+
},
|
141 |
+
}
|
142 |
+
|
143 |
+
|
144 |
+
# 业务消息国际化
|
145 |
+
BUSINESS_MESSAGES = {
|
146 |
+
Language.CHINESE: {
|
147 |
+
"search_started": "开始搜索...",
|
148 |
+
"search_completed": "搜索完成",
|
149 |
+
"no_search_results": "未找到搜索结果",
|
150 |
+
"processing_request": "正在处理请求...",
|
151 |
+
"request_completed": "请求处理完成",
|
152 |
+
"upload_success": "文件上传成功",
|
153 |
+
"upload_failed": "文件上传失败",
|
154 |
+
"download_started": "开始下载...",
|
155 |
+
"download_completed": "下载完成",
|
156 |
+
"operation_success": "操作成功",
|
157 |
+
"operation_failed": "操作失败",
|
158 |
+
"data_saved": "数据已保存",
|
159 |
+
"data_deleted": "数据已删除",
|
160 |
+
"data_updated": "数据已更新",
|
161 |
+
"connection_timeout": "连接超时",
|
162 |
+
"server_busy": "服务器繁忙",
|
163 |
+
"maintenance_notice": "系统维护通知",
|
164 |
+
},
|
165 |
+
Language.ENGLISH: {
|
166 |
+
"search_started": "Search started...",
|
167 |
+
"search_completed": "Search completed",
|
168 |
+
"no_search_results": "No search results found",
|
169 |
+
"processing_request": "Processing request...",
|
170 |
+
"request_completed": "Request completed",
|
171 |
+
"upload_success": "File uploaded successfully",
|
172 |
+
"upload_failed": "File upload failed",
|
173 |
+
"download_started": "Download started...",
|
174 |
+
"download_completed": "Download completed",
|
175 |
+
"operation_success": "Operation successful",
|
176 |
+
"operation_failed": "Operation failed",
|
177 |
+
"data_saved": "Data saved",
|
178 |
+
"data_deleted": "Data deleted",
|
179 |
+
"data_updated": "Data updated",
|
180 |
+
"connection_timeout": "Connection timeout",
|
181 |
+
"server_busy": "Server busy",
|
182 |
+
"maintenance_notice": "System maintenance notice",
|
183 |
+
},
|
184 |
+
}
|
185 |
+
|
186 |
+
# 所有消息类型的映射
|
187 |
+
ALL_MESSAGE_TYPES = {
|
188 |
+
"error": ERROR_MESSAGES,
|
189 |
+
"success": SUCCESS_MESSAGES,
|
190 |
+
"status": STATUS_MESSAGES,
|
191 |
+
"label": LABEL_MESSAGES,
|
192 |
+
"system": SYSTEM_MESSAGES,
|
193 |
+
"business": BUSINESS_MESSAGES,
|
194 |
+
}
|
195 |
+
|
196 |
+
|
197 |
+
def get_message(message_type: str, key: str, language: Language) -> str:
    """Look up a localized message.

    Args:
        message_type: One of ``error``, ``success``, ``status``, ``label``,
            ``system`` or ``business``.
        key: Message key inside that category.
        language: Target language.

    Returns:
        The localized message. Falls back to Chinese when the language is
        missing, and to a placeholder string when the type or key is unknown.
    """
    table = ALL_MESSAGE_TYPES.get(message_type)
    if table is None:
        return f"Unknown message type: {message_type}"

    fallback = table[Language.CHINESE]
    selected = table.get(language, fallback)
    missing = f"Unknown {message_type} message: {key}"
    return selected.get(key, fallback.get(key, missing))
|
219 |
+
|
220 |
+
|
221 |
+
def get_all_messages_for_language(language: Language) -> dict:
    """Collect every message category localized for *language*.

    Args:
        language: Target language.

    Returns:
        Mapping of category name to its message dict; Chinese is used as the
        fallback when the language is missing for a category.
    """
    return {
        category: table.get(language, table[Language.CHINESE])
        for category, table in ALL_MESSAGE_TYPES.items()
    }
|
235 |
+
|
236 |
+
|
237 |
+
def get_available_message_types() -> list:
    """Return the names of all registered message categories."""
    return [*ALL_MESSAGE_TYPES]
|
245 |
+
|
246 |
+
|
247 |
+
def get_available_keys_for_type(message_type: str) -> list:
    """List every message key available for *message_type*.

    Args:
        message_type: Message category name.

    Returns:
        The keys of the category (taken from the Chinese table, which is the
        reference locale), or an empty list for an unknown category.
    """
    table = ALL_MESSAGE_TYPES.get(message_type)
    if table is None:
        return []
    return list(table[Language.CHINESE])
|
python-services/Retrieve/utils/i18n_types.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
国际化类型定义
|
3 |
+
"""
|
4 |
+
|
5 |
+
from enum import Enum
|
6 |
+
|
7 |
+
|
8 |
+
class Language(Enum):
    """Languages supported by the i18n layer."""

    CHINESE = "zh"  # Simplified Chinese
    ENGLISH = "en"  # English
|
python-services/Retrieve/utils/i18n_util.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
国际化工具类,支持中英文切换功能
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Any, Optional
|
6 |
+
from utils.i18n_types import Language
|
7 |
+
from utils.i18n_messages import get_message
|
8 |
+
from utils.i18n_context import I18nContext
|
9 |
+
|
10 |
+
|
11 |
+
class I18nUtil:
    """Helper for resolving localized messages and building API responses.

    All lookups fall back to the language stored in ``I18nContext`` when no
    explicit language is supplied.
    """

    # Language used when a request carries no usable language hint.
    DEFAULT_LANGUAGE = Language.ENGLISH

    # Accepted language identifiers -> canonical enum values.
    LANGUAGE_MAPPING = {
        "zh": Language.CHINESE,
        "zh_cn": Language.CHINESE,
        "en": Language.ENGLISH,
        "en_us": Language.ENGLISH,
    }

    @classmethod
    def parse_language(cls, language_str: Optional[str]) -> Language:
        """Parse a raw language identifier (e.g. ``"zh-CN"``) into a ``Language``.

        Args:
            language_str: Raw language string; case-insensitive, ``-`` and
                ``_`` are treated the same. ``None``/empty yields the default.

        Returns:
            The matching ``Language``, or ``DEFAULT_LANGUAGE`` when unknown.
        """
        if not language_str:
            return cls.DEFAULT_LANGUAGE
        normalized = language_str.lower().replace("-", "_")
        return cls.LANGUAGE_MAPPING.get(normalized, cls.DEFAULT_LANGUAGE)

    @classmethod
    def _resolve(
        cls, message_type: str, key: str, language: Optional[Language]
    ) -> str:
        """Shared lookup: default the language from context, then delegate."""
        if language is None:
            language = I18nContext.get_language()
        return get_message(message_type, key, language)

    @classmethod
    def get_error_message(cls, key: str, language: Optional[Language] = None) -> str:
        """Return a localized error message (context language when None)."""
        return cls._resolve("error", key, language)

    @classmethod
    def get_success_message(cls, key: str, language: Optional[Language] = None) -> str:
        """Return a localized success message (context language when None)."""
        return cls._resolve("success", key, language)

    @classmethod
    def get_status_message(cls, key: str, language: Optional[Language] = None) -> str:
        """Return a localized status message (context language when None)."""
        return cls._resolve("status", key, language)

    @classmethod
    def get_label_message(cls, key: str, language: Optional[Language] = None) -> str:
        """Return a localized UI label message (context language when None)."""
        return cls._resolve("label", key, language)

    @classmethod
    def get_system_message(cls, key: str, language: Optional[Language] = None) -> str:
        """Return a localized system message (context language when None)."""
        return cls._resolve("system", key, language)

    @classmethod
    def get_business_message(cls, key: str, language: Optional[Language] = None) -> str:
        """Return a localized business message (context language when None)."""
        return cls._resolve("business", key, language)

    @classmethod
    def create_error_response(
        cls,
        error_key: str,
        language: Optional[Language] = None,
        details: Optional[str] = None,
        error_code: int = 400,
    ) -> Dict[str, Any]:
        """Build a standard error payload.

        Args:
            error_key: Error message key.
            language: Explicit language; defaults to the context language.
            details: Optional detail string, included only when truthy.
            error_code: Application error code (default 400).

        Returns:
            ``{"success": False, "error": {...}}`` dictionary.
        """
        if language is None:
            language = I18nContext.get_language()

        response: Dict[str, Any] = {
            "success": False,
            "error": {
                "code": error_code,
                "message": cls.get_error_message(error_key, language),
                "language": language.value,
            },
        }
        if details:
            response["error"]["details"] = details
        return response

    @classmethod
    def create_success_response(
        cls,
        data: Any,
        language: Optional[Language] = None,
        message_key: str = "search_success",
    ) -> Dict[str, Any]:
        """Build a standard success payload wrapping *data*.

        Args:
            data: Response payload.
            language: Explicit language; defaults to the context language.
            message_key: Success message key (default ``"search_success"``).

        Returns:
            ``{"success": True, ...}`` dictionary.
        """
        if language is None:
            language = I18nContext.get_language()

        return {
            "success": True,
            "data": data,
            "message": cls.get_success_message(message_key, language),
            "language": language.value,
        }

    @classmethod
    def create_status_response(
        cls,
        status_key: str,
        language: Optional[Language] = None,
        data: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Build a status payload; ``data`` is attached only when not None.

        Args:
            status_key: Status message key.
            language: Explicit language; defaults to the context language.
            data: Optional payload.

        Returns:
            ``{"status": ..., "language": ...}`` dictionary.
        """
        if language is None:
            language = I18nContext.get_language()

        response: Dict[str, Any] = {
            "status": cls.get_status_message(status_key, language),
            "language": language.value,
        }
        if data is not None:
            response["data"] = data
        return response
|
243 |
+
|
244 |
+
|
245 |
+
# Module-level convenience wrappers around I18nUtil.
def get_language(language_str: Optional[str]) -> Language:
    """Parse a raw language string into a ``Language`` enum value."""
    return I18nUtil.parse_language(language_str)
|
249 |
+
|
250 |
+
|
251 |
+
def get_error_message(key: str, language: Optional[Language] = None) -> str:
    """Convenience wrapper for ``I18nUtil.get_error_message``."""
    return I18nUtil.get_error_message(key, language)
|
254 |
+
|
255 |
+
|
256 |
+
def get_success_message(key: str, language: Optional[Language] = None) -> str:
    """Convenience wrapper for ``I18nUtil.get_success_message``."""
    return I18nUtil.get_success_message(key, language)
|
259 |
+
|
260 |
+
|
261 |
+
def get_status_message(key: str, language: Optional[Language] = None) -> str:
    """Convenience wrapper for ``I18nUtil.get_status_message``."""
    return I18nUtil.get_status_message(key, language)
|
264 |
+
|
265 |
+
|
266 |
+
def get_label_message(key: str, language: Optional[Language] = None) -> str:
    """Convenience wrapper for ``I18nUtil.get_label_message``."""
    return I18nUtil.get_label_message(key, language)
|
269 |
+
|
270 |
+
|
271 |
+
def get_system_message(key: str, language: Optional[Language] = None) -> str:
    """Convenience wrapper for ``I18nUtil.get_system_message``."""
    return I18nUtil.get_system_message(key, language)
|
274 |
+
|
275 |
+
|
276 |
+
def get_business_message(key: str, language: Optional[Language] = None) -> str:
    """Convenience wrapper for ``I18nUtil.get_business_message``."""
    return I18nUtil.get_business_message(key, language)
|
279 |
+
|
280 |
+
|
281 |
+
def create_error_response(
    error_key: str,
    language: Optional[Language] = None,
    details: Optional[str] = None,
    error_code: int = 400,
) -> Dict[str, Any]:
    """Convenience wrapper for ``I18nUtil.create_error_response``."""
    return I18nUtil.create_error_response(error_key, language, details, error_code)
|
289 |
+
|
290 |
+
|
291 |
+
def create_success_response(
    data: Any, language: Optional[Language] = None, message_key: str = "search_success"
) -> Dict[str, Any]:
    """Convenience wrapper for ``I18nUtil.create_success_response``."""
    return I18nUtil.create_success_response(data, language, message_key)
|
296 |
+
|
297 |
+
|
298 |
+
def create_status_response(
    status_key: str, language: Optional[Language] = None, data: Optional[Any] = None
) -> Dict[str, Any]:
    """Convenience wrapper for ``I18nUtil.create_status_response``."""
    return I18nUtil.create_status_response(status_key, language, data)
|
python-services/Retrieve/utils/snowflake_id.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import threading
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
|
6 |
+
class SnowflakeIDGenerator:
    """Snowflake-style 64-bit unique ID generator.

    Bit layout of a generated ID:
        sign bit (1, always 0) | timestamp (41 bits, ms since EPOCH) |
        datacenter id (5 bits) | worker id (5 bits) | sequence (12 bits)

    Properties: roughly time-ordered, globally unique across distinct
    datacenter/worker pairs, thread-safe, and fast.
    """

    def __init__(self, datacenter_id: int = 1, worker_id: int = 1, sequence: int = 0):
        """Configure the generator.

        Args:
            datacenter_id: Datacenter identifier (0-31).
            worker_id: Worker/machine identifier (0-31).
            sequence: Starting per-millisecond sequence number.

        Raises:
            ValueError: If ``datacenter_id`` or ``worker_id`` is out of range.
        """
        # Field widths.
        self.TIMESTAMP_BITS = 41
        self.DATACENTER_ID_BITS = 5
        self.WORKER_ID_BITS = 5
        self.SEQUENCE_BITS = 12

        # All-ones masks / maximum value of each field.
        self.MAX_DATACENTER_ID = (1 << self.DATACENTER_ID_BITS) - 1
        self.MAX_WORKER_ID = (1 << self.WORKER_ID_BITS) - 1
        self.MAX_SEQUENCE = (1 << self.SEQUENCE_BITS) - 1

        # Bit offset of each field inside the ID.
        self.WORKER_ID_SHIFT = self.SEQUENCE_BITS
        self.DATACENTER_ID_SHIFT = self.SEQUENCE_BITS + self.WORKER_ID_BITS
        self.TIMESTAMP_LEFT_SHIFT = (
            self.SEQUENCE_BITS + self.WORKER_ID_BITS + self.DATACENTER_ID_BITS
        )

        # Validate configuration before storing it.
        if not 0 <= datacenter_id <= self.MAX_DATACENTER_ID:
            raise ValueError(
                f"Datacenter ID must be between 0 and {self.MAX_DATACENTER_ID}"
            )
        if not 0 <= worker_id <= self.MAX_WORKER_ID:
            raise ValueError(f"Worker ID must be between 0 and {self.MAX_WORKER_ID}")

        self.datacenter_id = datacenter_id
        self.worker_id = worker_id
        self.sequence = sequence

        # Custom epoch: 2023-01-01 00:00:00 UTC, in milliseconds.
        self.EPOCH = 1672531200000

        # Timestamp (ms) of the most recently issued ID; -1 means "never".
        self.last_timestamp = -1

        # Serializes ID generation across threads.
        self.lock = threading.Lock()

    def _get_timestamp(self) -> int:
        """Return the current wall-clock time in milliseconds."""
        return int(time.time() * 1000)

    def _wait_for_next_millis(self, last_timestamp: int) -> int:
        """Busy-wait until the clock advances past *last_timestamp*."""
        now = self._get_timestamp()
        while now <= last_timestamp:
            now = self._get_timestamp()
        return now

    def generate_id(self) -> int:
        """Produce the next 64-bit snowflake ID.

        Returns:
            A 64-bit snowflake ID.

        Raises:
            RuntimeError: If the system clock moved backwards.
        """
        with self.lock:
            now = self._get_timestamp()

            # Refuse to issue IDs while the clock is behind the last issue
            # time -- doing so could produce duplicates.
            if now < self.last_timestamp:
                raise RuntimeError(
                    f"Clock moved backwards. Refusing to generate id for {self.last_timestamp - now} milliseconds"
                )

            if now == self.last_timestamp:
                # Same millisecond: bump the sequence; on overflow, spin
                # until the next millisecond.
                self.sequence = (self.sequence + 1) & self.MAX_SEQUENCE
                if self.sequence == 0:
                    now = self._wait_for_next_millis(self.last_timestamp)
            else:
                # New millisecond: reset the sequence.
                self.sequence = 0

            self.last_timestamp = now

            # Assemble the fields into one 64-bit integer.
            return (
                ((now - self.EPOCH) << self.TIMESTAMP_LEFT_SHIFT)
                | (self.datacenter_id << self.DATACENTER_ID_SHIFT)
                | (self.worker_id << self.WORKER_ID_SHIFT)
                | self.sequence
            )

    def generate_id_str(self) -> str:
        """Produce the next snowflake ID as a decimal string."""
        return str(self.generate_id())

    def parse_id(self, snowflake_id: int) -> dict:
        """Decompose *snowflake_id* back into its fields.

        Args:
            snowflake_id: A snowflake ID produced with the same bit layout.

        Returns:
            Dict with ``timestamp`` (ms), ``datacenter_id``, ``worker_id``,
            ``sequence`` and a human-readable local-time ``datetime``.
        """
        ms = (snowflake_id >> self.TIMESTAMP_LEFT_SHIFT) + self.EPOCH
        return {
            "timestamp": ms,
            "datacenter_id": (snowflake_id >> self.DATACENTER_ID_SHIFT)
            & self.MAX_DATACENTER_ID,
            "worker_id": (snowflake_id >> self.WORKER_ID_SHIFT) & self.MAX_WORKER_ID,
            "sequence": snowflake_id & self.MAX_SEQUENCE,
            "datetime": time.strftime(
                "%Y-%m-%d %H:%M:%S", time.localtime(ms / 1000)
            ),
        }
|
171 |
+
|
172 |
+
|
173 |
+
# Lazily-created, process-wide generator instance.
_snowflake_generator: Optional[SnowflakeIDGenerator] = None
_generator_lock = threading.Lock()


def get_snowflake_generator(
    datacenter_id: int = 1, worker_id: int = 1
) -> SnowflakeIDGenerator:
    """Return the process-wide snowflake generator, creating it on first use.

    Args:
        datacenter_id: Datacenter ID, honored only when the shared instance
            is first created; ignored on later calls.
        worker_id: Worker ID, likewise only honored on first creation.

    Returns:
        The shared ``SnowflakeIDGenerator`` instance.
    """
    global _snowflake_generator

    if _snowflake_generator is None:
        # Double-checked locking: cheap unlocked read first, then re-check
        # under the lock before constructing.
        with _generator_lock:
            if _snowflake_generator is None:
                _snowflake_generator = SnowflakeIDGenerator(datacenter_id, worker_id)

    return _snowflake_generator
|
199 |
+
|
200 |
+
|
201 |
+
def generate_snowflake_id() -> int:
    """Generate a 64-bit snowflake ID using the shared default generator."""
    return get_snowflake_generator().generate_id()
|
209 |
+
|
210 |
+
|
211 |
+
def generate_snowflake_id_str() -> str:
    """Generate a decimal-string snowflake ID using the shared generator."""
    return get_snowflake_generator().generate_id_str()
|
219 |
+
|
220 |
+
|
221 |
+
def parse_snowflake_id(snowflake_id: int) -> dict:
    """Decompose *snowflake_id* into its fields via the shared generator.

    Args:
        snowflake_id: A snowflake ID.

    Returns:
        Dict with timestamp, datacenter_id, worker_id, sequence and a
        human-readable datetime string.
    """
    return get_snowflake_generator().parse_id(snowflake_id)
|
232 |
+
|
233 |
+
|
234 |
+
# Short-named convenience aliases.
def snowflake_id() -> int:
    """Shorthand for :func:`generate_snowflake_id`."""
    return generate_snowflake_id()
|
243 |
+
|
244 |
+
|
245 |
+
def snowflake_id_str() -> str:
    """Shorthand for :func:`generate_snowflake_id_str`."""
    return generate_snowflake_id_str()
|
python-services/Retrieve/utils/token_util.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tiktoken
|
2 |
+
|
3 |
+
|
4 |
+
def num_tokens_from_messages(messages, model="gpt-4o"):
    """Count the tokens a list of chat messages will consume.

    Implements the per-message accounting rules published by OpenAI for the
    supported model snapshots (per-message overhead plus encoded content).

    Args:
        messages (list): Chat messages, each a dict of role/name/content fields.
        model (str): Model name used to pick the tokenizer and the overheads.

    Returns:
        int: The number of tokens used by the messages.

    Raises:
        NotImplementedError: For models with unknown accounting rules.
    """
    # Resolve convenience aliases to the snapshot whose accounting rules we
    # know BEFORE fetching an encoding, so we look up the tokenizer once
    # (the original fetched an encoding and then discarded it by recursing).
    _aliases = {"gpt-3.5-turbo": "gpt-3.5-turbo-0301", "gpt-4o": "gpt-4-0314"}
    model = _aliases.get(model, model)

    if model == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_message = 4
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-4-0314":
        tokens_per_message = 3
        tokens_per_name = 1
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
|
45 |
+
|
46 |
+
|
47 |
+
def num_tokens_from_text(text: str, model: str = "gpt-4o") -> int:
    """Count the tokens in *text* using *model*'s tokenizer.

    Args:
        text (str): The text to tokenize; empty/None counts as zero tokens.
        model (str): Model whose encoding should be used; falls back to
            ``cl100k_base`` when the model is unknown.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text)) if text else 0
|
requirements.txt
CHANGED
@@ -12,4 +12,21 @@ pymupdf>=1.25.4
|
|
12 |
python-dotenv>=1.1.0
|
13 |
streamlit>=1.44.1
|
14 |
nest-asyncio>=1.6.0
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
python-dotenv>=1.1.0
|
13 |
streamlit>=1.44.1
|
14 |
nest-asyncio>=1.6.0
|
15 |
+
asgi_correlation_id==4.3.4
|
16 |
+
fastapi==0.115.12
|
17 |
+
uvicorn==0.34.0
|
18 |
+
loguru==0.7.3
|
19 |
+
pyyaml==6.0.2
|
20 |
+
httpx==0.28.1
|
21 |
+
requests==2.32.3
|
22 |
+
biopython==1.85
|
23 |
+
openpyxl==3.1.5
|
24 |
+
openai==1.86.0
|
25 |
+
openai-agents==0.0.17
|
26 |
+
pandas==2.2.3
|
27 |
+
pymilvus==2.5.8
|
28 |
+
crawl4ai==0.7.0
|
29 |
+
aiohttp==3.11.18
|
30 |
+
beautifulsoup4==4.12.3
|
31 |
+
tiktoken==0.9.0
|
32 |
+
fastapi-mcp==0.4.0
|
requirements_back.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
faiss-cpu>=1.10.0
|
2 |
+
jupyter>=1.1.1
|
3 |
+
langchain-anthropic>=0.3.10
|
4 |
+
langchain-community>=0.3.20
|
5 |
+
langchain-mcp-adapters==0.1.9
|
6 |
+
langchain-openai>=0.3.11
|
7 |
+
langgraph>=0.3.21
|
8 |
+
mcp>=1.6.0
|
9 |
+
fastmcp
|
10 |
+
notebook>=7.3.3
|
11 |
+
pymupdf>=1.25.4
|
12 |
+
python-dotenv>=1.1.0
|
13 |
+
streamlit>=1.44.1
|
14 |
+
nest-asyncio>=1.6.0
|
15 |
+
fastapi
|