Spaces:
Running
Running
Upload 77 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +8 -0
- Dockerfile +23 -0
- app/config/config.py +479 -0
- app/core/application.py +153 -0
- app/core/constants.py +79 -0
- app/core/security.py +90 -0
- app/database/__init__.py +3 -0
- app/database/connection.py +71 -0
- app/database/initialization.py +77 -0
- app/database/models.py +62 -0
- app/database/services.py +429 -0
- app/domain/gemini_models.py +79 -0
- app/domain/image_models.py +20 -0
- app/domain/openai_models.py +42 -0
- app/exception/exceptions.py +140 -0
- app/handler/error_handler.py +32 -0
- app/handler/message_converter.py +349 -0
- app/handler/response_handler.py +360 -0
- app/handler/retry_handler.py +50 -0
- app/handler/stream_optimizer.py +143 -0
- app/log/logger.py +233 -0
- app/main.py +15 -0
- app/middleware/middleware.py +80 -0
- app/middleware/request_logging_middleware.py +40 -0
- app/middleware/smart_routing_middleware.py +210 -0
- app/router/config_routes.py +133 -0
- app/router/error_log_routes.py +233 -0
- app/router/gemini_routes.py +374 -0
- app/router/openai_compatiable_routes.py +113 -0
- app/router/openai_routes.py +175 -0
- app/router/routes.py +187 -0
- app/router/scheduler_routes.py +57 -0
- app/router/stats_routes.py +55 -0
- app/router/version_routes.py +37 -0
- app/router/vertex_express_routes.py +146 -0
- app/scheduler/scheduled_tasks.py +159 -0
- app/service/chat/gemini_chat_service.py +287 -0
- app/service/chat/openai_chat_service.py +606 -0
- app/service/chat/vertex_express_chat_service.py +277 -0
- app/service/client/api_client.py +222 -0
- app/service/config/config_service.py +261 -0
- app/service/embedding/embedding_service.py +78 -0
- app/service/error_log/error_log_service.py +178 -0
- app/service/image/image_create_service.py +162 -0
- app/service/key/key_manager.py +463 -0
- app/service/model/model_service.py +92 -0
- app/service/openai_compatiable/openai_compatiable_service.py +190 -0
- app/service/request_log/request_log_service.py +50 -0
- app/service/stats/stats_service.py +255 -0
- app/service/tts/tts_service.py +94 -0
.gitattributes
CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
files/image.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
files/image1.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
files/image2.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
files/image3.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
files/image4.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
files/image5.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
files/image6.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
files/image7.png filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# 复制所需文件到容器中
|
6 |
+
COPY ./requirements.txt /app
|
7 |
+
COPY ./VERSION /app
|
8 |
+
|
9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
10 |
+
COPY ./app /app/app
|
11 |
+
ENV API_KEYS='["your_api_key_1"]'
|
12 |
+
ENV ALLOWED_TOKENS='["your_token_1"]'
|
13 |
+
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
|
14 |
+
ENV TOOLS_CODE_EXECUTION_ENABLED=false
|
15 |
+
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
|
16 |
+
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
|
17 |
+
ENV URL_NORMALIZATION_ENABLED=false
|
18 |
+
|
19 |
+
# Expose port
|
20 |
+
EXPOSE 7860
|
21 |
+
|
22 |
+
# Run the application
|
23 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--no-access-log"]
|
app/config/config.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
应用程序配置模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
import datetime
|
6 |
+
import json
|
7 |
+
from typing import Any, Dict, List, Type
|
8 |
+
|
9 |
+
from pydantic import ValidationError, ValidationInfo, field_validator
|
10 |
+
from pydantic_settings import BaseSettings
|
11 |
+
from sqlalchemy import insert, select, update
|
12 |
+
|
13 |
+
from app.core.constants import (
|
14 |
+
API_VERSION,
|
15 |
+
DEFAULT_CREATE_IMAGE_MODEL,
|
16 |
+
DEFAULT_FILTER_MODELS,
|
17 |
+
DEFAULT_MODEL,
|
18 |
+
DEFAULT_SAFETY_SETTINGS,
|
19 |
+
DEFAULT_STREAM_CHUNK_SIZE,
|
20 |
+
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
21 |
+
DEFAULT_STREAM_MAX_DELAY,
|
22 |
+
DEFAULT_STREAM_MIN_DELAY,
|
23 |
+
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
24 |
+
DEFAULT_TIMEOUT,
|
25 |
+
MAX_RETRIES,
|
26 |
+
)
|
27 |
+
from app.log.logger import Logger
|
28 |
+
|
29 |
+
|
30 |
+
class Settings(BaseSettings):
|
31 |
+
# 数据库配置
|
32 |
+
DATABASE_TYPE: str = "mysql" # sqlite 或 mysql
|
33 |
+
SQLITE_DATABASE: str = "default_db"
|
34 |
+
MYSQL_HOST: str = ""
|
35 |
+
MYSQL_PORT: int = 3306
|
36 |
+
MYSQL_USER: str = ""
|
37 |
+
MYSQL_PASSWORD: str = ""
|
38 |
+
MYSQL_DATABASE: str = ""
|
39 |
+
MYSQL_SOCKET: str = ""
|
40 |
+
|
41 |
+
# 验证 MySQL 配置
|
42 |
+
@field_validator(
|
43 |
+
"MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE"
|
44 |
+
)
|
45 |
+
def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any:
|
46 |
+
if info.data.get("DATABASE_TYPE") == "mysql":
|
47 |
+
if v is None or v == "":
|
48 |
+
raise ValueError(
|
49 |
+
"MySQL configuration is required when DATABASE_TYPE is 'mysql'"
|
50 |
+
)
|
51 |
+
return v
|
52 |
+
|
53 |
+
# API相关配置
|
54 |
+
API_KEYS: List[str]
|
55 |
+
ALLOWED_TOKENS: List[str]
|
56 |
+
BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
|
57 |
+
AUTH_TOKEN: str = ""
|
58 |
+
MAX_FAILURES: int = 3
|
59 |
+
TEST_MODEL: str = DEFAULT_MODEL
|
60 |
+
TIME_OUT: int = DEFAULT_TIMEOUT
|
61 |
+
MAX_RETRIES: int = MAX_RETRIES
|
62 |
+
PROXIES: List[str] = []
|
63 |
+
PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True # 是否使用一致性哈希来选择代理
|
64 |
+
VERTEX_API_KEYS: List[str] = []
|
65 |
+
VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"
|
66 |
+
|
67 |
+
# 智能路由配置
|
68 |
+
URL_NORMALIZATION_ENABLED: bool = False # 是否启用智能路由映射功能
|
69 |
+
|
70 |
+
# 模型相关配置
|
71 |
+
SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
72 |
+
IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
|
73 |
+
FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
|
74 |
+
TOOLS_CODE_EXECUTION_ENABLED: bool = False
|
75 |
+
SHOW_SEARCH_LINK: bool = True
|
76 |
+
SHOW_THINKING_PROCESS: bool = True
|
77 |
+
THINKING_MODELS: List[str] = []
|
78 |
+
THINKING_BUDGET_MAP: Dict[str, float] = {}
|
79 |
+
|
80 |
+
# TTS相关配置
|
81 |
+
TTS_MODEL: str = "gemini-2.5-flash-preview-tts"
|
82 |
+
TTS_VOICE_NAME: str = "Zephyr"
|
83 |
+
TTS_SPEED: str = "normal"
|
84 |
+
|
85 |
+
# 图像生成相关配置
|
86 |
+
PAID_KEY: str = ""
|
87 |
+
CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
|
88 |
+
UPLOAD_PROVIDER: str = "smms"
|
89 |
+
SMMS_SECRET_TOKEN: str = ""
|
90 |
+
PICGO_API_KEY: str = ""
|
91 |
+
CLOUDFLARE_IMGBED_URL: str = ""
|
92 |
+
CLOUDFLARE_IMGBED_AUTH_CODE: str = ""
|
93 |
+
|
94 |
+
# 流式输出优化器配置
|
95 |
+
STREAM_OPTIMIZER_ENABLED: bool = False
|
96 |
+
STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
|
97 |
+
STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
|
98 |
+
STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
|
99 |
+
STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
|
100 |
+
STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE
|
101 |
+
|
102 |
+
# 假流式配置 (Fake Streaming Configuration)
|
103 |
+
FAKE_STREAM_ENABLED: bool = False # 是否启用假流式输出
|
104 |
+
FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5 # 假流式发送空数据的间隔时间(秒)
|
105 |
+
|
106 |
+
# 调度器配置
|
107 |
+
CHECK_INTERVAL_HOURS: int = 1 # 默认检查间隔为1小时
|
108 |
+
TIMEZONE: str = "Asia/Shanghai" # 默认时区
|
109 |
+
|
110 |
+
# github
|
111 |
+
GITHUB_REPO_OWNER: str = "snailyp"
|
112 |
+
GITHUB_REPO_NAME: str = "gemini-balance"
|
113 |
+
|
114 |
+
# 日志配置
|
115 |
+
LOG_LEVEL: str = "INFO"
|
116 |
+
AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
|
117 |
+
AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
|
118 |
+
AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
|
119 |
+
AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
|
120 |
+
SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS
|
121 |
+
|
122 |
+
|
123 |
+
def __init__(self, **kwargs):
|
124 |
+
super().__init__(**kwargs)
|
125 |
+
# 设置默认AUTH_TOKEN(如果未提供)
|
126 |
+
if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
|
127 |
+
self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
|
128 |
+
|
129 |
+
|
130 |
+
# 创建全局配置实例
|
131 |
+
settings = Settings()
|
132 |
+
|
133 |
+
|
134 |
+
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
|
135 |
+
"""尝试将数据库字符串值解析为目标 Python 类型"""
|
136 |
+
from app.log.logger import get_config_logger
|
137 |
+
|
138 |
+
logger = get_config_logger()
|
139 |
+
try:
|
140 |
+
# 处理 List[str]
|
141 |
+
if target_type == List[str]:
|
142 |
+
try:
|
143 |
+
parsed = json.loads(db_value)
|
144 |
+
if isinstance(parsed, list):
|
145 |
+
return [str(item) for item in parsed]
|
146 |
+
except json.JSONDecodeError:
|
147 |
+
return [item.strip() for item in db_value.split(",") if item.strip()]
|
148 |
+
logger.warning(
|
149 |
+
f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
|
150 |
+
)
|
151 |
+
return [item.strip() for item in db_value.split(",") if item.strip()]
|
152 |
+
# 处理 Dict[str, float]
|
153 |
+
elif target_type == Dict[str, float]:
|
154 |
+
parsed_dict = {}
|
155 |
+
try:
|
156 |
+
parsed = json.loads(db_value)
|
157 |
+
if isinstance(parsed, dict):
|
158 |
+
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
159 |
+
else:
|
160 |
+
logger.warning(
|
161 |
+
f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
|
162 |
+
)
|
163 |
+
except (json.JSONDecodeError, ValueError, TypeError) as e1:
|
164 |
+
if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
|
165 |
+
logger.warning(
|
166 |
+
f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
|
167 |
+
)
|
168 |
+
try:
|
169 |
+
corrected_db_value = db_value.replace("'", '"')
|
170 |
+
parsed = json.loads(corrected_db_value)
|
171 |
+
if isinstance(parsed, dict):
|
172 |
+
parsed_dict = {str(k): float(v) for k, v in parsed.items()}
|
173 |
+
else:
|
174 |
+
logger.warning(
|
175 |
+
f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
|
176 |
+
)
|
177 |
+
except (json.JSONDecodeError, ValueError, TypeError) as e2:
|
178 |
+
logger.error(
|
179 |
+
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
logger.error(
|
183 |
+
f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
|
184 |
+
)
|
185 |
+
return parsed_dict
|
186 |
+
# 处理 List[Dict[str, str]]
|
187 |
+
elif target_type == List[Dict[str, str]]:
|
188 |
+
try:
|
189 |
+
parsed = json.loads(db_value)
|
190 |
+
if isinstance(parsed, list):
|
191 |
+
# 验证列表中的每个元素是否为字典,并且键和值都是字符串
|
192 |
+
valid = all(
|
193 |
+
isinstance(item, dict)
|
194 |
+
and all(isinstance(k, str) for k in item.keys())
|
195 |
+
and all(isinstance(v, str) for v in item.values())
|
196 |
+
for item in parsed
|
197 |
+
)
|
198 |
+
if valid:
|
199 |
+
return parsed
|
200 |
+
else:
|
201 |
+
logger.warning(
|
202 |
+
f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
|
203 |
+
)
|
204 |
+
return []
|
205 |
+
else:
|
206 |
+
logger.warning(
|
207 |
+
f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
|
208 |
+
)
|
209 |
+
return []
|
210 |
+
except json.JSONDecodeError:
|
211 |
+
logger.error(
|
212 |
+
f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
|
213 |
+
)
|
214 |
+
return []
|
215 |
+
except Exception as e:
|
216 |
+
logger.error(
|
217 |
+
f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
|
218 |
+
)
|
219 |
+
return []
|
220 |
+
# 处理 bool
|
221 |
+
elif target_type == bool:
|
222 |
+
return db_value.lower() in ("true", "1", "yes", "on")
|
223 |
+
# 处理 int
|
224 |
+
elif target_type == int:
|
225 |
+
return int(db_value)
|
226 |
+
# 处理 float
|
227 |
+
elif target_type == float:
|
228 |
+
return float(db_value)
|
229 |
+
# 默认为 str 或其他 pydantic 能直接处理的类型
|
230 |
+
else:
|
231 |
+
return db_value
|
232 |
+
except (ValueError, TypeError, json.JSONDecodeError) as e:
|
233 |
+
logger.warning(
|
234 |
+
f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value."
|
235 |
+
)
|
236 |
+
return db_value # 解析失败则返回原始字符串
|
237 |
+
|
238 |
+
|
239 |
+
async def sync_initial_settings():
|
240 |
+
"""
|
241 |
+
应用启动时同步配置:
|
242 |
+
1. 从数据库加载设置。
|
243 |
+
2. 将数据库设置合并到内存 settings (数据库优先)。
|
244 |
+
3. 将最终的内存 settings 同步回数据库。
|
245 |
+
"""
|
246 |
+
from app.log.logger import get_config_logger
|
247 |
+
|
248 |
+
logger = get_config_logger()
|
249 |
+
# 延迟导入以避免循环依赖和确保数据库连接已初始化
|
250 |
+
from app.database.connection import database
|
251 |
+
from app.database.models import Settings as SettingsModel
|
252 |
+
|
253 |
+
global settings
|
254 |
+
logger.info("Starting initial settings synchronization...")
|
255 |
+
|
256 |
+
if not database.is_connected:
|
257 |
+
try:
|
258 |
+
await database.connect()
|
259 |
+
logger.info("Database connection established for initial sync.")
|
260 |
+
except Exception as e:
|
261 |
+
logger.error(
|
262 |
+
f"Failed to connect to database for initial settings sync: {e}. Skipping sync."
|
263 |
+
)
|
264 |
+
return
|
265 |
+
|
266 |
+
try:
|
267 |
+
# 1. 从数据库加载设置
|
268 |
+
db_settings_raw: List[Dict[str, Any]] = []
|
269 |
+
try:
|
270 |
+
query = select(SettingsModel.key, SettingsModel.value)
|
271 |
+
results = await database.fetch_all(query)
|
272 |
+
db_settings_raw = [
|
273 |
+
{"key": row["key"], "value": row["value"]} for row in results
|
274 |
+
]
|
275 |
+
logger.info(f"Fetched {len(db_settings_raw)} settings from database.")
|
276 |
+
except Exception as e:
|
277 |
+
logger.error(
|
278 |
+
f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings."
|
279 |
+
)
|
280 |
+
# 即使数据库读取失败,也要继续执行,确保基于 env/dotenv 的配置能同步到数据库
|
281 |
+
|
282 |
+
db_settings_map: Dict[str, str] = {
|
283 |
+
s["key"]: s["value"] for s in db_settings_raw
|
284 |
+
}
|
285 |
+
|
286 |
+
# 2. 将数据库设置合并到内存 settings (数据库优先)
|
287 |
+
updated_in_memory = False
|
288 |
+
|
289 |
+
for key, db_value in db_settings_map.items():
|
290 |
+
if key == "DATABASE_TYPE":
|
291 |
+
logger.debug(
|
292 |
+
f"Skipping update of '{key}' in memory from database. "
|
293 |
+
"This setting is controlled by environment/dotenv."
|
294 |
+
)
|
295 |
+
continue
|
296 |
+
if hasattr(settings, key):
|
297 |
+
target_type = Settings.__annotations__.get(key)
|
298 |
+
if target_type:
|
299 |
+
try:
|
300 |
+
parsed_db_value = _parse_db_value(key, db_value, target_type)
|
301 |
+
memory_value = getattr(settings, key)
|
302 |
+
|
303 |
+
# 比较解析后的值和内存中的值
|
304 |
+
# 注意:对于列表等复杂类型,直接比较可能不够健壮,但这里简化处理
|
305 |
+
if parsed_db_value != memory_value:
|
306 |
+
# 检查类型是否匹配,以防解析函数返回了不兼容的类型
|
307 |
+
type_match = False
|
308 |
+
if target_type == List[str] and isinstance(
|
309 |
+
parsed_db_value, list
|
310 |
+
):
|
311 |
+
type_match = True
|
312 |
+
elif target_type == Dict[str, float] and isinstance(
|
313 |
+
parsed_db_value, dict
|
314 |
+
):
|
315 |
+
type_match = True
|
316 |
+
elif target_type not in (
|
317 |
+
List[str],
|
318 |
+
Dict[str, float],
|
319 |
+
) and isinstance(parsed_db_value, target_type):
|
320 |
+
type_match = True
|
321 |
+
|
322 |
+
if type_match:
|
323 |
+
setattr(settings, key, parsed_db_value)
|
324 |
+
logger.debug(
|
325 |
+
f"Updated setting '{key}' in memory from database value ({target_type})."
|
326 |
+
)
|
327 |
+
updated_in_memory = True
|
328 |
+
else:
|
329 |
+
logger.warning(
|
330 |
+
f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update."
|
331 |
+
)
|
332 |
+
|
333 |
+
except Exception as e:
|
334 |
+
logger.error(
|
335 |
+
f"Error processing database setting for key '{key}': {e}"
|
336 |
+
)
|
337 |
+
else:
|
338 |
+
logger.warning(
|
339 |
+
f"Database setting '{key}' not found in Settings model definition. Ignoring."
|
340 |
+
)
|
341 |
+
|
342 |
+
# 如果内存中有更新,重新验证 Pydantic 模型(可选但推荐)
|
343 |
+
if updated_in_memory:
|
344 |
+
try:
|
345 |
+
# 重新加载以确保类型转换和验证
|
346 |
+
settings = Settings(**settings.model_dump())
|
347 |
+
logger.info(
|
348 |
+
"Settings object re-validated after merging database values."
|
349 |
+
)
|
350 |
+
except ValidationError as e:
|
351 |
+
logger.error(
|
352 |
+
f"Validation error after merging database settings: {e}. Settings might be inconsistent."
|
353 |
+
)
|
354 |
+
|
355 |
+
# 3. 将最终的内存 settings 同步回数据库
|
356 |
+
final_memory_settings = settings.model_dump()
|
357 |
+
settings_to_update: List[Dict[str, Any]] = []
|
358 |
+
settings_to_insert: List[Dict[str, Any]] = []
|
359 |
+
now = datetime.datetime.now(datetime.timezone.utc)
|
360 |
+
|
361 |
+
existing_db_keys = set(db_settings_map.keys())
|
362 |
+
|
363 |
+
for key, value in final_memory_settings.items():
|
364 |
+
if key == "DATABASE_TYPE":
|
365 |
+
logger.debug(
|
366 |
+
f"Skipping synchronization of '{key}' to database. "
|
367 |
+
"This setting is controlled by environment/dotenv."
|
368 |
+
)
|
369 |
+
continue
|
370 |
+
|
371 |
+
# 序列化值为字符串或 JSON 字符串
|
372 |
+
if isinstance(value, (list, dict)):
|
373 |
+
db_value = json.dumps(
|
374 |
+
value, ensure_ascii=False
|
375 |
+
)
|
376 |
+
elif isinstance(value, bool):
|
377 |
+
db_value = str(value).lower()
|
378 |
+
elif value is None:
|
379 |
+
db_value = ""
|
380 |
+
else:
|
381 |
+
db_value = str(value)
|
382 |
+
|
383 |
+
data = {
|
384 |
+
"key": key,
|
385 |
+
"value": db_value,
|
386 |
+
"description": f"{key} configuration setting",
|
387 |
+
"updated_at": now,
|
388 |
+
}
|
389 |
+
|
390 |
+
if key in existing_db_keys:
|
391 |
+
# 仅当值与数据库中的不同时才更新
|
392 |
+
if db_settings_map[key] != db_value:
|
393 |
+
settings_to_update.append(data)
|
394 |
+
else:
|
395 |
+
# 如果键不在数据库中,则插入
|
396 |
+
data["created_at"] = now
|
397 |
+
settings_to_insert.append(data)
|
398 |
+
|
399 |
+
# 在事务中执行批量插入和更新
|
400 |
+
if settings_to_insert or settings_to_update:
|
401 |
+
try:
|
402 |
+
async with database.transaction():
|
403 |
+
if settings_to_insert:
|
404 |
+
# 获取现有描述以避免覆盖
|
405 |
+
query_existing = select(
|
406 |
+
SettingsModel.key, SettingsModel.description
|
407 |
+
).where(
|
408 |
+
SettingsModel.key.in_(
|
409 |
+
[s["key"] for s in settings_to_insert]
|
410 |
+
)
|
411 |
+
)
|
412 |
+
existing_desc = {
|
413 |
+
row["key"]: row["description"]
|
414 |
+
for row in await database.fetch_all(query_existing)
|
415 |
+
}
|
416 |
+
for item in settings_to_insert:
|
417 |
+
item["description"] = existing_desc.get(
|
418 |
+
item["key"], item["description"]
|
419 |
+
)
|
420 |
+
|
421 |
+
query_insert = insert(SettingsModel).values(settings_to_insert)
|
422 |
+
await database.execute(query=query_insert)
|
423 |
+
logger.info(
|
424 |
+
f"Synced (inserted) {len(settings_to_insert)} settings to database."
|
425 |
+
)
|
426 |
+
|
427 |
+
if settings_to_update:
|
428 |
+
# 获取现有描述以避免覆盖
|
429 |
+
query_existing = select(
|
430 |
+
SettingsModel.key, SettingsModel.description
|
431 |
+
).where(
|
432 |
+
SettingsModel.key.in_(
|
433 |
+
[s["key"] for s in settings_to_update]
|
434 |
+
)
|
435 |
+
)
|
436 |
+
existing_desc = {
|
437 |
+
row["key"]: row["description"]
|
438 |
+
for row in await database.fetch_all(query_existing)
|
439 |
+
}
|
440 |
+
|
441 |
+
for setting_data in settings_to_update:
|
442 |
+
setting_data["description"] = existing_desc.get(
|
443 |
+
setting_data["key"], setting_data["description"]
|
444 |
+
)
|
445 |
+
query_update = (
|
446 |
+
update(SettingsModel)
|
447 |
+
.where(SettingsModel.key == setting_data["key"])
|
448 |
+
.values(
|
449 |
+
value=setting_data["value"],
|
450 |
+
description=setting_data["description"],
|
451 |
+
updated_at=setting_data["updated_at"],
|
452 |
+
)
|
453 |
+
)
|
454 |
+
await database.execute(query=query_update)
|
455 |
+
logger.info(
|
456 |
+
f"Synced (updated) {len(settings_to_update)} settings to database."
|
457 |
+
)
|
458 |
+
except Exception as e:
|
459 |
+
logger.error(
|
460 |
+
f"Failed to sync settings to database during startup: {str(e)}"
|
461 |
+
)
|
462 |
+
else:
|
463 |
+
logger.info(
|
464 |
+
"No setting changes detected between memory and database during initial sync."
|
465 |
+
)
|
466 |
+
|
467 |
+
# 刷新日志等级
|
468 |
+
Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL"))
|
469 |
+
|
470 |
+
except Exception as e:
|
471 |
+
logger.error(f"An unexpected error occurred during initial settings sync: {e}")
|
472 |
+
finally:
|
473 |
+
if database.is_connected:
|
474 |
+
try:
|
475 |
+
pass
|
476 |
+
except Exception as e:
|
477 |
+
logger.error(f"Error disconnecting database after initial sync: {e}")
|
478 |
+
|
479 |
+
logger.info("Initial settings synchronization finished.")
|
app/core/application.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import asynccontextmanager
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from fastapi import FastAPI
|
5 |
+
from fastapi.staticfiles import StaticFiles
|
6 |
+
from fastapi.templating import Jinja2Templates
|
7 |
+
|
8 |
+
from app.config.config import settings, sync_initial_settings
|
9 |
+
from app.database.connection import connect_to_db, disconnect_from_db
|
10 |
+
from app.database.initialization import initialize_database
|
11 |
+
from app.exception.exceptions import setup_exception_handlers
|
12 |
+
from app.log.logger import get_application_logger
|
13 |
+
from app.middleware.middleware import setup_middlewares
|
14 |
+
from app.router.routes import setup_routers
|
15 |
+
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
16 |
+
from app.service.key.key_manager import get_key_manager_instance
|
17 |
+
from app.service.update.update_service import check_for_updates
|
18 |
+
from app.utils.helpers import get_current_version
|
19 |
+
|
20 |
+
logger = get_application_logger()
|
21 |
+
|
22 |
+
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
|
23 |
+
STATIC_DIR = PROJECT_ROOT / "app" / "static"
|
24 |
+
TEMPLATES_DIR = PROJECT_ROOT / "app" / "templates"
|
25 |
+
|
26 |
+
# 初始化模板引擎,并添加全局变量
|
27 |
+
templates = Jinja2Templates(directory="app/templates")
|
28 |
+
|
29 |
+
|
30 |
+
# 定义一个函数来更新模板全局变量
|
31 |
+
def update_template_globals(app: FastAPI, update_info: dict):
|
32 |
+
# Jinja2Templates 实例没有直接更新全局变量的方法
|
33 |
+
# 我们需要在请求上下文中传递这些变量,或者修改 Jinja 环境
|
34 |
+
# 更简单的方法是将其存储在 app.state 中,并在渲染时传递
|
35 |
+
app.state.update_info = update_info
|
36 |
+
logger.info(f"Update info stored in app.state: {update_info}")
|
37 |
+
|
38 |
+
|
39 |
+
# --- Helper functions for lifespan ---
|
40 |
+
async def _setup_database_and_config(app_settings):
|
41 |
+
"""Initializes database, syncs settings, and initializes KeyManager."""
|
42 |
+
initialize_database()
|
43 |
+
logger.info("Database initialized successfully")
|
44 |
+
await connect_to_db()
|
45 |
+
await sync_initial_settings()
|
46 |
+
await get_key_manager_instance(app_settings.API_KEYS, app_settings.VERTEX_API_KEYS)
|
47 |
+
logger.info("Database, config sync, and KeyManager initialized successfully")
|
48 |
+
|
49 |
+
|
50 |
+
async def _shutdown_database():
|
51 |
+
"""Disconnects from the database."""
|
52 |
+
await disconnect_from_db()
|
53 |
+
|
54 |
+
|
55 |
+
def _start_scheduler():
|
56 |
+
"""Starts the background scheduler."""
|
57 |
+
try:
|
58 |
+
start_scheduler()
|
59 |
+
logger.info("Scheduler started successfully.")
|
60 |
+
except Exception as e:
|
61 |
+
logger.error(f"Failed to start scheduler: {e}")
|
62 |
+
|
63 |
+
|
64 |
+
def _stop_scheduler():
|
65 |
+
"""Stops the background scheduler."""
|
66 |
+
stop_scheduler()
|
67 |
+
|
68 |
+
|
69 |
+
async def _perform_update_check(app: FastAPI):
|
70 |
+
"""Checks for updates and stores the info in app.state."""
|
71 |
+
update_available, latest_version, error_message = await check_for_updates()
|
72 |
+
current_version = get_current_version()
|
73 |
+
update_info = {
|
74 |
+
"update_available": update_available,
|
75 |
+
"latest_version": latest_version,
|
76 |
+
"error_message": error_message,
|
77 |
+
"current_version": current_version,
|
78 |
+
}
|
79 |
+
if not hasattr(app, "state"):
|
80 |
+
from starlette.datastructures import State
|
81 |
+
|
82 |
+
app.state = State()
|
83 |
+
app.state.update_info = update_info
|
84 |
+
logger.info(f"Update check completed. Info: {update_info}")
|
85 |
+
|
86 |
+
|
87 |
+
@asynccontextmanager
|
88 |
+
async def lifespan(app: FastAPI):
|
89 |
+
"""
|
90 |
+
Manages the application startup and shutdown events.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
app: FastAPI应用实例
|
94 |
+
"""
|
95 |
+
logger.info("Application starting up...")
|
96 |
+
try:
|
97 |
+
await _setup_database_and_config(settings)
|
98 |
+
await _perform_update_check(app)
|
99 |
+
_start_scheduler()
|
100 |
+
|
101 |
+
except Exception as e:
|
102 |
+
logger.critical(
|
103 |
+
f"Critical error during application startup: {str(e)}", exc_info=True
|
104 |
+
)
|
105 |
+
|
106 |
+
yield
|
107 |
+
|
108 |
+
logger.info("Application shutting down...")
|
109 |
+
_stop_scheduler()
|
110 |
+
await _shutdown_database()
|
111 |
+
|
112 |
+
|
113 |
+
def create_app() -> FastAPI:
|
114 |
+
"""
|
115 |
+
创建并配置FastAPI应用程序实例
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
FastAPI: 配置好的FastAPI应用程序实例
|
119 |
+
"""
|
120 |
+
|
121 |
+
# 创建FastAPI应用
|
122 |
+
current_version = get_current_version()
|
123 |
+
app = FastAPI(
|
124 |
+
title="Gemini Balance API",
|
125 |
+
description="Gemini API代理服务,支持负载均衡和密钥管理",
|
126 |
+
version=current_version,
|
127 |
+
lifespan=lifespan,
|
128 |
+
)
|
129 |
+
|
130 |
+
if not hasattr(app, "state"):
|
131 |
+
from starlette.datastructures import State
|
132 |
+
|
133 |
+
app.state = State()
|
134 |
+
app.state.update_info = {
|
135 |
+
"update_available": False,
|
136 |
+
"latest_version": None,
|
137 |
+
"error_message": "Initializing...",
|
138 |
+
"current_version": current_version,
|
139 |
+
}
|
140 |
+
|
141 |
+
# 配置静态文件
|
142 |
+
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
143 |
+
|
144 |
+
# 配置中间件
|
145 |
+
setup_middlewares(app)
|
146 |
+
|
147 |
+
# 配置异常处理器
|
148 |
+
setup_exception_handlers(app)
|
149 |
+
|
150 |
+
# 配置路由
|
151 |
+
setup_routers(app)
|
152 |
+
|
153 |
+
return app
|
app/core/constants.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
常量定义模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
# API相关常量
|
6 |
+
API_VERSION = "v1beta"
|
7 |
+
DEFAULT_TIMEOUT = 300 # 秒
|
8 |
+
MAX_RETRIES = 3 # 最大重试次数
|
9 |
+
|
10 |
+
# 模型相关常量
|
11 |
+
SUPPORTED_ROLES = ["user", "model", "system"]
|
12 |
+
DEFAULT_MODEL = "gemini-1.5-flash"
|
13 |
+
DEFAULT_TEMPERATURE = 0.7
|
14 |
+
DEFAULT_MAX_TOKENS = 8192
|
15 |
+
DEFAULT_TOP_P = 0.9
|
16 |
+
DEFAULT_TOP_K = 40
|
17 |
+
DEFAULT_FILTER_MODELS = [
|
18 |
+
"gemini-1.0-pro-vision-latest",
|
19 |
+
"gemini-pro-vision",
|
20 |
+
"chat-bison-001",
|
21 |
+
"text-bison-001",
|
22 |
+
"embedding-gecko-001"
|
23 |
+
]
|
24 |
+
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"
|
25 |
+
|
26 |
+
# 图像生成相关常量
|
27 |
+
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
|
28 |
+
|
29 |
+
# 上传提供商
|
30 |
+
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
|
31 |
+
DEFAULT_UPLOAD_PROVIDER = "smms"
|
32 |
+
|
33 |
+
# 流式输出相关常量
|
34 |
+
DEFAULT_STREAM_MIN_DELAY = 0.016
|
35 |
+
DEFAULT_STREAM_MAX_DELAY = 0.024
|
36 |
+
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
|
37 |
+
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
|
38 |
+
DEFAULT_STREAM_CHUNK_SIZE = 5
|
39 |
+
|
40 |
+
# 正则表达式模式
|
41 |
+
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'
|
42 |
+
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'
|
43 |
+
|
44 |
+
# Audio/Video Settings
|
45 |
+
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
|
46 |
+
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
|
47 |
+
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024 # Example: 50MB limit for Base64 payload
|
48 |
+
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024 # Example: 200MB limit
|
49 |
+
|
50 |
+
# Optional: Define MIME type mappings if needed, or handle directly in converter
|
51 |
+
AUDIO_FORMAT_TO_MIMETYPE = {
|
52 |
+
"wav": "audio/wav",
|
53 |
+
"mp3": "audio/mpeg",
|
54 |
+
"flac": "audio/flac",
|
55 |
+
"ogg": "audio/ogg",
|
56 |
+
}
|
57 |
+
|
58 |
+
VIDEO_FORMAT_TO_MIMETYPE = {
|
59 |
+
"mp4": "video/mp4",
|
60 |
+
"mov": "video/quicktime",
|
61 |
+
"avi": "video/x-msvideo",
|
62 |
+
"webm": "video/webm",
|
63 |
+
}
|
64 |
+
|
65 |
+
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
|
66 |
+
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
67 |
+
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
68 |
+
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
69 |
+
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
70 |
+
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
|
71 |
+
]
|
72 |
+
|
73 |
+
DEFAULT_SAFETY_SETTINGS = [
|
74 |
+
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
75 |
+
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
76 |
+
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
77 |
+
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
78 |
+
{"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
|
79 |
+
]
|
app/core/security.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from fastapi import Header, HTTPException
|
4 |
+
|
5 |
+
from app.config.config import settings
|
6 |
+
from app.log.logger import get_security_logger
|
7 |
+
|
8 |
+
logger = get_security_logger()
|
9 |
+
|
10 |
+
|
11 |
+
def verify_auth_token(token: str) -> bool:
    """Return True when *token* equals the configured admin AUTH_TOKEN.

    Uses secrets.compare_digest for a constant-time comparison: the
    original `==` short-circuits on the first mismatching character,
    which can leak information about the secret via timing. Falls back
    to plain equality when either side is not a string (e.g. AUTH_TOKEN
    unset), preserving the original False result in that case.
    """
    import secrets  # stdlib; local import keeps the module import block untouched

    expected = settings.AUTH_TOKEN
    if isinstance(token, str) and isinstance(expected, str):
        return secrets.compare_digest(token, expected)
    return token == expected
+
class SecurityService:
    """Token verification helpers used as FastAPI dependencies.

    A request is authorized when it presents either one of the configured
    ALLOWED_TOKENS or the admin AUTH_TOKEN. Each verifier reads the
    credential from a different place (URL key, Authorization header, or
    x-goog-api-key header) and raises HTTP 401 on failure.
    """

    async def verify_key(self, key: str):
        """Validate a key passed in the URL; returns it on success."""
        if key not in settings.ALLOWED_TOKENS and key != settings.AUTH_TOKEN:
            logger.error("Invalid key")
            raise HTTPException(status_code=401, detail="Invalid key")
        return key

    async def verify_authorization(
        self, authorization: Optional[str] = Header(None)
    ) -> str:
        """Validate a `Bearer <token>` Authorization header; returns the token."""
        if not authorization:
            logger.error("Missing Authorization header")
            raise HTTPException(status_code=401, detail="Missing Authorization header")

        if not authorization.startswith("Bearer "):
            logger.error("Invalid Authorization header format")
            raise HTTPException(
                status_code=401, detail="Invalid Authorization header format"
            )

        # Strip only the leading scheme. The previous str.replace() removed
        # EVERY "Bearer " occurrence, corrupting tokens that happen to
        # contain that substring; removeprefix() touches the prefix alone.
        token = authorization.removeprefix("Bearer ")
        if token not in settings.ALLOWED_TOKENS and token != settings.AUTH_TOKEN:
            logger.error("Invalid token")
            raise HTTPException(status_code=401, detail="Invalid token")

        return token

    async def verify_goog_api_key(
        self, x_goog_api_key: Optional[str] = Header(None)
    ) -> str:
        """Validate the Google-style x-goog-api-key header; returns the key."""
        if not x_goog_api_key:
            logger.error("Missing x-goog-api-key header")
            raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")

        if (
            x_goog_api_key not in settings.ALLOWED_TOKENS
            and x_goog_api_key != settings.AUTH_TOKEN
        ):
            logger.error("Invalid x-goog-api-key")
            raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")

        return x_goog_api_key

    async def verify_auth_token(
        self, authorization: Optional[str] = Header(None)
    ) -> str:
        """Validate the admin token (AUTH_TOKEN only); returns the bare token.

        Unlike verify_authorization, ALLOWED_TOKENS are not accepted here
        and the "Bearer " prefix is optional (a bare token passes through
        unchanged, matching the original behavior).
        """
        if not authorization:
            logger.error("Missing auth_token header")
            raise HTTPException(status_code=401, detail="Missing auth_token header")
        # Same prefix-only strip as verify_authorization (see note there).
        token = authorization.removeprefix("Bearer ")
        if token != settings.AUTH_TOKEN:
            logger.error("Invalid auth_token")
            raise HTTPException(status_code=401, detail="Invalid auth_token")

        return token

    async def verify_key_or_goog_api_key(
        self, key: Optional[str] = None, x_goog_api_key: Optional[str] = Header(None)
    ) -> str:
        """Accept either a valid URL `key` or a valid x-goog-api-key header.

        The URL key wins when valid; otherwise the header is checked and
        a 401 is raised if it is absent or invalid.
        """
        # A valid URL key short-circuits the header check.
        if key in settings.ALLOWED_TOKENS or key == settings.AUTH_TOKEN:
            return key

        # Fall back to the x-goog-api-key request header.
        if not x_goog_api_key:
            logger.error("Invalid key and missing x-goog-api-key header")
            raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")

        if x_goog_api_key not in settings.ALLOWED_TOKENS and x_goog_api_key != settings.AUTH_TOKEN:
            logger.error("Invalid key and invalid x-goog-api-key")
            raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")

        return x_goog_api_key
|
app/database/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
数据库模块
|
3 |
+
"""
|
app/database/connection.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
数据库连接池模块
|
3 |
+
"""
|
4 |
+
from pathlib import Path
|
5 |
+
from urllib.parse import quote_plus
|
6 |
+
from databases import Database
|
7 |
+
from sqlalchemy import create_engine, MetaData
|
8 |
+
from sqlalchemy.ext.declarative import declarative_base
|
9 |
+
|
10 |
+
from app.config.config import settings
|
11 |
+
from app.log.logger import get_database_logger
|
12 |
+
|
13 |
+
logger = get_database_logger()
|
14 |
+
|
15 |
+
# 数据库URL
|
16 |
+
if settings.DATABASE_TYPE == "sqlite":
|
17 |
+
# 确保 data 目录存在
|
18 |
+
data_dir = Path("data")
|
19 |
+
data_dir.mkdir(exist_ok=True)
|
20 |
+
db_path = data_dir / settings.SQLITE_DATABASE
|
21 |
+
DATABASE_URL = f"sqlite:///{db_path}"
|
22 |
+
elif settings.DATABASE_TYPE == "mysql":
|
23 |
+
if settings.MYSQL_SOCKET:
|
24 |
+
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
|
25 |
+
else:
|
26 |
+
DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
|
27 |
+
else:
|
28 |
+
raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")
|
29 |
+
|
30 |
+
# 创建数据库引擎
|
31 |
+
# pool_pre_ping=True: 在从连接池获取连接前执行简单的 "ping" 测试,确保连接有效
|
32 |
+
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
|
33 |
+
|
34 |
+
# 创建元数据对象
|
35 |
+
metadata = MetaData()
|
36 |
+
|
37 |
+
# 创建基类
|
38 |
+
Base = declarative_base(metadata=metadata)
|
39 |
+
|
40 |
+
# 创建数据库连接池,并配置连接池参数,在sqlite中不使用连接池
|
41 |
+
# min_size/max_size: 连接池的最小/最大连接数
|
42 |
+
# pool_recycle=3600: 连接在池中允许存在的最大秒数(生命周期)。
|
43 |
+
# 设置为 3600 秒(1小时),确保在 MySQL 默认的 wait_timeout (通常8小时) 或其他网络超时之前回收连接。
|
44 |
+
# 如果遇到连接失效问题,可以尝试调低此值,使其小于实际的 wait_timeout 或网络超时时间。
|
45 |
+
# databases 库会自动处理连接失效后的重连尝试。
|
46 |
+
if settings.DATABASE_TYPE == "sqlite":
|
47 |
+
database = Database(DATABASE_URL)
|
48 |
+
else:
|
49 |
+
database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800)
|
50 |
+
|
51 |
+
async def connect_to_db():
    """Open the shared async database connection pool.

    Logs and re-raises any connection failure so startup aborts loudly.
    """
    try:
        await database.connect()
        logger.info(f"Connected to {settings.DATABASE_TYPE}")
    except Exception as exc:
        logger.error(f"Failed to connect to database: {str(exc)}")
        raise
|
63 |
+
async def disconnect_from_db():
    """Close the shared async database connection pool.

    Failures are logged but deliberately swallowed so shutdown proceeds.
    """
    try:
        await database.disconnect()
        logger.info(f"Disconnected from {settings.DATABASE_TYPE}")
    except Exception as exc:
        # Best-effort shutdown: log and continue.
        logger.error(f"Failed to disconnect from database: {str(exc)}")
|
app/database/initialization.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
数据库初始化模块
|
3 |
+
"""
|
4 |
+
from dotenv import dotenv_values
|
5 |
+
|
6 |
+
from sqlalchemy import inspect
|
7 |
+
from sqlalchemy.orm import Session
|
8 |
+
|
9 |
+
from app.database.connection import engine, Base
|
10 |
+
from app.database.models import Settings
|
11 |
+
from app.log.logger import get_database_logger
|
12 |
+
|
13 |
+
logger = get_database_logger()
|
14 |
+
|
15 |
+
|
16 |
+
def create_tables():
    """Create every table registered on Base.metadata (no-op for existing ones).

    Re-raises on failure so callers can abort initialization.
    """
    try:
        Base.metadata.create_all(engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        logger.error(f"Failed to create database tables: {str(exc)}")
        raise
+
|
29 |
+
def import_env_to_settings():
    """
    Import configuration entries from the .env file into the t_settings table.

    Only keys not already present in the table are inserted; existing rows
    are left untouched. Re-raises on any database error.
    """
    try:
        # Read every key/value pair from the .env file.
        env_values = dotenv_values(".env")

        # Inspector lets us check for table existence first.
        inspector = inspect(engine)

        # Only proceed when the t_settings table exists.
        if "t_settings" in inspector.get_table_names():
            # One-off synchronous Session for the import.
            with Session(engine) as session:
                # Existing settings keyed by their key name.
                current_settings = {setting.key: setting for setting in session.query(Settings).all()}

                # Walk the .env entries.
                for key, value in env_values.items():
                    # Insert only keys that are not in the table yet.
                    if key not in current_settings:
                        new_setting = Settings(key=key, value=value)
                        session.add(new_setting)
                        logger.info(f"Inserted setting: {key}")

                # Commit all inserts in one transaction.
                session.commit()

        logger.info("Environment variables imported to settings table successfully")
    except Exception as e:
        logger.error(f"Failed to import environment variables to settings table: {str(e)}")
        raise
64 |
+
|
65 |
+
def initialize_database():
    """Initialize the database: create tables, then seed settings from .env.

    Re-raises on any failure after logging it.
    """
    try:
        # Order matters: the settings import needs the tables to exist.
        for step in (create_tables, import_env_to_settings):
            step()
    except Exception as exc:
        logger.error(f"Failed to initialize database: {str(exc)}")
        raise
|
app/database/models.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
数据库模型模块
|
3 |
+
"""
|
4 |
+
import datetime
|
5 |
+
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
|
6 |
+
|
7 |
+
from app.database.connection import Base
|
8 |
+
|
9 |
+
|
10 |
+
class Settings(Base):
    """
    Settings table mirroring the configuration entries from .env.
    """
    __tablename__ = "t_settings"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique configuration key name.
    key = Column(String(100), nullable=False, unique=True, comment="配置项键名")
    # Configuration value (free-form text).
    value = Column(Text, nullable=True, comment="配置项值")
    # Optional human-readable description.
    description = Column(String(255), nullable=True, comment="配置项描述")
    # Creation / last-update timestamps (local time via datetime.now).
    created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")

    def __repr__(self):
        return f"<Settings(key='{self.key}', value='{self.value}')>"
|
26 |
+
|
27 |
+
class ErrorLog(Base):
    """
    Error-log table: one row per recorded upstream API error.
    """
    __tablename__ = "t_error_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Gemini API key that produced the error.
    gemini_key = Column(String(100), nullable=True, comment="Gemini API密钥")
    # Model in use when the error occurred.
    model_name = Column(String(100), nullable=True, comment="模型名称")
    # Short error classification.
    error_type = Column(String(50), nullable=True, comment="错误类型")
    # Full error text.
    error_log = Column(Text, nullable=True, comment="错误日志")
    # Numeric error code (e.g. HTTP status).
    error_code = Column(Integer, nullable=True, comment="错误代码")
    # Originating request payload, stored as JSON.
    request_msg = Column(JSON, nullable=True, comment="请求消息")
    request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")

    def __repr__(self):
        return f"<ErrorLog(id='{self.id}', gemini_key='{self.gemini_key}')>"
45 |
+
|
46 |
+
class RequestLog(Base):
    """
    API request log table: one row per upstream API call.
    """

    __tablename__ = "t_request_log"

    id = Column(Integer, primary_key=True, autoincrement=True)
    request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
    model_name = Column(String(100), nullable=True, comment="模型名称")
    api_key = Column(String(100), nullable=True, comment="使用的API密钥")
    is_success = Column(Boolean, nullable=False, comment="请求是否成功")
    status_code = Column(Integer, nullable=True, comment="API响应状态码")
    latency_ms = Column(Integer, nullable=True, comment="请求耗时(毫秒)")

    def __repr__(self):
        # api_key is nullable: the original sliced self.api_key[:4]
        # unconditionally, raising TypeError for rows with a NULL key.
        key_preview = f"{self.api_key[:4]}..." if self.api_key else "None"
        return f"<RequestLog(id='{self.id}', key='{key_preview}', success='{self.is_success}')>"
|
app/database/services.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
数据库服务模块
|
3 |
+
"""
|
4 |
+
from typing import List, Optional, Dict, Any, Union
|
5 |
+
from datetime import datetime
|
6 |
+
from sqlalchemy import func, desc, asc, select, insert, update, delete
|
7 |
+
import json
|
8 |
+
from app.database.connection import database
|
9 |
+
from app.database.models import Settings, ErrorLog, RequestLog
|
10 |
+
from app.log.logger import get_database_logger
|
11 |
+
|
12 |
+
logger = get_database_logger()
|
13 |
+
|
14 |
+
|
15 |
+
async def get_all_settings() -> List[Dict[str, Any]]:
    """
    Fetch every row from the settings table.

    Returns:
        List[Dict[str, Any]]: one dict per settings row.

    Raises:
        Exception: re-raised after logging on any database failure.
    """
    try:
        rows = await database.fetch_all(select(Settings))
    except Exception as exc:
        logger.error(f"Failed to get all settings: {str(exc)}")
        raise
    return [dict(row) for row in rows]
|
30 |
+
|
31 |
+
async def get_setting(key: str) -> Optional[Dict[str, Any]]:
    """
    Look up a single setting by its key name.

    Args:
        key: setting key name

    Returns:
        Optional[Dict[str, Any]]: the settings row, or None when absent.

    Raises:
        Exception: re-raised after logging on any database failure.
    """
    try:
        row = await database.fetch_one(select(Settings).where(Settings.key == key))
    except Exception as exc:
        logger.error(f"Failed to get setting {key}: {str(exc)}")
        raise
    return dict(row) if row else None
|
49 |
+
|
50 |
+
async def update_setting(key: str, value: str, description: Optional[str] = None) -> bool:
    """
    Upsert a configuration entry: update when *key* exists, insert otherwise.

    Args:
        key: setting key name
        value: setting value
        description: optional description; on update, None keeps the
            existing description

    Returns:
        bool: True on success, False on any database error (never raises)
    """
    try:
        # Look the key up first to decide between UPDATE and INSERT.
        setting = await get_setting(key)

        if setting:
            # Update the existing row; updated_at is refreshed explicitly.
            query = (
                update(Settings)
                .where(Settings.key == key)
                .values(
                    value=value,
                    description=description if description else setting["description"],
                    updated_at=datetime.now()
                )
            )
            await database.execute(query)
            logger.info(f"Updated setting: {key}")
            return True
        else:
            # Insert a brand-new row with both timestamps set.
            query = (
                insert(Settings)
                .values(
                    key=key,
                    value=value,
                    description=description,
                    created_at=datetime.now(),
                    updated_at=datetime.now()
                )
            )
            await database.execute(query)
            logger.info(f"Inserted setting: {key}")
            return True
    except Exception as e:
        logger.error(f"Failed to update setting {key}: {str(e)}")
        return False
|
99 |
+
|
100 |
+
async def add_error_log(
    gemini_key: Optional[str] = None,
    model_name: Optional[str] = None,
    error_type: Optional[str] = None,
    error_log: Optional[str] = None,
    error_code: Optional[int] = None,
    request_msg: Optional[Union[Dict[str, Any], str]] = None
) -> bool:
    """
    Insert a row into the error-log table.

    Args:
        gemini_key: Gemini API key that produced the error
        model_name: model name in use
        error_type: short error classification
        error_log: full error text
        error_code: error code (e.g. HTTP status)
        request_msg: originating request; dicts are stored as-is, strings
            are parsed as JSON (falling back to {"message": <str>})

    Returns:
        bool: True on success, False on any database error (never raises)
    """
    try:
        # Normalize request_msg into a JSON-serializable dict (or None).
        if isinstance(request_msg, dict):
            request_msg_json = request_msg
        elif isinstance(request_msg, str):
            try:
                request_msg_json = json.loads(request_msg)
            except json.JSONDecodeError:
                # Not valid JSON: wrap the raw string.
                request_msg_json = {"message": request_msg}
        else:
            request_msg_json = None

        # Insert the error-log row with the current timestamp.
        query = (
            insert(ErrorLog)
            .values(
                gemini_key=gemini_key,
                error_type=error_type,
                error_log=error_log,
                model_name=model_name,
                error_code=error_code,
                request_msg=request_msg_json,
                request_time=datetime.now()
            )
        )
        await database.execute(query)
        logger.info(f"Added error log for key: {gemini_key}")
        return True
    except Exception as e:
        logger.error(f"Failed to add error log: {str(e)}")
        return False
|
152 |
+
|
153 |
+
async def get_error_logs(
    limit: int = 20,
    offset: int = 0,
    key_search: Optional[str] = None,
    error_search: Optional[str] = None,
    error_code_search: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    sort_by: str = 'id',
    sort_order: str = 'desc'
) -> List[Dict[str, Any]]:
    """
    Fetch error logs with optional search filters, date range and sorting.

    Args:
        limit (int): maximum number of rows to return
        offset (int): number of rows to skip (pagination)
        key_search (Optional[str]): fuzzy match on the Gemini key
        error_search (Optional[str]): fuzzy match on error type OR error text
        error_code_search (Optional[str]): exact match on the error code;
            non-integer input is ignored with a warning
        start_date (Optional[datetime]): inclusive lower bound on request_time
        end_date (Optional[datetime]): exclusive upper bound on request_time
        sort_by (str): sort column name (e.g. 'id', 'request_time');
            unknown names silently fall back to 'id'
        sort_order (str): 'asc' or 'desc' (anything else means 'desc')

    Returns:
        List[Dict[str, Any]]: matching error-log rows (request_msg excluded)
    """
    try:
        # Select explicit columns: request_msg is omitted here because it
        # can be large; fetch it per-row via get_error_log_details().
        query = select(
            ErrorLog.id,
            ErrorLog.gemini_key,
            ErrorLog.model_name,
            ErrorLog.error_type,
            ErrorLog.error_log,
            ErrorLog.error_code,
            ErrorLog.request_time
        )

        if key_search:
            query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
        if error_search:
            # Match either the error type or the error text.
            query = query.where(
                (ErrorLog.error_type.ilike(f"%{error_search}%")) |
                (ErrorLog.error_log.ilike(f"%{error_search}%"))
            )
        if start_date:
            query = query.where(ErrorLog.request_time >= start_date)
        if end_date:
            query = query.where(ErrorLog.request_time < end_date)
        if error_code_search:
            try:
                error_code_int = int(error_code_search)
                query = query.where(ErrorLog.error_code == error_code_int)
            except ValueError:
                logger.warning(f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter.")

        # Unknown sort_by values fall back to the id column.
        sort_column = getattr(ErrorLog, sort_by, ErrorLog.id)
        if sort_order.lower() == 'asc':
            query = query.order_by(asc(sort_column))
        else:
            query = query.order_by(desc(sort_column))

        query = query.limit(limit).offset(offset)

        result = await database.fetch_all(query)
        return [dict(row) for row in result]
    except Exception as e:
        logger.exception(f"Failed to get error logs with filters: {str(e)}")
        raise
|
224 |
+
|
225 |
+
async def get_error_logs_count(
    key_search: Optional[str] = None,
    error_search: Optional[str] = None,
    error_code_search: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None
) -> int:
    """
    Count error logs matching the given filters.

    The filter semantics are identical to get_error_logs() so the two can
    drive a paginated listing together.

    Args:
        key_search (Optional[str]): fuzzy match on the Gemini key
        error_search (Optional[str]): fuzzy match on error type OR error text
        error_code_search (Optional[str]): exact match on the error code;
            non-integer input is ignored with a warning
        start_date (Optional[datetime]): inclusive lower bound on request_time
        end_date (Optional[datetime]): exclusive upper bound on request_time

    Returns:
        int: number of matching rows
    """
    try:
        query = select(func.count()).select_from(ErrorLog)

        if key_search:
            query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
        if error_search:
            # Match either the error type or the error text.
            query = query.where(
                (ErrorLog.error_type.ilike(f"%{error_search}%")) |
                (ErrorLog.error_log.ilike(f"%{error_search}%"))
            )
        if start_date:
            query = query.where(ErrorLog.request_time >= start_date)
        if end_date:
            query = query.where(ErrorLog.request_time < end_date)
        if error_code_search:
            try:
                error_code_int = int(error_code_search)
                query = query.where(ErrorLog.error_code == error_code_int)
            except ValueError:
                logger.warning(f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter.")


        count_result = await database.fetch_one(query)
        return count_result[0] if count_result else 0
    except Exception as e:
        logger.exception(f"Failed to count error logs with filters: {str(e)}")
        raise
+
|
273 |
+
|
274 |
+
# 新增函数:获取单条错误日志详情
|
275 |
+
async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
    """
    Fetch the full details of a single error log by ID.

    Args:
        log_id (int): error-log row ID

    Returns:
        Optional[Dict[str, Any]]: the row with request_msg rendered as a
        pretty-printed JSON string, or None when the ID does not exist.
    """
    try:
        query = select(ErrorLog).where(ErrorLog.id == log_id)
        result = await database.fetch_one(query)
        if result:
            # Serialize request_msg (JSON column) to text for the API layer.
            log_dict = dict(result)
            if 'request_msg' in log_dict and log_dict['request_msg'] is not None:
                # Fall back to str() for values json.dumps cannot handle.
                try:
                    log_dict['request_msg'] = json.dumps(log_dict['request_msg'], ensure_ascii=False, indent=2)
                except TypeError:
                    log_dict['request_msg'] = str(log_dict['request_msg'])
            return log_dict
        else:
            return None
    except Exception as e:
        logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
        raise
+
|
304 |
+
|
305 |
+
async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
    """
    Bulk-delete error logs by ID (async).

    Args:
        log_ids: IDs of the error logs to delete.

    Returns:
        int: the number of IDs *requested* for deletion — NOT a verified
        deleted count, because the databases library's execute() does not
        report a rowcount.
    """
    if not log_ids:
        return 0
    try:
        query = delete(ErrorLog).where(ErrorLog.id.in_(log_ids))
        # databases' execute() does not return the affected-row count.
        # An exact count would require a prior SELECT COUNT(*) (or relying
        # on database constraints/triggers); for simplicity we execute the
        # delete, log it, and report the attempted count.
        await database.execute(query)
        logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}")
        return len(log_ids)  # number of IDs we attempted to delete
    except Exception as e:
        # Database connection or execution error.
        logger.error(f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True)
        raise
+
|
335 |
+
async def delete_error_log_by_id(log_id: int) -> bool:
    """
    Delete a single error log by ID (async).

    Args:
        log_id: ID of the error log to delete.

    Returns:
        bool: True when the row existed and was deleted, False when no
        row with that ID exists.
    """
    try:
        # Explicit existence check so unknown IDs report False instead of
        # silently succeeding.
        row = await database.fetch_one(
            select(ErrorLog.id).where(ErrorLog.id == log_id)
        )
        if row is None:
            logger.warning(f"Attempted to delete non-existent error log with ID: {log_id}")
            return False

        await database.execute(delete(ErrorLog).where(ErrorLog.id == log_id))
        logger.info(f"Successfully deleted error log with ID: {log_id}")
        return True
    except Exception as e:
        logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True)
        raise
|
363 |
+
|
364 |
+
async def delete_all_error_logs() -> int:
    """
    Delete every error-log row.

    Returns:
        int: how many rows were removed (the pre-delete count).
    """
    try:
        # Count first: databases' execute() gives no rowcount, so the
        # pre-delete total is the only way to report how many went away.
        total_to_delete = await database.fetch_val(
            select(func.count()).select_from(ErrorLog)
        )
        if total_to_delete == 0:
            logger.info("No error logs found to delete.")
            return 0

        await database.execute(delete(ErrorLog))
        logger.info(f"Successfully deleted all {total_to_delete} error logs.")
        return total_to_delete
    except Exception as e:
        logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True)
        raise
|
390 |
+
|
391 |
+
# 新增函数:添加请求日志
|
392 |
+
async def add_request_log(
    model_name: Optional[str],
    api_key: Optional[str],
    is_success: bool,
    status_code: Optional[int] = None,
    latency_ms: Optional[int] = None,
    request_time: Optional[datetime] = None
) -> bool:
    """
    Insert one API request-log row.

    Args:
        model_name: model name
        api_key: API key used for the call
        is_success: whether the call succeeded
        status_code: upstream response status code
        latency_ms: call latency in milliseconds
        request_time: when the request happened (defaults to now)

    Returns:
        bool: True on success, False on any database error (never raises)
    """
    try:
        stmt = insert(RequestLog).values(
            request_time=request_time if request_time else datetime.now(),
            model_name=model_name,
            api_key=api_key,
            is_success=is_success,
            status_code=status_code,
            latency_ms=latency_ms
        )
        await database.execute(stmt)
        return True
    except Exception as exc:
        logger.error(f"Failed to add request log: {str(exc)}")
        return False
|
app/domain/gemini_models.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P


class SafetySetting(BaseModel):
    """One Gemini safety rule: a harm category plus its blocking threshold."""
    category: Optional[
        Literal[
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_CIVIC_INTEGRITY",
        ]
    ] = None
    threshold: Optional[
        Literal[
            "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
            "BLOCK_LOW_AND_ABOVE",
            "BLOCK_MEDIUM_AND_ABOVE",
            "BLOCK_ONLY_HIGH",
            "BLOCK_NONE",
            "OFF",
        ]
    ] = None


class GenerationConfig(BaseModel):
    """Gemini generationConfig payload (field names mirror the REST API's camelCase)."""
    stopSequences: Optional[List[str]] = None
    responseMimeType: Optional[str] = None
    responseSchema: Optional[Dict[str, Any]] = None
    candidateCount: Optional[int] = 1
    maxOutputTokens: Optional[int] = None
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    topP: Optional[float] = DEFAULT_TOP_P
    topK: Optional[int] = DEFAULT_TOP_K
    presencePenalty: Optional[float] = None
    frequencyPenalty: Optional[float] = None
    responseLogprobs: Optional[bool] = None
    logprobs: Optional[int] = None
    thinkingConfig: Optional[Dict[str, Any]] = None


class SystemInstruction(BaseModel):
    """System prompt content; parts may be a single dict or a list of dicts."""
    role: Optional[str] = "system"
    parts: Union[List[Dict[str, Any]], Dict[str, Any]]


class GeminiContent(BaseModel):
    """One conversation turn: an optional role plus its content parts."""
    role: Optional[str] = None
    parts: List[Dict[str, Any]]


class GeminiRequest(BaseModel):
    """Top-level generateContent request body.

    Aliased fields accept either camelCase (field name) or snake_case
    (alias) on input, thanks to populate_by_name below.
    """
    contents: List[GeminiContent] = []
    tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
    safetySettings: Optional[List[SafetySetting]] = Field(
        default=None, alias="safety_settings"
    )
    generationConfig: Optional[GenerationConfig] = Field(
        default=None, alias="generation_config"
    )
    systemInstruction: Optional[SystemInstruction] = Field(
        default=None, alias="system_instruction"
    )

    class Config:
        # Accept both the field names and their snake_case aliases on input.
        populate_by_name = True


class ResetSelectedKeysRequest(BaseModel):
    """Payload carrying a list of keys and their type for the reset endpoint."""
    keys: List[str]
    key_type: str


class VerifySelectedKeysRequest(BaseModel):
    """Payload carrying the list of keys to verify."""
    keys: List[str]
|
app/domain/image_models.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
|
4 |
+
class ImageMetadata:
|
5 |
+
def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
|
6 |
+
self.width = width
|
7 |
+
self.height = height
|
8 |
+
self.filename = filename
|
9 |
+
self.size = size
|
10 |
+
self.url = url
|
11 |
+
self.delete_url = delete_url
|
12 |
+
class UploadResponse:
    """Outcome of an image upload attempt."""

    def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
        # Whether the upload succeeded, plus the provider's status code/message.
        self.success, self.code = success, code
        # Human-readable message and the stored image's metadata.
        self.message, self.data = message, data
|
18 |
+
class ImageUploader:
    """Interface for image upload backends."""

    def upload(self, file: bytes, filename: str) -> UploadResponse:
        """Upload raw image bytes under `filename`; implemented by subclasses."""
        raise NotImplementedError
|
app/domain/openai_models.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
from typing import Any, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
|
5 |
+
|
6 |
+
|
7 |
+
class ChatRequest(BaseModel):
    """OpenAI-compatible chat completion request body."""

    # OpenAI-style message dicts ({"role": ..., "content": ...}).
    messages: List[dict]
    model: str = DEFAULT_MODEL
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    top_p: Optional[float] = DEFAULT_TOP_P
    # top_k is a Gemini extension; not part of the standard OpenAI schema.
    top_k: Optional[int] = DEFAULT_TOP_K
    stop: Optional[Union[List[str],str]] = None
    reasoning_effort: Optional[str] = None
    tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
    tool_choice: Optional[str] = None
    response_format: Optional[dict] = None
|
20 |
+
|
21 |
+
|
22 |
+
class EmbeddingRequest(BaseModel):
    """OpenAI-compatible embeddings request body."""

    # One string or a batch of strings to embed.
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"
|
26 |
+
|
27 |
+
|
28 |
+
class ImageGenerationRequest(BaseModel):
    """OpenAI-compatible image generation request body (backed by Imagen)."""

    model: str = "imagen-3.0-generate-002"
    prompt: str = ""
    # Number of images to generate.
    n: int = 1
    size: Optional[str] = "1024x1024"
    quality: Optional[str] = None
    style: Optional[str] = None
    # "url" or "b64_json", following the OpenAI images API convention.
    response_format: Optional[str] = "url"
|
36 |
+
|
37 |
+
|
38 |
+
class TTSRequest(BaseModel):
    """Text-to-speech request body for the Gemini TTS preview model."""

    model: str = "gemini-2.5-flash-preview-tts"
    # Text to synthesize.
    input: str
    # Voice preset name.
    voice: str = "Kore"
    response_format: Optional[str] = "wav"
|
app/exception/exceptions.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
异常处理模块,定义应用程序中使用的自定义异常和异常处理器
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.log.logger import get_exceptions_logger
|
11 |
+
|
12 |
+
logger = get_exceptions_logger()
|
13 |
+
|
14 |
+
|
15 |
+
class APIError(Exception):
    """Base class for API errors.

    Attributes:
        status_code: HTTP status code returned to the client.
        detail: Human-readable error description.
        error_code: Machine-readable error identifier; defaults to "api_error".
    """

    # Fix: the annotation said `str` but the default is None -> Optional[str].
    def __init__(self, status_code: int, detail: str, error_code: Optional[str] = None):
        self.status_code = status_code
        self.detail = detail
        self.error_code = error_code or "api_error"
        super().__init__(self.detail)
|
23 |
+
|
24 |
+
|
25 |
+
class AuthenticationError(APIError):
    """Raised when a request cannot be authenticated (HTTP 401)."""

    def __init__(self, detail: str = "Authentication failed"):
        super().__init__(401, detail, "authentication_error")
|
32 |
+
|
33 |
+
|
34 |
+
class AuthorizationError(APIError):
    """Raised when the caller lacks permission for a resource (HTTP 403)."""

    def __init__(self, detail: str = "Not authorized to access this resource"):
        super().__init__(403, detail, "authorization_error")
|
41 |
+
|
42 |
+
|
43 |
+
class ResourceNotFoundError(APIError):
    """Raised when a requested resource does not exist (HTTP 404)."""

    def __init__(self, detail: str = "Resource not found"):
        super().__init__(404, detail, "resource_not_found")
|
50 |
+
|
51 |
+
|
52 |
+
class ModelNotSupportedError(APIError):
    """Raised when the requested model is not supported (HTTP 400)."""

    def __init__(self, model: str):
        super().__init__(400, f"Model {model} is not supported", "model_not_supported")
|
61 |
+
|
62 |
+
|
63 |
+
class APIKeyError(APIError):
    """Raised for invalid or expired API keys (HTTP 401)."""

    def __init__(self, detail: str = "Invalid or expired API key"):
        super().__init__(401, detail, "api_key_error")
|
68 |
+
|
69 |
+
|
70 |
+
class ServiceUnavailableError(APIError):
    """Raised when the upstream service is temporarily unavailable (HTTP 503)."""

    def __init__(self, detail: str = "Service temporarily unavailable"):
        super().__init__(503, detail, "service_unavailable")
|
77 |
+
|
78 |
+
|
79 |
+
def setup_exception_handlers(app: FastAPI) -> None:
    """
    Register the application's exception handlers.

    Installs handlers for APIError, Starlette HTTPException, request
    validation errors, and a catch-all for unexpected exceptions. Every
    handler logs the error and returns a JSON body of the shape
    {"error": {"code": ..., "message": ...}}.

    Args:
        app: FastAPI application instance
    """

    @app.exception_handler(APIError)
    async def api_error_handler(request: Request, exc: APIError):
        """Handle custom APIError exceptions raised by application code."""
        logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": {"code": exc.error_code, "message": exc.detail}},
        )

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request: Request, exc: StarletteHTTPException):
        """Handle framework-raised HTTP exceptions."""
        logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": {"code": "http_error", "message": exc.detail}},
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ):
        """Handle request body/query validation failures (HTTP 422)."""
        # Flatten pydantic's error objects into JSON-serializable dicts.
        error_details = []
        for error in exc.errors():
            error_details.append(
                {"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
            )

        logger.error(f"Validation Error: {error_details}")
        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "code": "validation_error",
                    "message": "Request validation failed",
                    "details": error_details,
                }
            },
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request: Request, exc: Exception):
        """Catch-all: log the traceback and return a generic 500 response."""
        logger.exception(f"Unhandled Exception: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "code": "internal_server_error",
                    "message": "An unexpected error occurred",
                }
            },
        )
|
app/handler/error_handler.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import asynccontextmanager
|
2 |
+
from fastapi import HTTPException
|
3 |
+
import logging
|
4 |
+
|
5 |
+
@asynccontextmanager
async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
    """
    Async context manager that unifies error handling and logging for FastAPI routes.

    Args:
        logger: Logger instance used for recording.
        operation_name: Name of the operation, used in log lines and error details.
        success_message: Custom message logged when the operation succeeds (optional).
        failure_message: Custom message logged when the operation fails (optional).
    """
    default_success_msg = f"{operation_name} request successful"
    default_failure_msg = f"{operation_name} request failed"

    # Visual separator so each operation is easy to spot in the log stream.
    logger.info("-" * 50 + operation_name + "-" * 50)
    try:
        yield
        logger.info(success_message or default_success_msg)
    except HTTPException as http_exc:
        # Already an HTTPException: re-raise as-is, preserving the original
        # status code and detail.
        logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
        raise http_exc
    except Exception as e:
        # Any other exception: log it and raise a standard 500 error.
        logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error during {operation_name}"
        ) from e
|
app/handler/message_converter.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from typing import Any, Dict, List, Optional
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from app.core.constants import (
|
10 |
+
AUDIO_FORMAT_TO_MIMETYPE,
|
11 |
+
DATA_URL_PATTERN,
|
12 |
+
IMAGE_URL_PATTERN,
|
13 |
+
MAX_AUDIO_SIZE_BYTES,
|
14 |
+
MAX_VIDEO_SIZE_BYTES,
|
15 |
+
SUPPORTED_AUDIO_FORMATS,
|
16 |
+
SUPPORTED_ROLES,
|
17 |
+
SUPPORTED_VIDEO_FORMATS,
|
18 |
+
VIDEO_FORMAT_TO_MIMETYPE,
|
19 |
+
)
|
20 |
+
from app.log.logger import get_message_converter_logger
|
21 |
+
|
22 |
+
logger = get_message_converter_logger()
|
23 |
+
|
24 |
+
|
25 |
+
class MessageConverter(ABC):
    """Abstract base class for chat message converters."""

    @abstractmethod
    def convert(
        self, messages: List[Dict[str, Any]]
    ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Convert messages; returns (converted_messages, system_instruction or None)."""
        pass
|
33 |
+
|
34 |
+
|
35 |
+
def _get_mime_type_and_data(base64_string):
    """
    Split a ``data:`` URL into its MIME type and Base64 payload.

    Args:
        base64_string (str): A string that may carry MIME type information
            in ``data:`` URL form, or be raw Base64 data.

    Returns:
        tuple: (mime_type, encoded_data); mime_type is None when the input
        is not a recognizable ``data:`` URL.
    """
    if not base64_string.startswith("data:"):
        # Not a data URL: treat the whole string as the payload.
        return None, base64_string

    matched = re.match(DATA_URL_PATTERN, base64_string)
    if matched is None:
        # Unexpected format: fall back to returning the raw string as data.
        return None, base64_string

    mime = matched.group(1)
    # Normalize the non-standard "image/jpg" to the canonical "image/jpeg".
    if mime == "image/jpg":
        mime = "image/jpeg"
    return mime, matched.group(2)
|
59 |
+
|
60 |
+
|
61 |
+
def _convert_image(image_url: str) -> Dict[str, Any]:
    """Convert an image URL (data URL or remote URL) into a Gemini inline_data part."""
    if image_url.startswith("data:image"):
        mime_type, encoded_data = _get_mime_type_and_data(image_url)
        return {"inline_data": {"mime_type": mime_type, "data": encoded_data}}

    # Remote URL: download the image and embed it as Base64 (labelled PNG).
    return {
        "inline_data": {
            "mime_type": "image/png",
            "data": _convert_image_to_base64(image_url),
        }
    }
|
68 |
+
|
69 |
+
|
70 |
+
def _convert_image_to_base64(url: str) -> str:
    """
    Download an image and return its contents Base64-encoded.

    Args:
        url: Image URL to fetch.

    Returns:
        str: Base64-encoded image data.

    Raises:
        Exception: If the HTTP response status is not 200 (or the request
            times out / fails at the transport level).
    """
    # Fix: the original call had no timeout, so a slow or unresponsive host
    # could hang the request indefinitely.
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        # Encode the raw image bytes as Base64 text.
        return base64.b64encode(response.content).decode("utf-8")
    raise Exception(f"Failed to fetch image: {response.status_code}")
|
85 |
+
|
86 |
+
|
87 |
+
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
    """
    Process text that may embed an image URL, extracting the image as Base64.

    Args:
        text: Text possibly containing an image URL (per IMAGE_URL_PATTERN).

    Returns:
        List[Dict[str, Any]]: A single-element parts list - either an inline
        image part or a plain text part.
    """
    matched = re.search(IMAGE_URL_PATTERN, text)
    if matched is None:
        # No image URL present: pass the text through unchanged.
        return [{"text": text}]

    try:
        encoded = _convert_image_to_base64(matched.group(2))
    except Exception:
        # Download or encoding failed: fall back to plain text.
        return [{"text": text}]
    return [{"inline_data": {"mimeType": "image/png", "data": encoded}}]
|
115 |
+
|
116 |
+
|
117 |
+
class OpenAIMessageConverter(MessageConverter):
    """Converts OpenAI-style chat messages into Gemini `contents` format.

    System messages are collected separately and returned as a Gemini
    system_instruction structure.
    """

    def _validate_media_data(
        self, format: str, data: str, supported_formats: List[str], max_size: int
    ) -> str:
        """Validates format and size of Base64 media data.

        Returns the original Base64 string when valid.

        Raises:
            ValueError: If the format is unsupported, the data is not valid
                Base64, or the decoded payload exceeds `max_size` bytes.
        """
        if format.lower() not in supported_formats:
            logger.error(
                f"Unsupported media format: {format}. Supported: {supported_formats}"
            )
            raise ValueError(f"Unsupported media format: {format}")

        try:
            decoded_data = base64.b64decode(data, validate=True)
            if len(decoded_data) > max_size:
                logger.error(
                    f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
                )
                raise ValueError(
                    f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
                )
            return data
        except base64.binascii.Error as e:
            logger.error(f"Invalid Base64 data provided: {e}")
            raise ValueError("Invalid Base64 data") from e
        except Exception as e:
            # NOTE: the size-limit ValueError raised above is also routed
            # through here (it is not a binascii.Error), logging it once
            # more before re-raising.
            logger.error(f"Error validating media data: {e}")
            raise

    def _convert_media_part(
        self,
        media_info: Dict[str, Any],
        label: str,
        supported_formats: List[str],
        max_size: int,
        mime_map: Dict[str, str],
    ) -> Optional[Dict[str, Any]]:
        """Validate one audio/video content item and build its inline_data part.

        Returns None when the item lacks data/format (caller skips it), and a
        text part describing the error when validation fails.
        """
        media_data = media_info.get("data")
        media_format = media_info.get("format", "").lower()

        if not media_data or not media_format:
            logger.warning(f"Skipping {label} part due to missing data or format.")
            return None

        try:
            validated_data = self._validate_media_data(
                media_format, media_data, supported_formats, max_size
            )
            mime_type = mime_map.get(media_format)
            if not mime_type:
                # Should not happen if format validation passed, but double-check.
                raise ValueError(
                    f"Internal error: MIME type mapping missing for {media_format}"
                )
            logger.debug(f"Successfully added {label} part (format: {media_format})")
            return {"inline_data": {"mimeType": mime_type, "data": validated_data}}
        except ValueError as e:
            logger.error(f"Skipping {label} part due to validation error: {e}")
            return {"text": f"[Error processing {label}: {e}]"}
        except Exception:
            logger.exception(f"Unexpected error processing {label} part.")
            return {"text": f"[Unexpected error processing {label}]"}

    def convert(
        self, messages: List[Dict[str, Any]]
    ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Convert OpenAI messages to Gemini contents plus a system instruction.

        Returns:
            tuple: (list of Gemini content dicts, system_instruction dict or None)
        """
        converted_messages = []
        system_instruction_parts = []

        for idx, msg in enumerate(messages):
            role = msg.get("role", "")
            parts = []

            if "content" in msg and isinstance(msg["content"], list):
                # Structured (multimodal) content: text / image / audio / video items.
                for content_item in msg["content"]:
                    if not isinstance(content_item, dict):
                        logger.warning(
                            f"Skipping unexpected content item format: {type(content_item)}"
                        )
                        continue

                    content_type = content_item.get("type")

                    if content_type == "text" and content_item.get("text"):
                        parts.append({"text": content_item["text"]})
                    elif content_type == "image_url" and content_item.get(
                        "image_url", {}
                    ).get("url"):
                        try:
                            parts.append(
                                _convert_image(content_item["image_url"]["url"])
                            )
                        except Exception as e:
                            logger.error(
                                f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
                            )
                            parts.append(
                                {
                                    "text": f"[Error processing image: {content_item['image_url']['url']}]"
                                }
                            )
                    elif content_type == "input_audio" and content_item.get(
                        "input_audio"
                    ):
                        media_part = self._convert_media_part(
                            content_item["input_audio"],
                            "audio",
                            SUPPORTED_AUDIO_FORMATS,
                            MAX_AUDIO_SIZE_BYTES,
                            AUDIO_FORMAT_TO_MIMETYPE,
                        )
                        if media_part is not None:
                            parts.append(media_part)
                    elif content_type == "input_video" and content_item.get(
                        "input_video"
                    ):
                        media_part = self._convert_media_part(
                            content_item["input_video"],
                            "video",
                            SUPPORTED_VIDEO_FORMATS,
                            MAX_VIDEO_SIZE_BYTES,
                            VIDEO_FORMAT_TO_MIMETYPE,
                        )
                        if media_part is not None:
                            parts.append(media_part)
                    else:
                        # Log unrecognized but present types.
                        if content_type:
                            logger.warning(
                                f"Unsupported content type or missing data in structured content: {content_type}"
                            )

            elif (
                "content" in msg and isinstance(msg["content"], str) and msg["content"]
            ):
                parts.extend(_process_text_with_image(msg["content"]))
            elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
                for tool_call in msg["tool_calls"]:
                    function_call = tool_call.get("function", {})
                    # Gemini expects parsed "args"; OpenAI sends a JSON string
                    # in "arguments".
                    arguments_str = function_call.get("arguments", "{}")
                    try:
                        function_call["args"] = json.loads(arguments_str)
                    except json.JSONDecodeError:
                        logger.warning(
                            f"Failed to decode tool call arguments: {arguments_str}"
                        )
                        function_call["args"] = {}
                    # Fix: the original duplicated this membership check.
                    if "arguments" in function_call:
                        del function_call["arguments"]

                    parts.append({"functionCall": function_call})

            if role not in SUPPORTED_ROLES:
                if role == "tool":
                    role = "user"
                elif idx == len(messages) - 1:
                    # Treat a trailing message with an unknown role as the user's turn.
                    role = "user"
                else:
                    role = "model"

            if parts:
                if role == "system":
                    # Gemini takes system text separately; non-text parts are dropped.
                    text_only_parts = [p for p in parts if "text" in p]
                    if len(text_only_parts) != len(parts):
                        logger.warning(
                            "Non-text parts found in system message; discarding them."
                        )
                    if text_only_parts:
                        system_instruction_parts.extend(text_only_parts)
                else:
                    converted_messages.append({"role": role, "parts": parts})

        system_instruction = (
            None
            if not system_instruction_parts
            else {
                "role": "system",
                "parts": system_instruction_parts,
            }
        )
        return converted_messages, system_instruction
|
app/handler/response_handler.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import string
|
5 |
+
import time
|
6 |
+
import uuid
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
from typing import Any, Dict, List, Optional
|
9 |
+
|
10 |
+
from app.config.config import settings
|
11 |
+
from app.utils.uploader import ImageUploaderFactory
|
12 |
+
|
13 |
+
|
14 |
+
class ResponseHandler(ABC):
    """Abstract base class for response handlers."""

    @abstractmethod
    def handle_response(
        self, response: Dict[str, Any], model: str, stream: bool = False
    ) -> Dict[str, Any]:
        """Transform a raw Gemini API response into the target output format."""
        pass
|
22 |
+
|
23 |
+
|
24 |
+
class GeminiResponseHandler(ResponseHandler):
    """Response handler for native Gemini-format responses."""

    def __init__(self):
        # Bookkeeping flags for "thinking" content handling.
        self.thinking_first = True
        self.thinking_status = False

    def handle_response(
        self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Dispatch to the stream or non-stream Gemini response formatter."""
        formatter = (
            _handle_gemini_stream_response if stream else _handle_gemini_normal_response
        )
        return formatter(response, model, stream)
|
37 |
+
|
38 |
+
|
39 |
+
def _handle_openai_stream_response(
    response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build an OpenAI chat.completion.chunk payload from one Gemini stream piece."""
    text, tool_calls, _ = _extract_result(
        response, model, stream=True, gemini_format=False
    )

    # Empty delta when there is neither text nor tool calls.
    delta: Dict[str, Any] = {}
    if text or tool_calls:
        delta = {"content": text, "role": "assistant"}
        if tool_calls:
            delta["tool_calls"] = tool_calls

    chunk = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    if usage_metadata:
        # Map Gemini token-count names onto the OpenAI usage schema.
        chunk["usage"] = {
            "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
            "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
            "total_tokens": usage_metadata.get("totalTokenCount", 0),
        }
    return chunk
|
61 |
+
|
62 |
+
|
63 |
+
def _handle_openai_normal_response(
    response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build a non-streaming OpenAI chat.completion payload from a Gemini response.

    Args:
        response: Raw Gemini response dict.
        model: Model name echoed back in the payload.
        finish_reason: OpenAI-style finish reason for the single choice.
        usage_metadata: Gemini usage metadata, or None when unavailable.
    """
    text, tool_calls, _ = _extract_result(
        response, model, stream=False, gemini_format=False
    )
    # Fix: usage_metadata is Optional, but the original called .get() on it
    # unconditionally and crashed with AttributeError when it was None
    # (the stream variant already guarded against this).
    usage = usage_metadata or {}
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": text,
                    "tool_calls": tool_calls,
                },
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("promptTokenCount", 0),
            "completion_tokens": usage.get("candidatesTokenCount", 0),
            "total_tokens": usage.get("totalTokenCount", 0),
        },
    }
|
87 |
+
|
88 |
+
|
89 |
+
class OpenAIResponseHandler(ResponseHandler):
    """Response handler that converts Gemini responses to OpenAI payloads."""

    def __init__(self, config):
        # config is stored for callers; not read within this class's visible code.
        self.config = config
        # Bookkeeping flags for "thinking" content handling.
        self.thinking_first = True
        self.thinking_status = False

    def handle_response(
        self,
        response: Dict[str, Any],
        model: str,
        stream: bool = False,
        finish_reason: str = None,
        usage_metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Format a Gemini response as an OpenAI completion (chunk or full)."""
        if stream:
            return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
        return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)

    def handle_image_chat_response(
        self, image_str: str, model: str, stream=False, finish_reason="stop"
    ):
        """Format generated-image text as an OpenAI completion (chunk or full)."""
        if stream:
            return _handle_openai_stream_image_response(image_str, model, finish_reason)
        return _handle_openai_normal_image_response(image_str, model, finish_reason)
115 |
+
|
116 |
+
|
117 |
+
def _handle_openai_stream_image_response(
|
118 |
+
image_str: str, model: str, finish_reason: str
|
119 |
+
) -> Dict[str, Any]:
|
120 |
+
return {
|
121 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
122 |
+
"object": "chat.completion.chunk",
|
123 |
+
"created": int(time.time()),
|
124 |
+
"model": model,
|
125 |
+
"choices": [
|
126 |
+
{
|
127 |
+
"index": 0,
|
128 |
+
"delta": {"content": image_str} if image_str else {},
|
129 |
+
"finish_reason": finish_reason,
|
130 |
+
}
|
131 |
+
],
|
132 |
+
}
|
133 |
+
|
134 |
+
|
135 |
+
def _handle_openai_normal_image_response(
|
136 |
+
image_str: str, model: str, finish_reason: str
|
137 |
+
) -> Dict[str, Any]:
|
138 |
+
return {
|
139 |
+
"id": f"chatcmpl-{uuid.uuid4()}",
|
140 |
+
"object": "chat.completion",
|
141 |
+
"created": int(time.time()),
|
142 |
+
"model": model,
|
143 |
+
"choices": [
|
144 |
+
{
|
145 |
+
"index": 0,
|
146 |
+
"message": {"role": "assistant", "content": image_str},
|
147 |
+
"finish_reason": finish_reason,
|
148 |
+
}
|
149 |
+
],
|
150 |
+
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
151 |
+
}
|
152 |
+
|
153 |
+
|
154 |
+
def _extract_result(
    response: Dict[str, Any],
    model: str,
    stream: bool = False,
    gemini_format: bool = False,
) -> tuple[str, List[Dict[str, Any]], Optional[bool]]:
    """Extract (text, tool_calls, thought) from a Gemini response.

    In stream mode only the first part of the first candidate is inspected;
    in non-stream mode all parts are concatenated. `gemini_format` controls
    whether tool calls are passed through in Gemini form or converted.
    """
    text, tool_calls = "", []
    thought = None
    if stream:
        if response.get("candidates"):
            candidate = response["candidates"][0]
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            if not parts:
                return "", [], None
            # Only parts[0] is examined per stream chunk.
            if "text" in parts[0]:
                text = parts[0].get("text")
                if "thought" in parts[0]:
                    thought = parts[0].get("thought")
            elif "executableCode" in parts[0]:
                text = _format_code_block(parts[0]["executableCode"])
            elif "codeExecution" in parts[0]:
                text = _format_code_block(parts[0]["codeExecution"])
            elif "executableCodeResult" in parts[0]:
                text = _format_execution_result(parts[0]["executableCodeResult"])
            elif "codeExecutionResult" in parts[0]:
                text = _format_execution_result(parts[0]["codeExecutionResult"])
            elif "inlineData" in parts[0]:
                # Inline images are uploaded and replaced with text (see _extract_image_data).
                text = _extract_image_data(parts[0])
            else:
                text = ""
            text = _add_search_link_text(model, candidate, text)
            tool_calls = _extract_tool_calls(parts, gemini_format)
    else:
        if response.get("candidates"):
            candidate = response["candidates"][0]
            if "thinking" in model:
                # NOTE(review): this branch indexes candidate["content"]["parts"]
                # directly and assumes 1 or 2 parts - confirm upstream guarantees.
                if settings.SHOW_THINKING_PROCESS:
                    if len(candidate["content"]["parts"]) == 2:
                        # Two parts: [thinking, output] - render both.
                        text = (
                            "> thinking\n\n"
                            + candidate["content"]["parts"][0]["text"]
                            + "\n\n---\n> output\n\n"
                            + candidate["content"]["parts"][1]["text"]
                        )
                    else:
                        text = candidate["content"]["parts"][0]["text"]
                else:
                    if len(candidate["content"]["parts"]) == 2:
                        # Hide the thinking part; keep only the output.
                        text = candidate["content"]["parts"][1]["text"]
                    else:
                        text = candidate["content"]["parts"][0]["text"]
            else:
                # Non-thinking model: concatenate all text/image parts.
                text = ""
                if "parts" in candidate["content"]:
                    for part in candidate["content"]["parts"]:
                        if "text" in part:
                            text += part["text"]
                            if "thought" in part and thought is None:
                                thought = part.get("thought")
                        elif "inlineData" in part:
                            text += _extract_image_data(part)

            text = _add_search_link_text(model, candidate, text)
            tool_calls = _extract_tool_calls(
                candidate["content"]["parts"], gemini_format
            )
        else:
            # No candidates returned; placeholder text (Chinese: "no response yet").
            text = "暂无返回"
    return text, tool_calls, thought
|
224 |
+
|
225 |
+
|
226 |
+
def _extract_image_data(part: dict) -> str:
    """Upload the inline image carried by a response part and return Markdown text.

    Builds an uploader according to ``settings.UPLOAD_PROVIDER`` (smms / picgo /
    cloudflare_imgbed), decodes the part's base64 ``inlineData`` payload and
    uploads it under a date-based path with a short random filename.

    Returns an empty string when no uploader is configured or the upload fails.
    """
    image_uploader = None
    if settings.UPLOAD_PROVIDER == "smms":
        image_uploader = ImageUploaderFactory.create(
            provider=settings.UPLOAD_PROVIDER, api_key=settings.SMMS_SECRET_TOKEN
        )
    elif settings.UPLOAD_PROVIDER == "picgo":
        image_uploader = ImageUploaderFactory.create(
            provider=settings.UPLOAD_PROVIDER, api_key=settings.PICGO_API_KEY
        )
    elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
        image_uploader = ImageUploaderFactory.create(
            provider=settings.UPLOAD_PROVIDER,
            base_url=settings.CLOUDFLARE_IMGBED_URL,
            auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
        )
    if image_uploader is None:
        # Bug fix: an unknown or unset UPLOAD_PROVIDER previously left
        # image_uploader as None and crashed with AttributeError on
        # image_uploader.upload(...) below. Degrade gracefully instead.
        return ""
    # Store uploads under YYYY/MM/DD with a short random basename.
    current_date = time.strftime("%Y/%m/%d")
    filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
    base64_data = part["inlineData"]["data"]
    # Decode the base64 payload into raw bytes for the uploader.
    bytes_data = base64.b64decode(base64_data)
    upload_response = image_uploader.upload(bytes_data, filename)
    if upload_response.success:
        # NOTE(review): this f-string contains no placeholder, so a successful
        # upload emits only blank lines. Presumably a Markdown image link built
        # from upload_response was intended here — confirm against the uploader
        # API before changing it.
        text = f"\n\n\n\n"
    else:
        text = ""
    return text
|
253 |
+
|
254 |
+
|
255 |
+
def _extract_tool_calls(
|
256 |
+
parts: List[Dict[str, Any]], gemini_format: bool
|
257 |
+
) -> List[Dict[str, Any]]:
|
258 |
+
"""提取工具调用信息"""
|
259 |
+
if not parts or not isinstance(parts, list):
|
260 |
+
return []
|
261 |
+
|
262 |
+
letters = string.ascii_lowercase + string.digits
|
263 |
+
|
264 |
+
tool_calls = list()
|
265 |
+
for i in range(len(parts)):
|
266 |
+
part = parts[i]
|
267 |
+
if not part or not isinstance(part, dict):
|
268 |
+
continue
|
269 |
+
|
270 |
+
item = part.get("functionCall", {})
|
271 |
+
if not item or not isinstance(item, dict):
|
272 |
+
continue
|
273 |
+
|
274 |
+
if gemini_format:
|
275 |
+
tool_calls.append(part)
|
276 |
+
else:
|
277 |
+
id = f"call_{''.join(random.sample(letters, 32))}"
|
278 |
+
name = item.get("name", "")
|
279 |
+
arguments = json.dumps(item.get("args", None) or {})
|
280 |
+
|
281 |
+
tool_calls.append(
|
282 |
+
{
|
283 |
+
"index": i,
|
284 |
+
"id": id,
|
285 |
+
"type": "function",
|
286 |
+
"function": {"name": name, "arguments": arguments},
|
287 |
+
}
|
288 |
+
)
|
289 |
+
|
290 |
+
return tool_calls
|
291 |
+
|
292 |
+
|
293 |
+
def _handle_gemini_stream_response(
    response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
    """Rebuild the first candidate's content for a Gemini streaming chunk.

    Tool calls win over plain text; an optional ``thought`` flag is carried
    through on the text part.
    """
    text, tool_calls, thought = _extract_result(
        response, model, stream=stream, gemini_format=True
    )
    if tool_calls:
        new_content = {"parts": tool_calls, "role": "model"}
    else:
        text_part: Dict[str, Any] = {"text": text}
        if thought is not None:
            text_part["thought"] = thought
        new_content = {"parts": [text_part], "role": "model"}
    response["candidates"][0]["content"] = new_content
    return response
|
308 |
+
|
309 |
+
|
310 |
+
def _handle_gemini_normal_response(
    response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
    """Rebuild the first candidate's content for a non-streaming Gemini response.

    Mirrors the streaming variant: tool calls take precedence, otherwise a
    single text part (with optional ``thought``) is installed.
    """
    text, tool_calls, thought = _extract_result(
        response, model, stream=stream, gemini_format=True
    )
    if tool_calls:
        rebuilt = {"parts": tool_calls, "role": "model"}
    else:
        payload: Dict[str, Any] = {"text": text}
        if thought is not None:
            payload["thought"] = thought
        rebuilt = {"parts": [payload], "role": "model"}
    response["candidates"][0]["content"] = rebuilt
    return response
|
325 |
+
|
326 |
+
|
327 |
+
def _format_code_block(code_data: dict) -> str:
|
328 |
+
"""格式化代码块输出"""
|
329 |
+
language = code_data.get("language", "").lower()
|
330 |
+
code = code_data.get("code", "").strip()
|
331 |
+
return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n"""
|
332 |
+
|
333 |
+
|
334 |
+
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
    """Append grounding-citation links to the output of ``-search`` models.

    Only acts when SHOW_SEARCH_LINK is enabled and the candidate carries
    groundingMetadata with groundingChunks; otherwise returns text unchanged.
    """
    should_append = (
        settings.SHOW_SEARCH_LINK
        and model.endswith("-search")
        and "groundingMetadata" in candidate
        and "groundingChunks" in candidate["groundingMetadata"]
    )
    if not should_append:
        return text

    chunks = candidate["groundingMetadata"]["groundingChunks"]
    pieces = [text, "\n\n---\n\n", "**【引用来源】**\n\n"]
    for chunk in chunks:
        if "web" in chunk:
            pieces.append(_create_search_link(chunk["web"]))
    return "".join(pieces)
|
350 |
+
|
351 |
+
|
352 |
+
def _create_search_link(grounding_chunk: dict) -> str:
|
353 |
+
return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})'
|
354 |
+
|
355 |
+
|
356 |
+
def _format_execution_result(result_data: dict) -> str:
|
357 |
+
"""格式化执行结果输出"""
|
358 |
+
outcome = result_data.get("outcome", "")
|
359 |
+
output = result_data.get("output", "").strip()
|
360 |
+
return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""
|
app/handler/retry_handler.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from functools import wraps
|
3 |
+
from typing import Callable, TypeVar
|
4 |
+
|
5 |
+
from app.config.config import settings
|
6 |
+
from app.log.logger import get_retry_logger
|
7 |
+
|
8 |
+
T = TypeVar("T")
|
9 |
+
logger = get_retry_logger()
|
10 |
+
|
11 |
+
|
12 |
+
class RetryHandler:
    """Decorator that retries an async API call, rotating API keys between attempts.

    Wraps an async function; on each failure it logs the error, asks the
    (optional) ``key_manager`` kwarg for a replacement key, and retries up to
    ``settings.MAX_RETRIES`` times. If no replacement key is available the loop
    stops early. After the loop the last exception is re-raised.
    """

    def __init__(self, key_arg: str = "api_key"):
        # Name of the keyword argument holding the API key to swap on failure.
        self.key_arg = key_arg

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            last_exception = None

            for attempt in range(settings.MAX_RETRIES):
                retries = attempt + 1
                try:
                    # Success exits immediately; only failures fall through below.
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    logger.warning(
                        f"API call failed with error: {str(e)}. Attempt {retries} of {settings.MAX_RETRIES}"
                    )

                    # Pull the key manager from the call's kwargs, if the caller passed one.
                    key_manager = kwargs.get("key_manager")
                    if key_manager:
                        old_key = kwargs.get(self.key_arg)
                        new_key = await key_manager.handle_api_failure(old_key, retries)
                        if new_key:
                            # Next attempt will use the replacement key.
                            kwargs[self.key_arg] = new_key
                            logger.info(f"Switched to new API key: {new_key}")
                        else:
                            # No usable key left: abandon retries early.
                            logger.error(f"No valid API key available after {retries} retries.")
                            break

            # Reached only when every attempt failed (or retries were abandoned).
            logger.error(
                f"All retry attempts failed, raising final exception: {str(last_exception)}"
            )
            raise last_exception

        return wrapper
|
app/handler/stream_optimizer.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import asyncio
|
3 |
+
import math
|
4 |
+
from typing import Any, AsyncGenerator, Callable, List
|
5 |
+
|
6 |
+
from app.config.config import settings
|
7 |
+
from app.core.constants import (
|
8 |
+
DEFAULT_STREAM_CHUNK_SIZE,
|
9 |
+
DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
|
10 |
+
DEFAULT_STREAM_MAX_DELAY,
|
11 |
+
DEFAULT_STREAM_MIN_DELAY,
|
12 |
+
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
|
13 |
+
)
|
14 |
+
from app.log.logger import get_gemini_logger, get_openai_logger
|
15 |
+
|
16 |
+
logger_openai = get_openai_logger()
|
17 |
+
logger_gemini = get_gemini_logger()
|
18 |
+
|
19 |
+
|
20 |
+
class StreamOptimizer:
    """Stream output optimizer.

    Smooths streamed text delivery: short texts are emitted character by
    character with a larger delay, long texts in fixed-size chunks with a
    smaller delay, and mid-length texts get a log-interpolated delay.
    """

    def __init__(
        self,
        logger=None,
        min_delay: float = DEFAULT_STREAM_MIN_DELAY,
        max_delay: float = DEFAULT_STREAM_MAX_DELAY,
        short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
        long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
        chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
    ):
        """Initialize the stream optimizer.

        Args:
            logger: logger instance
            min_delay: minimum per-emission delay (seconds)
            max_delay: maximum per-emission delay (seconds)
            short_text_threshold: short-text threshold (characters)
            long_text_threshold: long-text threshold (characters)
            chunk_size: chunk size used for long text (characters)
        """
        self.logger = logger
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.short_text_threshold = short_text_threshold
        self.long_text_threshold = long_text_threshold
        self.chunk_size = chunk_size

    def calculate_delay(self, text_length: int) -> float:
        """Compute the emission delay for a text of the given length.

        Args:
            text_length: text length in characters

        Returns:
            Delay in seconds.
        """
        if text_length <= self.short_text_threshold:
            # Short text: use the larger delay.
            return self.max_delay
        elif text_length >= self.long_text_threshold:
            # Long text: use the smaller delay.
            return self.min_delay
        else:
            # Mid-length text: interpolate between max and min delay.
            # A logarithmic ratio makes the transition smoother than linear.
            ratio = math.log(text_length / self.short_text_threshold) / math.log(
                self.long_text_threshold / self.short_text_threshold
            )
            return self.max_delay - ratio * (self.max_delay - self.min_delay)

    def split_text_into_chunks(self, text: str) -> List[str]:
        """Split text into chunk_size-character pieces (last piece may be shorter).

        Args:
            text: text to split

        Returns:
            List of text chunks.
        """
        return [
            text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
        ]

    async def optimize_stream_output(
        self,
        text: str,
        create_response_chunk: Callable[[str], Any],
        format_chunk: Callable[[Any], str],
    ) -> AsyncGenerator[str, None]:
        """Yield formatted response chunks for ``text`` with pacing delays.

        Args:
            text: text to emit
            create_response_chunk: builds a response object from a text piece
            format_chunk: serializes a response object to the wire format

        Returns:
            Async generator of formatted chunks.
        """
        if not text:
            return

        # One delay value is computed up front from the total length.
        delay = self.calculate_delay(len(text))

        # Emission strategy depends on total length.
        if len(text) >= self.long_text_threshold:
            # Long text: emit in chunks.
            chunks = self.split_text_into_chunks(text)
            for chunk_text in chunks:
                chunk_response = create_response_chunk(chunk_text)
                yield format_chunk(chunk_response)
                await asyncio.sleep(delay)
        else:
            # Short/medium text: emit character by character.
            for char in text:
                char_chunk = create_response_chunk(char)
                yield format_chunk(char_chunk)
                await asyncio.sleep(delay)
|
124 |
+
|
125 |
+
|
126 |
+
# Default optimizer instances, importable directly by the API handlers.
# Both share the same settings-driven tuning; they differ only in logger.
openai_optimizer = StreamOptimizer(
    logger=logger_openai,
    min_delay=settings.STREAM_MIN_DELAY,
    max_delay=settings.STREAM_MAX_DELAY,
    short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
    long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
    chunk_size=settings.STREAM_CHUNK_SIZE,
)

gemini_optimizer = StreamOptimizer(
    logger=logger_gemini,
    min_delay=settings.STREAM_MIN_DELAY,
    max_delay=settings.STREAM_MAX_DELAY,
    short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
    long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
    chunk_size=settings.STREAM_CHUNK_SIZE,
)
|
app/log/logger.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import platform
|
3 |
+
import sys
|
4 |
+
from typing import Dict, Optional
|
5 |
+
|
6 |
+
# ANSI escape-sequence color codes, keyed by log level name.
COLORS = {
    "DEBUG": "\033[34m",  # blue
    "INFO": "\033[32m",  # green
    "WARNING": "\033[33m",  # yellow
    "ERROR": "\033[31m",  # red
    "CRITICAL": "\033[1;31m",  # bold red
}

# Enable ANSI escape-sequence processing on Windows consoles.
if platform.system() == "Windows":
    import ctypes

    kernel32 = ctypes.windll.kernel32
    # -11 is STD_OUTPUT_HANDLE; mode 7 enables virtual-terminal processing
    # so the ANSI color codes above are interpreted rather than printed.
    kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
|
21 |
+
|
22 |
+
|
23 |
+
class ColoredFormatter(logging.Formatter):
    """Logging formatter that colorizes the level name and adds a file:line column."""

    def format(self, record):
        # Wrap the level name in its ANSI color, resetting afterwards.
        # NOTE: this mutates the record in place, like the surrounding code expects.
        color_code = COLORS.get(record.levelname, "")
        record.levelname = f"{color_code}{record.levelname}\033[0m"
        # Fixed-width "[file:line]" column consumed by the format string.
        record.fileloc = f"[{record.filename}:{record.lineno}]"
        return super().format(record)
|
36 |
+
|
37 |
+
|
38 |
+
# Log format — uses the fileloc field (set by ColoredFormatter.format) with a
# fixed width of 30 so columns line up across records.
FORMATTER = ColoredFormatter(
    "%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
)

# Mapping from configuration strings to stdlib logging levels.
LOG_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}
|
51 |
+
|
52 |
+
|
53 |
+
class Logger:
    """Factory/cache for named application loggers configured from settings."""

    def __init__(self):
        pass

    # Cache of loggers created by setup_logger, keyed by name.
    _loggers: Dict[str, logging.Logger] = {}

    @staticmethod
    def setup_logger(name: str) -> logging.Logger:
        """
        Create (or fetch from cache) a named logger.
        :param name: logger name
        :return: logger instance
        """
        # Imported here, not at module top, to avoid a circular import with
        # app.config.config (which itself uses this logging module).
        from app.config.config import settings

        # Resolve the level from global configuration; unknown values fall back to INFO.
        log_level_str = settings.LOG_LEVEL.lower()
        level = LOG_LEVELS.get(log_level_str, logging.INFO)

        if name in Logger._loggers:
            # Logger already exists: re-sync its level with the current config.
            existing_logger = Logger._loggers[name]
            if existing_logger.level != level:
                existing_logger.setLevel(level)
            return existing_logger

        logger = logging.getLogger(name)
        logger.setLevel(level)
        # Don't bubble records up to the root logger (avoids double output).
        logger.propagate = False

        # Single console handler writing colored records to stdout.
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(FORMATTER)
        logger.addHandler(console_handler)

        Logger._loggers[name] = logger
        return logger

    @staticmethod
    def get_logger(name: str) -> Optional[logging.Logger]:
        """
        Fetch an already-created logger without creating one.
        :param name: logger name
        :return: logger instance or None
        """
        return Logger._loggers.get(name)

    @staticmethod
    def update_log_levels(log_level: str):
        """
        Update every cached logger to the given level (used when config changes at runtime).
        """
        log_level_str = log_level.lower()
        new_level = LOG_LEVELS.get(log_level_str, logging.INFO)

        updated_count = 0
        for logger_name, logger_instance in Logger._loggers.items():
            if logger_instance.level != new_level:
                logger_instance.setLevel(new_level)
                # Optionally announce the change — kept silent to avoid the
                # logging module logging about itself.
                # print(f"Updated log level for logger '{logger_name}' to {log_level_str.upper()}")
                updated_count += 1
|
116 |
+
|
117 |
+
|
118 |
+
# Predefined logger factories — one per subsystem. Each simply delegates to
# Logger.setup_logger with a fixed name, so repeated calls return the cached
# instance with its level kept in sync with settings.
def get_openai_logger():
    return Logger.setup_logger("openai")


def get_gemini_logger():
    return Logger.setup_logger("gemini")


def get_chat_logger():
    return Logger.setup_logger("chat")


def get_model_logger():
    return Logger.setup_logger("model")


def get_security_logger():
    return Logger.setup_logger("security")


def get_key_manager_logger():
    return Logger.setup_logger("key_manager")


def get_main_logger():
    return Logger.setup_logger("main")


def get_embeddings_logger():
    return Logger.setup_logger("embeddings")


def get_request_logger():
    return Logger.setup_logger("request")


def get_retry_logger():
    return Logger.setup_logger("retry")


def get_image_create_logger():
    return Logger.setup_logger("image_create")


def get_exceptions_logger():
    return Logger.setup_logger("exceptions")


def get_application_logger():
    return Logger.setup_logger("application")


def get_initialization_logger():
    return Logger.setup_logger("initialization")


def get_middleware_logger():
    return Logger.setup_logger("middleware")


def get_routes_logger():
    return Logger.setup_logger("routes")


def get_config_routes_logger():
    return Logger.setup_logger("config_routes")


def get_config_logger():
    return Logger.setup_logger("config")


def get_database_logger():
    return Logger.setup_logger("database")


def get_log_routes_logger():
    return Logger.setup_logger("log_routes")


def get_stats_logger():
    return Logger.setup_logger("stats")


def get_update_logger():
    return Logger.setup_logger("update_service")


def get_scheduler_routes():
    return Logger.setup_logger("scheduler_routes")


def get_message_converter_logger():
    return Logger.setup_logger("message_converter")


def get_api_client_logger():
    return Logger.setup_logger("api_client")


def get_openai_compatible_logger():
    return Logger.setup_logger("openai_compatible")


def get_error_log_logger():
    return Logger.setup_logger("error_log")


def get_request_log_logger():
    return Logger.setup_logger("request_log")


def get_vertex_express_logger():
    return Logger.setup_logger("vertex_express")
|
233 |
+
|
app/main.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uvicorn
from dotenv import load_dotenv

# Load the .env file into the environment BEFORE importing the application
# modules below, so that app.config picks up the values at import time.
load_dotenv()

from app.core.application import create_app
from app.log.logger import get_main_logger

# Module-level app instance so ASGI servers can target "app.main:app".
app = create_app()

if __name__ == "__main__":
    logger = get_main_logger()
    logger.info("Starting application server...")
    uvicorn.run(app, host="0.0.0.0", port=8001)
|
app/middleware/middleware.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
中间件配置模块,负责设置和配置应用程序的中间件
|
3 |
+
"""
|
4 |
+
|
5 |
+
from fastapi import FastAPI, Request
|
6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
7 |
+
from fastapi.responses import RedirectResponse
|
8 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
9 |
+
|
10 |
+
# from app.middleware.request_logging_middleware import RequestLoggingMiddleware
|
11 |
+
from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
|
12 |
+
from app.core.constants import API_VERSION
|
13 |
+
from app.core.security import verify_auth_token
|
14 |
+
from app.log.logger import get_middleware_logger
|
15 |
+
|
16 |
+
logger = get_middleware_logger()
|
17 |
+
|
18 |
+
|
19 |
+
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Authentication middleware: redirects unauthenticated page requests to "/".
    """

    async def dispatch(self, request: Request, call_next):
        # Paths below bypass cookie authentication: the login pages, static
        # assets, and all API prefixes (which carry their own key-based auth).
        # NOTE(review): startswith("/v1") also matches any path beginning with
        # "/v1..." (e.g. "/v1x"), not only "/v1/" — confirm this is intended.
        if (
            request.url.path not in ["/", "/auth"]
            and not request.url.path.startswith("/static")
            and not request.url.path.startswith("/gemini")
            and not request.url.path.startswith("/v1")
            and not request.url.path.startswith(f"/{API_VERSION}")
            and not request.url.path.startswith("/health")
            and not request.url.path.startswith("/hf")
            and not request.url.path.startswith("/openai")
            and not request.url.path.startswith("/api/version/check")
            and not request.url.path.startswith("/vertex-express")
        ):

            # All other paths require a valid auth_token cookie.
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning(f"Unauthorized access attempt to {request.url.path}")
                return RedirectResponse(url="/")
            logger.debug("Request authenticated successfully")

        response = await call_next(request)
        return response
|
47 |
+
|
48 |
+
|
49 |
+
def setup_middlewares(app: FastAPI) -> None:
    """
    Install the application's middleware stack.

    Args:
        app: FastAPI application instance
    """
    # Add the smart routing middleware (must be added before the auth
    # middleware so URL normalization happens first in the request flow).
    app.add_middleware(SmartRoutingMiddleware)

    # Add the authentication middleware.
    app.add_middleware(AuthMiddleware)

    # Request-logging middleware (optional; disabled by default).
    # app.add_middleware(RequestLoggingMiddleware)

    # Configure the CORS middleware.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=[
            "GET",
            "POST",
            "PUT",
            "DELETE",
            "OPTIONS",
        ],
        allow_headers=["*"],
        expose_headers=["*"],
        max_age=600,
    )
|
app/middleware/request_logging_middleware.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from fastapi import Request
|
4 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
5 |
+
|
6 |
+
from app.log.logger import get_request_logger
|
7 |
+
|
8 |
+
logger = get_request_logger()
|
9 |
+
|
10 |
+
|
11 |
+
# Middleware that logs each request's path and (JSON) body for debugging.
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        """Log the request path and body, then replay the body downstream.

        Reading ``request.body()`` consumes the ASGI receive stream, so a
        replacement ``receive`` callable is installed afterwards to let the
        actual route handler read the body again.
        """
        # Log the request path.
        logger.info(f"Request path: {request.url.path}")

        # Bug fix: default to an empty body so the replay closure below never
        # references an unbound name when reading the body raises.
        body = b""
        # Read and log the request body.
        try:
            body = await request.body()
            if body:
                body_str = body.decode()
                # Pretty-print when the body is JSON.
                try:
                    formatted_body = json.loads(body_str)
                    logger.info(
                        f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
                    )
                except json.JSONDecodeError:
                    logger.error("Request body is not valid JSON.")
        except Exception as e:
            logger.error(f"Error reading request body: {str(e)}")

        # Reinstall a receive callable so downstream handlers can re-read the body.
        async def receive():
            return {"type": "http.request", "body": body, "more_body": False}

        request._receive = receive

        response = await call_next(request)
        return response
|
app/middleware/smart_routing_middleware.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import Request
|
2 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
3 |
+
from app.config.config import settings
|
4 |
+
from app.log.logger import get_main_logger
|
5 |
+
import re
|
6 |
+
|
7 |
+
logger = get_main_logger()
|
8 |
+
|
9 |
+
class SmartRoutingMiddleware(BaseHTTPMiddleware):
|
10 |
+
def __init__(self, app):
|
11 |
+
super().__init__(app)
|
12 |
+
# 简化的路由规则 - 直接根据检测结果路由
|
13 |
+
pass
|
14 |
+
|
15 |
+
async def dispatch(self, request: Request, call_next):
|
16 |
+
if not settings.URL_NORMALIZATION_ENABLED:
|
17 |
+
return await call_next(request)
|
18 |
+
logger.debug(f"request: {request}")
|
19 |
+
original_path = str(request.url.path)
|
20 |
+
method = request.method
|
21 |
+
|
22 |
+
# 尝试修复URL
|
23 |
+
fixed_path, fix_info = self.fix_request_url(original_path, method, request)
|
24 |
+
|
25 |
+
if fixed_path != original_path:
|
26 |
+
logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
|
27 |
+
if fix_info:
|
28 |
+
logger.debug(f"Fix details: {fix_info}")
|
29 |
+
|
30 |
+
# 重写请求路径
|
31 |
+
request.scope["path"] = fixed_path
|
32 |
+
request.scope["raw_path"] = fixed_path.encode()
|
33 |
+
|
34 |
+
return await call_next(request)
|
35 |
+
|
36 |
+
def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
|
37 |
+
"""简化的URL修复逻辑"""
|
38 |
+
|
39 |
+
# 首先检查是否已经是正确的格式,如果是则不处理
|
40 |
+
if self.is_already_correct_format(path):
|
41 |
+
return path, None
|
42 |
+
|
43 |
+
# 1. 最高优先级:包含generateContent → Gemini格式
|
44 |
+
if "generatecontent" in path.lower() or "v1beta/models" in path.lower():
|
45 |
+
return self.fix_gemini_by_operation(path, method, request)
|
46 |
+
|
47 |
+
# 2. 第二优先级:包含/openai/ → OpenAI格式
|
48 |
+
if "/openai/" in path.lower():
|
49 |
+
return self.fix_openai_by_operation(path, method)
|
50 |
+
|
51 |
+
# 3. 第三优先级:包含/v1/ → v1格式
|
52 |
+
if "/v1/" in path.lower():
|
53 |
+
return self.fix_v1_by_operation(path, method)
|
54 |
+
|
55 |
+
# 4. 第四优先级:包含/chat/completions → chat功能
|
56 |
+
if "/chat/completions" in path.lower():
|
57 |
+
return "/v1/chat/completions", {"type": "v1_chat"}
|
58 |
+
|
59 |
+
# 5. 默认:原样传递
|
60 |
+
return path, None
|
61 |
+
|
62 |
+
def is_already_correct_format(self, path: str) -> bool:
|
63 |
+
"""检查是否已经是正确的API格式"""
|
64 |
+
# 检查是否已经是正确的端点格式
|
65 |
+
correct_patterns = [
|
66 |
+
r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini原生
|
67 |
+
r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Gemini带前缀
|
68 |
+
r"^/v1beta/models$", # Gemini模型列表
|
69 |
+
r"^/gemini/v1beta/models$", # Gemini带前缀的模型列表
|
70 |
+
r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # v1格式
|
71 |
+
r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # OpenAI格式
|
72 |
+
r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$", # HF格式
|
73 |
+
r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$", # Vertex Express Gemini格式
|
74 |
+
r"^/vertex-express/v1beta/models$", # Vertex Express模型列表
|
75 |
+
r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$", # Vertex Express OpenAI格式
|
76 |
+
]
|
77 |
+
|
78 |
+
for pattern in correct_patterns:
|
79 |
+
if re.match(pattern, path):
|
80 |
+
return True
|
81 |
+
|
82 |
+
return False
|
83 |
+
|
84 |
+
def fix_gemini_by_operation(
|
85 |
+
self, path: str, method: str, request: Request
|
86 |
+
) -> tuple:
|
87 |
+
"""根据Gemini操作修复,考虑端点偏好"""
|
88 |
+
if method == "GET":
|
89 |
+
return "/v1beta/models", {
|
90 |
+
"role": "gemini_models",
|
91 |
+
}
|
92 |
+
|
93 |
+
# 提取模型名称
|
94 |
+
try:
|
95 |
+
model_name = self.extract_model_name(path, request)
|
96 |
+
except ValueError:
|
97 |
+
# 无法提取模型名称,返回原路径不做处理
|
98 |
+
return path, None
|
99 |
+
|
100 |
+
# 检测是否为流式请求
|
101 |
+
is_stream = self.detect_stream_request(path, request)
|
102 |
+
|
103 |
+
# 检查是否有vertex-express偏好
|
104 |
+
if "/vertex-express/" in path.lower():
|
105 |
+
if is_stream:
|
106 |
+
target_url = (
|
107 |
+
f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
target_url = (
|
111 |
+
f"/vertex-express/v1beta/models/{model_name}:generateContent"
|
112 |
+
)
|
113 |
+
|
114 |
+
fix_info = {
|
115 |
+
"rule": (
|
116 |
+
"vertex_express_generate"
|
117 |
+
if not is_stream
|
118 |
+
else "vertex_express_stream"
|
119 |
+
),
|
120 |
+
"preference": "vertex_express_format",
|
121 |
+
"is_stream": is_stream,
|
122 |
+
"model": model_name,
|
123 |
+
}
|
124 |
+
else:
|
125 |
+
# 标准Gemini端点
|
126 |
+
if is_stream:
|
127 |
+
target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
|
128 |
+
else:
|
129 |
+
target_url = f"/v1beta/models/{model_name}:generateContent"
|
130 |
+
|
131 |
+
fix_info = {
|
132 |
+
"rule": "gemini_generate" if not is_stream else "gemini_stream",
|
133 |
+
"preference": "gemini_format",
|
134 |
+
"is_stream": is_stream,
|
135 |
+
"model": model_name,
|
136 |
+
}
|
137 |
+
|
138 |
+
return target_url, fix_info
|
139 |
+
|
140 |
+
def fix_openai_by_operation(self, path: str, method: str) -> tuple:
|
141 |
+
"""根据操作类型修复OpenAI格式"""
|
142 |
+
if method == "POST":
|
143 |
+
if "chat" in path.lower() or "completion" in path.lower():
|
144 |
+
return "/openai/v1/chat/completions", {"type": "openai_chat"}
|
145 |
+
elif "embedding" in path.lower():
|
146 |
+
return "/openai/v1/embeddings", {"type": "openai_embeddings"}
|
147 |
+
elif "image" in path.lower():
|
148 |
+
return "/openai/v1/images/generations", {"type": "openai_images"}
|
149 |
+
elif "audio" in path.lower():
|
150 |
+
return "/openai/v1/audio/speech", {"type": "openai_audio"}
|
151 |
+
elif method == "GET":
|
152 |
+
if "model" in path.lower():
|
153 |
+
return "/openai/v1/models", {"type": "openai_models"}
|
154 |
+
|
155 |
+
return path, None
|
156 |
+
|
157 |
+
def fix_v1_by_operation(self, path: str, method: str) -> tuple:
    """Map a request onto the canonical ``/v1`` endpoint by operation.

    Mirrors :meth:`fix_openai_by_operation` but without the ``/openai``
    prefix. Returns ``(path, None)`` when no rule matches.
    """
    lowered = path.lower()
    if method == "POST":
        # Ordered keyword rules: first match wins (chat/completion first).
        post_rules = (
            (("chat", "completion"), "/v1/chat/completions", "v1_chat"),
            (("embedding",), "/v1/embeddings", "v1_embeddings"),
            (("image",), "/v1/images/generations", "v1_images"),
            (("audio",), "/v1/audio/speech", "v1_audio"),
        )
        for keywords, target, rule_type in post_rules:
            if any(word in lowered for word in keywords):
                return target, {"type": rule_type}
    elif method == "GET" and "model" in lowered:
        return "/v1/models", {"type": "v1_models"}
    return path, None
|
173 |
+
|
174 |
+
def detect_stream_request(self, path: str, request: Request) -> bool:
    """Decide whether the incoming request asks for a streaming response."""
    # A "stream" keyword anywhere in the path wins outright.
    if "stream" in path.lower():
        return True
    # Otherwise honour an explicit ?stream=true query parameter.
    return request.query_params.get("stream") == "true"
|
185 |
+
|
186 |
+
def extract_model_name(self, path: str, request: Request) -> str:
    """Resolve the model name used to build the Gemini API URL.

    Resolution order: JSON request body -> ``model`` query parameter ->
    ``/models/<name>`` path segment.

    Raises:
        ValueError: when no source yields a model name.
    """
    # 1. Request body. NOTE(review): relies on Starlette's private cached
    #    `_body`, which is only present if the body was read upstream —
    #    confirm middleware ordering guarantees that.
    try:
        if getattr(request, "_body", None):
            import json

            payload = json.loads(request._body.decode())
            body_model = payload.get("model")
            if body_model:
                return body_model
    except Exception:
        # Malformed/odd bodies are ignored; fall through to other sources.
        pass

    # 2. Explicit ?model= query parameter.
    query_model = request.query_params.get("model")
    if query_model:
        return query_model

    # 3. Path segment such as /models/<name>:generateContent.
    path_match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
    if path_match:
        return path_match.group(1)

    # 4. Nothing worked — signal the caller explicitly.
    raise ValueError("Unable to extract model name from request")
|
app/router/config_routes.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
配置路由模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Any, Dict, List
|
6 |
+
|
7 |
+
from fastapi import APIRouter, HTTPException, Request
|
8 |
+
from fastapi.responses import RedirectResponse
|
9 |
+
from pydantic import BaseModel, Field
|
10 |
+
|
11 |
+
from app.core.security import verify_auth_token
|
12 |
+
from app.log.logger import Logger, get_config_routes_logger
|
13 |
+
from app.service.config.config_service import ConfigService
|
14 |
+
|
15 |
+
router = APIRouter(prefix="/api/config", tags=["config"])
|
16 |
+
|
17 |
+
logger = get_config_routes_logger()
|
18 |
+
|
19 |
+
|
20 |
+
@router.get("", response_model=Dict[str, Any])
|
21 |
+
async def get_config(request: Request):
|
22 |
+
auth_token = request.cookies.get("auth_token")
|
23 |
+
if not auth_token or not verify_auth_token(auth_token):
|
24 |
+
logger.warning("Unauthorized access attempt to config page")
|
25 |
+
return RedirectResponse(url="/", status_code=302)
|
26 |
+
return await ConfigService.get_config()
|
27 |
+
|
28 |
+
|
29 |
+
@router.put("", response_model=Dict[str, Any])
|
30 |
+
async def update_config(config_data: Dict[str, Any], request: Request):
|
31 |
+
auth_token = request.cookies.get("auth_token")
|
32 |
+
if not auth_token or not verify_auth_token(auth_token):
|
33 |
+
logger.warning("Unauthorized access attempt to config page")
|
34 |
+
return RedirectResponse(url="/", status_code=302)
|
35 |
+
try:
|
36 |
+
result = await ConfigService.update_config(config_data)
|
37 |
+
# 配置更新成功后,立即更新所有 logger 的级别
|
38 |
+
Logger.update_log_levels(config_data["LOG_LEVEL"])
|
39 |
+
logger.info("Log levels updated after configuration change.")
|
40 |
+
return result
|
41 |
+
except Exception as e:
|
42 |
+
logger.error(f"Error updating config or log levels: {e}", exc_info=True)
|
43 |
+
raise HTTPException(status_code=400, detail=str(e))
|
44 |
+
|
45 |
+
|
46 |
+
@router.post("/reset", response_model=Dict[str, Any])
|
47 |
+
async def reset_config(request: Request):
|
48 |
+
auth_token = request.cookies.get("auth_token")
|
49 |
+
if not auth_token or not verify_auth_token(auth_token):
|
50 |
+
logger.warning("Unauthorized access attempt to config page")
|
51 |
+
return RedirectResponse(url="/", status_code=302)
|
52 |
+
try:
|
53 |
+
return await ConfigService.reset_config()
|
54 |
+
except Exception as e:
|
55 |
+
raise HTTPException(status_code=400, detail=str(e))
|
56 |
+
|
57 |
+
|
58 |
+
class DeleteKeysRequest(BaseModel):
    """Request payload for the bulk key-deletion endpoint."""

    keys: List[str] = Field(..., description="List of API keys to delete")
|
60 |
+
|
61 |
+
|
62 |
+
@router.delete("/keys/{key_to_delete}", response_model=Dict[str, Any])
|
63 |
+
async def delete_single_key(key_to_delete: str, request: Request):
|
64 |
+
auth_token = request.cookies.get("auth_token")
|
65 |
+
if not auth_token or not verify_auth_token(auth_token):
|
66 |
+
logger.warning(f"Unauthorized attempt to delete key: {key_to_delete}")
|
67 |
+
return RedirectResponse(url="/", status_code=302)
|
68 |
+
try:
|
69 |
+
logger.info(f"Attempting to delete key: {key_to_delete}")
|
70 |
+
result = await ConfigService.delete_key(key_to_delete)
|
71 |
+
if not result.get("success"):
|
72 |
+
raise HTTPException(
|
73 |
+
status_code=(
|
74 |
+
404 if "not found" in result.get("message", "").lower() else 400
|
75 |
+
),
|
76 |
+
detail=result.get("message"),
|
77 |
+
)
|
78 |
+
return result
|
79 |
+
except HTTPException as e:
|
80 |
+
raise e
|
81 |
+
except Exception as e:
|
82 |
+
logger.error(f"Error deleting key '{key_to_delete}': {e}", exc_info=True)
|
83 |
+
raise HTTPException(status_code=500, detail=f"Error deleting key: {str(e)}")
|
84 |
+
|
85 |
+
|
86 |
+
@router.post("/keys/delete-selected", response_model=Dict[str, Any])
|
87 |
+
async def delete_selected_keys_route(
|
88 |
+
delete_request: DeleteKeysRequest, request: Request
|
89 |
+
):
|
90 |
+
auth_token = request.cookies.get("auth_token")
|
91 |
+
if not auth_token or not verify_auth_token(auth_token):
|
92 |
+
logger.warning("Unauthorized attempt to bulk delete keys")
|
93 |
+
return RedirectResponse(url="/", status_code=302)
|
94 |
+
|
95 |
+
if not delete_request.keys:
|
96 |
+
logger.warning("Attempt to bulk delete keys with an empty list.")
|
97 |
+
raise HTTPException(status_code=400, detail="No keys provided for deletion.")
|
98 |
+
|
99 |
+
try:
|
100 |
+
logger.info(f"Attempting to bulk delete {len(delete_request.keys)} keys.")
|
101 |
+
result = await ConfigService.delete_selected_keys(delete_request.keys)
|
102 |
+
if not result.get("success") and result.get("deleted_count", 0) == 0:
|
103 |
+
raise HTTPException(
|
104 |
+
status_code=400, detail=result.get("message", "Failed to delete keys.")
|
105 |
+
)
|
106 |
+
return result
|
107 |
+
except HTTPException as e:
|
108 |
+
raise e
|
109 |
+
except Exception as e:
|
110 |
+
logger.error(f"Error bulk deleting keys: {e}", exc_info=True)
|
111 |
+
raise HTTPException(
|
112 |
+
status_code=500, detail=f"Error bulk deleting keys: {str(e)}"
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
@router.get("/ui/models")
|
117 |
+
async def get_ui_models(request: Request):
|
118 |
+
auth_token_cookie = request.cookies.get("auth_token")
|
119 |
+
if not auth_token_cookie or not verify_auth_token(auth_token_cookie):
|
120 |
+
logger.warning("Unauthorized access attempt to /api/config/ui/models")
|
121 |
+
raise HTTPException(status_code=403, detail="Not authenticated")
|
122 |
+
|
123 |
+
try:
|
124 |
+
models = await ConfigService.fetch_ui_models()
|
125 |
+
return models
|
126 |
+
except HTTPException as e:
|
127 |
+
raise e
|
128 |
+
except Exception as e:
|
129 |
+
logger.error(f"Unexpected error in /ui/models endpoint: {e}", exc_info=True)
|
130 |
+
raise HTTPException(
|
131 |
+
status_code=500,
|
132 |
+
detail=f"An unexpected error occurred while fetching UI models: {str(e)}",
|
133 |
+
)
|
app/router/error_log_routes.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
日志路由模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import Dict, List, Optional
|
7 |
+
|
8 |
+
from fastapi import (
|
9 |
+
APIRouter,
|
10 |
+
Body,
|
11 |
+
HTTPException,
|
12 |
+
Path,
|
13 |
+
Query,
|
14 |
+
Request,
|
15 |
+
Response,
|
16 |
+
status,
|
17 |
+
)
|
18 |
+
from pydantic import BaseModel
|
19 |
+
|
20 |
+
from app.core.security import verify_auth_token
|
21 |
+
from app.log.logger import get_log_routes_logger
|
22 |
+
from app.service.error_log import error_log_service
|
23 |
+
|
24 |
+
router = APIRouter(prefix="/api/logs", tags=["logs"])
|
25 |
+
|
26 |
+
logger = get_log_routes_logger()
|
27 |
+
|
28 |
+
|
29 |
+
class ErrorLogListItem(BaseModel):
    """One summary row in the error-log list (excludes the full log body)."""

    id: int
    gemini_key: Optional[str] = None
    error_type: Optional[str] = None
    error_code: Optional[int] = None
    model_name: Optional[str] = None
    request_time: Optional[datetime] = None
|
36 |
+
|
37 |
+
|
38 |
+
class ErrorLogListResponse(BaseModel):
    """Paginated error-log list: current page of rows plus the total count."""

    logs: List[ErrorLogListItem]
    total: int
|
41 |
+
|
42 |
+
|
43 |
+
@router.get("/errors", response_model=ErrorLogListResponse)
|
44 |
+
async def get_error_logs_api(
|
45 |
+
request: Request,
|
46 |
+
limit: int = Query(10, ge=1, le=1000),
|
47 |
+
offset: int = Query(0, ge=0),
|
48 |
+
key_search: Optional[str] = Query(
|
49 |
+
None, description="Search term for Gemini key (partial match)"
|
50 |
+
),
|
51 |
+
error_search: Optional[str] = Query(
|
52 |
+
None, description="Search term for error type or log message"
|
53 |
+
),
|
54 |
+
error_code_search: Optional[str] = Query(
|
55 |
+
None, description="Search term for error code"
|
56 |
+
),
|
57 |
+
start_date: Optional[datetime] = Query(
|
58 |
+
None, description="Start datetime for filtering"
|
59 |
+
),
|
60 |
+
end_date: Optional[datetime] = Query(
|
61 |
+
None, description="End datetime for filtering"
|
62 |
+
),
|
63 |
+
sort_by: str = Query(
|
64 |
+
"id", description="Field to sort by (e.g., 'id', 'request_time')"
|
65 |
+
),
|
66 |
+
sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
获取错误日志列表 (返回错误码),支持过滤和排序
|
70 |
+
|
71 |
+
Args:
|
72 |
+
request: 请求对象
|
73 |
+
limit: 限制数量
|
74 |
+
offset: 偏移量
|
75 |
+
key_search: 密钥搜索
|
76 |
+
error_search: 错误搜索 (可能搜索类型或日志内容,由DB层决定)
|
77 |
+
error_code_search: 错误码搜索
|
78 |
+
start_date: 开始日期
|
79 |
+
end_date: 结束日期
|
80 |
+
sort_by: 排序字段
|
81 |
+
sort_order: 排序顺序
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
ErrorLogListResponse: An object containing the list of logs (with error_code) and the total count.
|
85 |
+
"""
|
86 |
+
auth_token = request.cookies.get("auth_token")
|
87 |
+
if not auth_token or not verify_auth_token(auth_token):
|
88 |
+
logger.warning("Unauthorized access attempt to error logs list")
|
89 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
90 |
+
|
91 |
+
try:
|
92 |
+
result = await error_log_service.process_get_error_logs(
|
93 |
+
limit=limit,
|
94 |
+
offset=offset,
|
95 |
+
key_search=key_search,
|
96 |
+
error_search=error_search,
|
97 |
+
error_code_search=error_code_search,
|
98 |
+
start_date=start_date,
|
99 |
+
end_date=end_date,
|
100 |
+
sort_by=sort_by,
|
101 |
+
sort_order=sort_order,
|
102 |
+
)
|
103 |
+
logs_data = result["logs"]
|
104 |
+
total_count = result["total"]
|
105 |
+
|
106 |
+
validated_logs = [ErrorLogListItem(**log) for log in logs_data]
|
107 |
+
return ErrorLogListResponse(logs=validated_logs, total=total_count)
|
108 |
+
except Exception as e:
|
109 |
+
logger.exception(f"Failed to get error logs list: {str(e)}")
|
110 |
+
raise HTTPException(
|
111 |
+
status_code=500, detail=f"Failed to get error logs list: {str(e)}"
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
class ErrorLogDetailResponse(BaseModel):
    """Full error-log record, including the log body and original request."""

    id: int
    gemini_key: Optional[str] = None
    error_type: Optional[str] = None
    error_log: Optional[str] = None
    request_msg: Optional[str] = None
    model_name: Optional[str] = None
    request_time: Optional[datetime] = None
|
123 |
+
|
124 |
+
|
125 |
+
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
|
126 |
+
async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)):
|
127 |
+
"""
|
128 |
+
根据日志 ID 获取错误日志的详细信息 (包括 error_log 和 request_msg)
|
129 |
+
"""
|
130 |
+
auth_token = request.cookies.get("auth_token")
|
131 |
+
if not auth_token or not verify_auth_token(auth_token):
|
132 |
+
logger.warning(
|
133 |
+
f"Unauthorized access attempt to error log details for ID: {log_id}"
|
134 |
+
)
|
135 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
136 |
+
|
137 |
+
try:
|
138 |
+
log_details = await error_log_service.process_get_error_log_details(
|
139 |
+
log_id=log_id
|
140 |
+
)
|
141 |
+
if not log_details:
|
142 |
+
raise HTTPException(status_code=404, detail="Error log not found")
|
143 |
+
|
144 |
+
return ErrorLogDetailResponse(**log_details)
|
145 |
+
except HTTPException as http_exc:
|
146 |
+
raise http_exc
|
147 |
+
except Exception as e:
|
148 |
+
logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
|
149 |
+
raise HTTPException(
|
150 |
+
status_code=500, detail=f"Failed to get error log details: {str(e)}"
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
|
155 |
+
async def delete_error_logs_bulk_api(
|
156 |
+
request: Request, payload: Dict[str, List[int]] = Body(...)
|
157 |
+
):
|
158 |
+
"""
|
159 |
+
批量删除错误日志 (异步)
|
160 |
+
"""
|
161 |
+
auth_token = request.cookies.get("auth_token")
|
162 |
+
if not auth_token or not verify_auth_token(auth_token):
|
163 |
+
logger.warning("Unauthorized access attempt to bulk delete error logs")
|
164 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
165 |
+
|
166 |
+
log_ids = payload.get("ids")
|
167 |
+
if not log_ids:
|
168 |
+
raise HTTPException(status_code=400, detail="No log IDs provided for deletion.")
|
169 |
+
|
170 |
+
try:
|
171 |
+
deleted_count = await error_log_service.process_delete_error_logs_by_ids(
|
172 |
+
log_ids
|
173 |
+
)
|
174 |
+
# 注意:异步函数返回的是尝试删除的数量,可能不是精确值
|
175 |
+
logger.info(
|
176 |
+
f"Attempted bulk deletion for {deleted_count} error logs with IDs: {log_ids}"
|
177 |
+
)
|
178 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
179 |
+
except Exception as e:
|
180 |
+
logger.exception(f"Error bulk deleting error logs with IDs {log_ids}: {str(e)}")
|
181 |
+
raise HTTPException(
|
182 |
+
status_code=500, detail="Internal server error during bulk deletion"
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
@router.delete("/errors/all", status_code=status.HTTP_204_NO_CONTENT)
|
187 |
+
async def delete_all_error_logs_api(request: Request):
|
188 |
+
"""
|
189 |
+
删除所有错误日志 (异步)
|
190 |
+
"""
|
191 |
+
auth_token = request.cookies.get("auth_token")
|
192 |
+
if not auth_token or not verify_auth_token(auth_token):
|
193 |
+
logger.warning("Unauthorized access attempt to delete all error logs")
|
194 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
195 |
+
|
196 |
+
try:
|
197 |
+
deleted_count = await error_log_service.process_delete_all_error_logs()
|
198 |
+
logger.info(f"Successfully deleted all {deleted_count} error logs.")
|
199 |
+
# No body needed for 204 response
|
200 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
201 |
+
except Exception as e:
|
202 |
+
logger.exception(f"Error deleting all error logs: {str(e)}")
|
203 |
+
raise HTTPException(
|
204 |
+
status_code=500, detail="Internal server error during deletion of all logs"
|
205 |
+
)
|
206 |
+
|
207 |
+
|
208 |
+
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
|
209 |
+
async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
|
210 |
+
"""
|
211 |
+
删除单个错误日志 (异步)
|
212 |
+
"""
|
213 |
+
auth_token = request.cookies.get("auth_token")
|
214 |
+
if not auth_token or not verify_auth_token(auth_token):
|
215 |
+
logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}")
|
216 |
+
raise HTTPException(status_code=401, detail="Not authenticated")
|
217 |
+
|
218 |
+
try:
|
219 |
+
success = await error_log_service.process_delete_error_log_by_id(log_id)
|
220 |
+
if not success:
|
221 |
+
# 服务层现在在未找到时返回 False,我们在这里转换为 404
|
222 |
+
raise HTTPException(
|
223 |
+
status_code=404, detail=f"Error log with ID {log_id} not found"
|
224 |
+
)
|
225 |
+
logger.info(f"Successfully deleted error log with ID: {log_id}")
|
226 |
+
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
227 |
+
except HTTPException as http_exc:
|
228 |
+
raise http_exc
|
229 |
+
except Exception as e:
|
230 |
+
logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}")
|
231 |
+
raise HTTPException(
|
232 |
+
status_code=500, detail="Internal server error during deletion"
|
233 |
+
)
|
app/router/gemini_routes.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
2 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
+
from copy import deepcopy
|
4 |
+
import asyncio
|
5 |
+
from app.config.config import settings
|
6 |
+
from app.log.logger import get_gemini_logger
|
7 |
+
from app.core.security import SecurityService
|
8 |
+
from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
|
9 |
+
from app.service.chat.gemini_chat_service import GeminiChatService
|
10 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
11 |
+
from app.service.model.model_service import ModelService
|
12 |
+
from app.handler.retry_handler import RetryHandler
|
13 |
+
from app.handler.error_handler import handle_route_errors
|
14 |
+
from app.core.constants import API_VERSION
|
15 |
+
|
16 |
+
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
|
17 |
+
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
|
18 |
+
logger = get_gemini_logger()
|
19 |
+
|
20 |
+
security_service = SecurityService()
|
21 |
+
model_service = ModelService()
|
22 |
+
|
23 |
+
|
24 |
+
async def get_key_manager():
    """获取密钥管理器实例"""
    # Thin FastAPI dependency around the module-level singleton accessor.
    return await get_key_manager_instance()
|
27 |
+
|
28 |
+
|
29 |
+
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
    """获取下一个可用的API密钥"""
    # Dependency used by the generate endpoints to rotate through keys.
    return await key_manager.get_next_working_key()
|
32 |
+
|
33 |
+
|
34 |
+
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """获取Gemini聊天服务实例"""
    # A fresh service per request, bound to the shared key manager.
    return GeminiChatService(settings.BASE_URL, key_manager)
|
37 |
+
|
38 |
+
|
39 |
+
@router.get("/models")
|
40 |
+
@router_v1beta.get("/models")
|
41 |
+
async def list_models(
|
42 |
+
_=Depends(security_service.verify_key_or_goog_api_key),
|
43 |
+
key_manager: KeyManager = Depends(get_key_manager)
|
44 |
+
):
|
45 |
+
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
46 |
+
operation_name = "list_gemini_models"
|
47 |
+
logger.info("-" * 50 + operation_name + "-" * 50)
|
48 |
+
logger.info("Handling Gemini models list request")
|
49 |
+
|
50 |
+
try:
|
51 |
+
api_key = await key_manager.get_first_valid_key()
|
52 |
+
if not api_key:
|
53 |
+
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
54 |
+
logger.info(f"Using API key: {api_key}")
|
55 |
+
|
56 |
+
models_data = await model_service.get_gemini_models(api_key)
|
57 |
+
if not models_data or "models" not in models_data:
|
58 |
+
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
59 |
+
|
60 |
+
models_json = deepcopy(models_data)
|
61 |
+
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
62 |
+
|
63 |
+
def add_derived_model(base_name, suffix, display_suffix):
|
64 |
+
model = model_mapping.get(base_name)
|
65 |
+
if not model:
|
66 |
+
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
67 |
+
return
|
68 |
+
item = deepcopy(model)
|
69 |
+
item["name"] = f"models/{base_name}{suffix}"
|
70 |
+
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
71 |
+
item["displayName"] = display_name
|
72 |
+
item["description"] = display_name
|
73 |
+
models_json["models"].append(item)
|
74 |
+
|
75 |
+
if settings.SEARCH_MODELS:
|
76 |
+
for name in settings.SEARCH_MODELS:
|
77 |
+
add_derived_model(name, "-search", " For Search")
|
78 |
+
if settings.IMAGE_MODELS:
|
79 |
+
for name in settings.IMAGE_MODELS:
|
80 |
+
add_derived_model(name, "-image", " For Image")
|
81 |
+
if settings.THINKING_MODELS:
|
82 |
+
for name in settings.THINKING_MODELS:
|
83 |
+
add_derived_model(name, "-non-thinking", " Non Thinking")
|
84 |
+
|
85 |
+
logger.info("Gemini models list request successful")
|
86 |
+
return models_json
|
87 |
+
except HTTPException as http_exc:
|
88 |
+
raise http_exc
|
89 |
+
except Exception as e:
|
90 |
+
logger.error(f"Error getting Gemini models list: {str(e)}")
|
91 |
+
raise HTTPException(
|
92 |
+
status_code=500, detail="Internal server error while fetching Gemini models list"
|
93 |
+
) from e
|
94 |
+
|
95 |
+
|
96 |
+
@router.post("/models/{model_name}:generateContent")
|
97 |
+
@router_v1beta.post("/models/{model_name}:generateContent")
|
98 |
+
@RetryHandler(key_arg="api_key")
|
99 |
+
async def generate_content(
|
100 |
+
model_name: str,
|
101 |
+
request: GeminiRequest,
|
102 |
+
_=Depends(security_service.verify_key_or_goog_api_key),
|
103 |
+
api_key: str = Depends(get_next_working_key),
|
104 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
105 |
+
chat_service: GeminiChatService = Depends(get_chat_service)
|
106 |
+
):
|
107 |
+
"""处理 Gemini 非流式内容生成请求。"""
|
108 |
+
operation_name = "gemini_generate_content"
|
109 |
+
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
110 |
+
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
111 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
112 |
+
logger.info(f"Using API key: {api_key}")
|
113 |
+
|
114 |
+
if not await model_service.check_model_support(model_name):
|
115 |
+
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
116 |
+
|
117 |
+
response = await chat_service.generate_content(
|
118 |
+
model=model_name,
|
119 |
+
request=request,
|
120 |
+
api_key=api_key
|
121 |
+
)
|
122 |
+
return response
|
123 |
+
|
124 |
+
|
125 |
+
@router.post("/models/{model_name}:streamGenerateContent")
|
126 |
+
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
|
127 |
+
@RetryHandler(key_arg="api_key")
|
128 |
+
async def stream_generate_content(
|
129 |
+
model_name: str,
|
130 |
+
request: GeminiRequest,
|
131 |
+
_=Depends(security_service.verify_key_or_goog_api_key),
|
132 |
+
api_key: str = Depends(get_next_working_key),
|
133 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
134 |
+
chat_service: GeminiChatService = Depends(get_chat_service)
|
135 |
+
):
|
136 |
+
"""处理 Gemini 流式内容生成请求。"""
|
137 |
+
operation_name = "gemini_stream_generate_content"
|
138 |
+
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
139 |
+
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
140 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
141 |
+
logger.info(f"Using API key: {api_key}")
|
142 |
+
|
143 |
+
if not await model_service.check_model_support(model_name):
|
144 |
+
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
145 |
+
|
146 |
+
response_stream = chat_service.stream_generate_content(
|
147 |
+
model=model_name,
|
148 |
+
request=request,
|
149 |
+
api_key=api_key
|
150 |
+
)
|
151 |
+
return StreamingResponse(response_stream, media_type="text/event-stream")
|
152 |
+
|
153 |
+
|
154 |
+
@router.post("/reset-all-fail-counts")
|
155 |
+
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
|
156 |
+
"""批量重置Gemini API密钥的失败计数,可选择性地仅重置有效或无效密钥"""
|
157 |
+
logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
|
158 |
+
logger.info(f"Received reset request with key_type: {key_type}")
|
159 |
+
|
160 |
+
try:
|
161 |
+
# 获取分类后的密钥
|
162 |
+
keys_by_status = await key_manager.get_keys_by_status()
|
163 |
+
valid_keys = keys_by_status.get("valid_keys", {})
|
164 |
+
invalid_keys = keys_by_status.get("invalid_keys", {})
|
165 |
+
|
166 |
+
# 根据类型选择要重置的密钥
|
167 |
+
keys_to_reset = []
|
168 |
+
if key_type == "valid":
|
169 |
+
keys_to_reset = list(valid_keys.keys())
|
170 |
+
logger.info(f"Resetting only valid keys, count: {len(keys_to_reset)}")
|
171 |
+
elif key_type == "invalid":
|
172 |
+
keys_to_reset = list(invalid_keys.keys())
|
173 |
+
logger.info(f"Resetting only invalid keys, count: {len(keys_to_reset)}")
|
174 |
+
else:
|
175 |
+
# 重置所有密钥
|
176 |
+
await key_manager.reset_failure_counts()
|
177 |
+
return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
|
178 |
+
|
179 |
+
# 批量重置指定类型的密钥
|
180 |
+
for key in keys_to_reset:
|
181 |
+
await key_manager.reset_key_failure_count(key)
|
182 |
+
|
183 |
+
return JSONResponse({
|
184 |
+
"success": True,
|
185 |
+
"message": f"{key_type}密钥的失败计数已重置",
|
186 |
+
"reset_count": len(keys_to_reset)
|
187 |
+
})
|
188 |
+
except Exception as e:
|
189 |
+
logger.error(f"Failed to reset key failure counts: {str(e)}")
|
190 |
+
return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
|
191 |
+
|
192 |
+
|
193 |
+
@router.post("/reset-selected-fail-counts")
|
194 |
+
async def reset_selected_key_fail_counts(
|
195 |
+
request: ResetSelectedKeysRequest,
|
196 |
+
key_manager: KeyManager = Depends(get_key_manager)
|
197 |
+
):
|
198 |
+
"""批量重置选定Gemini API密钥的失败计数"""
|
199 |
+
logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
|
200 |
+
keys_to_reset = request.keys
|
201 |
+
key_type = request.key_type
|
202 |
+
logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")
|
203 |
+
|
204 |
+
if not keys_to_reset:
|
205 |
+
return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)
|
206 |
+
|
207 |
+
reset_count = 0
|
208 |
+
errors = []
|
209 |
+
|
210 |
+
try:
|
211 |
+
for key in keys_to_reset:
|
212 |
+
try:
|
213 |
+
result = await key_manager.reset_key_failure_count(key)
|
214 |
+
if result:
|
215 |
+
reset_count += 1
|
216 |
+
else:
|
217 |
+
logger.warning(f"Key not found during selective reset: {key}")
|
218 |
+
except Exception as key_error:
|
219 |
+
logger.error(f"Error resetting key {key}: {str(key_error)}")
|
220 |
+
errors.append(f"Key {key}: {str(key_error)}")
|
221 |
+
|
222 |
+
if errors:
|
223 |
+
error_message = f"批量重置完成,但出现错误: {'; '.join(errors)}"
|
224 |
+
final_success = reset_count > 0
|
225 |
+
status_code = 207 if final_success and errors else 500
|
226 |
+
return JSONResponse({
|
227 |
+
"success": final_success,
|
228 |
+
"message": error_message,
|
229 |
+
"reset_count": reset_count
|
230 |
+
}, status_code=status_code)
|
231 |
+
|
232 |
+
return JSONResponse({
|
233 |
+
"success": True,
|
234 |
+
"message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
|
235 |
+
"reset_count": reset_count
|
236 |
+
})
|
237 |
+
except Exception as e:
|
238 |
+
logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
|
239 |
+
return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
|
240 |
+
|
241 |
+
|
242 |
+
@router.post("/reset-fail-count/{api_key}")
|
243 |
+
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
|
244 |
+
"""重置指定Gemini API密钥的失败计数"""
|
245 |
+
logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
|
246 |
+
logger.info(f"Resetting failure count for API key: {api_key}")
|
247 |
+
|
248 |
+
try:
|
249 |
+
result = await key_manager.reset_key_failure_count(api_key)
|
250 |
+
if result:
|
251 |
+
return JSONResponse({"success": True, "message": "失败计数已重置"})
|
252 |
+
return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
|
253 |
+
except Exception as e:
|
254 |
+
logger.error(f"Failed to reset key failure count: {str(e)}")
|
255 |
+
return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
|
256 |
+
|
257 |
+
|
258 |
+
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
    """Verify that a Gemini API key can serve a minimal request.

    Sends a tiny "hi" prompt against the configured test model. On failure the
    key's failure counter is incremented (or initialized, mirroring the bulk
    endpoint) and an "invalid" status with the error message is returned.
    """
    logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
    logger.info("Verifying API key validity")

    try:
        gemini_request = GeminiRequest(
            contents=[
                GeminiContent(
                    role="user",
                    parts=[{"text": "hi"}],
                )
            ],
            generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10},
        )

        response = await chat_service.generate_content(
            settings.TEST_MODEL,
            gemini_request,
            api_key,
        )

        if response:
            return JSONResponse({"status": "valid"})
        # BUG FIX: previously this path fell through and implicitly returned
        # None (HTTP 200 with a null body); report an empty response as invalid.
        return JSONResponse({"status": "invalid", "error": "Empty response from model"})
    except Exception as e:
        logger.error(f"Key verification failed: {str(e)}")

        async with key_manager.failure_count_lock:
            if api_key in key_manager.key_failure_counts:
                key_manager.key_failure_counts[api_key] += 1
                logger.warning(f"Verification exception for key: {api_key}, incrementing failure count")
            else:
                # Consistency fix: the bulk endpoint (verify_selected_keys)
                # initializes the counter for unknown keys; do the same here.
                key_manager.key_failure_counts[api_key] = 1
                logger.warning(f"Verification exception for key: {api_key}, initializing failure count to 1")

        return JSONResponse({"status": "invalid", "error": str(e)})
|
292 |
+
|
293 |
+
|
294 |
+
@router.post("/verify-selected-keys")
async def verify_selected_keys(
    request: VerifySelectedKeysRequest,
    chat_service: GeminiChatService = Depends(get_chat_service),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Bulk-verify the validity of the selected Gemini API keys.

    Fires one minimal "hi" request per key concurrently. Failing keys get
    their failure counter incremented (or initialized) and are reported back
    with the error message; the response always lists successes and failures.
    """
    logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
    keys_to_verify = request.keys
    logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")

    if not keys_to_verify:
        return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)

    successful_keys = []
    failed_keys = {}

    async def _verify_single_key(api_key: str):
        """Verify one key; record the outcome in the enclosing collections."""
        # NOTE: the lists/dicts are only mutated (append / item assignment),
        # never rebound, so no `nonlocal` declaration is needed.
        try:
            gemini_request = GeminiRequest(
                contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
                generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
            )
            await chat_service.generate_content(
                settings.TEST_MODEL,
                gemini_request,
                api_key
            )
            successful_keys.append(api_key)
            return api_key, "valid", None
        except Exception as e:
            error_message = str(e)
            logger.warning(f"Key verification failed for {api_key}: {error_message}")
            async with key_manager.failure_count_lock:
                if api_key in key_manager.key_failure_counts:
                    key_manager.key_failure_counts[api_key] += 1
                    logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
                else:
                    key_manager.key_failure_counts[api_key] = 1
                    logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
            failed_keys[api_key] = error_message
            return api_key, "invalid", error_message

    tasks = [_verify_single_key(key) for key in keys_to_verify]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # BUG FIX: the original loop contained an unreachable
    # `elif isinstance(result, Exception)` nested under `elif result:` and
    # unpacked tuples into variables that were never used. Outcomes are
    # already collected inside _verify_single_key; only log stray exceptions.
    for result in results:
        if isinstance(result, Exception):
            logger.error(f"An unexpected error occurred during bulk verification task: {result}")

    valid_count = len(successful_keys)
    invalid_count = len(failed_keys)
    logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")

    if failed_keys:
        message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
    else:
        message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
    return JSONResponse({
        "success": True,
        "message": message,
        "successful_keys": successful_keys,
        "failed_keys": failed_keys,
        "valid_count": valid_count,
        "invalid_count": invalid_count
    })
|
app/router/openai_compatiable_routes.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends
|
2 |
+
from fastapi.responses import StreamingResponse
|
3 |
+
|
4 |
+
from app.config.config import settings
|
5 |
+
from app.core.security import SecurityService
|
6 |
+
from app.domain.openai_models import (
|
7 |
+
ChatRequest,
|
8 |
+
EmbeddingRequest,
|
9 |
+
ImageGenerationRequest,
|
10 |
+
)
|
11 |
+
from app.handler.retry_handler import RetryHandler
|
12 |
+
from app.handler.error_handler import handle_route_errors
|
13 |
+
from app.log.logger import get_openai_compatible_logger
|
14 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
15 |
+
from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
|
16 |
+
|
17 |
+
|
18 |
+
router = APIRouter()
|
19 |
+
logger = get_openai_compatible_logger()
|
20 |
+
|
21 |
+
security_service = SecurityService()
|
22 |
+
|
23 |
+
async def get_key_manager():
    """Dependency provider returning the shared KeyManager instance."""
    return await get_key_manager_instance()
|
25 |
+
|
26 |
+
|
27 |
+
async def get_next_working_key_wrapper(
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Dependency provider yielding the next working API key from rotation."""
    return await key_manager.get_next_working_key()
|
31 |
+
|
32 |
+
|
33 |
+
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
    """Get an OpenAI-compatible chat service bound to the configured BASE_URL."""
    return OpenAICompatiableService(settings.BASE_URL, key_manager)
|
36 |
+
|
37 |
+
|
38 |
+
@router.get("/openai/v1/models")
async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Fetch the list of available models (OpenAI-compatible endpoint)."""
    operation_name = "list_models"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling models list request")
        api_key = await key_manager.get_first_valid_key()
        # NOTE(review): logs the full API key at INFO level — consider masking.
        logger.info(f"Using API key: {api_key}")
        return await openai_service.get_models(api_key)
|
51 |
+
|
52 |
+
|
53 |
+
@router.post("/openai/v1/chat/completions")
@RetryHandler(key_arg="api_key")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Handle chat completion requests, with streaming and per-model key switching.

    If the requested model is the special "<CREATE_IMAGE_MODEL>-chat" model,
    a paid key is used instead of the rotating key.
    """
    operation_name = "chat_completion"
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    current_api_key = api_key
    if is_image_chat:
        current_api_key = await key_manager.get_paid_key()

    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling chat completion request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {current_api_key}")

        if is_image_chat:
            # NOTE(review): unlike the /v1/chat/completions route in
            # openai_routes.py, the image-chat branch here never wraps the
            # response in a StreamingResponse even when request.stream is set.
            # Confirm whether create_image_chat_completion can stream.
            response = await openai_service.create_image_chat_completion(request, current_api_key)
            return response
        else:
            response = await openai_service.create_chat_completion(request, current_api_key)
            if request.stream:
                return StreamingResponse(response, media_type="text/event-stream")
            return response
|
82 |
+
|
83 |
+
|
84 |
+
@router.post("/openai/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Handle image generation requests (OpenAI-compatible endpoint)."""
    operation_name = "generate_image"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling image generation request for prompt: {request.prompt}")
        # The requested model is overridden with the configured image model.
        request.model = settings.CREATE_IMAGE_MODEL
        return await openai_service.generate_images(request)
|
96 |
+
|
97 |
+
|
98 |
+
@router.post("/openai/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Handle text embedding requests (OpenAI-compatible endpoint)."""
    operation_name = "embedding"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling embedding request for model: {request.model}")
        api_key = await key_manager.get_next_working_key()
        # NOTE(review): logs the full API key at INFO level — consider masking.
        logger.info(f"Using API key: {api_key}")
        return await openai_service.create_embeddings(
            input_text=request.input, model=request.model, api_key=api_key
        )
|
app/router/openai_routes.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends, HTTPException, Response
|
2 |
+
from fastapi.responses import StreamingResponse
|
3 |
+
|
4 |
+
from app.config.config import settings
|
5 |
+
from app.core.security import SecurityService
|
6 |
+
from app.domain.openai_models import (
|
7 |
+
ChatRequest,
|
8 |
+
EmbeddingRequest,
|
9 |
+
ImageGenerationRequest,
|
10 |
+
TTSRequest,
|
11 |
+
)
|
12 |
+
from app.handler.retry_handler import RetryHandler
|
13 |
+
from app.handler.error_handler import handle_route_errors
|
14 |
+
from app.log.logger import get_openai_logger
|
15 |
+
from app.service.chat.openai_chat_service import OpenAIChatService
|
16 |
+
from app.service.embedding.embedding_service import EmbeddingService
|
17 |
+
from app.service.image.image_create_service import ImageCreateService
|
18 |
+
from app.service.tts.tts_service import TTSService
|
19 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
20 |
+
from app.service.model.model_service import ModelService
|
21 |
+
|
22 |
+
router = APIRouter()
|
23 |
+
logger = get_openai_logger()
|
24 |
+
|
25 |
+
security_service = SecurityService()
|
26 |
+
model_service = ModelService()
|
27 |
+
embedding_service = EmbeddingService()
|
28 |
+
image_create_service = ImageCreateService()
|
29 |
+
tts_service = TTSService()
|
30 |
+
|
31 |
+
|
32 |
+
async def get_key_manager():
    """Dependency provider returning the shared KeyManager instance."""
    return await get_key_manager_instance()
|
34 |
+
|
35 |
+
|
36 |
+
async def get_next_working_key_wrapper(
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Dependency provider yielding the next working API key from rotation."""
    return await key_manager.get_next_working_key()
|
40 |
+
|
41 |
+
|
42 |
+
async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """Get an OpenAI chat service instance bound to the configured BASE_URL."""
    return OpenAIChatService(settings.BASE_URL, key_manager)
|
45 |
+
|
46 |
+
|
47 |
+
async def get_tts_service():
    """Get the module-level TTS service singleton."""
    return tts_service
|
50 |
+
|
51 |
+
|
52 |
+
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """List available OpenAI-style models (covers both Gemini and OpenAI)."""
    operation_name = "list_models"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling models list request")
        api_key = await key_manager.get_first_valid_key()
        # NOTE(review): logs the full API key at INFO level — consider masking.
        logger.info(f"Using API key: {api_key}")
        return await model_service.get_gemini_openai_models(api_key)
|
65 |
+
|
66 |
+
|
67 |
+
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(key_arg="api_key")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: OpenAIChatService = Depends(get_openai_chat_service),
):
    """Handle an OpenAI chat completion request.

    Supports streaming responses, and switches to a paid key when the
    special image-chat model is requested.
    """
    operation_name = "chat_completion"
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    # Image-chat requests are served with a paid key instead of the rotating one.
    current_api_key = await key_manager.get_paid_key() if is_image_chat else api_key

    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling chat completion request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {current_api_key}")

        if not await model_service.check_model_support(request.model):
            raise HTTPException(
                status_code=400, detail=f"Model {request.model} is not supported"
            )

        # Both branches share identical stream handling, so pick the service
        # call first and wrap once afterwards.
        if is_image_chat:
            response = await chat_service.create_image_chat_completion(request, current_api_key)
        else:
            response = await chat_service.create_chat_completion(request, current_api_key)

        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        return response
|
104 |
+
|
105 |
+
|
106 |
+
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
):
    """Handle an OpenAI image generation request."""
    operation_name = "generate_image"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling image generation request for prompt: {request.prompt}")
        # NOTE(review): unlike the compatible route, this call is not awaited.
        # Confirm that ImageCreateService.generate_images is synchronous;
        # if it is a coroutine function, this returns an unawaited coroutine.
        response = image_create_service.generate_images(request)
        return response
|
118 |
+
|
119 |
+
|
120 |
+
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Handle an OpenAI text embedding request."""
    operation_name = "embedding"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling embedding request for model: {request.model}")
        api_key = await key_manager.get_next_working_key()
        # NOTE(review): logs the full API key at INFO level — consider masking.
        logger.info(f"Using API key: {api_key}")
        response = await embedding_service.create_embedding(
            input_text=request.input, model=request.model, api_key=api_key
        )
        return response
|
137 |
+
|
138 |
+
|
139 |
+
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
    _=Depends(security_service.verify_auth_token),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Return the lists of valid and invalid API keys (admin token required)."""
    operation_name = "get_keys_list"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling keys list request")
        keys_status = await key_manager.get_keys_by_status()
        return {
            "status": "success",
            "data": {
                "valid_keys": keys_status["valid_keys"],
                "invalid_keys": keys_status["invalid_keys"],
            },
            "total": len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"]),
        }
|
158 |
+
|
159 |
+
|
160 |
+
@router.post("/v1/audio/speech")
@router.post("/hf/v1/audio/speech")
async def text_to_speech(
    request: TTSRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Handle an OpenAI TTS request; returns raw WAV audio bytes."""
    operation_name = "text_to_speech"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling TTS request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # NOTE(review): logs the full API key at INFO level — consider masking.
        logger.info(f"Using API key: {api_key}")
        audio_data = await tts_service.create_tts(request, api_key)
        return Response(content=audio_data, media_type="audio/wav")
|
app/router/routes.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
路由配置模块,负责设置和配置应用程序的路由
|
3 |
+
"""
|
4 |
+
|
5 |
+
from fastapi import FastAPI, Request
|
6 |
+
from fastapi.responses import HTMLResponse, RedirectResponse
|
7 |
+
from fastapi.templating import Jinja2Templates
|
8 |
+
|
9 |
+
from app.core.security import verify_auth_token
|
10 |
+
from app.log.logger import get_routes_logger
|
11 |
+
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes
|
12 |
+
from app.service.key.key_manager import get_key_manager_instance
|
13 |
+
from app.service.stats.stats_service import StatsService
|
14 |
+
|
15 |
+
logger = get_routes_logger()
|
16 |
+
|
17 |
+
templates = Jinja2Templates(directory="app/templates")
|
18 |
+
|
19 |
+
|
20 |
+
def setup_routers(app: FastAPI) -> None:
    """Register all API routers plus page/health/stats routes on the app.

    Args:
        app: FastAPI application instance.
    """
    app.include_router(openai_routes.router)
    app.include_router(gemini_routes.router)
    app.include_router(gemini_routes.router_v1beta)
    app.include_router(config_routes.router)
    app.include_router(error_log_routes.router)
    app.include_router(scheduler_routes.router)
    app.include_router(stats_routes.router)
    app.include_router(version_routes.router)
    app.include_router(openai_compatiable_routes.router)
    app.include_router(vertex_express_routes.router)

    setup_page_routes(app)

    setup_health_routes(app)
    setup_api_stats_routes(app)
|
42 |
+
|
43 |
+
|
44 |
+
def setup_page_routes(app: FastAPI) -> None:
    """Register the HTML page routes (auth, keys, config, logs).

    All pages except "/" require a valid ``auth_token`` cookie; unauthorized
    requests are redirected back to the auth page.

    Args:
        app: FastAPI application instance.
    """

    @app.get("/", response_class=HTMLResponse)
    async def auth_page(request: Request):
        """Render the authentication page."""
        return templates.TemplateResponse("auth.html", {"request": request})

    @app.post("/auth")
    async def authenticate(request: Request):
        """Handle the auth form: set the auth cookie on success, else redirect to "/"."""
        try:
            form = await request.form()
            auth_token = form.get("auth_token")
            if not auth_token:
                logger.warning("Authentication attempt with empty token")
                return RedirectResponse(url="/", status_code=302)

            if verify_auth_token(auth_token):
                logger.info("Successful authentication")
                response = RedirectResponse(url="/config", status_code=302)
                # Cookie lifetime is 1 hour; httponly keeps it away from JS.
                response.set_cookie(
                    key="auth_token", value=auth_token, httponly=True, max_age=3600
                )
                return response
            logger.warning("Failed authentication attempt with invalid token")
            return RedirectResponse(url="/", status_code=302)
        except Exception as e:
            # Any failure degrades to a redirect back to the auth page.
            logger.error(f"Authentication error: {str(e)}")
            return RedirectResponse(url="/", status_code=302)

    @app.get("/keys", response_class=HTMLResponse)
    async def keys_page(request: Request):
        """Render the key-management page with key counts and API usage stats."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to keys page")
                return RedirectResponse(url="/", status_code=302)

            key_manager = await get_key_manager_instance()
            keys_status = await key_manager.get_keys_by_status()
            total_keys = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
            valid_key_count = len(keys_status["valid_keys"])
            invalid_key_count = len(keys_status["invalid_keys"])

            stats_service = StatsService()
            api_stats = await stats_service.get_api_usage_stats()
            logger.info(f"API stats retrieved: {api_stats}")

            logger.info(f"Keys status retrieved successfully. Total keys: {total_keys}")
            return templates.TemplateResponse(
                "keys_status.html",
                {
                    "request": request,
                    "valid_keys": keys_status["valid_keys"],
                    "invalid_keys": keys_status["invalid_keys"],
                    "total_keys": total_keys,
                    "valid_key_count": valid_key_count,
                    "invalid_key_count": invalid_key_count,
                    "api_stats": api_stats,
                },
            )
        except Exception as e:
            logger.error(f"Error retrieving keys status or API stats: {str(e)}")
            raise

    @app.get("/config", response_class=HTMLResponse)
    async def config_page(request: Request):
        """Render the configuration editor page."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to config page")
                return RedirectResponse(url="/", status_code=302)

            logger.info("Config page accessed successfully")
            return templates.TemplateResponse("config_editor.html", {"request": request})
        except Exception as e:
            logger.error(f"Error accessing config page: {str(e)}")
            raise

    @app.get("/logs", response_class=HTMLResponse)
    async def logs_page(request: Request):
        """Render the error-log viewer page."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to logs page")
                return RedirectResponse(url="/", status_code=302)

            logger.info("Logs page accessed successfully")
            return templates.TemplateResponse("error_logs.html", {"request": request})
        except Exception as e:
            logger.error(f"Error accessing logs page: {str(e)}")
            raise
|
145 |
+
|
146 |
+
|
147 |
+
def setup_health_routes(app: FastAPI) -> None:
    """Register the health-check route.

    Args:
        app: FastAPI application instance.
    """

    @app.get("/health")
    async def health_check(request: Request):
        """Liveness endpoint; always reports healthy."""
        logger.info("Health check endpoint called")
        return {"status": "healthy"}
|
160 |
+
|
161 |
+
|
162 |
+
def setup_api_stats_routes(app: FastAPI) -> None:
    """Register the API-statistics routes.

    Args:
        app: FastAPI application instance.
    """
    # Local import: this module's top-level only imports HTML/Redirect responses.
    from fastapi.responses import JSONResponse

    @app.get("/api/stats/details")
    async def api_stats_details(request: Request, period: str):
        """Return API call details for the given period (cookie auth required)."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to API stats details")
                # BUG FIX: returning `dict, int` from a FastAPI handler does NOT
                # set the HTTP status — the tuple is serialized as a JSON array
                # with status 200. Use JSONResponse so clients see real codes.
                return JSONResponse({"error": "Unauthorized"}, status_code=401)

            logger.info(f"Fetching API call details for period: {period}")
            stats_service = StatsService()
            details = await stats_service.get_api_call_details(period)
            return details
        except ValueError as e:
            logger.warning(f"Invalid period requested for API stats details: {period} - {str(e)}")
            return JSONResponse({"error": str(e)}, status_code=400)
        except Exception as e:
            logger.error(f"Error fetching API stats details for period {period}: {str(e)}")
            return JSONResponse({"error": "Internal server error"}, status_code=500)
|
app/router/scheduler_routes.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
定时任务控制路由模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
from fastapi import APIRouter, Request, HTTPException, status
|
6 |
+
from fastapi.responses import JSONResponse
|
7 |
+
|
8 |
+
from app.core.security import verify_auth_token
|
9 |
+
from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
|
10 |
+
from app.log.logger import get_scheduler_routes
|
11 |
+
|
12 |
+
logger = get_scheduler_routes()
|
13 |
+
|
14 |
+
router = APIRouter(
|
15 |
+
prefix="/api/scheduler",
|
16 |
+
tags=["Scheduler"]
|
17 |
+
)
|
18 |
+
|
19 |
+
async def verify_token(request: Request):
    """Raise 401 unless the request carries a valid ``auth_token`` cookie."""
    auth_token = request.cookies.get("auth_token")
    if not auth_token or not verify_auth_token(auth_token):
        logger.warning("Unauthorized access attempt to scheduler API")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
28 |
+
|
29 |
+
@router.post("/start", summary="启动定时任务")
async def start_scheduler_endpoint(request: Request):
    """Start the background scheduler task (cookie auth required)."""
    await verify_token(request)
    try:
        logger.info("Received request to start scheduler.")
        start_scheduler()
        return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error starting scheduler: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to start scheduler: {str(e)}"
        )
|
43 |
+
|
44 |
+
@router.post("/stop", summary="停止定时任务")
async def stop_scheduler_endpoint(request: Request):
    """Stop the background scheduler task (cookie auth required)."""
    await verify_token(request)
    try:
        logger.info("Received request to stop scheduler.")
        stop_scheduler()
        return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK)
    except Exception as e:
        logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to stop scheduler: {str(e)}"
        )
|
app/router/stats_routes.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
2 |
+
from starlette import status
|
3 |
+
from app.core.security import verify_auth_token
|
4 |
+
from app.service.stats.stats_service import StatsService
|
5 |
+
from app.log.logger import get_stats_logger
|
6 |
+
|
7 |
+
logger = get_stats_logger()
|
8 |
+
|
9 |
+
|
10 |
+
async def verify_token(request: Request):
    """Raise 401 unless the request carries a valid ``auth_token`` cookie."""
    auth_token = request.cookies.get("auth_token")
    if not auth_token or not verify_auth_token(auth_token):
        # NOTE(review): message says "scheduler API" but this guards the stats
        # API — looks copied from scheduler_routes; confirm and reword.
        logger.warning("Unauthorized access attempt to scheduler API")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
19 |
+
|
20 |
+
router = APIRouter(
|
21 |
+
prefix="/api",
|
22 |
+
tags=["stats"],
|
23 |
+
dependencies=[Depends(verify_token)]
|
24 |
+
)
|
25 |
+
|
26 |
+
stats_service = StatsService()
|
27 |
+
|
28 |
+
@router.get("/key-usage-details/{key}",
            summary="获取指定密钥最近24小时的模型调用次数",
            description="根据提供的 API 密钥,返回过去24小时内每个模型被调用的次数统计。")
async def get_key_usage_details(key: str):
    """
    Retrieves the model usage count for a specific API key within the last 24 hours.

    Args:
        key: The API key to get usage details for.

    Returns:
        A dictionary with model names as keys and their call counts as values.
        Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}

    Raises:
        HTTPException: If an error occurs during data retrieval.
    """
    try:
        usage_details = await stats_service.get_key_usage_details_last_24h(key)
        if usage_details is None:
            # No recorded usage is reported as an empty mapping, not an error.
            return {}
        return usage_details
    except Exception as e:
        # Only the key prefix is logged to avoid leaking the full secret.
        logger.error(f"Error fetching key usage details for key {key[:4]}...: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取密钥使用详情时出错: {e}"
        )
|
app/router/version_routes.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, HTTPException
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from app.service.update.update_service import check_for_updates
|
6 |
+
from app.utils.helpers import get_current_version
|
7 |
+
from app.log.logger import get_update_logger
|
8 |
+
|
9 |
+
router = APIRouter(prefix="/api/version", tags=["Version"])
|
10 |
+
logger = get_update_logger()
|
11 |
+
|
12 |
+
class VersionInfo(BaseModel):
    """Response model for the version-check endpoint (/api/version/check)."""

    current_version: str = Field(..., description="当前应用程序版本")
    latest_version: Optional[str] = Field(None, description="可用的最新版本")
    update_available: bool = Field(False, description="是否有可用更新")
    error_message: Optional[str] = Field(None, description="检查更新时发生的错误信息")
|
17 |
+
|
18 |
+
@router.get("/check", response_model=VersionInfo, summary="检查应用程序更新")
|
19 |
+
async def get_version_info():
|
20 |
+
"""
|
21 |
+
检查当前应用程序版本与最新的 GitHub release 版本。
|
22 |
+
"""
|
23 |
+
try:
|
24 |
+
current_version = get_current_version()
|
25 |
+
update_available, latest_version, error_message = await check_for_updates()
|
26 |
+
|
27 |
+
logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'")
|
28 |
+
|
29 |
+
return VersionInfo(
|
30 |
+
current_version=current_version,
|
31 |
+
latest_version=latest_version,
|
32 |
+
update_available=update_available,
|
33 |
+
error_message=error_message
|
34 |
+
)
|
35 |
+
except Exception as e:
|
36 |
+
logger.error(f"Error in /api/version/check endpoint: {e}", exc_info=True)
|
37 |
+
raise HTTPException(status_code=500, detail="检查版本信息时发生内部错误")
|
app/router/vertex_express_routes.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
2 |
+
from fastapi.responses import StreamingResponse
|
3 |
+
from copy import deepcopy
|
4 |
+
from app.config.config import settings
|
5 |
+
from app.log.logger import get_vertex_express_logger
|
6 |
+
from app.core.security import SecurityService
|
7 |
+
from app.domain.gemini_models import GeminiRequest
|
8 |
+
from app.service.chat.vertex_express_chat_service import GeminiChatService
|
9 |
+
from app.service.key.key_manager import KeyManager, get_key_manager_instance
|
10 |
+
from app.service.model.model_service import ModelService
|
11 |
+
from app.handler.retry_handler import RetryHandler
|
12 |
+
from app.handler.error_handler import handle_route_errors
|
13 |
+
from app.core.constants import API_VERSION
|
14 |
+
|
15 |
+
# Vertex Express routes are mounted under this API-version-specific prefix.
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
logger = get_vertex_express_logger()

# Module-level singletons shared by every route in this router.
security_service = SecurityService()
model_service = ModelService()
|
20 |
+
|
21 |
+
|
22 |
+
async def get_key_manager():
    """FastAPI dependency: return the shared KeyManager singleton."""
    return await get_key_manager_instance()
|
25 |
+
|
26 |
+
|
27 |
+
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: return the next usable Vertex API key."""
    return await key_manager.get_next_working_vertex_key()
|
30 |
+
|
31 |
+
|
32 |
+
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: build a GeminiChatService bound to the Vertex Express base URL."""
    return GeminiChatService(settings.VERTEX_EXPRESS_BASE_URL, key_manager)
|
35 |
+
|
36 |
+
|
37 |
+
@router.get("/models")
|
38 |
+
async def list_models(
|
39 |
+
_=Depends(security_service.verify_key_or_goog_api_key),
|
40 |
+
key_manager: KeyManager = Depends(get_key_manager)
|
41 |
+
):
|
42 |
+
"""获取可用的 Gemini 模型列表,并根据配置添加衍生模型(搜索、图像、非思考)。"""
|
43 |
+
operation_name = "list_gemini_models"
|
44 |
+
logger.info("-" * 50 + operation_name + "-" * 50)
|
45 |
+
logger.info("Handling Gemini models list request")
|
46 |
+
|
47 |
+
try:
|
48 |
+
api_key = await key_manager.get_first_valid_key()
|
49 |
+
if not api_key:
|
50 |
+
raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
|
51 |
+
logger.info(f"Using API key: {api_key}")
|
52 |
+
|
53 |
+
models_data = await model_service.get_gemini_models(api_key)
|
54 |
+
if not models_data or "models" not in models_data:
|
55 |
+
raise HTTPException(status_code=500, detail="Failed to fetch base models list.")
|
56 |
+
|
57 |
+
models_json = deepcopy(models_data)
|
58 |
+
model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}
|
59 |
+
|
60 |
+
def add_derived_model(base_name, suffix, display_suffix):
|
61 |
+
model = model_mapping.get(base_name)
|
62 |
+
if not model:
|
63 |
+
logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
|
64 |
+
return
|
65 |
+
item = deepcopy(model)
|
66 |
+
item["name"] = f"models/{base_name}{suffix}"
|
67 |
+
display_name = f'{item.get("displayName", base_name)}{display_suffix}'
|
68 |
+
item["displayName"] = display_name
|
69 |
+
item["description"] = display_name
|
70 |
+
models_json["models"].append(item)
|
71 |
+
|
72 |
+
if settings.SEARCH_MODELS:
|
73 |
+
for name in settings.SEARCH_MODELS:
|
74 |
+
add_derived_model(name, "-search", " For Search")
|
75 |
+
if settings.IMAGE_MODELS:
|
76 |
+
for name in settings.IMAGE_MODELS:
|
77 |
+
add_derived_model(name, "-image", " For Image")
|
78 |
+
if settings.THINKING_MODELS:
|
79 |
+
for name in settings.THINKING_MODELS:
|
80 |
+
add_derived_model(name, "-non-thinking", " Non Thinking")
|
81 |
+
|
82 |
+
logger.info("Gemini models list request successful")
|
83 |
+
return models_json
|
84 |
+
except HTTPException as http_exc:
|
85 |
+
raise http_exc
|
86 |
+
except Exception as e:
|
87 |
+
logger.error(f"Error getting Gemini models list: {str(e)}")
|
88 |
+
raise HTTPException(
|
89 |
+
status_code=500, detail="Internal server error while fetching Gemini models list"
|
90 |
+
) from e
|
91 |
+
|
92 |
+
|
93 |
+
@router.post("/models/{model_name}:generateContent")
|
94 |
+
@RetryHandler(key_arg="api_key")
|
95 |
+
async def generate_content(
|
96 |
+
model_name: str,
|
97 |
+
request: GeminiRequest,
|
98 |
+
_=Depends(security_service.verify_key_or_goog_api_key),
|
99 |
+
api_key: str = Depends(get_next_working_key),
|
100 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
101 |
+
chat_service: GeminiChatService = Depends(get_chat_service)
|
102 |
+
):
|
103 |
+
"""处理 Gemini 非流式内容生成请求。"""
|
104 |
+
operation_name = "gemini_generate_content"
|
105 |
+
async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
|
106 |
+
logger.info(f"Handling Gemini content generation request for model: {model_name}")
|
107 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
108 |
+
logger.info(f"Using API key: {api_key}")
|
109 |
+
|
110 |
+
if not await model_service.check_model_support(model_name):
|
111 |
+
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
112 |
+
|
113 |
+
response = await chat_service.generate_content(
|
114 |
+
model=model_name,
|
115 |
+
request=request,
|
116 |
+
api_key=api_key
|
117 |
+
)
|
118 |
+
return response
|
119 |
+
|
120 |
+
|
121 |
+
@router.post("/models/{model_name}:streamGenerateContent")
|
122 |
+
@RetryHandler(key_arg="api_key")
|
123 |
+
async def stream_generate_content(
|
124 |
+
model_name: str,
|
125 |
+
request: GeminiRequest,
|
126 |
+
_=Depends(security_service.verify_key_or_goog_api_key),
|
127 |
+
api_key: str = Depends(get_next_working_key),
|
128 |
+
key_manager: KeyManager = Depends(get_key_manager),
|
129 |
+
chat_service: GeminiChatService = Depends(get_chat_service)
|
130 |
+
):
|
131 |
+
"""处理 Gemini 流式内容生成请求。"""
|
132 |
+
operation_name = "gemini_stream_generate_content"
|
133 |
+
async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
|
134 |
+
logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
|
135 |
+
logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
|
136 |
+
logger.info(f"Using API key: {api_key}")
|
137 |
+
|
138 |
+
if not await model_service.check_model_support(model_name):
|
139 |
+
raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")
|
140 |
+
|
141 |
+
response_stream = chat_service.stream_generate_content(
|
142 |
+
model=model_name,
|
143 |
+
request=request,
|
144 |
+
api_key=api_key
|
145 |
+
)
|
146 |
+
return StreamingResponse(response_stream, media_type="text/event-stream")
|
app/scheduler/scheduled_tasks.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
3 |
+
|
4 |
+
from app.config.config import settings
|
5 |
+
from app.domain.gemini_models import GeminiContent, GeminiRequest
|
6 |
+
from app.log.logger import Logger
|
7 |
+
from app.service.chat.gemini_chat_service import GeminiChatService
|
8 |
+
from app.service.error_log.error_log_service import delete_old_error_logs
|
9 |
+
from app.service.key.key_manager import get_key_manager_instance
|
10 |
+
from app.service.request_log.request_log_service import delete_old_request_logs_task
|
11 |
+
|
12 |
+
logger = Logger.setup_logger("scheduler")
|
13 |
+
|
14 |
+
|
15 |
+
async def check_failed_keys():
    """
    Periodically verify every API key whose failure count is greater than zero.

    Keys that pass a minimal test request get their failure count reset; keys
    that fail again have their count incremented, capped at
    ``key_manager.MAX_FAILURES``.
    """
    logger.info("Starting scheduled check for failed API keys...")
    try:
        key_manager = await get_key_manager_instance()
        # Make sure the KeyManager singleton has been fully initialized.
        if not key_manager or not hasattr(key_manager, "key_failure_counts"):
            logger.warning(
                "KeyManager instance not available or not initialized. Skipping check."
            )
            return

        # Create a GeminiChatService directly (no dependency injection)
        # because this runs as a background task outside any request context.
        chat_service = GeminiChatService(settings.BASE_URL, key_manager)

        # Collect the keys that need checking (failure count > 0).
        keys_to_check = []
        async with key_manager.failure_count_lock:  # shared state requires the lock
            # Copy so the dict is not mutated while we iterate it.
            failure_counts_copy = key_manager.key_failure_counts.copy()
            keys_to_check = [
                key for key, count in failure_counts_copy.items() if count > 0
            ]  # every key with at least one recorded failure

        if not keys_to_check:
            logger.info("No keys with failure count > 0 found. Skipping verification.")
            return

        logger.info(
            f"Found {len(keys_to_check)} keys with failure count > 0 to verify."
        )

        for key in keys_to_check:
            # Mask most of the key so logs never expose the full secret.
            log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key
            logger.info(f"Verifying key: {log_key}...")
            try:
                # Build a minimal one-message test request.
                gemini_request = GeminiRequest(
                    contents=[
                        GeminiContent(
                            role="user",
                            parts=[{"text": "hi"}],
                        )
                    ]
                )
                await chat_service.generate_content(
                    settings.TEST_MODEL, gemini_request, key
                )
                logger.info(
                    f"Key {log_key} verification successful. Resetting failure count."
                )
                await key_manager.reset_key_failure_count(key)
            except Exception as e:
                logger.warning(
                    f"Key {log_key} verification failed: {str(e)}. Incrementing failure count."
                )
                # Mutating the counter directly, so re-acquire the lock.
                async with key_manager.failure_count_lock:
                    # Re-check the key still exists and has not hit the cap
                    # (state may have changed while the request was in flight).
                    if (
                        key in key_manager.key_failure_counts
                        and key_manager.key_failure_counts[key]
                        < key_manager.MAX_FAILURES
                    ):
                        key_manager.key_failure_counts[key] += 1
                        logger.info(
                            f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}."
                        )
                    elif key in key_manager.key_failure_counts:
                        logger.warning(
                            f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further."
                        )

    except Exception as e:
        logger.error(
            f"An error occurred during the scheduled key check: {str(e)}", exc_info=True
        )
|
97 |
+
|
98 |
+
|
99 |
+
def setup_scheduler():
    """Create, configure and start the APScheduler instance; return it."""
    scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE))  # timezone from config
    # Periodic job: re-verify API keys that have recorded failures.
    scheduler.add_job(
        check_failed_keys,
        "interval",
        hours=settings.CHECK_INTERVAL_HOURS,
        id="check_failed_keys_job",
        name="Check Failed API Keys",
    )
    logger.info(
        f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
    )

    # Daily job: purge old error logs at 03:00.
    scheduler.add_job(
        delete_old_error_logs,
        "cron",
        hour=3,
        minute=0,
        id="delete_old_error_logs_job",
        name="Delete Old Error Logs",
    )
    logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")

    # Daily job: purge old request logs at 03:05.
    scheduler.add_job(
        delete_old_request_logs_task,
        "cron",
        hour=3,
        minute=5,
        id="delete_old_request_logs_job",
        name="Delete Old Request Logs",
    )
    logger.info(
        f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
    )

    scheduler.start()
    logger.info("Scheduler started with all jobs.")
    return scheduler
|
141 |
+
|
142 |
+
|
143 |
+
# 可以在这里添加一个全局的 scheduler 实例,以便在应用关闭时优雅地停止
|
144 |
+
scheduler_instance = None
|
145 |
+
|
146 |
+
|
147 |
+
def start_scheduler():
    """Start the global scheduler if it is not already running."""
    global scheduler_instance
    if scheduler_instance is None or not scheduler_instance.running:
        logger.info("Starting scheduler...")
        scheduler_instance = setup_scheduler()
    else:
        # Bug fix: this message was previously logged unconditionally, so it
        # fired even immediately after a fresh start. Only log it when the
        # scheduler was in fact already running.
        logger.info("Scheduler is already running.")
|
153 |
+
|
154 |
+
|
155 |
+
def stop_scheduler():
    """Shut down the global scheduler when one is active; otherwise do nothing."""
    global scheduler_instance
    running = scheduler_instance is not None and scheduler_instance.running
    if running:
        scheduler_instance.shutdown()
        logger.info("Scheduler stopped.")
|
app/service/chat/gemini_chat_service.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app/services/chat_service.py
|
2 |
+
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import datetime
|
6 |
+
import time
|
7 |
+
from typing import Any, AsyncGenerator, Dict, List
|
8 |
+
from app.config.config import settings
|
9 |
+
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
10 |
+
from app.domain.gemini_models import GeminiRequest
|
11 |
+
from app.handler.response_handler import GeminiResponseHandler
|
12 |
+
from app.handler.stream_optimizer import gemini_optimizer
|
13 |
+
from app.log.logger import get_gemini_logger
|
14 |
+
from app.service.client.api_client import GeminiApiClient
|
15 |
+
from app.service.key.key_manager import KeyManager
|
16 |
+
from app.database.services import add_error_log, add_request_log
|
17 |
+
|
18 |
+
logger = get_gemini_logger()
|
19 |
+
|
20 |
+
|
21 |
+
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
22 |
+
"""判断消息是否包含图片部分"""
|
23 |
+
for content in contents:
|
24 |
+
if "parts" in content:
|
25 |
+
for part in content["parts"]:
|
26 |
+
if "image_url" in part or "inline_data" in part:
|
27 |
+
return True
|
28 |
+
return False
|
29 |
+
|
30 |
+
|
31 |
+
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build the Gemini ``tools`` list from the request payload and model name.

    NOTE(review): mutates ``payload["tools"]`` in place when it normalizes a
    single tool dict into a list.
    """

    def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Merge several tool dicts into one: functionDeclarations lists are
        # concatenated, every other key is overwritten by the last occurrence.
        record = dict()
        for item in tools:
            if not item or not isinstance(item, dict):
                continue

            for k, v in item.items():
                if k == "functionDeclarations" and v and isinstance(v, list):
                    functions = record.get("functionDeclarations", [])
                    functions.extend(v)
                    record["functionDeclarations"] = functions
                else:
                    record[k] = v
        return record

    tool = dict()
    if payload and isinstance(payload, dict) and "tools" in payload:
        if payload.get("tools") and isinstance(payload.get("tools"), dict):
            # Normalize a single tool dict into a one-element list.
            payload["tools"] = [payload.get("tools")]
        items = payload.get("tools", [])
        if items and isinstance(items, list):
            tool.update(_merge_tools(items))

    # Enable code execution only for models/requests that support it:
    # not search or thinking variants, and not requests containing images.
    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not (model.endswith("-search") or "-thinking" in model)
        and not _has_image_parts(payload.get("contents", []))
    ):
        tool["codeExecution"] = {}
    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Work around "Tool use with function calling is unsupported": built-in
    # tools cannot coexist with function declarations, so drop them.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
|
72 |
+
|
73 |
+
|
74 |
+
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety settings appropriate for the given model."""
    is_flash_exp = model == "gemini-2.0-flash-exp"
    return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS if is_flash_exp else settings.SAFETY_SETTINGS
|
79 |
+
|
80 |
+
|
81 |
+
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
    """Assemble the Gemini API request payload from the incoming request."""
    request_dict = request.model_dump()
    if request.generationConfig:
        if request.generationConfig.maxOutputTokens is None:
            # Omit the field entirely when unspecified, to avoid truncated output.
            request_dict["generationConfig"].pop("maxOutputTokens")

    payload = {
        "contents": request_dict.get("contents", []),
        "tools": _build_tools(model, request_dict),
        "safetySettings": _get_safety_settings(model),
        "generationConfig": request_dict.get("generationConfig"),
        "systemInstruction": request_dict.get("systemInstruction"),
    }

    if model.endswith("-image") or model.endswith("-image-generation"):
        # Image models reject system instructions and need both modalities.
        payload.pop("systemInstruction")
        payload["generationConfig"]["responseModalities"] = ["Text", "Image"]

    # Thinking config: prefer the client-supplied value over defaults.
    client_thinking_config = None
    if request.generationConfig and request.generationConfig.thinkingConfig:
        client_thinking_config = request.generationConfig.thinkingConfig

    if client_thinking_config is not None:
        # The client provided a thinking config; pass it through unchanged.
        payload["generationConfig"]["thinkingConfig"] = client_thinking_config
    else:
        # No client config: derive one from the model name / budget map.
        if model.endswith("-non-thinking"):
            payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
        elif model in settings.THINKING_BUDGET_MAP:
            payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}

    return payload
|
117 |
+
|
118 |
+
|
119 |
+
class GeminiChatService:
    """Gemini chat service: non-streaming and streaming content generation,
    with per-request logging, error-log recording and streaming key rotation."""

    def __init__(self, base_url: str, key_manager: KeyManager):
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.response_handler = GeminiResponseHandler()

    def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
        """Return the text of the first part of the first candidate, or ""."""
        if not response.get("candidates"):
            return ""

        candidate = response["candidates"][0]
        content = candidate.get("content", {})
        parts = content.get("parts", [])

        if parts and "text" in parts[0]:
            return parts[0].get("text", "")
        return ""

    def _create_char_response(
        self, original_response: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a deep copy of *original_response* with its first text part
        replaced by *text* (used by the stream optimizer to emit sub-chunks)."""
        response_copy = json.loads(json.dumps(original_response))  # deep copy via JSON round-trip
        if response_copy.get("candidates") and response_copy["candidates"][0].get(
            "content", {}
        ).get("parts"):
            response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
        return response_copy

    async def generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Non-streaming generation.

        Records an error-log entry and re-raises on failure; always records a
        request-log entry (success or not) with the measured latency.
        """
        payload = _build_payload(model, request)
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None

        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(response, model, stream=False)
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Best-effort extraction of an HTTP status code from the message;
            # falls back to 500 when none is present.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="gemini-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload
            )
            raise e
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )

    async def stream_generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Streaming generation, yielded as SSE ``data: ...`` lines.

        On failure the KeyManager is asked for a replacement key and the whole
        stream is retried, up to ``settings.MAX_RETRIES`` attempts.
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        payload = _build_payload(model, request)
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            request_datetime = datetime.datetime.now()
            start_time = time.perf_counter()
            current_attempt_key = api_key
            final_api_key = current_attempt_key
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, model, current_attempt_key
                ):
                    # print(line)
                    if line.startswith("data:"):
                        # NOTE(review): slices 6 chars, assuming the prefix is
                        # always "data: " with a trailing space — confirm
                        # against the API client's line format.
                        line = line[6:]
                        response_data = self.response_handler.handle_response(
                            json.loads(line), model, stream=True
                        )
                        text = self._extract_text_from_response(response_data)
                        # When there is plain text and the optimizer is on,
                        # let it smooth out the chunked output.
                        if text and settings.STREAM_OPTIMIZER_ENABLED:
                            async for (
                                optimized_chunk
                            ) in gemini_optimizer.optimize_stream_output(
                                text,
                                lambda t: self._create_char_response(response_data, t),
                                lambda c: "data: " + json.dumps(c) + "\n\n",
                            ):
                                yield optimized_chunk
                        else:
                            # No plain text (e.g. tool calls): emit the whole chunk.
                            yield "data: " + json.dumps(response_data) + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                # Best-effort extraction of an HTTP status code from the message.
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="gemini-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload
                )

                # Rotate to a replacement key for the next attempt.
                # NOTE(review): logs the full replacement key — consider masking.
                api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
                if api_key:
                    logger.info(f"Switched to new API key: {api_key}")
                else:
                    logger.error(f"No valid API key available after {retries} retries.")
                    break

                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming."
                    )
                    break
            finally:
                # Always record one request-log entry per attempt.
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime
                )
|
app/service/chat/openai_chat_service.py
ADDED
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app/services/chat_service.py
|
2 |
+
|
3 |
+
import asyncio
|
4 |
+
import datetime
|
5 |
+
import json
|
6 |
+
import re
|
7 |
+
import time
|
8 |
+
from copy import deepcopy
|
9 |
+
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
10 |
+
|
11 |
+
from app.config.config import settings
|
12 |
+
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
13 |
+
from app.database.services import (
|
14 |
+
add_error_log,
|
15 |
+
add_request_log,
|
16 |
+
)
|
17 |
+
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
18 |
+
from app.handler.message_converter import OpenAIMessageConverter
|
19 |
+
from app.handler.response_handler import OpenAIResponseHandler
|
20 |
+
from app.handler.stream_optimizer import openai_optimizer
|
21 |
+
from app.log.logger import get_openai_logger
|
22 |
+
from app.service.client.api_client import GeminiApiClient
|
23 |
+
from app.service.image.image_create_service import ImageCreateService
|
24 |
+
from app.service.key.key_manager import KeyManager
|
25 |
+
|
26 |
+
logger = get_openai_logger()
|
27 |
+
|
28 |
+
|
29 |
+
def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
|
30 |
+
"""判断消息是否包含图片、音频或视频部分 (inline_data)"""
|
31 |
+
for content in contents:
|
32 |
+
if content and "parts" in content and isinstance(content["parts"], list):
|
33 |
+
for part in content["parts"]:
|
34 |
+
if isinstance(part, dict) and "inline_data" in part:
|
35 |
+
return True
|
36 |
+
return False
|
37 |
+
|
38 |
+
|
39 |
+
def _build_tools(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Build the Gemini ``tools`` list for an OpenAI-style chat request."""
    tool = dict()
    model = request.model

    # Enable code execution only where it is supported: not for search,
    # thinking or image model variants, and not when media is attached.
    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not (
            model.endswith("-search")
            or "-thinking" in model
            or model.endswith("-image")
            or model.endswith("-image-generation")
        )
        and not _has_media_parts(messages)
    ):
        tool["codeExecution"] = {}
        logger.debug("Code execution tool enabled.")
    elif _has_media_parts(messages):
        logger.debug("Code execution tool disabled due to media parts presence.")

    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Merge tool definitions from the incoming request into the Gemini tool dict.
    if request.tools:
        function_declarations = []
        for item in request.tools:
            if not item or not isinstance(item, dict):
                continue

            if item.get("type", "") == "function" and item.get("function"):
                function = deepcopy(item.get("function"))
                parameters = function.get("parameters", {})
                # Gemini rejects an "object" schema with no properties; drop it.
                if parameters.get("type") == "object" and not parameters.get(
                    "properties", {}
                ):
                    function.pop("parameters", None)

                function_declarations.append(function)

        if function_declarations:
            # De-duplicate declarations by function name.
            names, functions = set(), []
            for fc in function_declarations:
                if fc.get("name") not in names:
                    if fc.get("name")=="googleSearch":
                        # Cherry Studio's built-in search: map it to the
                        # googleSearch tool instead of a function declaration.
                        tool["googleSearch"] = {}
                    else:
                        # Regular function: add it to functionDeclarations.
                        names.add(fc.get("name"))
                        functions.append(fc)

            tool["functionDeclarations"] = functions

    # Work around "Tool use with function calling is unsupported": built-in
    # tools cannot coexist with function declarations, so drop them.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
|
102 |
+
|
103 |
+
|
104 |
+
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety settings to send for *model*.

    The experimental ``gemini-2.0-flash-exp`` model requires its own
    safety-setting category list; every other model uses the globally
    configured defaults from settings.
    """
    if model == "gemini-2.0-flash-exp":
        return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
    return settings.SAFETY_SETTINGS
|
114 |
+
|
115 |
+
|
116 |
+
def _build_payload(
    request: ChatRequest,
    messages: List[Dict[str, Any]],
    instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the Gemini generateContent payload for an OpenAI-style request.

    Args:
        request: The incoming chat request (model name, sampling params, tools).
        messages: Messages already converted to Gemini ``contents`` format.
        instruction: Optional system instruction dict (``role``/``parts``).

    Returns:
        The request payload dict for the Gemini API.
    """
    model = request.model

    generation_config: Dict[str, Any] = {
        "temperature": request.temperature,
        "stopSequences": request.stop,
        "topP": request.top_p,
        "topK": request.top_k,
    }
    if request.max_tokens is not None:
        generation_config["maxOutputTokens"] = request.max_tokens

    is_image_model = model.endswith("-image") or model.endswith("-image-generation")
    if is_image_model:
        generation_config["responseModalities"] = ["Text", "Image"]
    if model.endswith("-non-thinking"):
        # Explicitly disable thinking for "-non-thinking" model aliases.
        generation_config["thinkingConfig"] = {"thinkingBudget": 0}
    if model in settings.THINKING_BUDGET_MAP:
        # A configured budget overrides any thinkingConfig set above.
        generation_config["thinkingConfig"] = {
            "thinkingBudget": settings.THINKING_BUDGET_MAP.get(model, 1000)
        }

    payload: Dict[str, Any] = {
        "contents": messages,
        "generationConfig": generation_config,
        "tools": _build_tools(request, messages),
        "safetySettings": _get_safety_settings(model),
    }

    has_system_instruction = (
        instruction
        and isinstance(instruction, dict)
        and instruction.get("role") == "system"
        and instruction.get("parts")
    )
    # Image-generation models do not accept a system instruction.
    if has_system_instruction and not is_image_model:
        payload["systemInstruction"] = instruction

    return payload
|
155 |
+
|
156 |
+
|
157 |
+
class OpenAIChatService:
|
158 |
+
"""聊天服务"""
|
159 |
+
|
160 |
+
def __init__(self, base_url: str, key_manager: KeyManager = None):
|
161 |
+
self.message_converter = OpenAIMessageConverter()
|
162 |
+
self.response_handler = OpenAIResponseHandler(config=None)
|
163 |
+
self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
|
164 |
+
self.key_manager = key_manager
|
165 |
+
self.image_create_service = ImageCreateService()
|
166 |
+
|
167 |
+
def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
|
168 |
+
"""从OpenAI响应块中提取文本内容"""
|
169 |
+
if not chunk.get("choices"):
|
170 |
+
return ""
|
171 |
+
|
172 |
+
choice = chunk["choices"][0]
|
173 |
+
if "delta" in choice and "content" in choice["delta"]:
|
174 |
+
return choice["delta"]["content"]
|
175 |
+
return ""
|
176 |
+
|
177 |
+
def _create_char_openai_chunk(
|
178 |
+
self, original_chunk: Dict[str, Any], text: str
|
179 |
+
) -> Dict[str, Any]:
|
180 |
+
"""创建包含指定文本的OpenAI响应块"""
|
181 |
+
chunk_copy = json.loads(json.dumps(original_chunk))
|
182 |
+
if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
|
183 |
+
chunk_copy["choices"][0]["delta"]["content"] = text
|
184 |
+
return chunk_copy
|
185 |
+
|
186 |
+
async def create_chat_completion(
|
187 |
+
self,
|
188 |
+
request: ChatRequest,
|
189 |
+
api_key: str,
|
190 |
+
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
191 |
+
"""创建聊天完成"""
|
192 |
+
messages, instruction = self.message_converter.convert(request.messages)
|
193 |
+
|
194 |
+
payload = _build_payload(request, messages, instruction)
|
195 |
+
|
196 |
+
if request.stream:
|
197 |
+
return self._handle_stream_completion(request.model, payload, api_key)
|
198 |
+
return await self._handle_normal_completion(request.model, payload, api_key)
|
199 |
+
|
200 |
+
async def _handle_normal_completion(
|
201 |
+
self, model: str, payload: Dict[str, Any], api_key: str
|
202 |
+
) -> Dict[str, Any]:
|
203 |
+
"""处理普通聊天完成"""
|
204 |
+
start_time = time.perf_counter()
|
205 |
+
request_datetime = datetime.datetime.now()
|
206 |
+
is_success = False
|
207 |
+
status_code = None
|
208 |
+
response = None
|
209 |
+
try:
|
210 |
+
response = await self.api_client.generate_content(payload, model, api_key)
|
211 |
+
usage_metadata = response.get("usageMetadata", {})
|
212 |
+
is_success = True
|
213 |
+
status_code = 200
|
214 |
+
return self.response_handler.handle_response(
|
215 |
+
response,
|
216 |
+
model,
|
217 |
+
stream=False,
|
218 |
+
finish_reason="stop",
|
219 |
+
usage_metadata=usage_metadata,
|
220 |
+
)
|
221 |
+
except Exception as e:
|
222 |
+
is_success = False
|
223 |
+
error_log_msg = str(e)
|
224 |
+
logger.error(f"Normal API call failed with error: {error_log_msg}")
|
225 |
+
match = re.search(r"status code (\d+)", error_log_msg)
|
226 |
+
if match:
|
227 |
+
status_code = int(match.group(1))
|
228 |
+
else:
|
229 |
+
status_code = 500
|
230 |
+
|
231 |
+
await add_error_log(
|
232 |
+
gemini_key=api_key,
|
233 |
+
model_name=model,
|
234 |
+
error_type="openai-chat-non-stream",
|
235 |
+
error_log=error_log_msg,
|
236 |
+
error_code=status_code,
|
237 |
+
request_msg=payload,
|
238 |
+
)
|
239 |
+
raise e
|
240 |
+
finally:
|
241 |
+
end_time = time.perf_counter()
|
242 |
+
latency_ms = int((end_time - start_time) * 1000)
|
243 |
+
await add_request_log(
|
244 |
+
model_name=model,
|
245 |
+
api_key=api_key,
|
246 |
+
is_success=is_success,
|
247 |
+
status_code=status_code,
|
248 |
+
latency_ms=latency_ms,
|
249 |
+
request_time=request_datetime,
|
250 |
+
)
|
251 |
+
|
252 |
+
async def _fake_stream_logic_impl(
|
253 |
+
self, model: str, payload: Dict[str, Any], api_key: str
|
254 |
+
) -> AsyncGenerator[str, None]:
|
255 |
+
"""处理伪流式 (fake stream) 的核心逻辑"""
|
256 |
+
logger.info(
|
257 |
+
f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
|
258 |
+
)
|
259 |
+
keep_sending_empty_data = True
|
260 |
+
|
261 |
+
async def send_empty_data_locally() -> AsyncGenerator[str, None]:
|
262 |
+
"""定期发送空数据以保持连接"""
|
263 |
+
while keep_sending_empty_data:
|
264 |
+
await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS)
|
265 |
+
if keep_sending_empty_data:
|
266 |
+
empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
267 |
+
yield f"data: {json.dumps(empty_chunk)}\n\n"
|
268 |
+
logger.debug("Sent empty data chunk for fake stream heartbeat.")
|
269 |
+
|
270 |
+
empty_data_generator = send_empty_data_locally()
|
271 |
+
api_response_task = asyncio.create_task(
|
272 |
+
self.api_client.generate_content(payload, model, api_key)
|
273 |
+
)
|
274 |
+
|
275 |
+
try:
|
276 |
+
while not api_response_task.done():
|
277 |
+
try:
|
278 |
+
next_empty_chunk = await asyncio.wait_for(
|
279 |
+
empty_data_generator.__anext__(), timeout=0.1
|
280 |
+
)
|
281 |
+
yield next_empty_chunk
|
282 |
+
except asyncio.TimeoutError:
|
283 |
+
pass
|
284 |
+
except (
|
285 |
+
StopAsyncIteration
|
286 |
+
):
|
287 |
+
break
|
288 |
+
|
289 |
+
response = await api_response_task
|
290 |
+
finally:
|
291 |
+
keep_sending_empty_data = False
|
292 |
+
|
293 |
+
if response and response.get("candidates"):
|
294 |
+
response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {}))
|
295 |
+
yield f"data: {json.dumps(response)}\n\n"
|
296 |
+
logger.info(f"Sent full response content for fake stream: {model}")
|
297 |
+
else:
|
298 |
+
error_message = "Failed to get response from model"
|
299 |
+
if (
|
300 |
+
response and isinstance(response, dict) and response.get("error")
|
301 |
+
):
|
302 |
+
error_details = response.get("error")
|
303 |
+
if isinstance(error_details, dict):
|
304 |
+
error_message = error_details.get("message", error_message)
|
305 |
+
|
306 |
+
logger.error(
|
307 |
+
f"No candidates or error in response for fake stream model {model}: {response}"
|
308 |
+
)
|
309 |
+
error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
|
310 |
+
yield f"data: {json.dumps(error_chunk)}\n\n"
|
311 |
+
|
312 |
+
async def _real_stream_logic_impl(
|
313 |
+
self, model: str, payload: Dict[str, Any], api_key: str
|
314 |
+
) -> AsyncGenerator[str, None]:
|
315 |
+
"""处理真实流式 (real stream) 的核心逻辑"""
|
316 |
+
tool_call_flag = False
|
317 |
+
usage_metadata = None
|
318 |
+
async for line in self.api_client.stream_generate_content(
|
319 |
+
payload, model, api_key
|
320 |
+
):
|
321 |
+
if line.startswith("data:"):
|
322 |
+
chunk_str = line[6:]
|
323 |
+
if not chunk_str or chunk_str.isspace():
|
324 |
+
logger.debug(
|
325 |
+
f"Received empty data line for model {model}, skipping."
|
326 |
+
)
|
327 |
+
continue
|
328 |
+
try:
|
329 |
+
chunk = json.loads(chunk_str)
|
330 |
+
usage_metadata = chunk.get("usageMetadata", {})
|
331 |
+
except json.JSONDecodeError:
|
332 |
+
logger.error(
|
333 |
+
f"Failed to decode JSON from stream for model {model}: {chunk_str}"
|
334 |
+
)
|
335 |
+
continue
|
336 |
+
openai_chunk = self.response_handler.handle_response(
|
337 |
+
chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata
|
338 |
+
)
|
339 |
+
if openai_chunk:
|
340 |
+
text = self._extract_text_from_openai_chunk(openai_chunk)
|
341 |
+
if text and settings.STREAM_OPTIMIZER_ENABLED:
|
342 |
+
async for (
|
343 |
+
optimized_chunk_data
|
344 |
+
) in openai_optimizer.optimize_stream_output(
|
345 |
+
text,
|
346 |
+
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
347 |
+
lambda c: f"data: {json.dumps(c)}\n\n",
|
348 |
+
):
|
349 |
+
yield optimized_chunk_data
|
350 |
+
else:
|
351 |
+
if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"):
|
352 |
+
tool_call_flag = True
|
353 |
+
|
354 |
+
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
355 |
+
|
356 |
+
if tool_call_flag:
|
357 |
+
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls', usage_metadata=usage_metadata))}\n\n"
|
358 |
+
else:
|
359 |
+
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=usage_metadata))}\n\n"
|
360 |
+
|
361 |
+
async def _handle_stream_completion(
|
362 |
+
self, model: str, payload: Dict[str, Any], api_key: str
|
363 |
+
) -> AsyncGenerator[str, None]:
|
364 |
+
"""处理流式聊天完成,添加重试逻辑和假流式支持"""
|
365 |
+
retries = 0
|
366 |
+
max_retries = settings.MAX_RETRIES
|
367 |
+
is_success = False
|
368 |
+
status_code = None
|
369 |
+
final_api_key = api_key
|
370 |
+
|
371 |
+
while retries < max_retries:
|
372 |
+
start_time = time.perf_counter()
|
373 |
+
request_datetime = datetime.datetime.now()
|
374 |
+
current_attempt_key = final_api_key
|
375 |
+
|
376 |
+
try:
|
377 |
+
stream_generator = None
|
378 |
+
if settings.FAKE_STREAM_ENABLED:
|
379 |
+
logger.info(
|
380 |
+
f"Using fake stream logic for model: {model}, Attempt: {retries + 1}"
|
381 |
+
)
|
382 |
+
stream_generator = self._fake_stream_logic_impl(
|
383 |
+
model, payload, current_attempt_key
|
384 |
+
)
|
385 |
+
else:
|
386 |
+
logger.info(
|
387 |
+
f"Using real stream logic for model: {model}, Attempt: {retries + 1}"
|
388 |
+
)
|
389 |
+
stream_generator = self._real_stream_logic_impl(
|
390 |
+
model, payload, current_attempt_key
|
391 |
+
)
|
392 |
+
|
393 |
+
async for chunk_data in stream_generator:
|
394 |
+
yield chunk_data
|
395 |
+
|
396 |
+
yield "data: [DONE]\n\n"
|
397 |
+
logger.info(
|
398 |
+
f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
|
399 |
+
)
|
400 |
+
is_success = True
|
401 |
+
status_code = 200
|
402 |
+
break
|
403 |
+
|
404 |
+
except Exception as e:
|
405 |
+
retries += 1
|
406 |
+
is_success = False
|
407 |
+
error_log_msg = str(e)
|
408 |
+
logger.warning(
|
409 |
+
f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
|
410 |
+
)
|
411 |
+
|
412 |
+
match = re.search(r"status code (\\d+)", error_log_msg)
|
413 |
+
if match:
|
414 |
+
status_code = int(match.group(1))
|
415 |
+
else:
|
416 |
+
if isinstance(e, asyncio.TimeoutError):
|
417 |
+
status_code = 408
|
418 |
+
else:
|
419 |
+
status_code = 500
|
420 |
+
|
421 |
+
await add_error_log(
|
422 |
+
gemini_key=current_attempt_key,
|
423 |
+
model_name=model,
|
424 |
+
error_type="openai-chat-stream",
|
425 |
+
error_log=error_log_msg,
|
426 |
+
error_code=status_code,
|
427 |
+
request_msg=payload,
|
428 |
+
)
|
429 |
+
|
430 |
+
if self.key_manager:
|
431 |
+
new_api_key = await self.key_manager.handle_api_failure(
|
432 |
+
current_attempt_key, retries
|
433 |
+
)
|
434 |
+
if new_api_key and new_api_key != current_attempt_key:
|
435 |
+
final_api_key = new_api_key
|
436 |
+
logger.info(
|
437 |
+
f"Switched to new API key for next attempt: {final_api_key}"
|
438 |
+
)
|
439 |
+
elif not new_api_key:
|
440 |
+
logger.error(
|
441 |
+
f"No valid API key available after {retries} retries, ceasing attempts for this request."
|
442 |
+
)
|
443 |
+
break
|
444 |
+
else:
|
445 |
+
logger.error(
|
446 |
+
"KeyManager not available, cannot switch API key. Ceasing attempts for this request."
|
447 |
+
)
|
448 |
+
break
|
449 |
+
|
450 |
+
if retries >= max_retries:
|
451 |
+
logger.error(
|
452 |
+
f"Max retries ({max_retries}) reached for streaming model {model}."
|
453 |
+
)
|
454 |
+
finally:
|
455 |
+
end_time = time.perf_counter()
|
456 |
+
latency_ms = int((end_time - start_time) * 1000)
|
457 |
+
await add_request_log(
|
458 |
+
model_name=model,
|
459 |
+
api_key=current_attempt_key,
|
460 |
+
is_success=is_success,
|
461 |
+
status_code=status_code,
|
462 |
+
latency_ms=latency_ms,
|
463 |
+
request_time=request_datetime,
|
464 |
+
)
|
465 |
+
|
466 |
+
if not is_success:
|
467 |
+
logger.error(
|
468 |
+
f"Streaming failed permanently for model {model} after {retries} attempts."
|
469 |
+
)
|
470 |
+
yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
|
471 |
+
yield "data: [DONE]\n\n"
|
472 |
+
|
473 |
+
async def create_image_chat_completion(
|
474 |
+
self, request: ChatRequest, api_key: str
|
475 |
+
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
|
476 |
+
|
477 |
+
image_generate_request = ImageGenerationRequest()
|
478 |
+
image_generate_request.prompt = request.messages[-1]["content"]
|
479 |
+
image_res = self.image_create_service.generate_images_chat(
|
480 |
+
image_generate_request
|
481 |
+
)
|
482 |
+
|
483 |
+
if request.stream:
|
484 |
+
return self._handle_stream_image_completion(
|
485 |
+
request.model, image_res, api_key
|
486 |
+
)
|
487 |
+
else:
|
488 |
+
return await self._handle_normal_image_completion(
|
489 |
+
request.model, image_res, api_key
|
490 |
+
)
|
491 |
+
|
492 |
+
async def _handle_stream_image_completion(
|
493 |
+
self, model: str, image_data: str, api_key: str
|
494 |
+
) -> AsyncGenerator[str, None]:
|
495 |
+
logger.info(f"Starting stream image completion for model: {model}")
|
496 |
+
start_time = time.perf_counter()
|
497 |
+
request_datetime = datetime.datetime.now()
|
498 |
+
is_success = False
|
499 |
+
status_code = None
|
500 |
+
|
501 |
+
try:
|
502 |
+
if image_data:
|
503 |
+
openai_chunk = self.response_handler.handle_image_chat_response(
|
504 |
+
image_data, model, stream=True, finish_reason=None
|
505 |
+
)
|
506 |
+
if openai_chunk:
|
507 |
+
# 提取文本内容
|
508 |
+
text = self._extract_text_from_openai_chunk(openai_chunk)
|
509 |
+
if text:
|
510 |
+
# 使用流式输出优化器处理文本输出
|
511 |
+
async for (
|
512 |
+
optimized_chunk
|
513 |
+
) in openai_optimizer.optimize_stream_output(
|
514 |
+
text,
|
515 |
+
lambda t: self._create_char_openai_chunk(openai_chunk, t),
|
516 |
+
lambda c: f"data: {json.dumps(c)}\n\n",
|
517 |
+
):
|
518 |
+
yield optimized_chunk
|
519 |
+
else:
|
520 |
+
# 如果没有文本内容(如图片URL等),整块输出
|
521 |
+
yield f"data: {json.dumps(openai_chunk)}\n\n"
|
522 |
+
yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
|
523 |
+
logger.info(
|
524 |
+
f"Stream image completion finished successfully for model: {model}"
|
525 |
+
)
|
526 |
+
is_success = True
|
527 |
+
status_code = 200
|
528 |
+
yield "data: [DONE]\n\n"
|
529 |
+
except Exception as e:
|
530 |
+
is_success = False
|
531 |
+
error_log_msg = f"Stream image completion failed for model {model}: {e}"
|
532 |
+
logger.error(error_log_msg)
|
533 |
+
status_code = 500
|
534 |
+
await add_error_log(
|
535 |
+
gemini_key=api_key,
|
536 |
+
model_name=model,
|
537 |
+
error_type="openai-image-stream",
|
538 |
+
error_log=error_log_msg,
|
539 |
+
error_code=status_code,
|
540 |
+
request_msg={"image_data_truncated": image_data[:1000]},
|
541 |
+
)
|
542 |
+
yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
|
543 |
+
yield "data: [DONE]\n\n"
|
544 |
+
finally:
|
545 |
+
end_time = time.perf_counter()
|
546 |
+
latency_ms = int((end_time - start_time) * 1000)
|
547 |
+
logger.info(
|
548 |
+
f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
549 |
+
)
|
550 |
+
await add_request_log(
|
551 |
+
model_name=model,
|
552 |
+
api_key=api_key,
|
553 |
+
is_success=is_success,
|
554 |
+
status_code=status_code,
|
555 |
+
latency_ms=latency_ms,
|
556 |
+
request_time=request_datetime,
|
557 |
+
)
|
558 |
+
|
559 |
+
async def _handle_normal_image_completion(
|
560 |
+
self, model: str, image_data: str, api_key: str
|
561 |
+
) -> Dict[str, Any]:
|
562 |
+
logger.info(f"Starting normal image completion for model: {model}")
|
563 |
+
start_time = time.perf_counter()
|
564 |
+
request_datetime = datetime.datetime.now()
|
565 |
+
is_success = False
|
566 |
+
status_code = None
|
567 |
+
result = None
|
568 |
+
|
569 |
+
try:
|
570 |
+
result = self.response_handler.handle_image_chat_response(
|
571 |
+
image_data, model, stream=False, finish_reason="stop"
|
572 |
+
)
|
573 |
+
logger.info(
|
574 |
+
f"Normal image completion finished successfully for model: {model}"
|
575 |
+
)
|
576 |
+
is_success = True
|
577 |
+
status_code = 200
|
578 |
+
return result
|
579 |
+
except Exception as e:
|
580 |
+
is_success = False
|
581 |
+
error_log_msg = f"Normal image completion failed for model {model}: {e}"
|
582 |
+
logger.error(error_log_msg)
|
583 |
+
status_code = 500
|
584 |
+
await add_error_log(
|
585 |
+
gemini_key=api_key,
|
586 |
+
model_name=model,
|
587 |
+
error_type="openai-image-non-stream",
|
588 |
+
error_log=error_log_msg,
|
589 |
+
error_code=status_code,
|
590 |
+
request_msg={"image_data_truncated": image_data[:1000]},
|
591 |
+
)
|
592 |
+
raise e
|
593 |
+
finally:
|
594 |
+
end_time = time.perf_counter()
|
595 |
+
latency_ms = int((end_time - start_time) * 1000)
|
596 |
+
logger.info(
|
597 |
+
f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
|
598 |
+
)
|
599 |
+
await add_request_log(
|
600 |
+
model_name=model,
|
601 |
+
api_key=api_key,
|
602 |
+
is_success=is_success,
|
603 |
+
status_code=status_code,
|
604 |
+
latency_ms=latency_ms,
|
605 |
+
request_time=request_datetime,
|
606 |
+
)
|
app/service/chat/vertex_express_chat_service.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app/services/chat_service.py
|
2 |
+
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import datetime
|
6 |
+
import time
|
7 |
+
from typing import Any, AsyncGenerator, Dict, List
|
8 |
+
from app.config.config import settings
|
9 |
+
from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
|
10 |
+
from app.domain.gemini_models import GeminiRequest
|
11 |
+
from app.handler.response_handler import GeminiResponseHandler
|
12 |
+
from app.handler.stream_optimizer import gemini_optimizer
|
13 |
+
from app.log.logger import get_gemini_logger
|
14 |
+
from app.service.client.api_client import GeminiApiClient
|
15 |
+
from app.service.key.key_manager import KeyManager
|
16 |
+
from app.database.services import add_error_log, add_request_log
|
17 |
+
|
18 |
+
logger = get_gemini_logger()
|
19 |
+
|
20 |
+
|
21 |
+
def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
|
22 |
+
"""判断消息是否包含图片部分"""
|
23 |
+
for content in contents:
|
24 |
+
if "parts" in content:
|
25 |
+
for part in content["parts"]:
|
26 |
+
if "image_url" in part or "inline_data" in part:
|
27 |
+
return True
|
28 |
+
return False
|
29 |
+
|
30 |
+
|
31 |
+
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build the Gemini ``tools`` payload entry.

    Merges any tools already present on *payload* with the server-side
    codeExecution/googleSearch tools, then resolves conflicts with function
    calling. NOTE: this mutates *payload* in place when its ``tools`` value
    is a single dict (it is wrapped into a list).
    """

    def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Fold a list of tool dicts into one dict; functionDeclarations lists
        # are concatenated, every other key is last-writer-wins.
        record = dict()
        for item in tools:
            if not item or not isinstance(item, dict):
                continue

            for k, v in item.items():
                if k == "functionDeclarations" and v and isinstance(v, list):
                    functions = record.get("functionDeclarations", [])
                    functions.extend(v)
                    record["functionDeclarations"] = functions
                else:
                    record[k] = v
        return record

    tool = dict()
    if payload and isinstance(payload, dict) and "tools" in payload:
        # Normalize a bare dict into a one-element list before merging.
        if payload.get("tools") and isinstance(payload.get("tools"), dict):
            payload["tools"] = [payload.get("tools")]
        items = payload.get("tools", [])
        if items and isinstance(items, list):
            tool.update(_merge_tools(items))

    # Code execution is skipped for search/thinking models and for requests
    # containing image parts.
    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not (model.endswith("-search") or "-thinking" in model)
        and not _has_image_parts(payload.get("contents", []))
    ):
        tool["codeExecution"] = {}
    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Works around the API error "Tool use with function calling is
    # unsupported": native tools cannot be combined with functionDeclarations.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
|
72 |
+
|
73 |
+
|
74 |
+
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Select the safety settings appropriate for *model*.

    ``gemini-2.0-flash-exp`` uses its dedicated safety-setting list; all
    other models use the configured defaults.
    """
    return (
        GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
        if model == "gemini-2.0-flash-exp"
        else settings.SAFETY_SETTINGS
    )
|
79 |
+
|
80 |
+
|
81 |
+
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
    """Build the Gemini generateContent payload from a native GeminiRequest.

    Copies contents/generationConfig/systemInstruction from the request,
    attaches tools and safety settings, and applies model-suffix specific
    tweaks (image modalities, thinking budget).
    """
    request_dict = request.model_dump()
    if request.generationConfig:
        if request.generationConfig.maxOutputTokens is None:
            # If no max output length was specified, omit the field entirely
            # to avoid truncated responses.
            request_dict["generationConfig"].pop("maxOutputTokens")

    payload = {
        "contents": request_dict.get("contents", []),
        "tools": _build_tools(model, request_dict),
        "safetySettings": _get_safety_settings(model),
        "generationConfig": request_dict.get("generationConfig"),
        "systemInstruction": request_dict.get("systemInstruction"),
    }

    # Image models take no system instruction and must declare both output
    # modalities.
    # NOTE(review): if generationConfig is None here, the item assignment
    # below would raise — presumably GeminiRequest always supplies one;
    # confirm against the model definition.
    if model.endswith("-image") or model.endswith("-image-generation"):
        payload.pop("systemInstruction")
        payload["generationConfig"]["responseModalities"] = ["Text", "Image"]

    if model.endswith("-non-thinking"):
        payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
    # A configured per-model budget overrides the "-non-thinking" setting.
    if model in settings.THINKING_BUDGET_MAP:
        payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}

    return payload
|
107 |
+
|
108 |
+
|
109 |
+
class GeminiChatService:
    """Native Gemini chat service.

    Forwards GeminiRequest payloads to the Gemini API (normal and streaming),
    applies the stream optimizer to text output, rotates API keys on failure,
    and records request/error logs in the database.
    """

    def __init__(self, base_url: str, key_manager: KeyManager):
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.response_handler = GeminiResponseHandler()

    def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
        """Extract the text of the first part of the first candidate, or ""."""
        if not response.get("candidates"):
            return ""

        candidate = response["candidates"][0]
        content = candidate.get("content", {})
        parts = content.get("parts", [])

        if parts and "text" in parts[0]:
            return parts[0].get("text", "")
        return ""

    def _create_char_response(
        self, original_response: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a deep copy of *original_response* with its first text part replaced."""
        response_copy = json.loads(json.dumps(original_response))  # deep copy
        if response_copy.get("candidates") and response_copy["candidates"][0].get(
            "content", {}
        ).get("parts"):
            response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
        return response_copy

    async def generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Generate content (non-streaming), logging outcome and latency."""
        payload = _build_payload(model, request)
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None

        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(response, model, stream=False)
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Best-effort extraction of the upstream HTTP status from the message.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="gemini-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload
            )
            raise e
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )

    async def stream_generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream generated content as SSE lines, retrying with key rotation on failure."""
        retries = 0
        max_retries = settings.MAX_RETRIES
        payload = _build_payload(model, request)
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            request_datetime = datetime.datetime.now()
            start_time = time.perf_counter()
            current_attempt_key = api_key
            final_api_key = current_attempt_key  # Update final key used
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, model, current_attempt_key
                ):
                    # print(line)
                    if line.startswith("data:"):
                        # Strip the "data: " prefix before JSON-decoding.
                        line = line[6:]
                        response_data = self.response_handler.handle_response(
                            json.loads(line), model, stream=True
                        )
                        text = self._extract_text_from_response(response_data)
                        # With text content and the optimizer enabled, smooth
                        # the output through the stream optimizer.
                        if text and settings.STREAM_OPTIMIZER_ENABLED:
                            async for (
                                optimized_chunk
                            ) in gemini_optimizer.optimize_stream_output(
                                text,
                                lambda t: self._create_char_response(response_data, t),
                                lambda c: "data: " + json.dumps(c) + "\n\n",
                            ):
                                yield optimized_chunk
                        else:
                            # No text content (e.g. tool calls): emit the chunk whole.
                            yield "data: " + json.dumps(response_data) + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                # Best-effort extraction of the upstream HTTP status.
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="gemini-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload
                )

                # Rotate to a new key for the next loop iteration.
                api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
                if api_key:
                    logger.info(f"Switched to new API key: {api_key}")
                else:
                    logger.error(f"No valid API key available after {retries} retries.")
                    break

                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming."
                    )
                    break
            finally:
                # A request log is written for every attempt, success or not.
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime
                )
|
app/service/client/api_client.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app/services/chat/api_client.py
|
2 |
+
|
3 |
+
from typing import Dict, Any, AsyncGenerator, Optional
|
4 |
+
import httpx
|
5 |
+
import random
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
from app.config.config import settings
|
8 |
+
from app.log.logger import get_api_client_logger
|
9 |
+
from app.core.constants import DEFAULT_TIMEOUT
|
10 |
+
|
11 |
+
logger = get_api_client_logger()
|
12 |
+
|
13 |
+
class ApiClient(ABC):
    """Abstract base class for upstream chat-API clients.

    Concrete subclasses (Gemini, OpenAI-compatible) must provide both a
    one-shot and a streaming content-generation call.
    """

    @abstractmethod
    async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """Perform a single (non-streaming) generation request and return the parsed JSON."""
        ...

    @abstractmethod
    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """Perform a streaming generation request, yielding raw response lines."""
        ...
|
23 |
+
|
24 |
+
|
25 |
+
class GeminiApiClient(ApiClient):
    """Gemini API client.

    Thin async wrapper around the Google Generative Language REST API:
    model listing, ``generateContent`` and ``streamGenerateContent``.
    """

    def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
        self.base_url = base_url
        self.timeout = timeout

    def _select_proxy(self, api_key: str, purpose: str) -> Optional[str]:
        """Return the proxy URL to use for this request, or None when no proxies
        are configured.

        With PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY the same key always maps to
        the same proxy; otherwise one is picked at random per request.
        NOTE(review): ``hash(str)`` is randomized per process (PYTHONHASHSEED),
        so the key->proxy mapping is only stable within one process lifetime —
        confirm that is acceptable.
        """
        if not settings.PROXIES:
            return None
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy_to_use = random.choice(settings.PROXIES)
        logger.info(f"Using proxy for {purpose}: {proxy_to_use}")
        return proxy_to_use

    def _get_real_model(self, model: str) -> str:
        """Strip gateway-only suffixes (-search/-image/-non-thinking) to recover
        the real upstream Gemini model name."""
        if model.endswith("-search"):
            model = model[: -len("-search")]
        if model.endswith("-image"):
            model = model[: -len("-image")]
        if model.endswith("-non-thinking"):
            model = model[: -len("-non-thinking")]
        # NOTE(review): this blindly drops the last 20 characters
        # (len("-search") + len("-non-thinking")) when both markers remain in
        # the name; it is only correct when they form the trailing suffix pair
        # — confirm the expected input shapes.
        if "-search" in model and "-non-thinking" in model:
            model = model[:-20]
        return model

    async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch the list of available Gemini models.

        Returns the parsed JSON payload, or None on any HTTP/network error
        (errors are logged rather than raised).
        """
        # Deliberately short timeout: this endpoint is used for key checks / UI.
        timeout = httpx.Timeout(timeout=5)
        proxy_to_use = self._select_proxy(api_key, "getting models")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/models?key={api_key}&pageSize=1000"
            try:
                response = await client.get(url)
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                logger.error(f"获取模型列表失败: {e.response.status_code}")
                logger.error(e.response.text)
                return None
            except httpx.RequestError as e:
                logger.error(f"请求模型列表失败: {e}")
                return None

    async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """Call ``models/{model}:generateContent`` and return the parsed JSON.

        Raises:
            Exception: on non-200 responses; the message embeds
                ``status code <n>`` which callers parse for retry handling.
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        model = self._get_real_model(model)
        # Fixed copy-pasted log text: this is content generation, not model listing.
        proxy_to_use = self._select_proxy(api_key, "generating content")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
            response = await client.post(url, json=payload)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """Call ``models/{model}:streamGenerateContent`` (SSE) and yield raw lines.

        Raises:
            Exception: on non-200 responses, after reading the error body.
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        model = self._get_real_model(model)
        proxy_to_use = self._select_proxy(api_key, "streaming content")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
            async with client.stream(method="POST", url=url, json=payload) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    error_msg = error_content.decode("utf-8")
                    raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
                async for line in response.aiter_lines():
                    yield line
|
110 |
+
|
111 |
+
|
112 |
+
class OpenaiApiClient(ApiClient):
    """OpenAI-compatible API client (chat, embeddings, image generation).

    NOTE(review): ``generate_content``/``stream_generate_content`` here take
    ``(payload, api_key)`` while the ``ApiClient`` ABC declares
    ``(payload, model, api_key)`` — Python does not enforce this, but the
    signatures should be reconciled.
    """

    def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
        self.base_url = base_url
        self.timeout = timeout

    def _select_proxy(self, api_key: str, purpose: str) -> Optional[str]:
        """Return the proxy URL for this request, or None when unconfigured.

        With PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY the same key is always
        routed through the same proxy; otherwise one is chosen at random.
        NOTE(review): ``hash(str)`` is per-process randomized (PYTHONHASHSEED),
        so the mapping is only stable within one process lifetime.
        """
        if not settings.PROXIES:
            return None
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            proxy_to_use = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy_to_use = random.choice(settings.PROXIES)
        logger.info(f"Using proxy for {purpose}: {proxy_to_use}")
        return proxy_to_use

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """Fetch the model list from the OpenAI-compatible endpoint.

        Raises:
            Exception: on non-200 responses (message embeds the status code).
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key, "getting models")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/models"
            headers = {"Authorization": f"Bearer {api_key}"}
            response = await client.get(url, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
        """POST a (non-streaming) chat-completion request and return the JSON.

        Raises:
            Exception: on non-200 responses (message embeds the status code).
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        # Demoted from info: this was leftover debug output flooding the log.
        logger.debug(f"settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: {settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY}")
        proxy_to_use = self._select_proxy(api_key, "chat completion")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/chat/completions"
            headers = {"Authorization": f"Bearer {api_key}"}
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
        """POST a streaming chat-completion request and yield raw SSE lines.

        Raises:
            Exception: on non-200 responses, after reading the error body.
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key, "streaming chat completion")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/chat/completions"
            headers = {"Authorization": f"Bearer {api_key}"}
            async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    error_msg = error_content.decode("utf-8")
                    raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
                async for line in response.aiter_lines():
                    yield line

    async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
        """Create embeddings for ``input`` with the given model.

        The parameter name ``input`` shadows the builtin but is kept because
        callers may pass it by keyword.

        Raises:
            Exception: on non-200 responses (message embeds the status code).
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key, "creating embeddings")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/embeddings"
            headers = {"Authorization": f"Bearer {api_key}"}
            payload = {
                "input": input,
                "model": model,
            }
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
        """POST an image-generation request and return the parsed JSON.

        Raises:
            Exception: on non-200 responses (message embeds the status code).
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key, "generating images")

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/images/generations"
            headers = {"Authorization": f"Bearer {api_key}"}
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()
|
app/service/config/config_service.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
配置服务模块
|
3 |
+
"""
|
4 |
+
|
5 |
+
import datetime
|
6 |
+
import json
|
7 |
+
from typing import Any, Dict, List
|
8 |
+
|
9 |
+
from dotenv import find_dotenv, load_dotenv
|
10 |
+
from fastapi import HTTPException
|
11 |
+
from sqlalchemy import insert, update
|
12 |
+
|
13 |
+
from app.config.config import Settings as ConfigSettings
|
14 |
+
from app.config.config import settings
|
15 |
+
from app.database.connection import database
|
16 |
+
from app.database.models import Settings
|
17 |
+
from app.database.services import get_all_settings
|
18 |
+
from app.log.logger import get_config_routes_logger
|
19 |
+
from app.service.key.key_manager import (
|
20 |
+
get_key_manager_instance,
|
21 |
+
reset_key_manager_instance,
|
22 |
+
)
|
23 |
+
from app.service.model.model_service import ModelService
|
24 |
+
|
25 |
+
logger = get_config_routes_logger()
|
26 |
+
|
27 |
+
|
28 |
+
class ConfigService:
    """Configuration service: reads, updates, resets and persists app settings."""

    @staticmethod
    async def get_config() -> Dict[str, Any]:
        """Return a plain-dict snapshot of the current in-memory settings."""
        return settings.model_dump()

    @staticmethod
    async def update_config(config_data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply ``config_data`` to the in-memory settings and persist it.

        The in-memory ``settings`` object is mutated first; values whose stored
        form changed are then written to the ``Settings`` table in one
        transaction, and the KeyManager is re-initialized so key changes take
        effect immediately.

        NOTE(review): if the database write fails, the in-memory settings have
        already been mutated and will diverge from what is persisted.
        """
        for key, value in config_data.items():
            if hasattr(settings, key):
                setattr(settings, key, value)
                logger.debug(f"Updated setting in memory: {key}")

        # Load the persisted settings so we can diff against them.
        existing_settings_raw: List[Dict[str, Any]] = await get_all_settings()
        existing_settings_map: Dict[str, Dict[str, Any]] = {
            s["key"]: s for s in existing_settings_raw
        }
        existing_keys = set(existing_settings_map.keys())

        settings_to_update: List[Dict[str, Any]] = []
        settings_to_insert: List[Dict[str, Any]] = []
        # Timestamps are recorded in UTC+8 — presumably China Standard Time;
        # confirm this matches the database's expectation.
        now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))

        # Serialize each value to its stored string form.
        for key, value in config_data.items():
            if isinstance(value, list):
                db_value = json.dumps(value)
            elif isinstance(value, dict):
                db_value = json.dumps(value)
            elif isinstance(value, bool):
                # Stored as "true"/"false" to match how settings are re-parsed.
                db_value = str(value).lower()
            else:
                db_value = str(value)

            # Skip rows whose stored value is already up to date.
            if key in existing_keys and existing_settings_map[key]["value"] == db_value:
                continue

            description = f"{key}配置项"

            data = {
                "key": key,
                "value": db_value,
                "description": description,
                "updated_at": now,
            }

            if key in existing_keys:
                # Preserve any existing (possibly hand-written) description.
                data["description"] = existing_settings_map[key].get(
                    "description", description
                )
                settings_to_update.append(data)
            else:
                data["created_at"] = now
                settings_to_insert.append(data)

        # Persist inserts and updates atomically.
        if settings_to_insert or settings_to_update:
            try:
                async with database.transaction():
                    if settings_to_insert:
                        query_insert = insert(Settings).values(settings_to_insert)
                        await database.execute(query=query_insert)
                        logger.info(
                            f"Bulk inserted {len(settings_to_insert)} settings."
                        )

                    if settings_to_update:
                        for setting_data in settings_to_update:
                            query_update = (
                                update(Settings)
                                .where(Settings.key == setting_data["key"])
                                .values(
                                    value=setting_data["value"],
                                    description=setting_data["description"],
                                    updated_at=setting_data["updated_at"],
                                )
                            )
                            await database.execute(query=query_update)
                        logger.info(f"Updated {len(settings_to_update)} settings.")
            except Exception as e:
                logger.error(f"Failed to bulk update/insert settings: {str(e)}")
                raise

        # Recreate the KeyManager so the updated key pools take effect.
        try:
            await reset_key_manager_instance()
            await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
            logger.info("KeyManager instance re-initialized with updated settings.")
        except Exception as e:
            logger.error(f"Failed to re-initialize KeyManager: {str(e)}")

        return await ConfigService.get_config()

    @staticmethod
    async def delete_key(key_to_delete: str) -> Dict[str, Any]:
        """Delete a single API key from settings, persisting the change.

        Returns a ``{"success": bool, "message": str}`` result dict.
        """
        # Ensure settings.API_KEYS is a list before filtering.
        if not isinstance(settings.API_KEYS, list):
            settings.API_KEYS = []

        original_keys_count = len(settings.API_KEYS)
        # Build a new list without the key being removed.
        updated_api_keys = [k for k in settings.API_KEYS if k != key_to_delete]

        if len(updated_api_keys) < original_keys_count:
            # Key found and removed: update memory first, then persist via
            # update_config, which also refreshes the KeyManager.
            settings.API_KEYS = updated_api_keys
            await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
            logger.info(f"密钥 '{key_to_delete}' 已成功删除。")
            return {"success": True, "message": f"密钥 '{key_to_delete}' 已成功删除。"}
        else:
            # Key was not present.
            logger.warning(f"尝试删除密钥 '{key_to_delete}',但未找到该密钥。")
            return {"success": False, "message": f"未找到密钥 '{key_to_delete}'。"}

    @staticmethod
    async def delete_selected_keys(keys_to_delete: List[str]) -> Dict[str, Any]:
        """Bulk-delete the given API keys, persisting the change.

        Returns a result dict with ``success``, ``message``, ``deleted_count``
        and ``not_found_keys``.
        """
        if not isinstance(settings.API_KEYS, list):
            settings.API_KEYS = []

        deleted_count = 0
        not_found_keys: List[str] = []

        # Work on a copy so settings is only mutated once, on success.
        current_api_keys = list(settings.API_KEYS)
        keys_actually_removed: List[str] = []

        for key_to_del in keys_to_delete:
            if key_to_del in current_api_keys:
                current_api_keys.remove(key_to_del)
                keys_actually_removed.append(key_to_del)
                deleted_count += 1
            else:
                not_found_keys.append(key_to_del)

        if deleted_count > 0:
            settings.API_KEYS = current_api_keys
            await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
            logger.info(
                f"成功删除 {deleted_count} 个密钥。密钥: {keys_actually_removed}"
            )
            message = f"成功删除 {deleted_count} 个密钥。"
            if not_found_keys:
                message += f" {len(not_found_keys)} 个密钥未找到: {not_found_keys}。"
            return {
                "success": True,
                "message": message,
                "deleted_count": deleted_count,
                "not_found_keys": not_found_keys,
            }
        else:
            message = "没有密钥被删除。"
            if not_found_keys:
                message = f"所有 {len(not_found_keys)} 个指定的密钥均未找到: {not_found_keys}。"
            elif not keys_to_delete:
                message = "未指定要删除的密钥。"
            logger.warning(message)
            return {
                "success": False,
                "message": message,
                "deleted_count": 0,
                "not_found_keys": not_found_keys,
            }

    @staticmethod
    async def reset_config() -> Dict[str, Any]:
        """Reset configuration from the environment / .env file.

        Reloads the settings object, then rebuilds the KeyManager so the
        reloaded key pools take effect.

        Returns:
            Dict[str, Any]: the configuration after the reset.
        """
        # 1. Reload the settings object (env vars + .env precedence handled there).
        _reload_settings()
        logger.info(
            "Settings object reloaded, prioritizing system environment variables then .env file."
        )

        # 2. Reset and re-initialize the KeyManager.
        try:
            await reset_key_manager_instance()
            # Pass both key pools, matching update_config, so Vertex keys are
            # not silently dropped after a reset (previously only API_KEYS
            # was passed here).
            await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
            logger.info("KeyManager instance re-initialized with reloaded settings.")
        except Exception as e:
            logger.error(f"Failed to re-initialize KeyManager during reset: {str(e)}")
            # Deliberately swallow: a reset should still return the new config.

        # 3. Return the refreshed configuration.
        return await ConfigService.get_config()

    @staticmethod
    async def fetch_ui_models() -> List[Dict[str, Any]]:
        """Fetch the model list used by the admin UI.

        Raises:
            HTTPException: with status 500 when no valid key exists or the
                upstream model fetch fails.
        """
        try:
            key_manager = await get_key_manager_instance()
            model_service = ModelService()

            api_key = await key_manager.get_first_valid_key()
            if not api_key:
                logger.error("No valid API keys available to fetch model list for UI.")
                raise HTTPException(
                    status_code=500,
                    detail="No valid API keys available to fetch model list.",
                )

            models = await model_service.get_gemini_openai_models(api_key)
            return models
        except HTTPException as e:
            # Re-raise as-is so FastAPI preserves the original status/detail.
            raise e
        except Exception as e:
            logger.error(
                f"Failed to fetch models for UI in ConfigService: {e}", exc_info=True
            )
            raise HTTPException(
                status_code=500, detail=f"Failed to fetch models for UI: {str(e)}"
            )
|
252 |
+
|
253 |
+
|
254 |
+
# 重新加载配置的函数
|
255 |
+
def _reload_settings():
    """Re-read the .env file and refresh the shared in-memory settings object."""
    # NOTE(review): override=True makes .env values win over already-set
    # environment variables — confirm this matches the "system env first"
    # policy described by reset_config's docstring.
    load_dotenv(find_dotenv(), override=True)
    # Mutate the existing settings object in place (rather than rebinding a
    # new instance) so every module holding a reference sees the new values.
    fresh_values = ConfigSettings().model_dump()
    for name, value in fresh_values.items():
        setattr(settings, name, value)
|
app/service/embedding/embedding_service.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import time
|
3 |
+
import re
|
4 |
+
from typing import List, Union
|
5 |
+
|
6 |
+
import openai
|
7 |
+
from openai import APIStatusError
|
8 |
+
from openai.types import CreateEmbeddingResponse
|
9 |
+
|
10 |
+
from app.config.config import settings
|
11 |
+
from app.log.logger import get_embeddings_logger
|
12 |
+
from app.database.services import add_error_log, add_request_log
|
13 |
+
|
14 |
+
logger = get_embeddings_logger()
|
15 |
+
|
16 |
+
|
17 |
+
class EmbeddingService:
    """Creates OpenAI-compatible embeddings and records request/error logs."""

    async def create_embedding(
        self, input_text: Union[str, List[str]], model: str, api_key: str
    ) -> CreateEmbeddingResponse:
        """Create embeddings using OpenAI API with database logging.

        Args:
            input_text: a single string or a list of strings to embed.
            model: embedding model name.
            api_key: upstream API key.

        Raises:
            APIStatusError: re-raised from the OpenAI client on HTTP errors.
            Exception: any other failure, after logging.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        error_log_msg = ""
        # Build a truncated copy of the input for the error-log table so huge
        # payloads are not persisted verbatim.
        if isinstance(input_text, list):
            request_msg_log = {"input_truncated": [str(item)[:100] + "..." if len(str(item)) > 100 else str(item) for item in input_text[:5]]}
            if len(input_text) > 5:
                request_msg_log["input_truncated"].append("...")
        else:
            request_msg_log = {"input_truncated": input_text[:1000] + "..." if len(input_text) > 1000 else input_text}

        try:
            # Use the async client so the running event loop is not blocked by
            # the HTTP round-trip (the previous synchronous openai.OpenAI call
            # stalled all other coroutines for the duration of the request).
            client = openai.AsyncOpenAI(api_key=api_key, base_url=settings.BASE_URL)
            response = await client.embeddings.create(input=input_text, model=model)
            is_success = True
            status_code = 200
            return response
        except APIStatusError as e:
            is_success = False
            status_code = e.status_code
            error_log_msg = f"OpenAI API error: {e}"
            logger.error(f"Error creating embedding (APIStatusError): {error_log_msg}")
            raise e
        except Exception as e:
            is_success = False
            error_log_msg = f"Generic error: {e}"
            logger.error(f"Error creating embedding (Exception): {error_log_msg}")
            # Best-effort extraction of the status code from the message text.
            match = re.search(r"status code (\d+)", str(e))
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500
            raise e
        finally:
            # Always record the request outcome (and the error detail on failure).
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            if not is_success:
                await add_error_log(
                    gemini_key=api_key,
                    model_name=model,
                    error_type="openai-embedding",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=request_msg_log
                )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )
|
app/service/error_log/error_log_service.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime, timedelta, timezone
|
2 |
+
from typing import Any, Dict, List, Optional
|
3 |
+
|
4 |
+
from sqlalchemy import delete, func, select
|
5 |
+
|
6 |
+
from app.config.config import settings
|
7 |
+
from app.database import services as db_services
|
8 |
+
from app.database.connection import database
|
9 |
+
from app.database.models import ErrorLog
|
10 |
+
from app.log.logger import get_error_log_logger
|
11 |
+
|
12 |
+
logger = get_error_log_logger()
|
13 |
+
|
14 |
+
|
15 |
+
async def delete_old_error_logs() -> None:
    """
    Deletes error logs older than a specified number of days,
    based on the AUTO_DELETE_ERROR_LOGS_ENABLED and AUTO_DELETE_ERROR_LOGS_DAYS settings.

    Intended to run as a background/scheduled task: all failures are logged
    and swallowed rather than raised.
    """
    if not settings.AUTO_DELETE_ERROR_LOGS_ENABLED:
        logger.info("Auto-deletion of error logs is disabled. Skipping.")
        return

    days_to_keep = settings.AUTO_DELETE_ERROR_LOGS_DAYS
    # Guard against misconfiguration: the retention window must be a positive int.
    if not isinstance(days_to_keep, int) or days_to_keep <= 0:
        logger.error(
            f"Invalid AUTO_DELETE_ERROR_LOGS_DAYS value: {days_to_keep}. Must be a positive integer. Skipping deletion."
        )
        return

    # Logs with request_time strictly before this UTC instant are removed.
    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)

    logger.info(
        f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
    )

    try:
        # Lazily connect — this task may run from the scheduler before the
        # application has opened the shared database connection.
        if not database.is_connected:
            await database.connect()
            logger.info("Database connection established for deleting error logs.")

        # First, count how many logs will be deleted (optional, for logging)
        count_query = select(func.count(ErrorLog.id)).where(
            ErrorLog.request_time < cutoff_date
        )
        num_logs_to_delete = await database.fetch_val(count_query)

        if num_logs_to_delete == 0:
            logger.info(
                "No error logs found older than the specified period. No deletion needed."
            )
            return

        logger.info(f"Found {num_logs_to_delete} error logs to delete.")

        # Perform the deletion
        # NOTE(review): count and delete run as two separate statements, so the
        # logged count may differ slightly from rows actually removed if logs
        # are written concurrently.
        query = delete(ErrorLog).where(ErrorLog.request_time < cutoff_date)
        await database.execute(query)
        logger.info(
            f"Successfully deleted {num_logs_to_delete} error logs older than {days_to_keep} days."
        )

    except Exception as e:
        # Never propagate — this is best-effort background maintenance.
        logger.error(
            f"Error during automatic deletion of error logs: {e}", exc_info=True
        )
|
67 |
+
|
68 |
+
|
69 |
+
async def process_get_error_logs(
    limit: int,
    offset: int,
    key_search: Optional[str],
    error_search: Optional[str],
    error_code_search: Optional[str],
    start_date: Optional[datetime],
    end_date: Optional[datetime],
    sort_by: str,
    sort_order: str,
) -> Dict[str, Any]:
    """Retrieve one page of error logs plus the total count of matches."""
    # The page query and the count query share the same filter set.
    shared_filters = dict(
        key_search=key_search,
        error_search=error_search,
        error_code_search=error_code_search,
        start_date=start_date,
        end_date=end_date,
    )
    try:
        page = await db_services.get_error_logs(
            limit=limit,
            offset=offset,
            sort_by=sort_by,
            sort_order=sort_order,
            **shared_filters,
        )
        total = await db_services.get_error_logs_count(**shared_filters)
    except Exception as e:
        logger.error(f"Service error in process_get_error_logs: {e}", exc_info=True)
        raise
    return {"logs": page, "total": total}
|
106 |
+
|
107 |
+
|
108 |
+
async def process_get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
    """Fetch the full details of one error log; None when it does not exist."""
    try:
        return await db_services.get_error_log_details(log_id=log_id)
    except Exception as e:
        logger.error(
            f"Service error in process_get_error_log_details for ID {log_id}: {e}",
            exc_info=True,
        )
        raise
|
122 |
+
|
123 |
+
|
124 |
+
async def process_delete_error_logs_by_ids(log_ids: List[int]) -> int:
    """Bulk-delete error logs by ID; returns how many deletions were attempted."""
    # Fast path: nothing requested, nothing to do.
    if not log_ids:
        return 0
    try:
        return await db_services.delete_error_logs_by_ids(log_ids)
    except Exception as e:
        logger.error(
            f"Service error in process_delete_error_logs_by_ids for IDs {log_ids}: {e}",
            exc_info=True,
        )
        raise
|
140 |
+
|
141 |
+
|
142 |
+
async def process_delete_error_log_by_id(log_id: int) -> bool:
    """Delete one error log by ID; True when the log was found and removed."""
    try:
        return await db_services.delete_error_log_by_id(log_id)
    except Exception as e:
        logger.error(
            f"Service error in process_delete_error_log_by_id for ID {log_id}: {e}",
            exc_info=True,
        )
        raise
|
156 |
+
|
157 |
+
|
158 |
+
async def process_delete_all_error_logs() -> int:
    """Delete every stored error log and return the number removed."""
    try:
        # Lazily open the database connection if the pool is not up yet.
        if not database.is_connected:
            await database.connect()
            logger.info("Database connection established for deleting all error logs.")

        removed = await db_services.delete_all_error_logs()
        logger.info(
            f"Successfully processed request to delete all error logs. Count: {removed}"
        )
        return removed
    except Exception as e:
        logger.error(
            f"Service error in process_delete_all_error_logs: {e}",
            exc_info=True,
        )
        raise
|
app/service/image/image_create_service.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import time
|
3 |
+
import uuid
|
4 |
+
|
5 |
+
from google import genai
|
6 |
+
from google.genai import types
|
7 |
+
|
8 |
+
from app.config.config import settings
|
9 |
+
from app.core.constants import VALID_IMAGE_RATIOS
|
10 |
+
from app.domain.openai_models import ImageGenerationRequest
|
11 |
+
from app.log.logger import get_image_create_logger
|
12 |
+
from app.utils.uploader import ImageUploaderFactory
|
13 |
+
|
14 |
+
logger = get_image_create_logger()
|
15 |
+
|
16 |
+
|
17 |
+
class ImageCreateService:
    """Creates images with the configured Gemini image model and returns
    OpenAI-style image responses (inline b64_json or hosted URL)."""

    # Maps OpenAI-style size strings to Gemini aspect ratios.
    # "1027x1792" is kept as a legacy alias: earlier revisions matched that
    # typo instead of the advertised "1024x1792", so the real portrait size
    # was rejected while the typo'd one was accepted.
    _SIZE_TO_RATIO = {
        "1024x1024": "1:1",
        "1792x1024": "16:9",
        "1024x1792": "9:16",
        "1027x1792": "9:16",
    }

    def __init__(self, aspect_ratio="1:1"):
        # Model name comes from global settings; the aspect ratio may later
        # be overridden by the request size or by inline prompt markers.
        self.image_model = settings.CREATE_IMAGE_MODEL
        self.aspect_ratio = aspect_ratio

    def parse_prompt_parameters(self, prompt: str) -> tuple:
        """Extract inline generation parameters from *prompt*.

        Supported markers (removed from the returned prompt):
        - {n:count}    e.g. {n:2}        -> generate 2 images (1..4)
        - {ratio:W:H}  e.g. {ratio:16:9} -> use that aspect ratio

        Returns:
            (cleaned_prompt, n, aspect_ratio)

        Raises:
            ValueError: if n is outside 1..4 or the ratio is not in
                VALID_IMAGE_RATIOS.
        """
        import re

        # Defaults when no markers are present.
        n = 1
        aspect_ratio = self.aspect_ratio

        n_match = re.search(r"{n:(\d+)}", prompt)
        if n_match:
            n = int(n_match.group(1))
            if n < 1 or n > 4:
                raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
            prompt = prompt.replace(n_match.group(0), "").strip()

        ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
        if ratio_match:
            aspect_ratio = ratio_match.group(1)
            if aspect_ratio not in VALID_IMAGE_RATIOS:
                raise ValueError(
                    f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
                )
            prompt = prompt.replace(ratio_match.group(0), "").strip()

        return prompt, n, aspect_ratio

    def generate_images(self, request: "ImageGenerationRequest"):
        """Generate images for *request* and return an OpenAI-style payload.

        Returns:
            dict with "created" (unix timestamp) and "data" (list of
            {"b64_json"|"url", "revised_prompt"} entries).

        Raises:
            ValueError: for an unsupported size or upload provider.
            Exception: when the API returns no images.
        """
        client = genai.Client(api_key=settings.PAID_KEY)

        # Bug fix: "1024x1792" used to be rejected because the branch
        # compared against the typo "1027x1792"; both now map to 9:16.
        try:
            self.aspect_ratio = self._SIZE_TO_RATIO[request.size]
        except KeyError:
            raise ValueError(
                f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
            ) from None

        # Inline prompt markers override both n and the aspect ratio.
        cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
            request.prompt
        )
        request.prompt = cleaned_prompt

        if prompt_n > 1:
            request.n = prompt_n

        if prompt_ratio != self.aspect_ratio:
            self.aspect_ratio = prompt_ratio

        response = client.models.generate_images(
            model=self.image_model,
            prompt=request.prompt,
            config=types.GenerateImagesConfig(
                number_of_images=request.n,
                output_mime_type="image/png",
                aspect_ratio=self.aspect_ratio,
                safety_filter_level="BLOCK_LOW_AND_ABOVE",
                person_generation="ALLOW_ADULT",
            ),
        )

        if not response.generated_images:
            raise Exception("I can't generate these images")

        images_data = []
        for index, generated_image in enumerate(response.generated_images):
            image_data = generated_image.image.image_bytes

            if request.response_format == "b64_json":
                base64_image = base64.b64encode(image_data).decode("utf-8")
                images_data.append(
                    {"b64_json": base64_image, "revised_prompt": request.prompt}
                )
            else:
                # Any other response_format: upload and return a hosted URL.
                upload_response = self._build_uploader().upload(
                    image_data, self._make_filename()
                )
                images_data.append(
                    {
                        "url": f"{upload_response.data.url}",
                        "revised_prompt": request.prompt,
                    }
                )

        return {
            "created": int(time.time()),
            "data": images_data,
        }

    @staticmethod
    def _make_filename() -> str:
        """Build a date-sharded unique PNG filename for uploads."""
        current_date = time.strftime("%Y/%m/%d")
        return f"{current_date}/{uuid.uuid4().hex[:8]}.png"

    @staticmethod
    def _build_uploader():
        """Create the uploader selected by settings.UPLOAD_PROVIDER.

        Raises:
            ValueError: for an unknown provider.
        """
        if settings.UPLOAD_PROVIDER == "smms":
            return ImageUploaderFactory.create(
                provider=settings.UPLOAD_PROVIDER,
                api_key=settings.SMMS_SECRET_TOKEN,
            )
        if settings.UPLOAD_PROVIDER == "picgo":
            return ImageUploaderFactory.create(
                provider=settings.UPLOAD_PROVIDER,
                api_key=settings.PICGO_API_KEY,
            )
        if settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
            return ImageUploaderFactory.create(
                provider=settings.UPLOAD_PROVIDER,
                base_url=settings.CLOUDFLARE_IMGBED_URL,
                auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
            )
        raise ValueError(f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}")

    def generate_images_chat(self, request: "ImageGenerationRequest") -> str:
        """Generate images and render them as markdown image links,
        one per line (data URLs for b64_json results)."""
        response = self.generate_images(request)
        image_datas = response["data"]
        if image_datas:
            markdown_images = []
            for index, image_data in enumerate(image_datas):
                # NOTE(review): the original markdown literals were garbled
                # to empty f-strings in the extracted source; reconstructed
                # as standard markdown image links — confirm against history.
                if "url" in image_data:
                    markdown_images.append(
                        f"![image-{index}]({image_data['url']})"
                    )
                else:
                    # base64 payloads are embedded as data URLs.
                    markdown_images.append(
                        f"![image-{index}](data:image/png;base64,{image_data['b64_json']})"
                    )
            return "\n".join(markdown_images)
|
app/service/key/key_manager.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from itertools import cycle
|
3 |
+
from typing import Dict, Union
|
4 |
+
|
5 |
+
from app.config.config import settings
|
6 |
+
from app.log.logger import get_key_manager_logger
|
7 |
+
|
8 |
+
logger = get_key_manager_logger()
|
9 |
+
|
10 |
+
|
11 |
+
class KeyManager:
|
12 |
+
def __init__(self, api_keys: list, vertex_api_keys: list):
|
13 |
+
self.api_keys = api_keys
|
14 |
+
self.vertex_api_keys = vertex_api_keys
|
15 |
+
self.key_cycle = cycle(api_keys)
|
16 |
+
self.vertex_key_cycle = cycle(vertex_api_keys)
|
17 |
+
self.key_cycle_lock = asyncio.Lock()
|
18 |
+
self.vertex_key_cycle_lock = asyncio.Lock()
|
19 |
+
self.failure_count_lock = asyncio.Lock()
|
20 |
+
self.vertex_failure_count_lock = asyncio.Lock()
|
21 |
+
self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
|
22 |
+
self.vertex_key_failure_counts: Dict[str, int] = {
|
23 |
+
key: 0 for key in vertex_api_keys
|
24 |
+
}
|
25 |
+
self.MAX_FAILURES = settings.MAX_FAILURES
|
26 |
+
self.paid_key = settings.PAID_KEY
|
27 |
+
|
28 |
+
async def get_paid_key(self) -> str:
|
29 |
+
return self.paid_key
|
30 |
+
|
31 |
+
async def get_next_key(self) -> str:
|
32 |
+
"""获取下一个API key"""
|
33 |
+
async with self.key_cycle_lock:
|
34 |
+
return next(self.key_cycle)
|
35 |
+
|
36 |
+
async def get_next_vertex_key(self) -> str:
|
37 |
+
"""获取下一个 Vertex API key"""
|
38 |
+
async with self.vertex_key_cycle_lock:
|
39 |
+
return next(self.vertex_key_cycle)
|
40 |
+
|
41 |
+
async def is_key_valid(self, key: str) -> bool:
|
42 |
+
"""检查key是否有效"""
|
43 |
+
async with self.failure_count_lock:
|
44 |
+
return self.key_failure_counts[key] < self.MAX_FAILURES
|
45 |
+
|
46 |
+
async def is_vertex_key_valid(self, key: str) -> bool:
|
47 |
+
"""检查 Vertex key 是否有效"""
|
48 |
+
async with self.vertex_failure_count_lock:
|
49 |
+
return self.vertex_key_failure_counts[key] < self.MAX_FAILURES
|
50 |
+
|
51 |
+
async def reset_failure_counts(self):
|
52 |
+
"""重置所有key的失败计数"""
|
53 |
+
async with self.failure_count_lock:
|
54 |
+
for key in self.key_failure_counts:
|
55 |
+
self.key_failure_counts[key] = 0
|
56 |
+
|
57 |
+
async def reset_vertex_failure_counts(self):
|
58 |
+
"""重置所有 Vertex key 的失败计数"""
|
59 |
+
async with self.vertex_failure_count_lock:
|
60 |
+
for key in self.vertex_key_failure_counts:
|
61 |
+
self.vertex_key_failure_counts[key] = 0
|
62 |
+
|
63 |
+
async def reset_key_failure_count(self, key: str) -> bool:
|
64 |
+
"""重置指定key的失败计数"""
|
65 |
+
async with self.failure_count_lock:
|
66 |
+
if key in self.key_failure_counts:
|
67 |
+
self.key_failure_counts[key] = 0
|
68 |
+
logger.info(f"Reset failure count for key: {key}")
|
69 |
+
return True
|
70 |
+
logger.warning(
|
71 |
+
f"Attempt to reset failure count for non-existent key: {key}"
|
72 |
+
)
|
73 |
+
return False
|
74 |
+
|
75 |
+
async def reset_vertex_key_failure_count(self, key: str) -> bool:
|
76 |
+
"""重置指定 Vertex key 的失败计数"""
|
77 |
+
async with self.vertex_failure_count_lock:
|
78 |
+
if key in self.vertex_key_failure_counts:
|
79 |
+
self.vertex_key_failure_counts[key] = 0
|
80 |
+
logger.info(f"Reset failure count for Vertex key: {key}")
|
81 |
+
return True
|
82 |
+
logger.warning(
|
83 |
+
f"Attempt to reset failure count for non-existent Vertex key: {key}"
|
84 |
+
)
|
85 |
+
return False
|
86 |
+
|
87 |
+
async def get_next_working_key(self) -> str:
|
88 |
+
"""获取下一可用的API key"""
|
89 |
+
initial_key = await self.get_next_key()
|
90 |
+
current_key = initial_key
|
91 |
+
|
92 |
+
while True:
|
93 |
+
if await self.is_key_valid(current_key):
|
94 |
+
return current_key
|
95 |
+
|
96 |
+
current_key = await self.get_next_key()
|
97 |
+
if current_key == initial_key:
|
98 |
+
return current_key
|
99 |
+
|
100 |
+
async def get_next_working_vertex_key(self) -> str:
|
101 |
+
"""获取下一可用的 Vertex API key"""
|
102 |
+
initial_key = await self.get_next_vertex_key()
|
103 |
+
current_key = initial_key
|
104 |
+
|
105 |
+
while True:
|
106 |
+
if await self.is_vertex_key_valid(current_key):
|
107 |
+
return current_key
|
108 |
+
|
109 |
+
current_key = await self.get_next_vertex_key()
|
110 |
+
if current_key == initial_key:
|
111 |
+
return current_key
|
112 |
+
|
113 |
+
async def handle_api_failure(self, api_key: str, retries: int) -> str:
|
114 |
+
"""处理API调用失败"""
|
115 |
+
async with self.failure_count_lock:
|
116 |
+
self.key_failure_counts[api_key] += 1
|
117 |
+
if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
|
118 |
+
logger.warning(
|
119 |
+
f"API key {api_key} has failed {self.MAX_FAILURES} times"
|
120 |
+
)
|
121 |
+
if retries < settings.MAX_RETRIES:
|
122 |
+
return await self.get_next_working_key()
|
123 |
+
else:
|
124 |
+
return ""
|
125 |
+
|
126 |
+
async def handle_vertex_api_failure(self, api_key: str, retries: int) -> str:
|
127 |
+
"""处理 Vertex API 调用失败"""
|
128 |
+
async with self.vertex_failure_count_lock:
|
129 |
+
self.vertex_key_failure_counts[api_key] += 1
|
130 |
+
if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
|
131 |
+
logger.warning(
|
132 |
+
f"Vertex API key {api_key} has failed {self.MAX_FAILURES} times"
|
133 |
+
)
|
134 |
+
|
135 |
+
def get_fail_count(self, key: str) -> int:
|
136 |
+
"""获取指定密钥的失败次数"""
|
137 |
+
return self.key_failure_counts.get(key, 0)
|
138 |
+
|
139 |
+
def get_vertex_fail_count(self, key: str) -> int:
|
140 |
+
"""获取指定 Vertex 密钥的失败次数"""
|
141 |
+
return self.vertex_key_failure_counts.get(key, 0)
|
142 |
+
|
143 |
+
async def get_keys_by_status(self) -> dict:
|
144 |
+
"""获取分类后的API key列表,包括失败次数"""
|
145 |
+
valid_keys = {}
|
146 |
+
invalid_keys = {}
|
147 |
+
|
148 |
+
async with self.failure_count_lock:
|
149 |
+
for key in self.api_keys:
|
150 |
+
fail_count = self.key_failure_counts[key]
|
151 |
+
if fail_count < self.MAX_FAILURES:
|
152 |
+
valid_keys[key] = fail_count
|
153 |
+
else:
|
154 |
+
invalid_keys[key] = fail_count
|
155 |
+
|
156 |
+
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
157 |
+
|
158 |
+
async def get_vertex_keys_by_status(self) -> dict:
|
159 |
+
"""获取分类后的 Vertex API key 列表,包括失败次数"""
|
160 |
+
valid_keys = {}
|
161 |
+
invalid_keys = {}
|
162 |
+
|
163 |
+
async with self.vertex_failure_count_lock:
|
164 |
+
for key in self.vertex_api_keys:
|
165 |
+
fail_count = self.vertex_key_failure_counts[key]
|
166 |
+
if fail_count < self.MAX_FAILURES:
|
167 |
+
valid_keys[key] = fail_count
|
168 |
+
else:
|
169 |
+
invalid_keys[key] = fail_count
|
170 |
+
return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}
|
171 |
+
|
172 |
+
async def get_first_valid_key(self) -> str:
|
173 |
+
"""获取第一个有效的API key"""
|
174 |
+
async with self.failure_count_lock:
|
175 |
+
for key in self.key_failure_counts:
|
176 |
+
if self.key_failure_counts[key] < self.MAX_FAILURES:
|
177 |
+
return key
|
178 |
+
if self.api_keys:
|
179 |
+
return self.api_keys[0]
|
180 |
+
if not self.api_keys:
|
181 |
+
logger.warning(
|
182 |
+
"API key list is empty, cannot get first valid key.")
|
183 |
+
return ""
|
184 |
+
return self.api_keys[0]
|
185 |
+
|
186 |
+
|
187 |
+
# --- Module-level singleton state for KeyManager ----------------------------
# One KeyManager is shared process-wide, plus state "preserved" across a
# reset/re-create cycle so the new instance can inherit failure counts and
# resume the round-robin cycles where the old instance left off.
_singleton_instance = None
_singleton_lock = asyncio.Lock()
# Failure counts captured from the previous instance (restored on re-create).
_preserved_failure_counts: Union[Dict[str, int], None] = None
_preserved_vertex_failure_counts: Union[Dict[str, int], None] = None
# Key lists of the previous instance, used to locate the cycle resume point.
_preserved_old_api_keys_for_reset: Union[list, None] = None
_preserved_vertex_old_api_keys_for_reset: Union[list, None] = None
# The key each old cycle would have produced next (resume hint).
_preserved_next_key_in_cycle: Union[str, None] = None
_preserved_vertex_next_key_in_cycle: Union[str, None] = None
|
195 |
+
|
196 |
+
|
197 |
+
async def get_key_manager_instance(
    api_keys: list = None, vertex_api_keys: list = None
) -> KeyManager:
    """
    Return the KeyManager singleton.

    On first call (or right after reset_key_manager_instance) the instance
    is built from the provided api_keys / vertex_api_keys; on later calls
    the arguments are ignored and the existing singleton is returned.
    After a reset, the preserved state (failure counts and round-robin
    cycle positions) is restored onto the new instance.
    """
    global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle

    async with _singleton_lock:
        if _singleton_instance is None:
            # Key lists are mandatory for (re-)initialization; later calls
            # may omit them because the singleton already exists.
            if api_keys is None:
                raise ValueError(
                    "API keys are required to initialize or re-initialize the KeyManager instance."
                )
            if vertex_api_keys is None:
                raise ValueError(
                    "Vertex API keys are required to initialize or re-initialize the KeyManager instance."
                )

            if not api_keys:
                logger.warning(
                    "Initializing KeyManager with an empty list of API keys."
                )
            if not vertex_api_keys:
                logger.warning(
                    "Initializing KeyManager with an empty list of Vertex API keys."
                )

            _singleton_instance = KeyManager(api_keys, vertex_api_keys)
            logger.info(
                f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex API keys."
            )

            # 1. Restore failure counts preserved by the last reset, but only
            #    for keys that still exist in the new key list.
            if _preserved_failure_counts:
                current_failure_counts = {
                    key: 0 for key in _singleton_instance.api_keys
                }
                for key, count in _preserved_failure_counts.items():
                    if key in current_failure_counts:
                        current_failure_counts[key] = count
                _singleton_instance.key_failure_counts = current_failure_counts
                logger.info("Inherited failure counts for applicable keys.")
            _preserved_failure_counts = None

            if _preserved_vertex_failure_counts:
                current_vertex_failure_counts = {
                    key: 0 for key in _singleton_instance.vertex_api_keys
                }
                for key, count in _preserved_vertex_failure_counts.items():
                    if key in current_vertex_failure_counts:
                        current_vertex_failure_counts[key] = count
                _singleton_instance.vertex_key_failure_counts = (
                    current_vertex_failure_counts
                )
                logger.info(
                    "Inherited failure counts for applicable Vertex keys.")
            _preserved_vertex_failure_counts = None

            # 2. Re-align key_cycle: pick the key the old cycle would have
            #    served next (or the first later old key still present in
            #    the new list) as the new cycle's starting point.
            start_key_for_new_cycle = None
            if (
                _preserved_old_api_keys_for_reset
                and _preserved_next_key_in_cycle
                and _singleton_instance.api_keys
            ):
                try:
                    start_idx_in_old = _preserved_old_api_keys_for_reset.index(
                        _preserved_next_key_in_cycle
                    )

                    # Walk the old ordering starting at the preserved hint;
                    # the first candidate present in the new list wins.
                    for i in range(len(_preserved_old_api_keys_for_reset)):
                        current_old_key_idx = (start_idx_in_old + i) % len(
                            _preserved_old_api_keys_for_reset
                        )
                        key_candidate = _preserved_old_api_keys_for_reset[
                            current_old_key_idx
                        ]
                        if key_candidate in _singleton_instance.api_keys:
                            start_key_for_new_cycle = key_candidate
                            break
                except ValueError:
                    logger.warning(
                        f"Preserved next key '{_preserved_next_key_in_cycle}' not found in preserved old API keys. "
                        "New cycle will start from the beginning of the new list."
                    )
                except Exception as e:
                    logger.error(
                        f"Error determining start key for new cycle from preserved state: {e}. "
                        "New cycle will start from the beginning."
                    )

            if start_key_for_new_cycle and _singleton_instance.api_keys:
                try:
                    target_idx = _singleton_instance.api_keys.index(
                        start_key_for_new_cycle
                    )
                    # Advance the fresh cycle so its next value is the
                    # chosen start key.
                    for _ in range(target_idx):
                        next(_singleton_instance.key_cycle)
                    logger.info(
                        f"Key cycle in new instance advanced. Next call to get_next_key() will yield: {start_key_for_new_cycle}"
                    )
                except ValueError:
                    logger.warning(
                        f"Determined start key '{start_key_for_new_cycle}' not found in new API keys during cycle advancement. "
                        "New cycle will start from the beginning."
                    )
                except StopIteration:
                    logger.error(
                        "StopIteration while advancing key cycle, implies empty new API key list previously missed."
                    )
                except Exception as e:
                    logger.error(
                        f"Error advancing new key cycle: {e}. Cycle will start from beginning."
                    )
            else:
                if _singleton_instance.api_keys:
                    logger.info(
                        "New key cycle will start from the beginning of the new API key list (no specific start key determined or needed)."
                    )
                else:
                    logger.info(
                        "New key cycle not applicable as the new API key list is empty."
                    )

            # Clear the preserved (non-Vertex) cycle state.
            _preserved_old_api_keys_for_reset = None
            _preserved_next_key_in_cycle = None

            # 3. Re-align vertex_key_cycle using the same strategy.
            start_key_for_new_vertex_cycle = None
            if (
                _preserved_vertex_old_api_keys_for_reset
                and _preserved_vertex_next_key_in_cycle
                and _singleton_instance.vertex_api_keys
            ):
                try:
                    start_idx_in_old = _preserved_vertex_old_api_keys_for_reset.index(
                        _preserved_vertex_next_key_in_cycle
                    )

                    for i in range(len(_preserved_vertex_old_api_keys_for_reset)):
                        current_old_key_idx = (start_idx_in_old + i) % len(
                            _preserved_vertex_old_api_keys_for_reset
                        )
                        key_candidate = _preserved_vertex_old_api_keys_for_reset[
                            current_old_key_idx
                        ]
                        if key_candidate in _singleton_instance.vertex_api_keys:
                            start_key_for_new_vertex_cycle = key_candidate
                            break
                except ValueError:
                    logger.warning(
                        f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex API keys. "
                        "New cycle will start from the beginning of the new list."
                    )
                except Exception as e:
                    logger.error(
                        f"Error determining start key for new Vertex key cycle from preserved state: {e}. "
                        "New cycle will start from the beginning."
                    )

            if start_key_for_new_vertex_cycle and _singleton_instance.vertex_api_keys:
                try:
                    target_idx = _singleton_instance.vertex_api_keys.index(
                        start_key_for_new_vertex_cycle
                    )
                    for _ in range(target_idx):
                        next(_singleton_instance.vertex_key_cycle)
                    logger.info(
                        f"Vertex key cycle in new instance advanced. Next call to get_next_vertex_key() will yield: {start_key_for_new_vertex_cycle}"
                    )
                except ValueError:
                    logger.warning(
                        f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex API keys during cycle advancement. "
                        "New cycle will start from the beginning."
                    )
                except StopIteration:
                    logger.error(
                        "StopIteration while advancing Vertex key cycle, implies empty new Vertex API key list previously missed."
                    )
                except Exception as e:
                    logger.error(
                        f"Error advancing new Vertex key cycle: {e}. Cycle will start from beginning."
                    )
            else:
                if _singleton_instance.vertex_api_keys:
                    logger.info(
                        "New Vertex key cycle will start from the beginning of the new Vertex API key list (no specific start key determined or needed)."
                    )
                else:
                    logger.info(
                        "New Vertex key cycle not applicable as the new Vertex API key list is empty."
                    )

            # Clear the preserved Vertex cycle state.
            _preserved_vertex_old_api_keys_for_reset = None
            _preserved_vertex_next_key_in_cycle = None

        return _singleton_instance
|
401 |
+
|
402 |
+
|
403 |
+
async def reset_key_manager_instance():
    """
    Reset the KeyManager singleton.

    Captures the current instance's state — failure counts, the old API key
    lists, and a "next key" hint for each round-robin cycle — into module
    globals so the next get_key_manager_instance() call can restore it.
    """
    global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle
    async with _singleton_lock:
        if _singleton_instance:
            # 1. Preserve failure counts.
            _preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
            _preserved_vertex_failure_counts = _singleton_instance.vertex_key_failure_counts.copy()

            # 2. Preserve the old API key lists (needed to locate the
            #    cycle resume point against the new lists later).
            _preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
            _preserved_vertex_old_api_keys_for_reset = _singleton_instance.vertex_api_keys.copy()

            # 3. Preserve the key the cycle would serve next. This advances
            #    the old cycle by one, which is harmless — it is discarded.
            try:
                if _singleton_instance.api_keys:
                    _preserved_next_key_in_cycle = (
                        await _singleton_instance.get_next_key()
                    )
                else:
                    _preserved_next_key_in_cycle = None
            except StopIteration:
                logger.warning(
                    "Could not preserve next key hint: key cycle was empty or exhausted in old instance."
                )
                _preserved_next_key_in_cycle = None
            except Exception as e:
                logger.error(
                    f"Error preserving next key hint during reset: {e}")
                _preserved_next_key_in_cycle = None

            # 4. Same for the Vertex key cycle.
            try:
                if _singleton_instance.vertex_api_keys:
                    _preserved_vertex_next_key_in_cycle = (
                        await _singleton_instance.get_next_vertex_key()
                    )
                else:
                    _preserved_vertex_next_key_in_cycle = None
            except StopIteration:
                logger.warning(
                    "Could not preserve next key hint: Vertex key cycle was empty or exhausted in old instance."
                )
                _preserved_vertex_next_key_in_cycle = None
            except Exception as e:
                logger.error(
                    f"Error preserving next key hint during reset: {e}")
                _preserved_vertex_next_key_in_cycle = None

            _singleton_instance = None
            logger.info(
                "KeyManager instance has been reset. State (failure counts, old keys, next key hint) preserved for next instantiation."
            )
        else:
            logger.info(
                "KeyManager instance was not set (or already reset), no reset action performed."
            )
|
app/service/model/model_service.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime, timezone
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
+
|
4 |
+
from app.config.config import settings
|
5 |
+
from app.log.logger import get_model_logger
|
6 |
+
from app.service.client.api_client import GeminiApiClient
|
7 |
+
|
8 |
+
logger = get_model_logger()
|
9 |
+
|
10 |
+
|
11 |
+
class ModelService:
    """Fetches the Gemini model listing and adapts it to OpenAI format."""

    async def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch the Gemini model list, dropping ids in settings.FILTERED_MODELS.

        Returns the filtered raw Gemini payload, or None on any failure.
        """
        api_client = GeminiApiClient(base_url=settings.BASE_URL)
        gemini_models = await api_client.get_models(api_key)

        if gemini_models is None:
            logger.error("从 API 客户端获取模型列表失败。")
            return None

        try:
            filtered_models_list = []
            for model in gemini_models.get("models", []):
                # Model names look like "models/<id>"; keep only the id part.
                model_id = model["name"].split("/")[-1]
                if model_id not in settings.FILTERED_MODELS:
                    filtered_models_list.append(model)
                else:
                    logger.debug(f"Filtered out model: {model_id}")

            gemini_models["models"] = filtered_models_list
            return gemini_models
        except Exception as e:
            logger.error(f"处理模型列表时出错: {e}")
            return None

    async def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch Gemini models and convert them to the OpenAI /v1/models shape."""
        gemini_models = await self.get_gemini_models(api_key)
        if gemini_models is None:
            return None

        return await self.convert_to_openai_models_format(gemini_models)

    async def convert_to_openai_models_format(
        self, gemini_models: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert a Gemini model listing into an OpenAI models payload.

        For each base model, extra pseudo-model entries ("-search",
        "-image", "-non-thinking") are emitted when its id appears in the
        corresponding settings list.
        """
        openai_format = {"object": "list", "data": [], "success": True}

        for model in gemini_models.get("models", []):
            model_id = model["name"].split("/")[-1]
            openai_model = {
                "id": model_id,
                "object": "model",
                "created": int(datetime.now(timezone.utc).timestamp()),
                "owned_by": "google",
                "permission": [],
                "root": model["name"],
                "parent": None,
            }
            openai_format["data"].append(openai_model)

            if model_id in settings.SEARCH_MODELS:
                search_model = openai_model.copy()
                search_model["id"] = f"{model_id}-search"
                openai_format["data"].append(search_model)
            if model_id in settings.IMAGE_MODELS:
                image_model = openai_model.copy()
                image_model["id"] = f"{model_id}-image"
                openai_format["data"].append(image_model)
            if model_id in settings.THINKING_MODELS:
                non_thinking_model = openai_model.copy()
                non_thinking_model["id"] = f"{model_id}-non-thinking"
                openai_format["data"].append(non_thinking_model)

        # NOTE(review): reuses `openai_model` from the final loop iteration
        # as a template; if the incoming model list is empty this raises
        # NameError — confirm callers never pass an empty listing.
        if settings.CREATE_IMAGE_MODEL:
            image_model = openai_model.copy()
            image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
            openai_format["data"].append(image_model)
        return openai_format

    async def check_model_support(self, model: str) -> bool:
        """Return True if *model* (possibly a "-search"/"-image" variant)
        is usable.

        Variant suffixes are stripped and checked against their settings
        lists; any other name is accepted unless it is filtered out.
        """
        if not model or not isinstance(model, str):
            return False

        model = model.strip()
        if model.endswith("-search"):
            model = model[:-7]
            return model in settings.SEARCH_MODELS
        if model.endswith("-image"):
            model = model[:-6]
            return model in settings.IMAGE_MODELS

        return model not in settings.FILTERED_MODELS
|
app/service/openai_compatiable/openai_compatiable_service.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from typing import Any, AsyncGenerator, Dict, Union
|
7 |
+
|
8 |
+
from app.config.config import settings
|
9 |
+
from app.database.services import (
|
10 |
+
add_error_log,
|
11 |
+
add_request_log,
|
12 |
+
)
|
13 |
+
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
|
14 |
+
from app.service.client.api_client import OpenaiApiClient
|
15 |
+
from app.service.key.key_manager import KeyManager
|
16 |
+
from app.log.logger import get_openai_compatible_logger
|
17 |
+
|
18 |
+
logger = get_openai_compatible_logger()
|
19 |
+
|
20 |
+
class OpenAICompatiableService:
    """Proxy service for an OpenAI-compatible upstream API.

    Forwards chat-completion, embedding and image-generation requests to
    the configured ``base_url`` and records a request log (and, on failure,
    an error log) for every upstream call.
    """

    def __init__(self, base_url: str, key_manager: KeyManager = None):
        # key_manager is optional; without it the streaming path cannot
        # rotate to a fresh API key after a failure.
        self.key_manager = key_manager
        self.base_url = base_url
        self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """Return the upstream /models listing."""
        return await self.api_client.get_models(api_key)

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion (streaming or non-streaming).

        Returns the upstream response dict for non-streaming requests, or
        an async generator of SSE lines when ``request.stream`` is true.
        """
        request_dict = request.model_dump()
        # Drop null-valued fields so they are not sent upstream.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        # top_k is not supported upstream; remove it if present. pop() with a
        # default avoids the KeyError that `del` raised when top_k was None
        # (already filtered out above) or missing from the request model.
        request_dict.pop("top_k", None)
        if request.stream:
            return self._handle_stream_completion(request.model, request_dict, api_key)
        return await self._handle_normal_completion(request.model, request_dict, api_key)

    async def generate_images(
        self,
        request: ImageGenerationRequest,
    ) -> Dict[str, Any]:
        """Generate images via the paid upstream key."""
        request_dict = request.model_dump()
        # Drop null-valued fields so they are not sent upstream.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        api_key = settings.PAID_KEY
        return await self.api_client.generate_images(request_dict, api_key)

    async def create_embeddings(
        self,
        input_text: str,
        model: str,
        api_key: str,
    ) -> Dict[str, Any]:
        """Create embeddings for *input_text* with the given model."""
        return await self.api_client.create_embeddings(input_text, model, api_key)

    async def _handle_normal_completion(
        self, model: str, request: dict, api_key: str
    ) -> Dict[str, Any]:
        """Perform a non-streaming completion, logging outcome and latency."""
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        try:
            response = await self.api_client.generate_content(request, api_key)
            is_success = True
            status_code = 200
            return response
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Best effort: recover the upstream HTTP status from the message.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-compatiable-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=request,
            )
            raise e
        finally:
            # Always record the request, success or failure.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_stream_completion(
        self, model: str, payload: dict, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Perform a streaming completion with key-rotating retry logic.

        Yields SSE ``data:`` lines. On failure it logs the error, asks the
        key manager (if any) for a replacement key, and retries up to
        ``settings.MAX_RETRIES`` times; each attempt gets its own request
        log entry.
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = api_key
            final_api_key = current_attempt_key
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, current_attempt_key
                ):
                    if line.startswith("data:"):
                        yield line + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-compatiable-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                if self.key_manager:
                    # Rotate to a new key; give up if none is available.
                    api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if api_key:
                        logger.info(f"Switched to new API key: {api_key}")
                    else:
                        logger.error(
                            f"No valid API key available after {retries} retries."
                        )
                        break
                else:
                    logger.error("KeyManager not available for retry logic.")
                    break

                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached for streaming.")
                    break
            finally:
                # One request-log entry per attempt.
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )
        if not is_success and retries >= max_retries:
            # Exhausted all retries: surface a terminal error to the client.
            yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
            yield "data: [DONE]\n\n"
|
189 |
+
|
190 |
+
|
app/service/request_log/request_log_service.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Service for request log operations.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from datetime import datetime, timedelta, timezone
|
6 |
+
|
7 |
+
from sqlalchemy import delete
|
8 |
+
|
9 |
+
from app.database.connection import database
|
10 |
+
from app.config.config import settings
|
11 |
+
from app.database.models import RequestLog
|
12 |
+
from app.log.logger import get_request_log_logger
|
13 |
+
|
14 |
+
logger = get_request_log_logger()
|
15 |
+
|
16 |
+
|
17 |
+
async def delete_old_request_logs_task():
    """
    Scheduled task that deletes request logs older than the configured
    retention window.

    No-op when ``AUTO_DELETE_REQUEST_LOGS_ENABLED`` is off. Connects the
    shared database handle lazily if it is not already connected. Any
    failure is logged and swallowed so the scheduler keeps running.
    """
    if not settings.AUTO_DELETE_REQUEST_LOGS_ENABLED:
        logger.info(
            "Auto-delete for request logs is disabled by settings. Skipping task."
        )
        return

    days_to_keep = settings.AUTO_DELETE_REQUEST_LOGS_DAYS
    logger.info(
        f"Starting scheduled task to delete old request logs older than {days_to_keep} days."
    )

    try:
        cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)

        query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)

        if not database.is_connected:
            logger.info("Connecting to database for request log deletion.")
            await database.connect()

        result = await database.execute(query)
        # `database.execute` may return a plain value (driver-dependent)
        # rather than a cursor-like object; fall back to the raw result so
        # this log line cannot itself raise AttributeError on `.rowcount`.
        rows_affected = getattr(result, "rowcount", result)
        logger.info(
            f"Request logs older than {cutoff_date} potentially deleted. Rows affected: {rows_affected if rows_affected is not None else 'N/A'}"
        )

    except Exception as e:
        logger.error(
            f"An error occurred during the scheduled request log deletion: {str(e)}",
            exc_info=True,
        )
|
app/service/stats/stats_service.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app/service/stats_service.py
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
from sqlalchemy import and_, case, func, or_, select
|
7 |
+
|
8 |
+
from app.database.connection import database
|
9 |
+
from app.database.models import RequestLog
|
10 |
+
from app.log.logger import get_stats_logger
|
11 |
+
|
12 |
+
logger = get_stats_logger()
|
13 |
+
|
14 |
+
|
15 |
+
class StatsService:
    """Service class for handling statistics related operations."""

    @staticmethod
    def _counts_query(cutoff):
        """Build a query counting total/success/failure requests since *cutoff*.

        A request counts as success when its status code is in [200, 300);
        everything else — including a NULL status code — counts as failure.
        """
        success_cond = and_(
            RequestLog.status_code >= 200,
            RequestLog.status_code < 300,
        )
        # NOTE: `.is_(None)` is required here. The original code used a
        # plain `is None` comparison, which Python evaluates eagerly on the
        # column object (always False) instead of emitting an `IS NULL`
        # clause — so NULL status codes were never counted as failures.
        failure_cond = or_(
            RequestLog.status_code < 200,
            RequestLog.status_code >= 300,
            RequestLog.status_code.is_(None),
        )
        return select(
            func.count(RequestLog.id).label("total"),
            func.sum(case((success_cond, 1), else_=0)).label("success"),
            func.sum(case((failure_cond, 1), else_=0)).label("failure"),
        ).where(RequestLog.request_time >= cutoff)

    @staticmethod
    def _row_to_counts(result) -> dict[str, int]:
        """Convert a fetched aggregate row to a plain counts dict (zeros when empty)."""
        if result:
            return {
                "total": result["total"] or 0,
                "success": result["success"] or 0,
                "failure": result["failure"] or 0,
            }
        return {"total": 0, "success": 0, "failure": 0}

    async def get_calls_in_last_seconds(self, seconds: int) -> dict[str, int]:
        """Return call counts (total/success/failure) for the last N seconds."""
        try:
            cutoff_time = datetime.datetime.now() - datetime.timedelta(seconds=seconds)
            result = await database.fetch_one(self._counts_query(cutoff_time))
            return self._row_to_counts(result)
        except Exception as e:
            logger.error(f"Failed to get calls in last {seconds} seconds: {e}")
            return {"total": 0, "success": 0, "failure": 0}

    async def get_calls_in_last_minutes(self, minutes: int) -> dict[str, int]:
        """Return call counts (total/success/failure) for the last N minutes."""
        return await self.get_calls_in_last_seconds(minutes * 60)

    async def get_calls_in_last_hours(self, hours: int) -> dict[str, int]:
        """Return call counts (total/success/failure) for the last N hours."""
        return await self.get_calls_in_last_seconds(hours * 3600)

    async def get_calls_in_current_month(self) -> dict[str, int]:
        """Return call counts (total/success/failure) since the start of the calendar month."""
        try:
            now = datetime.datetime.now()
            start_of_month = now.replace(
                day=1, hour=0, minute=0, second=0, microsecond=0
            )
            result = await database.fetch_one(self._counts_query(start_of_month))
            return self._row_to_counts(result)
        except Exception as e:
            logger.error(f"Failed to get calls in current month: {e}")
            return {"total": 0, "success": 0, "failure": 0}

    async def get_api_usage_stats(self) -> dict:
        """Return all API usage statistics needed by the dashboard (total/success/failure)."""
        try:
            stats_1m = await self.get_calls_in_last_minutes(1)
            stats_1h = await self.get_calls_in_last_hours(1)
            stats_24h = await self.get_calls_in_last_hours(24)
            stats_month = await self.get_calls_in_current_month()

            return {
                "calls_1m": stats_1m,
                "calls_1h": stats_1h,
                "calls_24h": stats_24h,
                "calls_month": stats_month,
            }
        except Exception as e:
            logger.error(f"Failed to get API usage stats: {e}")
            default_stat = {"total": 0, "success": 0, "failure": 0}
            return {
                "calls_1m": default_stat.copy(),
                "calls_1h": default_stat.copy(),
                "calls_24h": default_stat.copy(),
                "calls_month": default_stat.copy(),
            }

    async def get_api_call_details(self, period: str) -> list[dict]:
        """
        Return API call details for the given period.

        Args:
            period: Period identifier ('1m', '1h', '24h').

        Returns:
            A list of dicts, each with timestamp, key, model and status.

        Raises:
            ValueError: If *period* is not one of the supported identifiers.
        """
        now = datetime.datetime.now()
        if period == "1m":
            start_time = now - datetime.timedelta(minutes=1)
        elif period == "1h":
            start_time = now - datetime.timedelta(hours=1)
        elif period == "24h":
            start_time = now - datetime.timedelta(hours=24)
        else:
            raise ValueError(f"无效的时间段标识: {period}")

        try:
            query = (
                select(
                    RequestLog.request_time.label("timestamp"),
                    RequestLog.api_key.label("key"),
                    RequestLog.model_name.label("model"),
                    RequestLog.status_code,
                )
                .where(RequestLog.request_time >= start_time)
                .order_by(RequestLog.request_time.desc())
            )

            results = await database.fetch_all(query)

            details = []
            for row in results:
                # A missing status code is reported as a failure.
                status = "failure"
                if row["status_code"] is not None:
                    status = "success" if 200 <= row["status_code"] < 300 else "failure"
                details.append(
                    {
                        "timestamp": row["timestamp"].isoformat(),
                        "key": row["key"],
                        "model": row["model"],
                        "status": status,
                    }
                )
            logger.info(
                f"Retrieved {len(details)} API call details for period '{period}'"
            )
            return details

        except Exception as e:
            logger.error(
                f"Failed to get API call details for period '{period}': {e}")
            raise

    async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]:
        """
        Return per-model call counts for the given API key over the last 24 hours.

        Args:
            key: The API key to query.

        Returns:
            A dict mapping model name to call count; empty when no records
            are found. Example: {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}

        Raises:
            Exception: Propagates any database error after logging it.
        """
        logger.info(
            f"Fetching usage details for key ending in ...{key[-4:]} for the last 24h."
        )
        cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=24)

        try:
            query = (
                select(
                    RequestLog.model_name,
                    func.count(RequestLog.id).label("call_count"),
                )
                .where(
                    RequestLog.api_key == key,
                    RequestLog.request_time >= cutoff_time,
                    RequestLog.model_name.isnot(None),
                )
                .group_by(RequestLog.model_name)
                .order_by(func.count(RequestLog.id).desc())
            )

            results = await database.fetch_all(query)

            if not results:
                logger.info(
                    f"No usage details found for key ending in ...{key[-4:]} in the last 24h."
                )
                return {}

            usage_details = {row["model_name"]: row["call_count"] for row in results}
            logger.info(
                f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
            )
            return usage_details

        except Exception as e:
            logger.error(
                f"Failed to get key usage details for key ending in ...{key[-4:]}: {e}",
                exc_info=True,
            )
            raise
|
app/service/tts/tts_service.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import io
|
3 |
+
import re
|
4 |
+
import time
|
5 |
+
import wave
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
from google import genai
|
9 |
+
|
10 |
+
from app.config.config import settings
|
11 |
+
from app.database.services import add_error_log, add_request_log
|
12 |
+
from app.domain.openai_models import TTSRequest
|
13 |
+
from app.log.logger import get_openai_logger
|
14 |
+
|
15 |
+
logger = get_openai_logger()
|
16 |
+
|
17 |
+
|
18 |
+
def _create_wav_file(audio_data: bytes) -> bytes:
|
19 |
+
"""Creates a WAV file in memory from raw audio data."""
|
20 |
+
with io.BytesIO() as wav_file:
|
21 |
+
with wave.open(wav_file, "wb") as wf:
|
22 |
+
wf.setnchannels(1) # Mono
|
23 |
+
wf.setsampwidth(2) # 16-bit
|
24 |
+
wf.setframerate(24000) # 24kHz sample rate
|
25 |
+
wf.writeframes(audio_data)
|
26 |
+
return wav_file.getvalue()
|
27 |
+
|
28 |
+
|
29 |
+
class TTSService:
    async def create_tts(self, request: TTSRequest, api_key: str) -> Optional[bytes]:
        """
        Generate speech audio for *request.input* via the Google Gemini SDK.

        Returns WAV bytes on success, or None when the model response
        carries no audio data. Every call is recorded in the request log;
        failures additionally produce an error-log entry.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        error_log_msg = ""
        try:
            client = genai.Client(api_key=api_key)
            response = await client.aio.models.generate_content(
                model=settings.TTS_MODEL,
                contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
                config={
                    "response_modalities": ["Audio"],
                    "speech_config": {
                        "voice_config": {
                            "prebuilt_voice_config": {
                                "voice_name": settings.TTS_VOICE_NAME
                            }
                        }
                    },
                },
            )
            if (
                response.candidates
                and response.candidates[0].content.parts
                and response.candidates[0].content.parts[0].inline_data
            ):
                raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
                is_success = True
                status_code = 200
                return _create_wav_file(raw_audio_data)
            # The call succeeded but returned no audio. Record a meaningful
            # error message and code instead of logging an empty message
            # with a NULL status code (as the implicit fall-through did).
            error_log_msg = "No audio data in TTS response"
            status_code = 500
            return None
        except Exception as e:
            is_success = False
            error_log_msg = f"Generic error: {e}"
            logger.error(f"An error occurred in TTSService: {error_log_msg}")
            # Best effort: recover the upstream HTTP status from the message.
            match = re.search(r"status code (\d+)", str(e))
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500
            raise
        finally:
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            if not is_success:
                await add_error_log(
                    gemini_key=api_key,
                    model_name=settings.TTS_MODEL,
                    error_type="google-tts",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=request.input
                )
            await add_request_log(
                model_name=settings.TTS_MODEL,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )
|