CatPtain committed on
Commit
76b9762
·
verified ·
1 Parent(s): 7865ae2

Upload 77 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. Dockerfile +23 -0
  3. app/config/config.py +479 -0
  4. app/core/application.py +153 -0
  5. app/core/constants.py +79 -0
  6. app/core/security.py +90 -0
  7. app/database/__init__.py +3 -0
  8. app/database/connection.py +71 -0
  9. app/database/initialization.py +77 -0
  10. app/database/models.py +62 -0
  11. app/database/services.py +429 -0
  12. app/domain/gemini_models.py +79 -0
  13. app/domain/image_models.py +20 -0
  14. app/domain/openai_models.py +42 -0
  15. app/exception/exceptions.py +140 -0
  16. app/handler/error_handler.py +32 -0
  17. app/handler/message_converter.py +349 -0
  18. app/handler/response_handler.py +360 -0
  19. app/handler/retry_handler.py +50 -0
  20. app/handler/stream_optimizer.py +143 -0
  21. app/log/logger.py +233 -0
  22. app/main.py +15 -0
  23. app/middleware/middleware.py +80 -0
  24. app/middleware/request_logging_middleware.py +40 -0
  25. app/middleware/smart_routing_middleware.py +210 -0
  26. app/router/config_routes.py +133 -0
  27. app/router/error_log_routes.py +233 -0
  28. app/router/gemini_routes.py +374 -0
  29. app/router/openai_compatiable_routes.py +113 -0
  30. app/router/openai_routes.py +175 -0
  31. app/router/routes.py +187 -0
  32. app/router/scheduler_routes.py +57 -0
  33. app/router/stats_routes.py +55 -0
  34. app/router/version_routes.py +37 -0
  35. app/router/vertex_express_routes.py +146 -0
  36. app/scheduler/scheduled_tasks.py +159 -0
  37. app/service/chat/gemini_chat_service.py +287 -0
  38. app/service/chat/openai_chat_service.py +606 -0
  39. app/service/chat/vertex_express_chat_service.py +277 -0
  40. app/service/client/api_client.py +222 -0
  41. app/service/config/config_service.py +261 -0
  42. app/service/embedding/embedding_service.py +78 -0
  43. app/service/error_log/error_log_service.py +178 -0
  44. app/service/image/image_create_service.py +162 -0
  45. app/service/key/key_manager.py +463 -0
  46. app/service/model/model_service.py +92 -0
  47. app/service/openai_compatiable/openai_compatiable_service.py +190 -0
  48. app/service/request_log/request_log_service.py +50 -0
  49. app/service/stats/stats_service.py +255 -0
  50. app/service/tts/tts_service.py +94 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ files/image.png filter=lfs diff=lfs merge=lfs -text
37
+ files/image1.png filter=lfs diff=lfs merge=lfs -text
38
+ files/image2.png filter=lfs diff=lfs merge=lfs -text
39
+ files/image3.png filter=lfs diff=lfs merge=lfs -text
40
+ files/image4.png filter=lfs diff=lfs merge=lfs -text
41
+ files/image5.png filter=lfs diff=lfs merge=lfs -text
42
+ files/image6.png filter=lfs diff=lfs merge=lfs -text
43
+ files/image7.png filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

WORKDIR /app

# Copy the dependency manifest and version marker into the image first,
# so the pip layer below is cached independently of application code.
COPY ./requirements.txt /app
COPY ./VERSION /app

RUN pip install --no-cache-dir -r requirements.txt
COPY ./app /app/app
# Default runtime configuration — placeholders only; override at deploy time.
ENV API_KEYS='["your_api_key_1"]'
ENV ALLOWED_TOKENS='["your_token_1"]'
ENV BASE_URL=https://generativelanguage.googleapis.com/v1beta
ENV TOOLS_CODE_EXECUTION_ENABLED=false
ENV IMAGE_MODELS='["gemini-2.0-flash-exp"]'
ENV SEARCH_MODELS='["gemini-2.0-flash-exp","gemini-2.0-pro-exp"]'
ENV URL_NORMALIZATION_ENABLED=false

# Expose port
EXPOSE 7860

# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--no-access-log"]
app/config/config.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 应用程序配置模块
3
+ """
4
+
5
+ import datetime
6
+ import json
7
+ from typing import Any, Dict, List, Type
8
+
9
+ from pydantic import ValidationError, ValidationInfo, field_validator
10
+ from pydantic_settings import BaseSettings
11
+ from sqlalchemy import insert, select, update
12
+
13
+ from app.core.constants import (
14
+ API_VERSION,
15
+ DEFAULT_CREATE_IMAGE_MODEL,
16
+ DEFAULT_FILTER_MODELS,
17
+ DEFAULT_MODEL,
18
+ DEFAULT_SAFETY_SETTINGS,
19
+ DEFAULT_STREAM_CHUNK_SIZE,
20
+ DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
21
+ DEFAULT_STREAM_MAX_DELAY,
22
+ DEFAULT_STREAM_MIN_DELAY,
23
+ DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
24
+ DEFAULT_TIMEOUT,
25
+ MAX_RETRIES,
26
+ )
27
+ from app.log.logger import Logger
28
+
29
+
30
class Settings(BaseSettings):
    """Application configuration loaded from environment variables / .env.

    Values may later be overridden at runtime by rows stored in the database
    (see ``sync_initial_settings`` below); database values take precedence.
    NOTE: field declaration order matters — ``validate_mysql_config`` reads
    ``DATABASE_TYPE`` via ``info.data``, which only contains fields declared
    (and validated) before the field being checked.
    """

    # Database configuration
    DATABASE_TYPE: str = "mysql"  # "sqlite" or "mysql"
    SQLITE_DATABASE: str = "default_db"
    MYSQL_HOST: str = ""
    MYSQL_PORT: int = 3306
    MYSQL_USER: str = ""
    MYSQL_PASSWORD: str = ""
    MYSQL_DATABASE: str = ""
    MYSQL_SOCKET: str = ""

    # Validate MySQL configuration: every connection field must be non-empty
    # when DATABASE_TYPE is "mysql".
    # NOTE(review): MYSQL_HOST is still required even when MYSQL_SOCKET is
    # set, although the socket path in connection.py does not use the host —
    # confirm whether host should be optional for socket connections.
    @field_validator(
        "MYSQL_HOST", "MYSQL_PORT", "MYSQL_USER", "MYSQL_PASSWORD", "MYSQL_DATABASE"
    )
    def validate_mysql_config(cls, v: Any, info: ValidationInfo) -> Any:
        """Reject empty MySQL connection parameters when MySQL is selected."""
        if info.data.get("DATABASE_TYPE") == "mysql":
            if v is None or v == "":
                raise ValueError(
                    "MySQL configuration is required when DATABASE_TYPE is 'mysql'"
                )
        return v

    # API-related configuration
    API_KEYS: List[str]          # upstream Gemini API keys (required)
    ALLOWED_TOKENS: List[str]    # client tokens accepted by this proxy (required)
    BASE_URL: str = f"https://generativelanguage.googleapis.com/{API_VERSION}"
    AUTH_TOKEN: str = ""         # admin token; defaults to ALLOWED_TOKENS[0] in __init__
    MAX_FAILURES: int = 3
    TEST_MODEL: str = DEFAULT_MODEL
    TIME_OUT: int = DEFAULT_TIMEOUT
    MAX_RETRIES: int = MAX_RETRIES  # field shadows the imported constant of the same name
    PROXIES: List[str] = []
    PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY: bool = True  # select proxy via consistent hashing of the API key
    VERTEX_API_KEYS: List[str] = []
    VERTEX_EXPRESS_BASE_URL: str = "https://aiplatform.googleapis.com/v1beta1/publishers/google"

    # Smart routing configuration
    URL_NORMALIZATION_ENABLED: bool = False  # enable smart route mapping

    # Model-related configuration
    SEARCH_MODELS: List[str] = ["gemini-2.0-flash-exp"]
    IMAGE_MODELS: List[str] = ["gemini-2.0-flash-exp"]
    FILTERED_MODELS: List[str] = DEFAULT_FILTER_MODELS
    TOOLS_CODE_EXECUTION_ENABLED: bool = False
    SHOW_SEARCH_LINK: bool = True
    SHOW_THINKING_PROCESS: bool = True
    THINKING_MODELS: List[str] = []
    THINKING_BUDGET_MAP: Dict[str, float] = {}

    # TTS-related configuration
    TTS_MODEL: str = "gemini-2.5-flash-preview-tts"
    TTS_VOICE_NAME: str = "Zephyr"
    TTS_SPEED: str = "normal"

    # Image-generation configuration
    PAID_KEY: str = ""
    CREATE_IMAGE_MODEL: str = DEFAULT_CREATE_IMAGE_MODEL
    UPLOAD_PROVIDER: str = "smms"
    SMMS_SECRET_TOKEN: str = ""
    PICGO_API_KEY: str = ""
    CLOUDFLARE_IMGBED_URL: str = ""
    CLOUDFLARE_IMGBED_AUTH_CODE: str = ""

    # Stream output optimizer configuration
    STREAM_OPTIMIZER_ENABLED: bool = False
    STREAM_MIN_DELAY: float = DEFAULT_STREAM_MIN_DELAY
    STREAM_MAX_DELAY: float = DEFAULT_STREAM_MAX_DELAY
    STREAM_SHORT_TEXT_THRESHOLD: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD
    STREAM_LONG_TEXT_THRESHOLD: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD
    STREAM_CHUNK_SIZE: int = DEFAULT_STREAM_CHUNK_SIZE

    # Fake streaming configuration
    FAKE_STREAM_ENABLED: bool = False  # enable fake streaming output
    FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS: int = 5  # interval (seconds) between empty keep-alive chunks

    # Scheduler configuration
    CHECK_INTERVAL_HOURS: int = 1  # default check interval: 1 hour
    TIMEZONE: str = "Asia/Shanghai"  # default timezone

    # GitHub (used for update checks)
    GITHUB_REPO_OWNER: str = "snailyp"
    GITHUB_REPO_NAME: str = "gemini-balance"

    # Logging configuration
    LOG_LEVEL: str = "INFO"
    AUTO_DELETE_ERROR_LOGS_ENABLED: bool = True
    AUTO_DELETE_ERROR_LOGS_DAYS: int = 7
    AUTO_DELETE_REQUEST_LOGS_ENABLED: bool = False
    AUTO_DELETE_REQUEST_LOGS_DAYS: int = 30
    SAFETY_SETTINGS: List[Dict[str, str]] = DEFAULT_SAFETY_SETTINGS


    def __init__(self, **kwargs):
        """Run normal pydantic-settings init, then default AUTH_TOKEN.

        If AUTH_TOKEN was not provided, fall back to the first allowed token.
        """
        super().__init__(**kwargs)
        if not self.AUTH_TOKEN and self.ALLOWED_TOKENS:
            self.AUTH_TOKEN = self.ALLOWED_TOKENS[0]
128
+
129
+
130
# Create the module-level, globally shared configuration instance.
settings = Settings()
132
+
133
+
134
def _parse_db_value(key: str, db_value: str, target_type: Type) -> Any:
    """Parse a raw database string value into the target Python type.

    Settings are persisted as strings; this converts them back to the type
    declared on the ``Settings`` field. Parse failures are logged and fall
    back to a safe value (empty list/dict or the original string) rather
    than raising, so a corrupt row cannot break startup.

    Args:
        key: Settings field name (used only for log messages).
        db_value: Raw string value read from the database.
        target_type: Annotated type of the field (e.g. ``List[str]``).

    Returns:
        The parsed value, or the original string if parsing fails.
    """
    from app.log.logger import get_config_logger

    logger = get_config_logger()
    try:
        # Handle List[str]: try JSON first, then comma-separated fallback.
        if target_type == List[str]:
            try:
                parsed = json.loads(db_value)
                if isinstance(parsed, list):
                    return [str(item) for item in parsed]
            except json.JSONDecodeError:
                return [item.strip() for item in db_value.split(",") if item.strip()]
            # Reached only when JSON parsed successfully but was not a list.
            logger.warning(
                f"Could not parse '{db_value}' as List[str] for key '{key}', falling back to comma split or empty list."
            )
            return [item.strip() for item in db_value.split(",") if item.strip()]
        # Handle Dict[str, float]: JSON, with a single-quote repair attempt.
        elif target_type == Dict[str, float]:
            parsed_dict = {}
            try:
                parsed = json.loads(db_value)
                if isinstance(parsed, dict):
                    parsed_dict = {str(k): float(v) for k, v in parsed.items()}
                else:
                    logger.warning(
                        f"Parsed DB value for key '{key}' is not a dictionary type. Value: {db_value}"
                    )
            except (json.JSONDecodeError, ValueError, TypeError) as e1:
                # Second chance: values written with Python-style single quotes
                # are not valid JSON — swap quotes and retry once.
                if isinstance(e1, json.JSONDecodeError) and "'" in db_value:
                    logger.warning(
                        f"Failed initial JSON parse for key '{key}'. Attempting to replace single quotes. Error: {e1}"
                    )
                    try:
                        corrected_db_value = db_value.replace("'", '"')
                        parsed = json.loads(corrected_db_value)
                        if isinstance(parsed, dict):
                            parsed_dict = {str(k): float(v) for k, v in parsed.items()}
                        else:
                            logger.warning(
                                f"Parsed DB value (after quote replacement) for key '{key}' is not a dictionary type. Value: {corrected_db_value}"
                            )
                    except (json.JSONDecodeError, ValueError, TypeError) as e2:
                        logger.error(
                            f"Could not parse '{db_value}' as Dict[str, float] for key '{key}' even after replacing quotes: {e2}. Returning empty dict."
                        )
                else:
                    logger.error(
                        f"Could not parse '{db_value}' as Dict[str, float] for key '{key}': {e1}. Returning empty dict."
                    )
            return parsed_dict
        # Handle List[Dict[str, str]] (e.g. SAFETY_SETTINGS).
        elif target_type == List[Dict[str, str]]:
            try:
                parsed = json.loads(db_value)
                if isinstance(parsed, list):
                    # Each element must be a dict with string keys and values.
                    valid = all(
                        isinstance(item, dict)
                        and all(isinstance(k, str) for k in item.keys())
                        and all(isinstance(v, str) for v in item.values())
                        for item in parsed
                    )
                    if valid:
                        return parsed
                    else:
                        logger.warning(
                            f"Invalid structure in List[Dict[str, str]] for key '{key}'. Value: {db_value}"
                        )
                        return []
                else:
                    logger.warning(
                        f"Parsed DB value for key '{key}' is not a list type. Value: {db_value}"
                    )
                    return []
            except json.JSONDecodeError:
                logger.error(
                    f"Could not parse '{db_value}' as JSON for List[Dict[str, str]] for key '{key}'. Returning empty list."
                )
                return []
            except Exception as e:
                logger.error(
                    f"Error parsing List[Dict[str, str]] for key '{key}': {e}. Value: {db_value}. Returning empty list."
                )
                return []
        # Handle bool: accept common truthy spellings, anything else is False.
        elif target_type == bool:
            return db_value.lower() in ("true", "1", "yes", "on")
        # Handle int
        elif target_type == int:
            return int(db_value)
        # Handle float
        elif target_type == float:
            return float(db_value)
        # Default: str or any other type pydantic can coerce directly.
        else:
            return db_value
    except (ValueError, TypeError, json.JSONDecodeError) as e:
        logger.warning(
            f"Failed to parse db_value '{db_value}' for key '{key}' as type {target_type}: {e}. Using original string value."
        )
        return db_value  # On parse failure, return the original string.
237
+
238
+
239
async def sync_initial_settings():
    """Synchronize configuration at application startup.

    1. Load settings from the database.
    2. Merge database settings into the in-memory ``settings`` (database wins).
    3. Sync the final in-memory ``settings`` back to the database.

    All failures are logged rather than raised so startup can proceed with
    environment/dotenv configuration.
    """
    from app.log.logger import get_config_logger

    logger = get_config_logger()
    # Deferred imports avoid circular dependencies and ensure the database
    # layer is already initialized.
    from app.database.connection import database
    from app.database.models import Settings as SettingsModel

    global settings
    logger.info("Starting initial settings synchronization...")

    if not database.is_connected:
        try:
            await database.connect()
            logger.info("Database connection established for initial sync.")
        except Exception as e:
            logger.error(
                f"Failed to connect to database for initial settings sync: {e}. Skipping sync."
            )
            return

    try:
        # 1. Load settings rows (key/value strings) from the database.
        db_settings_raw: List[Dict[str, Any]] = []
        try:
            query = select(SettingsModel.key, SettingsModel.value)
            results = await database.fetch_all(query)
            db_settings_raw = [
                {"key": row["key"], "value": row["value"]} for row in results
            ]
            logger.info(f"Fetched {len(db_settings_raw)} settings from database.")
        except Exception as e:
            logger.error(
                f"Failed to fetch settings from database: {e}. Proceeding with environment/dotenv settings."
            )
            # Even if the read fails, continue so env/dotenv config can still
            # be pushed down to the database.

        db_settings_map: Dict[str, str] = {
            s["key"]: s["value"] for s in db_settings_raw
        }

        # 2. Merge database settings into memory (database takes precedence).
        updated_in_memory = False

        for key, db_value in db_settings_map.items():
            if key == "DATABASE_TYPE":
                # DATABASE_TYPE is owned by env/dotenv; never let the DB
                # switch the backend at runtime.
                logger.debug(
                    f"Skipping update of '{key}' in memory from database. "
                    "This setting is controlled by environment/dotenv."
                )
                continue
            if hasattr(settings, key):
                target_type = Settings.__annotations__.get(key)
                if target_type:
                    try:
                        parsed_db_value = _parse_db_value(key, db_value, target_type)
                        memory_value = getattr(settings, key)

                        # Compare parsed value against the in-memory one.
                        # NOTE: plain != on complex types is a simplification.
                        if parsed_db_value != memory_value:
                            # Guard against the parser returning an
                            # incompatible type before assigning.
                            type_match = False
                            if target_type == List[str] and isinstance(
                                parsed_db_value, list
                            ):
                                type_match = True
                            elif target_type == Dict[str, float] and isinstance(
                                parsed_db_value, dict
                            ):
                                type_match = True
                            elif target_type not in (
                                List[str],
                                Dict[str, float],
                            ) and isinstance(parsed_db_value, target_type):
                                type_match = True

                            if type_match:
                                setattr(settings, key, parsed_db_value)
                                logger.debug(
                                    f"Updated setting '{key}' in memory from database value ({target_type})."
                                )
                                updated_in_memory = True
                            else:
                                logger.warning(
                                    f"Parsed DB value type mismatch for key '{key}'. Expected {target_type}, got {type(parsed_db_value)}. Skipping update."
                                )

                    except Exception as e:
                        logger.error(
                            f"Error processing database setting for key '{key}': {e}"
                        )
            else:
                logger.warning(
                    f"Database setting '{key}' not found in Settings model definition. Ignoring."
                )

        # Re-validate the pydantic model after in-memory updates (optional
        # but recommended — enforces coercion and validators).
        if updated_in_memory:
            try:
                settings = Settings(**settings.model_dump())
                logger.info(
                    "Settings object re-validated after merging database values."
                )
            except ValidationError as e:
                logger.error(
                    f"Validation error after merging database settings: {e}. Settings might be inconsistent."
                )

        # 3. Sync the final in-memory settings back to the database.
        final_memory_settings = settings.model_dump()
        settings_to_update: List[Dict[str, Any]] = []
        settings_to_insert: List[Dict[str, Any]] = []
        now = datetime.datetime.now(datetime.timezone.utc)

        existing_db_keys = set(db_settings_map.keys())

        for key, value in final_memory_settings.items():
            if key == "DATABASE_TYPE":
                logger.debug(
                    f"Skipping synchronization of '{key}' to database. "
                    "This setting is controlled by environment/dotenv."
                )
                continue

            # Serialize each value to a string (JSON for containers,
            # lowercase for bools, "" for None) to match _parse_db_value.
            if isinstance(value, (list, dict)):
                db_value = json.dumps(
                    value, ensure_ascii=False
                )
            elif isinstance(value, bool):
                db_value = str(value).lower()
            elif value is None:
                db_value = ""
            else:
                db_value = str(value)

            data = {
                "key": key,
                "value": db_value,
                "description": f"{key} configuration setting",
                "updated_at": now,
            }

            if key in existing_db_keys:
                # Only update rows whose stored value actually differs.
                if db_settings_map[key] != db_value:
                    settings_to_update.append(data)
            else:
                # Key absent from the database: insert it.
                data["created_at"] = now
                settings_to_insert.append(data)

        # Perform bulk insert/update inside one transaction.
        if settings_to_insert or settings_to_update:
            try:
                async with database.transaction():
                    if settings_to_insert:
                        # Preserve any pre-existing descriptions.
                        query_existing = select(
                            SettingsModel.key, SettingsModel.description
                        ).where(
                            SettingsModel.key.in_(
                                [s["key"] for s in settings_to_insert]
                            )
                        )
                        existing_desc = {
                            row["key"]: row["description"]
                            for row in await database.fetch_all(query_existing)
                        }
                        for item in settings_to_insert:
                            item["description"] = existing_desc.get(
                                item["key"], item["description"]
                            )

                        query_insert = insert(SettingsModel).values(settings_to_insert)
                        await database.execute(query=query_insert)
                        logger.info(
                            f"Synced (inserted) {len(settings_to_insert)} settings to database."
                        )

                    if settings_to_update:
                        # Preserve any pre-existing descriptions.
                        query_existing = select(
                            SettingsModel.key, SettingsModel.description
                        ).where(
                            SettingsModel.key.in_(
                                [s["key"] for s in settings_to_update]
                            )
                        )
                        existing_desc = {
                            row["key"]: row["description"]
                            for row in await database.fetch_all(query_existing)
                        }

                        # One UPDATE per changed row.
                        for setting_data in settings_to_update:
                            setting_data["description"] = existing_desc.get(
                                setting_data["key"], setting_data["description"]
                            )
                            query_update = (
                                update(SettingsModel)
                                .where(SettingsModel.key == setting_data["key"])
                                .values(
                                    value=setting_data["value"],
                                    description=setting_data["description"],
                                    updated_at=setting_data["updated_at"],
                                )
                            )
                            await database.execute(query=query_update)
                        logger.info(
                            f"Synced (updated) {len(settings_to_update)} settings to database."
                        )
            except Exception as e:
                logger.error(
                    f"Failed to sync settings to database during startup: {str(e)}"
                )
        else:
            logger.info(
                "No setting changes detected between memory and database during initial sync."
            )

        # Refresh log levels from the (possibly updated) configuration.
        Logger.update_log_levels(final_memory_settings.get("LOG_LEVEL"))

    except Exception as e:
        logger.error(f"An unexpected error occurred during initial settings sync: {e}")
    finally:
        # NOTE(review): this finally block is effectively dead code — the
        # inner `try: pass` never disconnects. Presumably the connection is
        # deliberately kept open for the application's lifetime; confirm
        # whether a disconnect was intended here.
        if database.is_connected:
            try:
                pass
            except Exception as e:
                logger.error(f"Error disconnecting database after initial sync: {e}")

    logger.info("Initial settings synchronization finished.")
app/core/application.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+ from pathlib import Path
3
+
4
+ from fastapi import FastAPI
5
+ from fastapi.staticfiles import StaticFiles
6
+ from fastapi.templating import Jinja2Templates
7
+
8
+ from app.config.config import settings, sync_initial_settings
9
+ from app.database.connection import connect_to_db, disconnect_from_db
10
+ from app.database.initialization import initialize_database
11
+ from app.exception.exceptions import setup_exception_handlers
12
+ from app.log.logger import get_application_logger
13
+ from app.middleware.middleware import setup_middlewares
14
+ from app.router.routes import setup_routers
15
+ from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
16
+ from app.service.key.key_manager import get_key_manager_instance
17
+ from app.service.update.update_service import check_for_updates
18
+ from app.utils.helpers import get_current_version
19
+
20
+ logger = get_application_logger()
21
+
22
# Project paths resolved from this file's location, so they are correct
# regardless of the process's current working directory.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
STATIC_DIR = PROJECT_ROOT / "app" / "static"
TEMPLATES_DIR = PROJECT_ROOT / "app" / "templates"

# Initialize the template engine (globals are attached via app.state later).
# Fix: use the absolute TEMPLATES_DIR computed above instead of the
# CWD-relative string "app/templates" — the relative path breaks template
# lookup whenever the server is started from a different working directory,
# and left TEMPLATES_DIR unused.
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
28
+
29
+
30
+ # 定义一个函数来更新模板全局变量
31
# Helper to expose update-check info to templates.
def update_template_globals(app: FastAPI, update_info: dict):
    """Store *update_info* on ``app.state`` for use at render time.

    Jinja2Templates offers no direct way to mutate its globals, so rather
    than patching the Jinja environment the info is kept on ``app.state``
    and must be passed into the template context by each view.
    """
    app.state.update_info = update_info
    logger.info(f"Update info stored in app.state: {update_info}")
37
+
38
+
39
+ # --- Helper functions for lifespan ---
40
# --- Helper functions for lifespan ---
async def _setup_database_and_config(app_settings):
    """Initializes database, syncs settings, and initializes KeyManager.

    Order matters: schema init -> connect -> DB/env settings merge ->
    KeyManager seeding (which reads the merged key lists).
    """
    initialize_database()
    logger.info("Database initialized successfully")
    await connect_to_db()
    await sync_initial_settings()
    # Seed the KeyManager singleton with both regular and Vertex API keys.
    await get_key_manager_instance(app_settings.API_KEYS, app_settings.VERTEX_API_KEYS)
    logger.info("Database, config sync, and KeyManager initialized successfully")
48
+
49
+
50
# Lifespan helper: release the shared database connection on shutdown.
async def _shutdown_database():
    """Disconnects from the database."""
    await disconnect_from_db()
53
+
54
+
55
def _start_scheduler():
    """Starts the background scheduler.

    A scheduler failure is treated as non-fatal: it is logged and the
    application continues serving without scheduled tasks.
    """
    try:
        start_scheduler()
        logger.info("Scheduler started successfully.")
    except Exception as e:
        logger.error(f"Failed to start scheduler: {e}")
62
+
63
+
64
# Lifespan helper: stop the background scheduler on shutdown.
def _stop_scheduler():
    """Stops the background scheduler."""
    stop_scheduler()
67
+
68
+
69
async def _perform_update_check(app: FastAPI):
    """Checks for updates and stores the info in ``app.state``.

    The resulting dict (availability flag, latest/current versions, and any
    error message) is what templates read via ``app.state.update_info``.
    """
    update_available, latest_version, error_message = await check_for_updates()
    current_version = get_current_version()
    update_info = {
        "update_available": update_available,
        "latest_version": latest_version,
        "error_message": error_message,
        "current_version": current_version,
    }
    # Defensive: FastAPI apps normally always expose .state; guard anyway.
    if not hasattr(app, "state"):
        from starlette.datastructures import State

        app.state = State()
    app.state.update_info = update_info
    logger.info(f"Update check completed. Info: {update_info}")
85
+
86
+
87
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manages the application startup and shutdown events.

    Args:
        app: the FastAPI application instance
    """
    logger.info("Application starting up...")
    try:
        await _setup_database_and_config(settings)
        await _perform_update_check(app)
        _start_scheduler()

    except Exception as e:
        # Startup errors are logged but deliberately not re-raised, so the
        # app still comes up (possibly degraded) rather than crash on boot.
        logger.critical(
            f"Critical error during application startup: {str(e)}", exc_info=True
        )

    # The application serves requests while suspended here.
    yield

    logger.info("Application shutting down...")
    _stop_scheduler()
    await _shutdown_database()
111
+
112
+
113
def create_app() -> FastAPI:
    """Build and fully configure the FastAPI application.

    Returns:
        FastAPI: the configured application instance, ready to be served.
    """
    version = get_current_version()

    application = FastAPI(
        title="Gemini Balance API",
        description="Gemini API代理服务,支持负载均衡和密钥管理",
        version=version,
        lifespan=lifespan,
    )

    # Defensive: make sure a state container exists before writing to it.
    if not hasattr(application, "state"):
        from starlette.datastructures import State

        application.state = State()

    # Placeholder update info; real values are filled in during lifespan startup.
    application.state.update_info = {
        "update_available": False,
        "latest_version": None,
        "error_message": "Initializing...",
        "current_version": version,
    }

    # Static assets, middleware, exception handlers, and routers.
    application.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
    setup_middlewares(application)
    setup_exception_handlers(application)
    setup_routers(application)

    return application
app/core/constants.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Constant definitions module.
"""

# API-related constants
API_VERSION = "v1beta"
DEFAULT_TIMEOUT = 300  # seconds
MAX_RETRIES = 3  # maximum number of retry attempts

# Model-related constants
SUPPORTED_ROLES = ["user", "model", "system"]
DEFAULT_MODEL = "gemini-1.5-flash"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 8192
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 40
# Models hidden from the model list by default
DEFAULT_FILTER_MODELS = [
    "gemini-1.0-pro-vision-latest",
    "gemini-pro-vision",
    "chat-bison-001",
    "text-bison-001",
    "embedding-gecko-001"
]
DEFAULT_CREATE_IMAGE_MODEL = "imagen-3.0-generate-002"

# Image-generation constants
VALID_IMAGE_RATIOS = ["1:1", "3:4", "4:3", "9:16", "16:9"]

# Upload providers
UPLOAD_PROVIDERS = ["smms", "picgo", "cloudflare_imgbed"]
DEFAULT_UPLOAD_PROVIDER = "smms"

# Stream-output constants (delays in seconds, thresholds in characters)
DEFAULT_STREAM_MIN_DELAY = 0.016
DEFAULT_STREAM_MAX_DELAY = 0.024
DEFAULT_STREAM_SHORT_TEXT_THRESHOLD = 10
DEFAULT_STREAM_LONG_TEXT_THRESHOLD = 50
DEFAULT_STREAM_CHUNK_SIZE = 5

# Regular-expression patterns
IMAGE_URL_PATTERN = r'!\[(.*?)\]\((.*?)\)'  # markdown image syntax
DATA_URL_PATTERN = r'data:([^;]+);base64,(.+)'  # base64 data URL

# Audio/Video Settings
SUPPORTED_AUDIO_FORMATS = ["wav", "mp3", "flac", "ogg"]
SUPPORTED_VIDEO_FORMATS = ["mp4", "mov", "avi", "webm"]
MAX_AUDIO_SIZE_BYTES = 50 * 1024 * 1024  # Example: 50MB limit for Base64 payload
MAX_VIDEO_SIZE_BYTES = 200 * 1024 * 1024  # Example: 200MB limit

# Optional: Define MIME type mappings if needed, or handle directly in converter
AUDIO_FORMAT_TO_MIMETYPE = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "flac": "audio/flac",
    "ogg": "audio/ogg",
}

VIDEO_FORMAT_TO_MIMETYPE = {
    "mp4": "video/mp4",
    "mov": "video/quicktime",
    "avi": "video/x-msvideo",
    "webm": "video/webm",
}

# Safety settings variant used for the gemini-2.0-flash-exp model
GEMINI_2_FLASH_EXP_SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
]

# Default safety settings (note: CIVIC_INTEGRITY uses BLOCK_NONE, not OFF)
DEFAULT_SAFETY_SETTINGS = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
    {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"},
]
app/core/security.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from fastapi import Header, HTTPException
4
+
5
+ from app.config.config import settings
6
+ from app.log.logger import get_security_logger
7
+
8
+ logger = get_security_logger()
9
+
10
+
11
def verify_auth_token(token: str) -> bool:
    """Return True when *token* matches the configured admin AUTH_TOKEN."""
    expected = settings.AUTH_TOKEN
    return expected == token
13
+
14
+
15
class SecurityService:
    """Request authentication helpers used as FastAPI dependencies.

    A credential is accepted when it is either one of
    ``settings.ALLOWED_TOKENS`` or the single admin ``settings.AUTH_TOKEN``.
    """

    @staticmethod
    def _is_authorized(token) -> bool:
        # Shared acceptance rule (previously duplicated in every method).
        return token in settings.ALLOWED_TOKENS or token == settings.AUTH_TOKEN

    async def verify_key(self, key: str):
        """Validate an API key passed directly (e.g. in the URL path)."""
        if not self._is_authorized(key):
            logger.error("Invalid key")
            raise HTTPException(status_code=401, detail="Invalid key")
        return key

    async def verify_authorization(
        self, authorization: Optional[str] = Header(None)
    ) -> str:
        """Validate an ``Authorization: Bearer <token>`` header and return the token.

        Raises:
            HTTPException: 401 when the header is missing, malformed, or the
                token is not authorized.
        """
        if not authorization:
            logger.error("Missing Authorization header")
            raise HTTPException(status_code=401, detail="Missing Authorization header")

        if not authorization.startswith("Bearer "):
            logger.error("Invalid Authorization header format")
            raise HTTPException(
                status_code=401, detail="Invalid Authorization header format"
            )

        # Fix: strip only the leading "Bearer " prefix. str.replace() would
        # also remove the substring anywhere inside the token itself.
        token = authorization.removeprefix("Bearer ")
        if not self._is_authorized(token):
            logger.error("Invalid token")
            raise HTTPException(status_code=401, detail="Invalid token")

        return token

    async def verify_goog_api_key(
        self, x_goog_api_key: Optional[str] = Header(None)
    ) -> str:
        """Validate the Google-style ``x-goog-api-key`` header."""
        if not x_goog_api_key:
            logger.error("Missing x-goog-api-key header")
            raise HTTPException(status_code=401, detail="Missing x-goog-api-key header")

        if not self._is_authorized(x_goog_api_key):
            logger.error("Invalid x-goog-api-key")
            raise HTTPException(status_code=401, detail="Invalid x-goog-api-key")

        return x_goog_api_key

    async def verify_auth_token(
        self, authorization: Optional[str] = Header(None)
    ) -> str:
        """Validate that the Authorization header carries exactly AUTH_TOKEN.

        NOTE(review): unlike verify_authorization, the original performs no
        "Bearer " format check here — that behavior is preserved.
        """
        if not authorization:
            logger.error("Missing auth_token header")
            raise HTTPException(status_code=401, detail="Missing auth_token header")
        # Same fix as above: remove only a leading "Bearer " prefix.
        token = authorization.removeprefix("Bearer ")
        if token != settings.AUTH_TOKEN:
            logger.error("Invalid auth_token")
            raise HTTPException(status_code=401, detail="Invalid auth_token")

        return token

    async def verify_key_or_goog_api_key(
        self, key: Optional[str] = None, x_goog_api_key: Optional[str] = Header(None)
    ) -> str:
        """Validate the URL ``key`` or, failing that, the ``x-goog-api-key`` header."""
        # Accept immediately when the URL key is valid.
        if self._is_authorized(key):
            return key

        # Otherwise fall back to the x-goog-api-key request header.
        if not x_goog_api_key:
            logger.error("Invalid key and missing x-goog-api-key header")
            raise HTTPException(status_code=401, detail="Invalid key and missing x-goog-api-key header")

        if not self._is_authorized(x_goog_api_key):
            logger.error("Invalid key and invalid x-goog-api-key")
            raise HTTPException(status_code=401, detail="Invalid key and invalid x-goog-api-key")

        return x_goog_api_key
app/database/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ 数据库模块
3
+ """
app/database/connection.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Database connection pool module.
"""
from pathlib import Path
from urllib.parse import quote_plus
from databases import Database
from sqlalchemy import create_engine, MetaData
from sqlalchemy.ext.declarative import declarative_base

from app.config.config import settings
from app.log.logger import get_database_logger

logger = get_database_logger()

# Build the database URL from the configured backend type.
if settings.DATABASE_TYPE == "sqlite":
    # Make sure the data directory exists before SQLite opens the file.
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    db_path = data_dir / settings.SQLITE_DATABASE
    DATABASE_URL = f"sqlite:///{db_path}"
elif settings.DATABASE_TYPE == "mysql":
    if settings.MYSQL_SOCKET:
        DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@/{settings.MYSQL_DATABASE}?unix_socket={settings.MYSQL_SOCKET}"
    else:
        DATABASE_URL = f"mysql+pymysql://{settings.MYSQL_USER}:{quote_plus(settings.MYSQL_PASSWORD)}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/{settings.MYSQL_DATABASE}"
else:
    raise ValueError("Unsupported database type. Please set DATABASE_TYPE to 'sqlite' or 'mysql'.")

# Synchronous SQLAlchemy engine (used for table creation / admin work).
# pool_pre_ping=True: "ping" each connection before handing it out of the
# pool so stale connections are detected and replaced.
engine = create_engine(DATABASE_URL, pool_pre_ping=True)

# Shared metadata object for all table definitions.
metadata = MetaData()

# Declarative base class for the ORM models.
Base = declarative_base(metadata=metadata)

# Async connection pool; SQLite is given no pool parameters.
# min_size/max_size: lower/upper bound on pooled connections.
# pool_recycle=1800: maximum lifetime (seconds) of a pooled connection, kept
# well below MySQL's default wait_timeout (typically 8 hours) so the server
# never closes a connection the pool still considers live. Lower it further
# if connections still go stale before that timeout. The databases library
# retries automatically after a dropped connection.
if settings.DATABASE_TYPE == "sqlite":
    database = Database(DATABASE_URL)
else:
    database = Database(DATABASE_URL, min_size=5, max_size=20, pool_recycle=1800)
50
+
51
async def connect_to_db():
    """Open the shared async database connection pool.

    Logs the outcome; any connection failure is logged and re-raised.
    """
    try:
        await database.connect()
        logger.info(f"Connected to {settings.DATABASE_TYPE}")
    except Exception as exc:
        logger.error(f"Failed to connect to database: {exc}")
        raise
61
+
62
+
63
async def disconnect_from_db():
    """Close the shared async database connection pool.

    Failures are logged but deliberately not re-raised (best-effort shutdown).
    """
    try:
        await database.disconnect()
        logger.info(f"Disconnected from {settings.DATABASE_TYPE}")
    except Exception as exc:
        logger.error(f"Failed to disconnect from database: {exc}")
app/database/initialization.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 数据库初始化模块
3
+ """
4
+ from dotenv import dotenv_values
5
+
6
+ from sqlalchemy import inspect
7
+ from sqlalchemy.orm import Session
8
+
9
+ from app.database.connection import engine, Base
10
+ from app.database.models import Settings
11
+ from app.log.logger import get_database_logger
12
+
13
+ logger = get_database_logger()
14
+
15
+
16
def create_tables():
    """Create every ORM-declared table on the configured engine.

    Raises on failure after logging the error.
    """
    try:
        # create_all is a no-op for tables that already exist.
        Base.metadata.create_all(engine)
        logger.info("Database tables created successfully")
    except Exception as exc:
        logger.error(f"Failed to create database tables: {exc}")
        raise
27
+
28
+
29
def import_env_to_settings():
    """Seed the t_settings table with any .env entries it does not yet hold.

    Existing rows are never overwritten; only missing keys are inserted.
    Raises on failure after logging the error.
    """
    try:
        env_values = dotenv_values(".env")

        # Only proceed when the settings table actually exists.
        inspector = inspect(engine)
        if "t_settings" in inspector.get_table_names():
            with Session(engine) as session:
                existing = {row.key: row for row in session.query(Settings).all()}

                for env_key, env_value in env_values.items():
                    if env_key not in existing:
                        session.add(Settings(key=env_key, value=env_value))
                        logger.info(f"Inserted setting: {env_key}")

                session.commit()

        logger.info("Environment variables imported to settings table successfully")
    except Exception as exc:
        logger.error(f"Failed to import environment variables to settings table: {exc}")
        raise
63
+
64
+
65
def initialize_database():
    """Initialize the database: create tables, then import .env settings.

    Raises on failure after logging the error.
    """
    try:
        create_tables()
        import_env_to_settings()
    except Exception as exc:
        logger.error(f"Failed to initialize database: {exc}")
        raise
app/database/models.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 数据库模型模块
3
+ """
4
+ import datetime
5
+ from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
6
+
7
+ from app.database.connection import Base
8
+
9
+
10
class Settings(Base):
    """
    Settings table; mirrors the configuration keys found in .env.
    """
    __tablename__ = "t_settings"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique configuration key name.
    key = Column(String(100), nullable=False, unique=True, comment="配置项键名")
    # Configuration value, stored as free text.
    value = Column(Text, nullable=True, comment="配置项值")
    # Optional human-readable description of the key.
    description = Column(String(255), nullable=True, comment="配置项描述")
    # NOTE(review): timestamps use naive local time (datetime.now) — confirm
    # the intended timezone policy.
    created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")

    def __repr__(self):
        return f"<Settings(key='{self.key}', value='{self.value}')>"
25
+
26
+
27
class ErrorLog(Base):
    """
    Error log table: one row per recorded upstream failure.
    """
    __tablename__ = "t_error_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    gemini_key = Column(String(100), nullable=True, comment="Gemini API密钥")
    model_name = Column(String(100), nullable=True, comment="模型名称")
    error_type = Column(String(50), nullable=True, comment="错误类型")
    # Raw error text (unbounded length).
    error_log = Column(Text, nullable=True, comment="错误日志")
    # Numeric error code, e.g. an HTTP status.
    error_code = Column(Integer, nullable=True, comment="错误代码")
    # Original request payload stored as JSON for later inspection.
    request_msg = Column(JSON, nullable=True, comment="请求消息")
    request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")

    def __repr__(self):
        return f"<ErrorLog(id='{self.id}', gemini_key='{self.gemini_key}')>"
44
+
45
+
46
class RequestLog(Base):
    """
    API request log table: one row per upstream API call.
    """

    __tablename__ = "t_request_log"

    id = Column(Integer, primary_key=True, autoincrement=True)
    request_time = Column(DateTime, default=datetime.datetime.now, comment="请求时间")
    model_name = Column(String(100), nullable=True, comment="模型名称")
    api_key = Column(String(100), nullable=True, comment="使用的API密钥")
    is_success = Column(Boolean, nullable=False, comment="请求是否成功")
    status_code = Column(Integer, nullable=True, comment="API响应状态码")
    latency_ms = Column(Integer, nullable=True, comment="请求耗时(毫秒)")

    def __repr__(self):
        # api_key is nullable; slicing None would raise TypeError, so only
        # show a redacted 4-character prefix when a key is present.
        key_preview = f"{self.api_key[:4]}..." if self.api_key else "None"
        return f"<RequestLog(id='{self.id}', key='{key_preview}', success='{self.is_success}')>"
app/database/services.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 数据库服务模块
3
+ """
4
+ from typing import List, Optional, Dict, Any, Union
5
+ from datetime import datetime
6
+ from sqlalchemy import func, desc, asc, select, insert, update, delete
7
+ import json
8
+ from app.database.connection import database
9
+ from app.database.models import Settings, ErrorLog, RequestLog
10
+ from app.log.logger import get_database_logger
11
+
12
+ logger = get_database_logger()
13
+
14
+
15
async def get_all_settings() -> List[Dict[str, Any]]:
    """
    Fetch every row from the settings table.

    Returns:
        List[Dict[str, Any]]: one dict per setting row.

    Raises:
        Exception: re-raised after logging when the query fails.
    """
    try:
        rows = await database.fetch_all(select(Settings))
        return [dict(row) for row in rows]
    except Exception as e:
        logger.error(f"Failed to get all settings: {str(e)}")
        raise
29
+
30
+
31
async def get_setting(key: str) -> Optional[Dict[str, Any]]:
    """
    Look up a single setting by its key.

    Args:
        key: setting key name.

    Returns:
        The row as a dict, or None when the key does not exist.

    Raises:
        Exception: re-raised after logging when the query fails.
    """
    try:
        row = await database.fetch_one(select(Settings).where(Settings.key == key))
        return dict(row) if row else None
    except Exception as e:
        logger.error(f"Failed to get setting {key}: {str(e)}")
        raise
48
+
49
+
50
async def update_setting(key: str, value: str, description: Optional[str] = None) -> bool:
    """
    Upsert a setting: update the row when the key exists, insert otherwise.

    Args:
        key: setting key name.
        value: new value.
        description: optional description; when omitted on update, the
            existing description is kept.

    Returns:
        bool: True on success, False when the write failed.
    """
    try:
        existing = await get_setting(key)
        now = datetime.now()

        if existing:
            stmt = (
                update(Settings)
                .where(Settings.key == key)
                .values(
                    value=value,
                    description=description if description else existing["description"],
                    updated_at=now,
                )
            )
            action = "Updated"
        else:
            stmt = insert(Settings).values(
                key=key,
                value=value,
                description=description,
                created_at=now,
                updated_at=now,
            )
            action = "Inserted"

        await database.execute(stmt)
        logger.info(f"{action} setting: {key}")
        return True
    except Exception as e:
        logger.error(f"Failed to update setting {key}: {str(e)}")
        return False
98
+
99
+
100
async def add_error_log(
    gemini_key: Optional[str] = None,
    model_name: Optional[str] = None,
    error_type: Optional[str] = None,
    error_log: Optional[str] = None,
    error_code: Optional[int] = None,
    request_msg: Optional[Union[Dict[str, Any], str]] = None
) -> bool:
    """
    Insert one row into the error-log table.

    Args:
        gemini_key: Gemini API key in use when the error occurred.
        model_name: model that was being called.
        error_type: coarse error category.
        error_log: raw error text.
        error_code: numeric code (e.g. an HTTP status).
        request_msg: original request payload; strings are parsed as JSON
            when possible, otherwise wrapped as {"message": <text>}.

    Returns:
        bool: True on success, False when the insert failed.
    """
    try:
        # Normalise request_msg into a JSON-serialisable dict (or None).
        if isinstance(request_msg, dict):
            payload = request_msg
        elif isinstance(request_msg, str):
            try:
                payload = json.loads(request_msg)
            except json.JSONDecodeError:
                payload = {"message": request_msg}
        else:
            payload = None

        stmt = insert(ErrorLog).values(
            gemini_key=gemini_key,
            error_type=error_type,
            error_log=error_log,
            model_name=model_name,
            error_code=error_code,
            request_msg=payload,
            request_time=datetime.now(),
        )
        await database.execute(stmt)
        logger.info(f"Added error log for key: {gemini_key}")
        return True
    except Exception as e:
        logger.error(f"Failed to add error log: {str(e)}")
        return False
151
+
152
+
153
async def get_error_logs(
    limit: int = 20,
    offset: int = 0,
    key_search: Optional[str] = None,
    error_search: Optional[str] = None,
    error_code_search: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    sort_by: str = 'id',
    sort_order: str = 'desc'
) -> List[Dict[str, Any]]:
    """
    Fetch error logs with optional search, date filtering and sorting.

    Args:
        limit (int): page size
        offset (int): page offset
        key_search (Optional[str]): fuzzy match on the Gemini key
        error_search (Optional[str]): fuzzy match on error type or error text
        error_code_search (Optional[str]): exact match on the error code
        start_date (Optional[datetime]): inclusive lower bound on request_time
        end_date (Optional[datetime]): exclusive upper bound on request_time
        sort_by (str): sort column name (e.g. 'id', 'request_time')
        sort_order (str): 'asc' or 'desc'

    Returns:
        List[Dict[str, Any]]: matching log rows
    """
    try:
        query = select(
            ErrorLog.id,
            ErrorLog.gemini_key,
            ErrorLog.model_name,
            ErrorLog.error_type,
            ErrorLog.error_log,
            ErrorLog.error_code,
            ErrorLog.request_time
        )

        if key_search:
            query = query.where(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
        if error_search:
            query = query.where(
                (ErrorLog.error_type.ilike(f"%{error_search}%")) |
                (ErrorLog.error_log.ilike(f"%{error_search}%"))
            )
        if start_date:
            query = query.where(ErrorLog.request_time >= start_date)
        if end_date:
            query = query.where(ErrorLog.request_time < end_date)
        if error_code_search:
            try:
                error_code_int = int(error_code_search)
                query = query.where(ErrorLog.error_code == error_code_int)
            except ValueError:
                logger.warning(f"Invalid format for error_code_search: '{error_code_search}'. Expected an integer. Skipping error code filter.")

        # Restrict sorting to known mapped columns. The previous
        # getattr(ErrorLog, sort_by, ErrorLog.id) accepted *any* class
        # attribute name (e.g. 'metadata'), which would break the query.
        sortable_columns = {
            "id", "gemini_key", "model_name", "error_type", "error_code", "request_time",
        }
        sort_column = getattr(ErrorLog, sort_by) if sort_by in sortable_columns else ErrorLog.id
        if sort_order.lower() == 'asc':
            query = query.order_by(asc(sort_column))
        else:
            query = query.order_by(desc(sort_column))

        query = query.limit(limit).offset(offset)

        result = await database.fetch_all(query)
        return [dict(row) for row in result]
    except Exception as e:
        logger.exception(f"Failed to get error logs with filters: {str(e)}")
        raise
223
+
224
+
225
async def get_error_logs_count(
    key_search: Optional[str] = None,
    error_search: Optional[str] = None,
    error_code_search: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None
) -> int:
    """
    Count error logs matching the same filters as get_error_logs.

    Args:
        key_search (Optional[str]): fuzzy match on the Gemini key
        error_search (Optional[str]): fuzzy match on error type or error text
        error_code_search (Optional[str]): exact match on the error code
        start_date (Optional[datetime]): inclusive lower bound on request_time
        end_date (Optional[datetime]): exclusive upper bound on request_time

    Returns:
        int: total number of matching rows
    """
    try:
        # Collect the filter conditions first, then apply them all; chained
        # .where() calls combine with AND, matching the list form below.
        conditions = []
        if key_search:
            conditions.append(ErrorLog.gemini_key.ilike(f"%{key_search}%"))
        if error_search:
            conditions.append(
                (ErrorLog.error_type.ilike(f"%{error_search}%"))
                | (ErrorLog.error_log.ilike(f"%{error_search}%"))
            )
        if start_date:
            conditions.append(ErrorLog.request_time >= start_date)
        if end_date:
            conditions.append(ErrorLog.request_time < end_date)
        if error_code_search:
            try:
                conditions.append(ErrorLog.error_code == int(error_code_search))
            except ValueError:
                logger.warning(f"Invalid format for error_code_search in count: '{error_code_search}'. Expected an integer. Skipping error code filter.")

        query = select(func.count()).select_from(ErrorLog)
        for condition in conditions:
            query = query.where(condition)

        count_row = await database.fetch_one(query)
        return count_row[0] if count_row else 0
    except Exception as e:
        logger.exception(f"Failed to count error logs with filters: {str(e)}")
        raise
272
+
273
+
274
# Fetch the full record for a single error log.
async def get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
    """
    Return the complete details of one error log by id.

    Args:
        log_id (int): id of the error log.

    Returns:
        Optional[Dict[str, Any]]: the row as a dict (request_msg serialised
        to a JSON string for the API layer), or None when not found.
    """
    try:
        row = await database.fetch_one(select(ErrorLog).where(ErrorLog.id == log_id))
        if row is None:
            return None

        details = dict(row)
        # request_msg is a JSON column; render it as a string so the API can
        # return it directly. Fall back to str() for non-serialisable data.
        if 'request_msg' in details and details['request_msg'] is not None:
            try:
                details['request_msg'] = json.dumps(details['request_msg'], ensure_ascii=False, indent=2)
            except TypeError:
                details['request_msg'] = str(details['request_msg'])
        return details
    except Exception as e:
        logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
        raise
303
+
304
+
305
async def delete_error_logs_by_ids(log_ids: List[int]) -> int:
    """
    Bulk-delete error logs by id (async).

    Note: the databases library's execute() does not report a row count, so
    the return value is the number of ids *submitted* for deletion, not the
    number of rows actually removed.

    Args:
        log_ids: ids of the error logs to delete.

    Returns:
        int: number of ids submitted (0 for an empty list).
    """
    if not log_ids:
        return 0
    try:
        await database.execute(delete(ErrorLog).where(ErrorLog.id.in_(log_ids)))
        logger.info(f"Attempted bulk deletion for error logs with IDs: {log_ids}")
        return len(log_ids)
    except Exception as e:
        logger.error(f"Error during bulk deletion of error logs {log_ids}: {e}", exc_info=True)
        raise
334
+
335
async def delete_error_log_by_id(log_id: int) -> bool:
    """
    Delete a single error log by id (async).

    Args:
        log_id: id of the row to delete.

    Returns:
        bool: True when the row existed and was deleted, False when absent.
    """
    try:
        # Explicit existence check so we can report False for unknown ids.
        found = await database.fetch_one(select(ErrorLog.id).where(ErrorLog.id == log_id))
        if found is None:
            logger.warning(f"Attempted to delete non-existent error log with ID: {log_id}")
            return False

        await database.execute(delete(ErrorLog).where(ErrorLog.id == log_id))
        logger.info(f"Successfully deleted error log with ID: {log_id}")
        return True
    except Exception as e:
        logger.error(f"Error deleting error log with ID {log_id}: {e}", exc_info=True)
        raise
362
+
363
+
364
async def delete_all_error_logs() -> int:
    """
    Remove every row from the error-log table.

    Returns:
        int: how many rows were deleted (counted before deletion).
    """
    try:
        # Count first so the caller learns how many rows were removed.
        total = await database.fetch_val(select(func.count()).select_from(ErrorLog))
        if total == 0:
            logger.info("No error logs found to delete.")
            return 0

        await database.execute(delete(ErrorLog))
        logger.info(f"Successfully deleted all {total} error logs.")
        return total
    except Exception as e:
        logger.error(f"Failed to delete all error logs: {str(e)}", exc_info=True)
        raise
389
+
390
+
391
# Record one API request in the request-log table.
async def add_request_log(
    model_name: Optional[str],
    api_key: Optional[str],
    is_success: bool,
    status_code: Optional[int] = None,
    latency_ms: Optional[int] = None,
    request_time: Optional[datetime] = None
) -> bool:
    """
    Insert one row into the request-log table.

    Args:
        model_name: model that was called.
        api_key: API key used for the call.
        is_success: whether the request succeeded.
        status_code: upstream HTTP status code.
        latency_ms: request latency in milliseconds.
        request_time: timestamp of the request (defaults to now).

    Returns:
        bool: True on success, False when the insert failed.
    """
    try:
        stmt = insert(RequestLog).values(
            request_time=request_time or datetime.now(),
            model_name=model_name,
            api_key=api_key,
            is_success=is_success,
            status_code=status_code,
            latency_ms=latency_ms,
        )
        await database.execute(stmt)
        return True
    except Exception as e:
        logger.error(f"Failed to add request log: {str(e)}")
        return False
app/domain/gemini_models.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Literal, Optional, Union
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from app.core.constants import DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
6
+
7
+
8
class SafetySetting(BaseModel):
    """One Gemini safety-filter rule: a harm category plus its block threshold."""
    # Harm category this rule applies to.
    category: Optional[
        Literal[
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_CIVIC_INTEGRITY",
        ]
    ] = None
    # Blocking threshold applied to that category.
    threshold: Optional[
        Literal[
            "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
            "BLOCK_LOW_AND_ABOVE",
            "BLOCK_MEDIUM_AND_ABOVE",
            "BLOCK_ONLY_HIGH",
            "BLOCK_NONE",
            "OFF",
        ]
    ] = None
28
+
29
+
30
class GenerationConfig(BaseModel):
    """Generation parameters forwarded to the Gemini generateContent API."""
    stopSequences: Optional[List[str]] = None
    responseMimeType: Optional[str] = None
    responseSchema: Optional[Dict[str, Any]] = None
    candidateCount: Optional[int] = 1
    maxOutputTokens: Optional[int] = None
    # Sampling defaults come from app.core.constants.
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    topP: Optional[float] = DEFAULT_TOP_P
    topK: Optional[int] = DEFAULT_TOP_K
    presencePenalty: Optional[float] = None
    frequencyPenalty: Optional[float] = None
    responseLogprobs: Optional[bool] = None
    logprobs: Optional[int] = None
    # Free-form thinking configuration — passed through to the API as-is.
    thinkingConfig: Optional[Dict[str, Any]] = None
44
+
45
+
46
class SystemInstruction(BaseModel):
    """System instruction content; parts may be a single dict or a list of dicts."""
    role: Optional[str] = "system"
    parts: Union[List[Dict[str, Any]], Dict[str, Any]]
49
+
50
+
51
class GeminiContent(BaseModel):
    """One conversation turn in Gemini format: a role plus content parts."""
    role: Optional[str] = None
    parts: List[Dict[str, Any]]
54
+
55
+
56
class GeminiRequest(BaseModel):
    """Request body for the Gemini generateContent endpoint.

    Field aliases accept both camelCase and snake_case input
    (populate_by_name below).
    """
    contents: List[GeminiContent] = []
    tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
    safetySettings: Optional[List[SafetySetting]] = Field(
        default=None, alias="safety_settings"
    )
    generationConfig: Optional[GenerationConfig] = Field(
        default=None, alias="generation_config"
    )
    systemInstruction: Optional[SystemInstruction] = Field(
        default=None, alias="system_instruction"
    )

    class Config:
        # Allow population via either the field name or its alias.
        populate_by_name = True
71
+
72
+
73
class ResetSelectedKeysRequest(BaseModel):
    """Admin request: reset the given API keys of the given key type."""
    keys: List[str]
    key_type: str
76
+
77
+
78
class VerifySelectedKeysRequest(BaseModel):
    """Admin request: verify the given API keys."""
    keys: List[str]
app/domain/image_models.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+
4
class ImageMetadata:
    """Metadata describing an uploaded image."""

    def __init__(self, width: int, height: int, filename: str, size: int, url: str, delete_url: Union[str, None] = None):
        # Pixel dimensions.
        self.width = width
        self.height = height
        # Stored file name and size in bytes.
        self.filename = filename
        self.size = size
        # Public URL; delete_url is only provided by hosts that support deletion.
        self.url = url
        self.delete_url = delete_url
12
class UploadResponse:
    """Result of an image upload attempt."""

    def __init__(self, success: bool, code: str, message: str, data: ImageMetadata):
        # Outcome flag plus host-specific status code and message.
        self.success = success
        self.code = code
        self.message = message
        # Metadata of the uploaded image (when the upload succeeded).
        self.data = data
18
class ImageUploader:
    """Interface for image-hosting backends; concrete hosts override upload()."""

    def upload(self, file: bytes, filename: str) -> UploadResponse:
        # Subclasses must provide a concrete upload implementation.
        raise NotImplementedError
app/domain/openai_models.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ from app.core.constants import DEFAULT_MODEL, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P
5
+
6
+
7
class ChatRequest(BaseModel):
    """OpenAI-compatible chat completion request body."""
    messages: List[dict]
    model: str = DEFAULT_MODEL
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    top_p: Optional[float] = DEFAULT_TOP_P
    # top_k is an extension beyond the OpenAI schema.
    top_k: Optional[int] = DEFAULT_TOP_K
    stop: Optional[Union[List[str],str]] = None
    reasoning_effort: Optional[str] = None
    tools: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = []
    tool_choice: Optional[str] = None
    response_format: Optional[dict] = None
20
+
21
+
22
class EmbeddingRequest(BaseModel):
    """OpenAI-compatible embeddings request body."""
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"
26
+
27
+
28
class ImageGenerationRequest(BaseModel):
    """OpenAI-compatible image generation request body."""
    model: str = "imagen-3.0-generate-002"
    prompt: str = ""
    # Number of images to generate.
    n: int = 1
    size: Optional[str] = "1024x1024"
    quality: Optional[str] = None
    style: Optional[str] = None
    response_format: Optional[str] = "url"
36
+
37
+
38
class TTSRequest(BaseModel):
    """Text-to-speech request body."""
    model: str = "gemini-2.5-flash-preview-tts"
    # Text to synthesise.
    input: str
    voice: str = "Kore"
    response_format: Optional[str] = "wav"
app/exception/exceptions.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 异常处理模块,定义应用程序中使用的自定义异常和异常处理器
3
+ """
4
+
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.exceptions import RequestValidationError
7
+ from fastapi.responses import JSONResponse
8
+ from starlette.exceptions import HTTPException as StarletteHTTPException
9
+
10
+ from app.log.logger import get_exceptions_logger
11
+
12
+ logger = get_exceptions_logger()
13
+
14
+
15
class APIError(Exception):
    """Base class for API-level errors returned to the client.

    Carries an HTTP status code, a human-readable detail string and a
    machine-readable error code (defaults to "api_error").
    """

    def __init__(self, status_code: int, detail: str, error_code: str = None):
        self.status_code = status_code
        self.detail = detail
        self.error_code = error_code if error_code else "api_error"
        super().__init__(detail)
23
+
24
+
25
class AuthenticationError(APIError):
    """Authentication failure (HTTP 401)."""

    def __init__(self, detail: str = "Authentication failed"):
        super().__init__(
            status_code=401, detail=detail, error_code="authentication_error"
        )
32
+
33
+
34
class AuthorizationError(APIError):
    """Authorization failure (HTTP 403)."""

    def __init__(self, detail: str = "Not authorized to access this resource"):
        super().__init__(
            status_code=403, detail=detail, error_code="authorization_error"
        )
41
+
42
+
43
class ResourceNotFoundError(APIError):
    """Requested resource does not exist (HTTP 404)."""

    def __init__(self, detail: str = "Resource not found"):
        super().__init__(
            status_code=404, detail=detail, error_code="resource_not_found"
        )
50
+
51
+
52
class ModelNotSupportedError(APIError):
    """Requested model is not supported (HTTP 400)."""

    def __init__(self, model: str):
        super().__init__(
            status_code=400,
            detail=f"Model {model} is not supported",
            error_code="model_not_supported",
        )
61
+
62
+
63
class APIKeyError(APIError):
    """API key is invalid or expired (HTTP 401)."""

    def __init__(self, detail: str = "Invalid or expired API key"):
        super().__init__(status_code=401, detail=detail, error_code="api_key_error")
68
+
69
+
70
class ServiceUnavailableError(APIError):
    """Service is temporarily unavailable (HTTP 503)."""

    def __init__(self, detail: str = "Service temporarily unavailable"):
        super().__init__(
            status_code=503, detail=detail, error_code="service_unavailable"
        )
77
+
78
+
79
def setup_exception_handlers(app: FastAPI) -> None:
    """
    Register the application's exception handlers.

    Each handler converts an exception into a JSON error envelope of the
    shape {"error": {"code": ..., "message": ...}}.

    Args:
        app: the FastAPI application instance.
    """

    @app.exception_handler(APIError)
    async def api_error_handler(request: Request, exc: APIError):
        """Handle APIError instances with their own status and error code."""
        logger.error(f"API Error: {exc.detail} (Code: {exc.error_code})")
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": {"code": exc.error_code, "message": exc.detail}},
        )

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request: Request, exc: StarletteHTTPException):
        """Handle plain HTTP exceptions raised anywhere in the stack."""
        logger.error(f"HTTP Exception: {exc.detail} (Status: {exc.status_code})")
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": {"code": "http_error", "message": exc.detail}},
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ):
        """Handle request-body/query validation errors with per-field details."""
        error_details = []
        for error in exc.errors():
            error_details.append(
                {"loc": error["loc"], "msg": error["msg"], "type": error["type"]}
            )

        logger.error(f"Validation Error: {error_details}")
        return JSONResponse(
            status_code=422,
            content={
                "error": {
                    "code": "validation_error",
                    "message": "Request validation failed",
                    "details": error_details,
                }
            },
        )

    @app.exception_handler(Exception)
    async def general_exception_handler(request: Request, exc: Exception):
        """Catch-all: log the traceback and return a generic 500 envelope."""
        logger.exception(f"Unhandled Exception: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "code": "internal_server_error",
                    "message": "An unexpected error occurred",
                }
            },
        )
app/handler/error_handler.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+ from fastapi import HTTPException
3
+ import logging
4
+
5
@asynccontextmanager
async def handle_route_errors(logger: logging.Logger, operation_name: str, success_message: str = None, failure_message: str = None):
    """
    Async context manager that centralises error handling and logging for
    FastAPI route bodies.

    Args:
        logger: logger instance used for all messages.
        operation_name: name of the operation, used in log lines and error details.
        success_message: custom message logged on success (optional).
        failure_message: custom message logged on failure (optional).

    Raises:
        HTTPException: re-raised unchanged when the wrapped code raises one;
            any other exception is logged and converted into a 500.
    """
    default_success_msg = f"{operation_name} request successful"
    default_failure_msg = f"{operation_name} request failed"

    # Visual separator so each operation is easy to spot in the log stream.
    logger.info("-" * 50 + operation_name + "-" * 50)
    try:
        yield
        logger.info(success_message or default_success_msg)
    except HTTPException as http_exc:
        # Already an HTTPException: preserve the original status code and detail.
        logger.error(f"{failure_message or default_failure_msg}: {http_exc.detail} (Status: {http_exc.status_code})")
        raise http_exc
    except Exception as e:
        # Everything else becomes a standard 500 with the operation name in the detail.
        logger.error(f"{failure_message or default_failure_msg}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error during {operation_name}"
        ) from e
app/handler/message_converter.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import re
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import requests
8
+
9
+ from app.core.constants import (
10
+ AUDIO_FORMAT_TO_MIMETYPE,
11
+ DATA_URL_PATTERN,
12
+ IMAGE_URL_PATTERN,
13
+ MAX_AUDIO_SIZE_BYTES,
14
+ MAX_VIDEO_SIZE_BYTES,
15
+ SUPPORTED_AUDIO_FORMATS,
16
+ SUPPORTED_ROLES,
17
+ SUPPORTED_VIDEO_FORMATS,
18
+ VIDEO_FORMAT_TO_MIMETYPE,
19
+ )
20
+ from app.log.logger import get_message_converter_logger
21
+
22
+ logger = get_message_converter_logger()
23
+
24
+
25
class MessageConverter(ABC):
    """Abstract base class for chat-message format converters."""

    @abstractmethod
    def convert(
        self, messages: List[Dict[str, Any]]
    ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Convert a message list; returns (converted_messages, system_instruction)."""
        pass
33
+
34
+
35
def _get_mime_type_and_data(base64_string):
    """Split a data-URL style base64 string into (mime_type, payload).

    Returns (None, original_string) when the input is not a recognised
    "data:" URL. "image/jpg" is normalised to the proper "image/jpeg".
    """
    if not base64_string.startswith("data:"):
        # Plain payload with no MIME header.
        return None, base64_string

    match = re.match(DATA_URL_PATTERN, base64_string)
    if not match:
        # Header-like but not in the expected format: treat as raw payload.
        return None, base64_string

    mime_type = match.group(1)
    if mime_type == "image/jpg":
        mime_type = "image/jpeg"
    return mime_type, match.group(2)
59
+
60
+
61
def _convert_image(image_url: str) -> Dict[str, Any]:
    """Turn an image reference (data URL or remote URL) into an inline_data part."""
    if image_url.startswith("data:image"):
        mime_type, encoded_data = _get_mime_type_and_data(image_url)
        payload = {"mime_type": mime_type, "data": encoded_data}
    else:
        # Remote image: fetch and base64-encode, defaulting the MIME type to PNG.
        payload = {"mime_type": "image/png", "data": _convert_image_to_base64(image_url)}
    return {"inline_data": payload}
68
+
69
+
70
def _convert_image_to_base64(url: str, timeout: float = 30.0) -> str:
    """Fetch an image over HTTP and return its contents base64-encoded.

    Args:
        url: Image URL to download.
        timeout: Request timeout in seconds. Added because the original call
            had no timeout and a stalled host could hang the request forever;
            the default keeps existing callers working unchanged.

    Returns:
        str: Base64-encoded image bytes (ASCII string).

    Raises:
        Exception: when the server answers with a non-200 status.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        return base64.b64encode(response.content).decode("utf-8")
    raise Exception(f"Failed to fetch image: {response.status_code}")
85
+
86
+
87
def _process_text_with_image(text: str) -> List[Dict[str, Any]]:
    """
    Process text that may contain an image URL, converting the image to base64.

    Note: only the FIRST matching image URL is handled, and when one is found
    the surrounding text is discarded — only the image part is returned.

    Args:
        text: Text that may contain an image URL.

    Returns:
        List[Dict[str, Any]]: A list of text and/or image parts.
    """
    parts = []
    img_url_match = re.search(IMAGE_URL_PATTERN, text)
    if img_url_match:
        # Extract the URL (group 2 of IMAGE_URL_PATTERN).
        img_url = img_url_match.group(2)
        try:
            base64_data = _convert_image_to_base64(img_url)
            # NOTE(review): key is camelCase "mimeType" here but snake_case
            # "mime_type" in _convert_image — confirm which casing the
            # downstream API actually accepts.
            parts.append(
                {"inline_data": {"mimeType": "image/png", "data": base64_data}}
            )
        except Exception:
            # Conversion failed: fall back to plain text.
            parts.append({"text": text})
    else:
        # No image URL: treat as plain text.
        parts.append({"text": text})
    return parts
115
+
116
+
117
class OpenAIMessageConverter(MessageConverter):
    """Converts OpenAI-format chat messages into Gemini ``contents``.

    ``convert`` returns a tuple ``(converted_messages, system_instruction)``
    where ``system_instruction`` collects the text parts of system messages.
    """

    def _validate_media_data(
        self, format: str, data: str, supported_formats: List[str], max_size: int
    ) -> str:
        """Validate format and size of Base64 media data.

        Args:
            format: Media format name (e.g. "wav"), case-insensitive.
            data: Base64-encoded payload.
            supported_formats: Whitelist of accepted (lower-case) formats.
            max_size: Maximum decoded size in bytes.

        Returns:
            The original Base64 string when valid. (Fix: the previous
            annotation claimed a tuple, but a single string has always been
            returned and callers use it as one.)

        Raises:
            ValueError: unsupported format, oversized payload or bad Base64.
        """
        if format.lower() not in supported_formats:
            logger.error(
                f"Unsupported media format: {format}. Supported: {supported_formats}"
            )
            raise ValueError(f"Unsupported media format: {format}")

        try:
            decoded_data = base64.b64decode(data, validate=True)
            if len(decoded_data) > max_size:
                logger.error(
                    f"Media data size ({len(decoded_data)} bytes) exceeds limit ({max_size} bytes)."
                )
                raise ValueError(
                    f"Media data size exceeds limit of {max_size // 1024 // 1024}MB"
                )
            return data
        except base64.binascii.Error as e:
            logger.error(f"Invalid Base64 data provided: {e}")
            raise ValueError("Invalid Base64 data") from e
        except Exception as e:
            logger.error(f"Error validating media data: {e}")
            raise

    def _convert_media_part(
        self,
        media_info: Dict[str, Any],
        supported_formats: List[str],
        max_size: int,
        mime_map: Dict[str, str],
        label: str,
    ) -> Optional[Dict[str, Any]]:
        """Validate one audio/video item and build its inline_data part.

        The audio and video branches of ``convert`` were byte-for-byte
        parallel; this helper deduplicates them.

        Returns the inline_data part, a ``{"text": ...}`` error placeholder
        (mirroring the original fallback behaviour), or None when data or
        format is missing.
        """
        data = media_info.get("data")
        media_format = media_info.get("format", "").lower()
        if not data or not media_format:
            logger.warning(f"Skipping {label} part due to missing data or format.")
            return None
        try:
            validated_data = self._validate_media_data(
                media_format, data, supported_formats, max_size
            )
            mime_type = mime_map.get(media_format)
            if not mime_type:
                # Should not happen once format validation passed.
                raise ValueError(
                    f"Internal error: MIME type mapping missing for {media_format}"
                )
            logger.debug(f"Successfully added {label} part (format: {media_format})")
            return {"inline_data": {"mimeType": mime_type, "data": validated_data}}
        except ValueError as e:
            logger.error(f"Skipping {label} part due to validation error: {e}")
            return {"text": f"[Error processing {label}: {e}]"}
        except Exception:
            logger.exception(f"Unexpected error processing {label} part.")
            return {"text": f"[Unexpected error processing {label}]"}

    def convert(
        self, messages: List[Dict[str, Any]]
    ) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """Convert an OpenAI message list to (gemini_contents, system_instruction)."""
        converted_messages = []
        system_instruction_parts = []

        for idx, msg in enumerate(messages):
            role = msg.get("role", "")
            parts = []

            if "content" in msg and isinstance(msg["content"], list):
                for content_item in msg["content"]:
                    if not isinstance(content_item, dict):
                        logger.warning(
                            f"Skipping unexpected content item format: {type(content_item)}"
                        )
                        continue

                    content_type = content_item.get("type")

                    if content_type == "text" and content_item.get("text"):
                        parts.append({"text": content_item["text"]})
                    elif content_type == "image_url" and content_item.get(
                        "image_url", {}
                    ).get("url"):
                        try:
                            parts.append(
                                _convert_image(content_item["image_url"]["url"])
                            )
                        except Exception as e:
                            logger.error(
                                f"Failed to convert image URL {content_item['image_url']['url']}: {e}"
                            )
                            parts.append(
                                {
                                    "text": f"[Error processing image: {content_item['image_url']['url']}]"
                                }
                            )
                    elif content_type == "input_audio" and content_item.get(
                        "input_audio"
                    ):
                        part = self._convert_media_part(
                            content_item["input_audio"],
                            SUPPORTED_AUDIO_FORMATS,
                            MAX_AUDIO_SIZE_BYTES,
                            AUDIO_FORMAT_TO_MIMETYPE,
                            "audio",
                        )
                        if part is not None:
                            parts.append(part)
                    elif content_type == "input_video" and content_item.get(
                        "input_video"
                    ):
                        part = self._convert_media_part(
                            content_item["input_video"],
                            SUPPORTED_VIDEO_FORMATS,
                            MAX_VIDEO_SIZE_BYTES,
                            VIDEO_FORMAT_TO_MIMETYPE,
                            "video",
                        )
                        if part is not None:
                            parts.append(part)
                    elif content_type:
                        # Present but unrecognised content type.
                        logger.warning(
                            f"Unsupported content type or missing data in structured content: {content_type}"
                        )

            elif (
                "content" in msg and isinstance(msg["content"], str) and msg["content"]
            ):
                parts.extend(_process_text_with_image(msg["content"]))
            elif "tool_calls" in msg and isinstance(msg["tool_calls"], list):
                for tool_call in msg["tool_calls"]:
                    function_call = tool_call.get("function", {})
                    # Sanitize arguments loading.
                    arguments_str = function_call.get("arguments", "{}")
                    try:
                        function_call["args"] = json.loads(arguments_str)
                    except json.JSONDecodeError:
                        logger.warning(
                            f"Failed to decode tool call arguments: {arguments_str}"
                        )
                        function_call["args"] = {}
                    # Fix: the original nested this identical membership test twice.
                    if "arguments" in function_call:
                        del function_call["arguments"]

                    parts.append({"functionCall": function_call})

            if role not in SUPPORTED_ROLES:
                if role == "tool":
                    role = "user"
                elif idx == len(messages) - 1:
                    # Treat a trailing unknown-role message as a user turn.
                    role = "user"
                else:
                    role = "model"
            if parts:
                if role == "system":
                    # Gemini's systemInstruction accepts text parts only.
                    text_only_parts = [p for p in parts if "text" in p]
                    if len(text_only_parts) != len(parts):
                        logger.warning(
                            "Non-text parts found in system message; discarding them."
                        )
                    if text_only_parts:
                        system_instruction_parts.extend(text_only_parts)

                else:
                    converted_messages.append({"role": role, "parts": parts})

        system_instruction = (
            None
            if not system_instruction_parts
            else {
                "role": "system",
                "parts": system_instruction_parts,
            }
        )
        return converted_messages, system_instruction
app/handler/response_handler.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import random
4
+ import string
5
+ import time
6
+ import uuid
7
+ from abc import ABC, abstractmethod
8
+ from typing import Any, Dict, List, Optional
9
+
10
+ from app.config.config import settings
11
+ from app.utils.uploader import ImageUploaderFactory
12
+
13
+
14
class ResponseHandler(ABC):
    """Abstract base class for response handlers."""

    @abstractmethod
    def handle_response(
        self, response: Dict[str, Any], model: str, stream: bool = False
    ) -> Dict[str, Any]:
        """Transform a raw Gemini response into the target API format."""
        pass
22
+
23
+
24
class GeminiResponseHandler(ResponseHandler):
    """Response handler that keeps responses in Gemini's native shape."""

    def __init__(self):
        # NOTE(review): these two flags are never read anywhere in this
        # class — possibly leftover shared state; confirm before removing.
        self.thinking_first = True
        self.thinking_status = False

    def handle_response(
        self, response: Dict[str, Any], model: str, stream: bool = False, usage_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Dispatch to the stream or non-stream Gemini normaliser.

        ``usage_metadata`` is accepted for signature parity with the OpenAI
        handler but is not used here.
        """
        if stream:
            return _handle_gemini_stream_response(response, model, stream)
        return _handle_gemini_normal_response(response, model, stream)
37
+
38
+
39
def _handle_openai_stream_response(
    response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build one OpenAI chat.completion.chunk from a Gemini stream chunk."""
    text, tool_calls, _ = _extract_result(
        response, model, stream=True, gemini_format=False
    )
    if text or tool_calls:
        delta: Dict[str, Any] = {"content": text, "role": "assistant"}
        if tool_calls:
            delta["tool_calls"] = tool_calls
    else:
        # Nothing extracted: emit an empty delta (keep-alive style chunk).
        delta = {}
    chunk: Dict[str, Any] = {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    if usage_metadata:
        chunk["usage"] = {
            "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
            "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
            "total_tokens": usage_metadata.get("totalTokenCount", 0),
        }
    return chunk
61
+
62
+
63
def _handle_openai_normal_response(
    response: Dict[str, Any], model: str, finish_reason: str, usage_metadata: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Build a non-streaming OpenAI chat.completion payload from a Gemini response.

    Fix: ``usage_metadata`` may legitimately be None (the streaming variant
    already guards with ``if usage_metadata``); fall back to an empty dict
    instead of crashing with AttributeError.
    """
    text, tool_calls, _ = _extract_result(
        response, model, stream=False, gemini_format=False
    )
    usage = usage_metadata or {}
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": text,
                    "tool_calls": tool_calls,
                },
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("promptTokenCount", 0),
            "completion_tokens": usage.get("candidatesTokenCount", 0),
            "total_tokens": usage.get("totalTokenCount", 0),
        },
    }
87
+
88
+
89
class OpenAIResponseHandler(ResponseHandler):
    """Response handler that converts Gemini responses to OpenAI format."""

    def __init__(self, config):
        # config is stored for callers; it is not read by these methods.
        self.config = config
        # NOTE(review): these flags are never read in this class — confirm
        # they are still needed before removing.
        self.thinking_first = True
        self.thinking_status = False

    def handle_response(
        self,
        response: Dict[str, Any],
        model: str,
        stream: bool = False,
        finish_reason: str = None,
        usage_metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Convert a Gemini response into an OpenAI chunk or full completion."""
        if stream:
            return _handle_openai_stream_response(response, model, finish_reason, usage_metadata)
        return _handle_openai_normal_response(response, model, finish_reason, usage_metadata)

    def handle_image_chat_response(
        self, image_str: str, model: str, stream=False, finish_reason="stop"
    ):
        """Wrap pre-rendered image markdown in an OpenAI response envelope."""
        if stream:
            return _handle_openai_stream_image_response(image_str, model, finish_reason)
        return _handle_openai_normal_image_response(image_str, model, finish_reason)
+
116
+
117
+ def _handle_openai_stream_image_response(
118
+ image_str: str, model: str, finish_reason: str
119
+ ) -> Dict[str, Any]:
120
+ return {
121
+ "id": f"chatcmpl-{uuid.uuid4()}",
122
+ "object": "chat.completion.chunk",
123
+ "created": int(time.time()),
124
+ "model": model,
125
+ "choices": [
126
+ {
127
+ "index": 0,
128
+ "delta": {"content": image_str} if image_str else {},
129
+ "finish_reason": finish_reason,
130
+ }
131
+ ],
132
+ }
133
+
134
+
135
+ def _handle_openai_normal_image_response(
136
+ image_str: str, model: str, finish_reason: str
137
+ ) -> Dict[str, Any]:
138
+ return {
139
+ "id": f"chatcmpl-{uuid.uuid4()}",
140
+ "object": "chat.completion",
141
+ "created": int(time.time()),
142
+ "model": model,
143
+ "choices": [
144
+ {
145
+ "index": 0,
146
+ "message": {"role": "assistant", "content": image_str},
147
+ "finish_reason": finish_reason,
148
+ }
149
+ ],
150
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
151
+ }
152
+
153
+
154
def _extract_result(
    response: Dict[str, Any],
    model: str,
    stream: bool = False,
    gemini_format: bool = False,
) -> tuple[str, List[Dict[str, Any]], Optional[bool]]:
    """Extract (text, tool_calls, thought) from a raw Gemini response.

    Stream chunks inspect only the first part; full responses concatenate
    all text/image parts. Models whose name contains "thinking" get the
    two-part (thinking, answer) handling driven by
    settings.SHOW_THINKING_PROCESS.
    """
    text, tool_calls = "", []
    thought = None
    if stream:
        if response.get("candidates"):
            candidate = response["candidates"][0]
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            if not parts:
                return "", [], None
            # A stream chunk carries a single payload part; dispatch on its key.
            if "text" in parts[0]:
                text = parts[0].get("text")
                if "thought" in parts[0]:
                    thought = parts[0].get("thought")
            elif "executableCode" in parts[0]:
                text = _format_code_block(parts[0]["executableCode"])
            elif "codeExecution" in parts[0]:
                text = _format_code_block(parts[0]["codeExecution"])
            elif "executableCodeResult" in parts[0]:
                text = _format_execution_result(parts[0]["executableCodeResult"])
            elif "codeExecutionResult" in parts[0]:
                text = _format_execution_result(parts[0]["codeExecutionResult"])
            elif "inlineData" in parts[0]:
                text = _extract_image_data(parts[0])
            else:
                text = ""
            text = _add_search_link_text(model, candidate, text)
            tool_calls = _extract_tool_calls(parts, gemini_format)
    else:
        if response.get("candidates"):
            candidate = response["candidates"][0]
            if "thinking" in model:
                if settings.SHOW_THINKING_PROCESS:
                    # Two parts => [thinking, answer]; render both sections.
                    if len(candidate["content"]["parts"]) == 2:
                        text = (
                            "> thinking\n\n"
                            + candidate["content"]["parts"][0]["text"]
                            + "\n\n---\n> output\n\n"
                            + candidate["content"]["parts"][1]["text"]
                        )
                    else:
                        text = candidate["content"]["parts"][0]["text"]
                else:
                    # Hide the thinking part; keep only the answer.
                    if len(candidate["content"]["parts"]) == 2:
                        text = candidate["content"]["parts"][1]["text"]
                    else:
                        text = candidate["content"]["parts"][0]["text"]
            else:
                text = ""
                if "parts" in candidate["content"]:
                    for part in candidate["content"]["parts"]:
                        if "text" in part:
                            text += part["text"]
                            # Keep the first "thought" flag encountered.
                            if "thought" in part and thought is None:
                                thought = part.get("thought")
                        elif "inlineData" in part:
                            text += _extract_image_data(part)

            text = _add_search_link_text(model, candidate, text)
            tool_calls = _extract_tool_calls(
                candidate["content"]["parts"], gemini_format
            )
        else:
            # No candidates at all: placeholder text ("no response yet").
            text = "暂无返回"
    return text, tool_calls, thought
+
225
+
226
def _extract_image_data(part: dict) -> str:
    """Upload inline image data from a Gemini response part; return markdown.

    The base64 payload in ``part["inlineData"]["data"]`` is decoded and
    pushed to the image host selected by ``settings.UPLOAD_PROVIDER``.

    Returns:
        A markdown image link on success, or "" when the upload fails.

    Raises:
        ValueError: when UPLOAD_PROVIDER is not a known provider. (Fix: the
            original fell through with ``image_uploader`` still None and
            crashed later with an opaque AttributeError on ``.upload``.)
    """
    provider = settings.UPLOAD_PROVIDER
    if provider == "smms":
        image_uploader = ImageUploaderFactory.create(
            provider=provider, api_key=settings.SMMS_SECRET_TOKEN
        )
    elif provider == "picgo":
        image_uploader = ImageUploaderFactory.create(
            provider=provider, api_key=settings.PICGO_API_KEY
        )
    elif provider == "cloudflare_imgbed":
        image_uploader = ImageUploaderFactory.create(
            provider=provider,
            base_url=settings.CLOUDFLARE_IMGBED_URL,
            auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
        )
    else:
        raise ValueError(f"Unsupported upload provider: {provider}")

    # Date-based path keeps uploads grouped; short random name avoids clashes.
    current_date = time.strftime("%Y/%m/%d")
    filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
    bytes_data = base64.b64decode(part["inlineData"]["data"])
    upload_response = image_uploader.upload(bytes_data, filename)
    if upload_response.success:
        return f"\n\n![image]({upload_response.data.url})\n\n"
    return ""
253
+
254
+
255
+ def _extract_tool_calls(
256
+ parts: List[Dict[str, Any]], gemini_format: bool
257
+ ) -> List[Dict[str, Any]]:
258
+ """提取工具调用信息"""
259
+ if not parts or not isinstance(parts, list):
260
+ return []
261
+
262
+ letters = string.ascii_lowercase + string.digits
263
+
264
+ tool_calls = list()
265
+ for i in range(len(parts)):
266
+ part = parts[i]
267
+ if not part or not isinstance(part, dict):
268
+ continue
269
+
270
+ item = part.get("functionCall", {})
271
+ if not item or not isinstance(item, dict):
272
+ continue
273
+
274
+ if gemini_format:
275
+ tool_calls.append(part)
276
+ else:
277
+ id = f"call_{''.join(random.sample(letters, 32))}"
278
+ name = item.get("name", "")
279
+ arguments = json.dumps(item.get("args", None) or {})
280
+
281
+ tool_calls.append(
282
+ {
283
+ "index": i,
284
+ "id": id,
285
+ "type": "function",
286
+ "function": {"name": name, "arguments": arguments},
287
+ }
288
+ )
289
+
290
+ return tool_calls
291
+
292
+
293
def _handle_gemini_stream_response(
    response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
    """Normalise a streamed Gemini chunk by rebuilding candidate[0].content."""
    text, tool_calls, thought = _extract_result(
        response, model, stream=stream, gemini_format=True
    )
    if tool_calls:
        new_parts = tool_calls
    else:
        part: Dict[str, Any] = {"text": text}
        if thought is not None:
            part["thought"] = thought
        new_parts = [part]
    response["candidates"][0]["content"] = {"parts": new_parts, "role": "model"}
    return response
308
+
309
+
310
def _handle_gemini_normal_response(
    response: Dict[str, Any], model: str, stream: bool
) -> Dict[str, Any]:
    """Normalise a full Gemini response by rebuilding candidate[0].content."""
    text, tool_calls, thought = _extract_result(
        response, model, stream=stream, gemini_format=True
    )
    if tool_calls:
        new_parts = tool_calls
    else:
        part: Dict[str, Any] = {"text": text}
        if thought is not None:
            part["thought"] = thought
        new_parts = [part]
    response["candidates"][0]["content"] = {"parts": new_parts, "role": "model"}
    return response
325
+
326
+
327
+ def _format_code_block(code_data: dict) -> str:
328
+ """格式化代码块输出"""
329
+ language = code_data.get("language", "").lower()
330
+ code = code_data.get("code", "").strip()
331
+ return f"""\n\n---\n\n【代码执行】\n```{language}\n{code}\n```\n"""
332
+
333
+
334
def _add_search_link_text(model: str, candidate: dict, text: str) -> str:
    """Append a citation-sources section for "-search" models.

    Only active when ``settings.SHOW_SEARCH_LINK`` is on and the candidate
    carries groundingMetadata.groundingChunks; otherwise *text* is returned
    unchanged.

    Fixes: the original looped with ``enumerate(..., 1)`` and discarded the
    index, and duplicated the return in both branches.
    """
    if not (
        settings.SHOW_SEARCH_LINK
        and model.endswith("-search")
        and "groundingMetadata" in candidate
        and "groundingChunks" in candidate["groundingMetadata"]
    ):
        return text

    grounding_chunks = candidate["groundingMetadata"]["groundingChunks"]
    text += "\n\n---\n\n"
    text += "**【引用来源】**\n\n"
    for grounding_chunk in grounding_chunks:
        if "web" in grounding_chunk:
            text += _create_search_link(grounding_chunk["web"])
    return text
350
+
351
+
352
+ def _create_search_link(grounding_chunk: dict) -> str:
353
+ return f'\n- [{grounding_chunk["title"]}]({grounding_chunk["uri"]})'
354
+
355
+
356
+ def _format_execution_result(result_data: dict) -> str:
357
+ """格式化执行结果输出"""
358
+ outcome = result_data.get("outcome", "")
359
+ output = result_data.get("output", "").strip()
360
+ return f"""\n【执行结果】\n> outcome: {outcome}\n\n【输出结果】\n```plaintext\n{output}\n```\n\n---\n\n"""
app/handler/retry_handler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from functools import wraps
3
+ from typing import Callable, TypeVar
4
+
5
+ from app.config.config import settings
6
+ from app.log.logger import get_retry_logger
7
+
8
+ T = TypeVar("T")
9
+ logger = get_retry_logger()
10
+
11
+
12
class RetryHandler:
    """Decorator that retries an async API call up to settings.MAX_RETRIES times.

    After each failure it asks the ``key_manager`` (if supplied as a keyword
    argument to the wrapped call) for a replacement API key and injects it
    into the keyword named by ``key_arg`` before the next attempt.
    """

    def __init__(self, key_arg: str = "api_key"):
        # Name of the keyword argument that carries the API key.
        self.key_arg = key_arg

    def __call__(self, func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            last_exception = None

            for attempt in range(settings.MAX_RETRIES):
                retries = attempt + 1
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    last_exception = e
                    logger.warning(
                        f"API call failed with error: {str(e)}. Attempt {retries} of {settings.MAX_RETRIES}"
                    )

                    # Fetch the key manager from the call's keyword arguments.
                    key_manager = kwargs.get("key_manager")
                    if key_manager:
                        old_key = kwargs.get(self.key_arg)
                        new_key = await key_manager.handle_api_failure(old_key, retries)
                        if new_key:
                            kwargs[self.key_arg] = new_key
                            logger.info(f"Switched to new API key: {new_key}")
                        else:
                            # No replacement key: give up early.
                            logger.error(f"No valid API key available after {retries} retries.")
                            break

            logger.error(
                f"All retry attempts failed, raising final exception: {str(last_exception)}"
            )
            raise last_exception

        return wrapper
app/handler/stream_optimizer.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import asyncio
3
+ import math
4
+ from typing import Any, AsyncGenerator, Callable, List
5
+
6
+ from app.config.config import settings
7
+ from app.core.constants import (
8
+ DEFAULT_STREAM_CHUNK_SIZE,
9
+ DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
10
+ DEFAULT_STREAM_MAX_DELAY,
11
+ DEFAULT_STREAM_MIN_DELAY,
12
+ DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
13
+ )
14
+ from app.log.logger import get_gemini_logger, get_openai_logger
15
+
16
+ logger_openai = get_openai_logger()
17
+ logger_gemini = get_gemini_logger()
18
+
19
+
20
class StreamOptimizer:
    """Stream output optimizer.

    Provides smart delay adjustment and chunked emission for long text.
    """

    def __init__(
        self,
        logger=None,
        min_delay: float = DEFAULT_STREAM_MIN_DELAY,
        max_delay: float = DEFAULT_STREAM_MAX_DELAY,
        short_text_threshold: int = DEFAULT_STREAM_SHORT_TEXT_THRESHOLD,
        long_text_threshold: int = DEFAULT_STREAM_LONG_TEXT_THRESHOLD,
        chunk_size: int = DEFAULT_STREAM_CHUNK_SIZE,
    ):
        """Initialise the stream optimizer.

        Args:
            logger: logger instance (stored; not used by the current methods)
            min_delay: minimum delay between emissions (seconds)
            max_delay: maximum delay between emissions (seconds)
            short_text_threshold: at/below this length, use max_delay
            long_text_threshold: at/above this length, use min_delay + chunking
            chunk_size: chunk length for long-text emission (characters)
        """
        self.logger = logger
        self.min_delay = min_delay
        self.max_delay = max_delay
        self.short_text_threshold = short_text_threshold
        self.long_text_threshold = long_text_threshold
        self.chunk_size = chunk_size

    def calculate_delay(self, text_length: int) -> float:
        """Compute the inter-emission delay for a text of the given length.

        Args:
            text_length: length of the text in characters

        Returns:
            Delay in seconds: max_delay for short texts, min_delay for long
            ones, log-interpolated in between for a smooth transition.
        """
        if text_length <= self.short_text_threshold:
            # Short text: use the larger delay.
            return self.max_delay
        elif text_length >= self.long_text_threshold:
            # Long text: use the smaller delay.
            return self.min_delay
        else:
            # Medium text: logarithmic interpolation smooths the transition.
            ratio = math.log(text_length / self.short_text_threshold) / math.log(
                self.long_text_threshold / self.short_text_threshold
            )
            return self.max_delay - ratio * (self.max_delay - self.min_delay)

    def split_text_into_chunks(self, text: str) -> List[str]:
        """Split text into chunk_size-character slices (last may be shorter).

        Args:
            text: text to split

        Returns:
            List of text chunks.
        """
        return [
            text[i : i + self.chunk_size] for i in range(0, len(text), self.chunk_size)
        ]

    async def optimize_stream_output(
        self,
        text: str,
        create_response_chunk: Callable[[str], Any],
        format_chunk: Callable[[Any], str],
    ) -> AsyncGenerator[str, None]:
        """Yield formatted chunks for *text* with smart pacing.

        Args:
            text: text to emit
            create_response_chunk: builds a response object from a text piece
            format_chunk: serialises a response object to the wire format

        Returns:
            Async generator of formatted response chunks.
        """
        if not text:
            return

        # Pick the delay once based on total length.
        delay = self.calculate_delay(len(text))

        # Emission strategy depends on overall length.
        if len(text) >= self.long_text_threshold:
            # Long text: emit in fixed-size chunks.
            chunks = self.split_text_into_chunks(text)
            for chunk_text in chunks:
                chunk_response = create_response_chunk(chunk_text)
                yield format_chunk(chunk_response)
                await asyncio.sleep(delay)
        else:
            # Short text: emit character by character.
            for char in text:
                char_chunk = create_response_chunk(char)
                yield format_chunk(char_chunk)
                await asyncio.sleep(delay)
124
+
125
+
126
# Default optimizer instances, importable directly by the chat services.
openai_optimizer = StreamOptimizer(
    logger=logger_openai,
    min_delay=settings.STREAM_MIN_DELAY,
    max_delay=settings.STREAM_MAX_DELAY,
    short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
    long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
    chunk_size=settings.STREAM_CHUNK_SIZE,
)

gemini_optimizer = StreamOptimizer(
    logger=logger_gemini,
    min_delay=settings.STREAM_MIN_DELAY,
    max_delay=settings.STREAM_MAX_DELAY,
    short_text_threshold=settings.STREAM_SHORT_TEXT_THRESHOLD,
    long_text_threshold=settings.STREAM_LONG_TEXT_THRESHOLD,
    chunk_size=settings.STREAM_CHUNK_SIZE,
)
app/log/logger.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import platform
import sys
from typing import Dict, Optional

# ANSI escape colour codes per log level.
COLORS = {
    "DEBUG": "\033[34m",  # blue
    "INFO": "\033[32m",  # green
    "WARNING": "\033[33m",  # yellow
    "ERROR": "\033[31m",  # red
    "CRITICAL": "\033[1;31m",  # bold red
}

# Enable ANSI escape processing on Windows consoles.
if platform.system() == "Windows":
    import ctypes

    kernel32 = ctypes.windll.kernel32
    # -11 = STD_OUTPUT_HANDLE; mode 7 enables virtual terminal processing.
    kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7)
21
+
22
+
23
class ColoredFormatter(logging.Formatter):
    """Log formatter that colours the level name with ANSI escape codes.

    Fix: the original permanently overwrote ``record.levelname`` with the
    coloured string, so any other handler formatting the same record (or a
    second format() call) saw an already-coloured name and wrapped it in
    colour codes again. The level name is now restored after formatting.
    """

    def format(self, record):
        original_levelname = record.levelname
        color = COLORS.get(original_levelname, "")
        record.levelname = f"{color}{original_levelname}\033[0m"
        # Fixed-width "[file:line]" column consumed by the format string.
        record.fileloc = f"[{record.filename}:{record.lineno}]"
        try:
            return super().format(record)
        finally:
            record.levelname = original_levelname
36
+
37
+
38
# Log format — uses the fileloc field with a fixed width (30) for alignment.
FORMATTER = ColoredFormatter(
    "%(asctime)s | %(levelname)-17s | %(fileloc)-30s | %(message)s"
)

# Mapping from configuration strings to logging levels.
LOG_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}
51
+
52
+
53
class Logger:
    """Central registry of application loggers.

    Every logger shares one stdout handler with the module FORMATTER and
    honours the global ``settings.LOG_LEVEL``; instances are cached in
    ``_loggers``. (Fixes: removed the no-op ``__init__``; ``update_log_levels``
    now returns the count it previously computed and discarded.)
    """

    _loggers: Dict[str, logging.Logger] = {}

    @staticmethod
    def setup_logger(name: str) -> logging.Logger:
        """Create — or fetch and level-sync — the logger called *name*.

        :param name: logger name
        :return: configured logger instance
        """
        # Imported lazily to avoid a circular import at module load time.
        from app.config.config import settings

        # Resolve the level from the global configuration.
        log_level_str = settings.LOG_LEVEL.lower()
        level = LOG_LEVELS.get(log_level_str, logging.INFO)

        if name in Logger._loggers:
            # Already registered: update its level if the config changed.
            existing_logger = Logger._loggers[name]
            if existing_logger.level != level:
                existing_logger.setLevel(level)
            return existing_logger

        logger = logging.getLogger(name)
        logger.setLevel(level)
        # Don't double-log through the root logger's handlers.
        logger.propagate = False

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(FORMATTER)
        logger.addHandler(console_handler)

        Logger._loggers[name] = logger
        return logger

    @staticmethod
    def get_logger(name: str) -> Optional[logging.Logger]:
        """Return an already-registered logger, or None if unknown.

        :param name: logger name
        :return: logger instance or None
        """
        return Logger._loggers.get(name)

    @staticmethod
    def update_log_levels(log_level: str) -> int:
        """Apply *log_level* to every registered logger.

        :param log_level: level name ("debug", "info", ... — falls back to INFO)
        :return: number of loggers whose level actually changed
        """
        new_level = LOG_LEVELS.get(log_level.lower(), logging.INFO)

        updated_count = 0
        for logger_instance in Logger._loggers.values():
            if logger_instance.level != new_level:
                logger_instance.setLevel(new_level)
                updated_count += 1
        return updated_count
116
+
117
+
118
# Predefined application loggers: one cached logging.Logger per subsystem.
def get_openai_logger() -> logging.Logger:
    return Logger.setup_logger("openai")


def get_gemini_logger() -> logging.Logger:
    return Logger.setup_logger("gemini")


def get_chat_logger() -> logging.Logger:
    return Logger.setup_logger("chat")


def get_model_logger() -> logging.Logger:
    return Logger.setup_logger("model")


def get_security_logger() -> logging.Logger:
    return Logger.setup_logger("security")


def get_key_manager_logger() -> logging.Logger:
    return Logger.setup_logger("key_manager")


def get_main_logger() -> logging.Logger:
    return Logger.setup_logger("main")


def get_embeddings_logger() -> logging.Logger:
    return Logger.setup_logger("embeddings")


def get_request_logger() -> logging.Logger:
    return Logger.setup_logger("request")


def get_retry_logger() -> logging.Logger:
    return Logger.setup_logger("retry")


def get_image_create_logger() -> logging.Logger:
    return Logger.setup_logger("image_create")


def get_exceptions_logger() -> logging.Logger:
    return Logger.setup_logger("exceptions")


def get_application_logger() -> logging.Logger:
    return Logger.setup_logger("application")


def get_initialization_logger() -> logging.Logger:
    return Logger.setup_logger("initialization")


def get_middleware_logger() -> logging.Logger:
    return Logger.setup_logger("middleware")


def get_routes_logger() -> logging.Logger:
    return Logger.setup_logger("routes")


def get_config_routes_logger() -> logging.Logger:
    return Logger.setup_logger("config_routes")


def get_config_logger() -> logging.Logger:
    return Logger.setup_logger("config")


def get_database_logger() -> logging.Logger:
    return Logger.setup_logger("database")


def get_log_routes_logger() -> logging.Logger:
    return Logger.setup_logger("log_routes")


def get_stats_logger() -> logging.Logger:
    return Logger.setup_logger("stats")


def get_update_logger() -> logging.Logger:
    return Logger.setup_logger("update_service")


def get_scheduler_routes() -> logging.Logger:
    return Logger.setup_logger("scheduler_routes")


def get_message_converter_logger() -> logging.Logger:
    return Logger.setup_logger("message_converter")


def get_api_client_logger() -> logging.Logger:
    return Logger.setup_logger("api_client")


def get_openai_compatible_logger() -> logging.Logger:
    return Logger.setup_logger("openai_compatible")


def get_error_log_logger() -> logging.Logger:
    return Logger.setup_logger("error_log")


def get_request_log_logger() -> logging.Logger:
    return Logger.setup_logger("request_log")


def get_vertex_express_logger() -> logging.Logger:
    return Logger.setup_logger("vertex_express")
233
+
app/main.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import uvicorn
from dotenv import load_dotenv

# Load the .env file into the process environment *before* importing the
# application config, so settings read at import time see those values.
load_dotenv()

from app.core.application import create_app
from app.log.logger import get_main_logger

# Module-level ASGI app so servers can import it as `app.main:app`.
app = create_app()

if __name__ == "__main__":
    # Direct invocation: run a local development server.
    logger = get_main_logger()
    logger.info("Starting application server...")
    uvicorn.run(app, host="0.0.0.0", port=8001)
app/middleware/middleware.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 中间件配置模块,负责设置和配置应用程序的中间件
3
+ """
4
+
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import RedirectResponse
8
+ from starlette.middleware.base import BaseHTTPMiddleware
9
+
10
+ # from app.middleware.request_logging_middleware import RequestLoggingMiddleware
11
+ from app.middleware.smart_routing_middleware import SmartRoutingMiddleware
12
+ from app.core.constants import API_VERSION
13
+ from app.core.security import verify_auth_token
14
+ from app.log.logger import get_middleware_logger
15
+
16
+ logger = get_middleware_logger()
17
+
18
+
19
class AuthMiddleware(BaseHTTPMiddleware):
    """Authentication middleware: redirects unauthenticated requests to "/"."""

    # Exact paths and path prefixes that bypass cookie authentication.
    _EXEMPT_PATHS = ("/", "/auth")
    _EXEMPT_PREFIXES = (
        "/static",
        "/gemini",
        "/v1",
        f"/{API_VERSION}",
        "/health",
        "/hf",
        "/openai",
        "/api/version/check",
        "/vertex-express",
    )

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        needs_auth = path not in self._EXEMPT_PATHS and not path.startswith(
            self._EXEMPT_PREFIXES
        )
        if needs_auth:
            token = request.cookies.get("auth_token")
            if not token or not verify_auth_token(token):
                logger.warning(f"Unauthorized access attempt to {path}")
                return RedirectResponse(url="/")
            logger.debug("Request authenticated successfully")

        return await call_next(request)
47
+
48
+
49
def setup_middlewares(app: FastAPI) -> None:
    """Configure the application's middleware stack.

    Args:
        app: The FastAPI application instance.
    """
    # Smart routing middleware (must be registered before the auth middleware).
    app.add_middleware(SmartRoutingMiddleware)

    # Authentication middleware.
    app.add_middleware(AuthMiddleware)

    # Request logging middleware (optional; disabled by default).
    # app.add_middleware(RequestLoggingMiddleware)

    # CORS configuration.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        allow_headers=["*"],
        expose_headers=["*"],
        max_age=600,
    )
app/middleware/request_logging_middleware.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from fastapi import Request
4
+ from starlette.middleware.base import BaseHTTPMiddleware
5
+
6
+ from app.log.logger import get_request_logger
7
+
8
+ logger = get_request_logger()
9
+
10
+
11
# Middleware that logs each request's path and (pretty-printed JSON) body.
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Log the request path.
        logger.info(f"Request path: {request.url.path}")

        # Fix: initialize body so the replacement receive() below never sees
        # an unbound name when request.body() raises.
        body = b""

        # Read and log the request body.
        try:
            body = await request.body()
            if body:
                body_str = body.decode()
                # Pretty-print when the body is valid JSON.
                try:
                    formatted_body = json.loads(body_str)
                    logger.info(
                        f"Formatted request body:\n{json.dumps(formatted_body, indent=2, ensure_ascii=False)}"
                    )
                except json.JSONDecodeError:
                    logger.error("Request body is not valid JSON.")
        except Exception as e:
            logger.error(f"Error reading request body: {str(e)}")

        # request.body() consumed the ASGI receive stream; install a
        # replacement so downstream handlers can re-read the body.
        async def receive():
            return {"type": "http.request", "body": body, "more_body": False}

        request._receive = receive

        response = await call_next(request)
        return response
app/middleware/smart_routing_middleware.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Request
2
+ from starlette.middleware.base import BaseHTTPMiddleware
3
+ from app.config.config import settings
4
+ from app.log.logger import get_main_logger
5
+ import re
6
+
7
+ logger = get_main_logger()
8
+
9
class SmartRoutingMiddleware(BaseHTTPMiddleware):
    """Normalizes variant/misspelled request URLs onto the canonical endpoints.

    Matching priority (first hit wins):
      1. generateContent / v1beta/models  -> native Gemini endpoints
      2. ``/openai/``                     -> OpenAI-compatible endpoints
      3. ``/v1/``                         -> bare v1 endpoints
      4. ``/chat/completions``            -> v1 chat endpoint
      5. otherwise the path passes through unchanged
    """

    def __init__(self, app):
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        # Feature-flagged: skip all rewriting when normalization is disabled.
        if not settings.URL_NORMALIZATION_ENABLED:
            return await call_next(request)
        logger.debug(f"request: {request}")
        original_path = str(request.url.path)
        method = request.method

        # Attempt to normalize the URL.
        fixed_path, fix_info = self.fix_request_url(original_path, method, request)

        if fixed_path != original_path:
            logger.info(f"URL fixed: {method} {original_path} → {fixed_path}")
            if fix_info:
                logger.debug(f"Fix details: {fix_info}")

            # Rewrite the ASGI scope so routing sees the canonical path.
            request.scope["path"] = fixed_path
            request.scope["raw_path"] = fixed_path.encode()

        return await call_next(request)

    def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
        """Return ``(possibly_fixed_path, fix_info_or_None)`` for a request."""
        # Already canonical: leave untouched.
        if self.is_already_correct_format(path):
            return path, None

        lowered = path.lower()

        # 1. Highest priority: generateContent / v1beta/models -> Gemini format.
        if "generatecontent" in lowered or "v1beta/models" in lowered:
            return self.fix_gemini_by_operation(path, method, request)

        # 2. Contains /openai/ -> OpenAI-compatible format.
        if "/openai/" in lowered:
            return self.fix_openai_by_operation(path, method)

        # 3. Contains /v1/ -> v1 format.
        if "/v1/" in lowered:
            return self.fix_v1_by_operation(path, method)

        # 4. Contains /chat/completions -> chat endpoint.
        if "/chat/completions" in lowered:
            return "/v1/chat/completions", {"type": "v1_chat"}

        # 5. Default: pass through unchanged.
        return path, None

    def is_already_correct_format(self, path: str) -> bool:
        """True when *path* already matches one of the canonical endpoint shapes."""
        correct_patterns = (
            r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # native Gemini
            r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # prefixed Gemini
            r"^/v1beta/models$",  # Gemini model list
            r"^/gemini/v1beta/models$",  # prefixed Gemini model list
            r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # v1
            r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # OpenAI
            r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # HF
            r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # Vertex Express Gemini
            r"^/vertex-express/v1beta/models$",  # Vertex Express model list
            r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$",  # Vertex Express OpenAI
        )
        return any(re.match(pattern, path) for pattern in correct_patterns)

    def fix_gemini_by_operation(
        self, path: str, method: str, request: Request
    ) -> tuple:
        """Build the canonical Gemini URL, honoring a vertex-express preference."""
        if method == "GET":
            return "/v1beta/models", {
                "role": "gemini_models",
            }

        # A model name is required to build a generateContent URL.
        try:
            model_name = self.extract_model_name(path, request)
        except ValueError:
            # Cannot determine the model; leave the path untouched.
            return path, None

        is_stream = self.detect_stream_request(path, request)
        action = "streamGenerateContent" if is_stream else "generateContent"

        if "/vertex-express/" in path.lower():
            target_url = f"/vertex-express/v1beta/models/{model_name}:{action}"
            fix_info = {
                "rule": (
                    "vertex_express_stream" if is_stream else "vertex_express_generate"
                ),
                "preference": "vertex_express_format",
                "is_stream": is_stream,
                "model": model_name,
            }
        else:
            # Standard Gemini endpoint.
            target_url = f"/v1beta/models/{model_name}:{action}"
            fix_info = {
                "rule": "gemini_stream" if is_stream else "gemini_generate",
                "preference": "gemini_format",
                "is_stream": is_stream,
                "model": model_name,
            }

        return target_url, fix_info

    def _fix_rest_by_operation(
        self, path: str, method: str, base: str, tag: str
    ) -> tuple:
        """Shared OpenAI/v1 rewrite logic (deduplicated from the two public methods)."""
        lowered = path.lower()
        if method == "POST":
            if "chat" in lowered or "completion" in lowered:
                return f"{base}/chat/completions", {"type": f"{tag}_chat"}
            if "embedding" in lowered:
                return f"{base}/embeddings", {"type": f"{tag}_embeddings"}
            if "image" in lowered:
                return f"{base}/images/generations", {"type": f"{tag}_images"}
            if "audio" in lowered:
                return f"{base}/audio/speech", {"type": f"{tag}_audio"}
        elif method == "GET":
            if "model" in lowered:
                return f"{base}/models", {"type": f"{tag}_models"}
        return path, None

    def fix_openai_by_operation(self, path: str, method: str) -> tuple:
        """Rewrite an /openai/ request onto the canonical OpenAI-compatible endpoint."""
        return self._fix_rest_by_operation(path, method, "/openai/v1", "openai")

    def fix_v1_by_operation(self, path: str, method: str) -> tuple:
        """Rewrite a /v1/ request onto the canonical v1 endpoint."""
        return self._fix_rest_by_operation(path, method, "/v1", "v1")

    def detect_stream_request(self, path: str, request: Request) -> bool:
        """Heuristic: stream when the path mentions 'stream' or ?stream=true."""
        if "stream" in path.lower():
            return True
        return request.query_params.get("stream") == "true"

    def extract_model_name(self, path: str, request: Request) -> str:
        """Extract the model name (body > query param > path); raise ValueError if absent."""
        # 1. Cached request body, if the framework already read it.
        try:
            if hasattr(request, "_body") and request._body:
                import json

                body = json.loads(request._body.decode())
                if "model" in body and body["model"]:
                    return body["model"]
        except Exception:
            pass

        # 2. Query parameter.
        model_param = request.query_params.get("model")
        if model_param:
            return model_param

        # 3. Path segment (for URLs that already embed the model).
        match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
        if match:
            return match.group(1)

        # 4. Nothing worked.
        raise ValueError("Unable to extract model name from request")
app/router/config_routes.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 配置路由模块
3
+ """
4
+
5
+ from typing import Any, Dict, List
6
+
7
+ from fastapi import APIRouter, HTTPException, Request
8
+ from fastapi.responses import RedirectResponse
9
+ from pydantic import BaseModel, Field
10
+
11
+ from app.core.security import verify_auth_token
12
+ from app.log.logger import Logger, get_config_routes_logger
13
+ from app.service.config.config_service import ConfigService
14
+
15
+ router = APIRouter(prefix="/api/config", tags=["config"])
16
+
17
+ logger = get_config_routes_logger()
18
+
19
+
20
@router.get("", response_model=Dict[str, Any])
async def get_config(request: Request):
    """Return the current configuration; redirect to "/" when unauthenticated."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized access attempt to config page")
        return RedirectResponse(url="/", status_code=302)
    return await ConfigService.get_config()
27
+
28
+
29
@router.put("", response_model=Dict[str, Any])
async def update_config(config_data: Dict[str, Any], request: Request):
    """Persist a new configuration and refresh runtime log levels.

    Returns the updated configuration, or redirects to "/" when the caller
    is not authenticated; failures surface as HTTP 400.
    """
    auth_token = request.cookies.get("auth_token")
    if not auth_token or not verify_auth_token(auth_token):
        logger.warning("Unauthorized access attempt to config page")
        return RedirectResponse(url="/", status_code=302)
    try:
        result = await ConfigService.update_config(config_data)
        # Refresh every logger's level right after a successful update.
        # Guarded with .get(): a partial payload without LOG_LEVEL previously
        # raised KeyError here, turning an already-persisted save into a 400.
        log_level = config_data.get("LOG_LEVEL")
        if log_level:
            Logger.update_log_levels(log_level)
            logger.info("Log levels updated after configuration change.")
        return result
    except Exception as e:
        logger.error(f"Error updating config or log levels: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=str(e))
44
+
45
+
46
@router.post("/reset", response_model=Dict[str, Any])
async def reset_config(request: Request):
    """Reset the configuration to its defaults; redirect when unauthenticated."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized access attempt to config page")
        return RedirectResponse(url="/", status_code=302)
    try:
        return await ConfigService.reset_config()
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
56
+
57
+
58
class DeleteKeysRequest(BaseModel):
    """Request body for the bulk key-deletion endpoint."""

    keys: List[str] = Field(..., description="List of API keys to delete")
60
+
61
+
62
@router.delete("/keys/{key_to_delete}", response_model=Dict[str, Any])
async def delete_single_key(key_to_delete: str, request: Request):
    """Delete one API key; 404 when unknown, 400/500 on other failures."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning(f"Unauthorized attempt to delete key: {key_to_delete}")
        return RedirectResponse(url="/", status_code=302)
    try:
        logger.info(f"Attempting to delete key: {key_to_delete}")
        outcome = await ConfigService.delete_key(key_to_delete)
        if not outcome.get("success"):
            message = outcome.get("message", "")
            raise HTTPException(
                status_code=404 if "not found" in message.lower() else 400,
                detail=outcome.get("message"),
            )
        return outcome
    except HTTPException as e:
        raise e
    except Exception as e:
        logger.error(f"Error deleting key '{key_to_delete}': {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error deleting key: {str(e)}")
84
+
85
+
86
@router.post("/keys/delete-selected", response_model=Dict[str, Any])
async def delete_selected_keys_route(
    delete_request: DeleteKeysRequest, request: Request
):
    """Bulk-delete the API keys listed in the request body."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized attempt to bulk delete keys")
        return RedirectResponse(url="/", status_code=302)

    if not delete_request.keys:
        logger.warning("Attempt to bulk delete keys with an empty list.")
        raise HTTPException(status_code=400, detail="No keys provided for deletion.")

    try:
        logger.info(f"Attempting to bulk delete {len(delete_request.keys)} keys.")
        outcome = await ConfigService.delete_selected_keys(delete_request.keys)
        if not outcome.get("success") and outcome.get("deleted_count", 0) == 0:
            raise HTTPException(
                status_code=400, detail=outcome.get("message", "Failed to delete keys.")
            )
        return outcome
    except HTTPException as e:
        raise e
    except Exception as e:
        logger.error(f"Error bulk deleting keys: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Error bulk deleting keys: {str(e)}"
        )
114
+
115
+
116
@router.get("/ui/models")
async def get_ui_models(request: Request):
    """Fetch the model list used to populate the admin UI."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized access attempt to /api/config/ui/models")
        raise HTTPException(status_code=403, detail="Not authenticated")

    try:
        return await ConfigService.fetch_ui_models()
    except HTTPException as e:
        raise e
    except Exception as e:
        logger.error(f"Unexpected error in /ui/models endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching UI models: {str(e)}",
        )
app/router/error_log_routes.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 日志路由模块
3
+ """
4
+
5
+ from datetime import datetime
6
+ from typing import Dict, List, Optional
7
+
8
+ from fastapi import (
9
+ APIRouter,
10
+ Body,
11
+ HTTPException,
12
+ Path,
13
+ Query,
14
+ Request,
15
+ Response,
16
+ status,
17
+ )
18
+ from pydantic import BaseModel
19
+
20
+ from app.core.security import verify_auth_token
21
+ from app.log.logger import get_log_routes_logger
22
+ from app.service.error_log import error_log_service
23
+
24
+ router = APIRouter(prefix="/api/logs", tags=["logs"])
25
+
26
+ logger = get_log_routes_logger()
27
+
28
+
29
class ErrorLogListItem(BaseModel):
    """Summary row returned by the error-log list endpoint (no log body)."""

    id: int
    gemini_key: Optional[str] = None
    error_type: Optional[str] = None
    error_code: Optional[int] = None
    model_name: Optional[str] = None
    request_time: Optional[datetime] = None
36
+
37
+
38
class ErrorLogListResponse(BaseModel):
    """Paged list of error-log summaries plus the total match count."""

    logs: List[ErrorLogListItem]
    total: int
41
+
42
+
43
@router.get("/errors", response_model=ErrorLogListResponse)
async def get_error_logs_api(
    request: Request,
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    key_search: Optional[str] = Query(
        None, description="Search term for Gemini key (partial match)"
    ),
    error_search: Optional[str] = Query(
        None, description="Search term for error type or log message"
    ),
    error_code_search: Optional[str] = Query(
        None, description="Search term for error code"
    ),
    start_date: Optional[datetime] = Query(
        None, description="Start datetime for filtering"
    ),
    end_date: Optional[datetime] = Query(
        None, description="End datetime for filtering"
    ),
    sort_by: str = Query(
        "id", description="Field to sort by (e.g., 'id', 'request_time')"
    ),
    sort_order: str = Query("desc", description="Sort order ('asc' or 'desc')"),
):
    """List error logs (summaries including error codes) with filtering and sorting.

    Supports pagination (limit/offset), partial Gemini-key match, error
    type/message and error-code search, a request-time window, and a
    configurable sort field/order.
    """
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized access attempt to error logs list")
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        result = await error_log_service.process_get_error_logs(
            limit=limit,
            offset=offset,
            key_search=key_search,
            error_search=error_search,
            error_code_search=error_code_search,
            start_date=start_date,
            end_date=end_date,
            sort_by=sort_by,
            sort_order=sort_order,
        )
        items = [ErrorLogListItem(**row) for row in result["logs"]]
        return ErrorLogListResponse(logs=items, total=result["total"])
    except Exception as e:
        logger.exception(f"Failed to get error logs list: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to get error logs list: {str(e)}"
        )
113
+
114
+
115
class ErrorLogDetailResponse(BaseModel):
    """Full error-log record, including the raw log text and request payload."""

    id: int
    gemini_key: Optional[str] = None
    error_type: Optional[str] = None
    error_log: Optional[str] = None
    request_msg: Optional[str] = None
    model_name: Optional[str] = None
    request_time: Optional[datetime] = None
123
+
124
+
125
@router.get("/errors/{log_id}/details", response_model=ErrorLogDetailResponse)
async def get_error_log_detail_api(request: Request, log_id: int = Path(..., ge=1)):
    """Return full details (error_log and request_msg) for one error-log entry."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning(
            f"Unauthorized access attempt to error log details for ID: {log_id}"
        )
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        details = await error_log_service.process_get_error_log_details(log_id=log_id)
        if not details:
            raise HTTPException(status_code=404, detail="Error log not found")
        return ErrorLogDetailResponse(**details)
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.exception(f"Failed to get error log details for ID {log_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to get error log details: {str(e)}"
        )
152
+
153
+
154
@router.delete("/errors", status_code=status.HTTP_204_NO_CONTENT)
async def delete_error_logs_bulk_api(
    request: Request, payload: Dict[str, List[int]] = Body(...)
):
    """Bulk-delete error logs by ID (asynchronous, best-effort)."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized access attempt to bulk delete error logs")
        raise HTTPException(status_code=401, detail="Not authenticated")

    log_ids = payload.get("ids")
    if not log_ids:
        raise HTTPException(status_code=400, detail="No log IDs provided for deletion.")

    try:
        attempted = await error_log_service.process_delete_error_logs_by_ids(log_ids)
        # NOTE: the service reports how many deletions were *attempted*, which
        # may differ from the number actually removed.
        logger.info(
            f"Attempted bulk deletion for {attempted} error logs with IDs: {log_ids}"
        )
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except Exception as e:
        logger.exception(f"Error bulk deleting error logs with IDs {log_ids}: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Internal server error during bulk deletion"
        )
184
+
185
+
186
@router.delete("/errors/all", status_code=status.HTTP_204_NO_CONTENT)
async def delete_all_error_logs_api(request: Request):
    """Delete every stored error log."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning("Unauthorized access attempt to delete all error logs")
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        removed = await error_log_service.process_delete_all_error_logs()
        logger.info(f"Successfully deleted all {removed} error logs.")
        # A 204 response carries no body.
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except Exception as e:
        logger.exception(f"Error deleting all error logs: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Internal server error during deletion of all logs"
        )
206
+
207
+
208
@router.delete("/errors/{log_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_error_log_api(request: Request, log_id: int = Path(..., ge=1)):
    """Delete a single error log; 404 when the ID does not exist."""
    token = request.cookies.get("auth_token")
    if not (token and verify_auth_token(token)):
        logger.warning(f"Unauthorized access attempt to delete error log ID: {log_id}")
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        deleted = await error_log_service.process_delete_error_log_by_id(log_id)
        if not deleted:
            # The service returns False for unknown IDs; surface that as 404.
            raise HTTPException(
                status_code=404, detail=f"Error log with ID {log_id} not found"
            )
        logger.info(f"Successfully deleted error log with ID: {log_id}")
        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.exception(f"Error deleting error log with ID {log_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Internal server error during deletion"
        )
app/router/gemini_routes.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
+ from copy import deepcopy
4
+ import asyncio
5
+ from app.config.config import settings
6
+ from app.log.logger import get_gemini_logger
7
+ from app.core.security import SecurityService
8
+ from app.domain.gemini_models import GeminiContent, GeminiRequest, ResetSelectedKeysRequest, VerifySelectedKeysRequest
9
+ from app.service.chat.gemini_chat_service import GeminiChatService
10
+ from app.service.key.key_manager import KeyManager, get_key_manager_instance
11
+ from app.service.model.model_service import ModelService
12
+ from app.handler.retry_handler import RetryHandler
13
+ from app.handler.error_handler import handle_route_errors
14
+ from app.core.constants import API_VERSION
15
+
16
# Two routers expose the same endpoints under /gemini/<ver> and bare /<ver>.
router = APIRouter(prefix=f"/gemini/{API_VERSION}")
router_v1beta = APIRouter(prefix=f"/{API_VERSION}")
logger = get_gemini_logger()

# Module-level singletons shared by all route handlers below.
security_service = SecurityService()
model_service = ModelService()
22
+
23
+
24
async def get_key_manager():
    """FastAPI dependency: return the shared KeyManager instance."""
    return await get_key_manager_instance()
27
+
28
+
29
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: return the next usable Gemini API key."""
    return await key_manager.get_next_working_key()
32
+
33
+
34
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: build a GeminiChatService bound to the configured base URL."""
    return GeminiChatService(settings.BASE_URL, key_manager)
37
+
38
+
39
@router.get("/models")
@router_v1beta.get("/models")
async def list_models(
    _=Depends(security_service.verify_key_or_goog_api_key),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """List available Gemini models and append configured derived variants
    (-search / -image / -non-thinking) based on settings."""
    operation_name = "list_gemini_models"
    logger.info("-" * 50 + operation_name + "-" * 50)
    logger.info("Handling Gemini models list request")

    try:
        api_key = await key_manager.get_first_valid_key()
        if not api_key:
            raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
        # Security fix: never log the full API key, only a short suffix.
        logger.info(f"Using API key: ...{api_key[-4:]}")

        models_data = await model_service.get_gemini_models(api_key)
        if not models_data or "models" not in models_data:
            raise HTTPException(status_code=500, detail="Failed to fetch base models list.")

        models_json = deepcopy(models_data)
        # Map bare model names ("gemini-…") to their upstream entries.
        model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}

        def add_derived_model(base_name, suffix, display_suffix):
            # Clone an upstream model entry and register it under a derived name.
            model = model_mapping.get(base_name)
            if not model:
                logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
                return
            item = deepcopy(model)
            item["name"] = f"models/{base_name}{suffix}"
            display_name = f'{item.get("displayName", base_name)}{display_suffix}'
            item["displayName"] = display_name
            item["description"] = display_name
            models_json["models"].append(item)

        if settings.SEARCH_MODELS:
            for name in settings.SEARCH_MODELS:
                add_derived_model(name, "-search", " For Search")
        if settings.IMAGE_MODELS:
            for name in settings.IMAGE_MODELS:
                add_derived_model(name, "-image", " For Image")
        if settings.THINKING_MODELS:
            for name in settings.THINKING_MODELS:
                add_derived_model(name, "-non-thinking", " Non Thinking")

        logger.info("Gemini models list request successful")
        return models_json
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.error(f"Error getting Gemini models list: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Internal server error while fetching Gemini models list"
        ) from e
94
+
95
+
96
@router.post("/models/{model_name}:generateContent")
@router_v1beta.post("/models/{model_name}:generateContent")
@RetryHandler(key_arg="api_key")
async def generate_content(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a non-streaming Gemini content-generation request."""
    operation_name = "gemini_generate_content"
    async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
        logger.info(f"Handling Gemini content generation request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # Security fix: never log the full API key, only a short suffix.
        logger.info(f"Using API key: ...{api_key[-4:]}")

        if not await model_service.check_model_support(model_name):
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        response = await chat_service.generate_content(
            model=model_name,
            request=request,
            api_key=api_key
        )
        return response
123
+
124
+
125
@router.post("/models/{model_name}:streamGenerateContent")
@router_v1beta.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(key_arg="api_key")
async def stream_generate_content(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a streaming Gemini content-generation request (SSE)."""
    operation_name = "gemini_stream_generate_content"
    async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
        logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # Security fix: never log the full API key, only a short suffix.
        logger.info(f"Using API key: ...{api_key[-4:]}")

        if not await model_service.check_model_support(model_name):
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        response_stream = chat_service.stream_generate_content(
            model=model_name,
            request=request,
            api_key=api_key
        )
        return StreamingResponse(response_stream, media_type="text/event-stream")
152
+
153
+
154
@router.post("/reset-all-fail-counts")
async def reset_all_key_fail_counts(key_type: str = None, key_manager: KeyManager = Depends(get_key_manager)):
    """Bulk-reset failure counts for Gemini API keys.

    ``key_type`` may be "valid" or "invalid" to restrict the reset to that
    group; any other value (or None) resets every key in one call.
    """
    logger.info("-" * 50 + "reset_all_gemini_key_fail_counts" + "-" * 50)
    logger.info(f"Received reset request with key_type: {key_type}")

    try:
        keys_by_status = await key_manager.get_keys_by_status()

        if key_type in ("valid", "invalid"):
            group_name = "valid_keys" if key_type == "valid" else "invalid_keys"
            keys_to_reset = list(keys_by_status.get(group_name, {}).keys())
            logger.info(f"Resetting only {key_type} keys, count: {len(keys_to_reset)}")

            # Reset only the selected group, one key at a time.
            for key in keys_to_reset:
                await key_manager.reset_key_failure_count(key)

            return JSONResponse({
                "success": True,
                "message": f"{key_type}密钥的失败计数已重置",
                "reset_count": len(keys_to_reset)
            })

        # No recognized filter: reset every key at once.
        await key_manager.reset_failure_counts()
        return JSONResponse({"success": True, "message": "所有密钥的失败计数已重置"})
    except Exception as e:
        logger.error(f"Failed to reset key failure counts: {str(e)}")
        return JSONResponse({"success": False, "message": f"批量重置失败: {str(e)}"}, status_code=500)
191
+
192
+
193
@router.post("/reset-selected-fail-counts")
async def reset_selected_key_fail_counts(
    request: ResetSelectedKeysRequest,
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Reset the failure counts of an explicit list of Gemini API keys.

    Returns 400 when no keys are supplied, 207 when only some keys could be
    reset, and 500 when every reset attempt failed.
    """
    logger.info("-" * 50 + "reset_selected_gemini_key_fail_counts" + "-" * 50)
    keys_to_reset = request.keys
    key_type = request.key_type
    logger.info(f"Received reset request for {len(keys_to_reset)} selected {key_type} keys.")

    if not keys_to_reset:
        return JSONResponse({"success": False, "message": "没有提供需要重置的密钥"}, status_code=400)

    reset_count = 0
    errors = []

    try:
        for key in keys_to_reset:
            try:
                if await key_manager.reset_key_failure_count(key):
                    reset_count += 1
                else:
                    # Unknown key: not an error, just note it in the log.
                    logger.warning(f"Key not found during selective reset: {key}")
            except Exception as key_error:
                logger.error(f"Error resetting key {key}: {str(key_error)}")
                errors.append(f"Key {key}: {str(key_error)}")

        if not errors:
            return JSONResponse({
                "success": True,
                "message": f"成功重置 {reset_count} 个选定 {key_type} 密钥的失败计数",
                "reset_count": reset_count
            })

        # Partial success -> 207 Multi-Status; total failure -> 500.
        final_success = reset_count > 0
        status_code = 207 if final_success and errors else 500
        return JSONResponse({
            "success": final_success,
            "message": f"批量重置完成,但出现错误: {'; '.join(errors)}",
            "reset_count": reset_count
        }, status_code=status_code)
    except Exception as e:
        logger.error(f"Failed to process reset selected key failure counts request: {str(e)}")
        return JSONResponse({"success": False, "message": f"批量重置处理失败: {str(e)}"}, status_code=500)
240
+
241
+
242
@router.post("/reset-fail-count/{api_key}")
async def reset_key_fail_count(api_key: str, key_manager: KeyManager = Depends(get_key_manager)):
    """Reset the failure count for one Gemini API key (404 if unknown)."""
    logger.info("-" * 50 + "reset_gemini_key_fail_count" + "-" * 50)
    logger.info(f"Resetting failure count for API key: {api_key}")

    try:
        if await key_manager.reset_key_failure_count(api_key):
            return JSONResponse({"success": True, "message": "失败计数已重置"})
        return JSONResponse({"success": False, "message": "未找到指定密钥"}, status_code=404)
    except Exception as e:
        logger.error(f"Failed to reset key failure count: {str(e)}")
        return JSONResponse({"success": False, "message": f"重置失败: {str(e)}"}, status_code=500)
256
+
257
+
258
@router.post("/verify-key/{api_key}")
async def verify_key(api_key: str, chat_service: GeminiChatService = Depends(get_chat_service), key_manager: KeyManager = Depends(get_key_manager)):
    """Verify that a Gemini API key can serve a minimal generation request.

    Probes the test model with a tiny "hi" prompt. On success returns
    {"status": "valid"}; on failure increments the key's failure count and
    returns {"status": "invalid", "error": ...}.
    """
    logger.info("-" * 50 + "verify_gemini_key" + "-" * 50)
    logger.info("Verifying API key validity")

    try:
        gemini_request = GeminiRequest(
            contents=[
                GeminiContent(
                    role="user",
                    parts=[{"text": "hi"}],
                )
            ],
            generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
        )

        response = await chat_service.generate_content(
            settings.TEST_MODEL,
            gemini_request,
            api_key
        )

        if response:
            return JSONResponse({"status": "valid"})
        # Bug fix: the original fell through here and implicitly returned
        # None (HTTP 200 with a null body) when the upstream returned an
        # empty response. Report the key as invalid instead.
        return JSONResponse({"status": "invalid", "error": "Empty response from upstream"})
    except Exception as e:
        logger.error(f"Key verification failed: {str(e)}")

        async with key_manager.failure_count_lock:
            if api_key in key_manager.key_failure_counts:
                key_manager.key_failure_counts[api_key] += 1
                logger.warning(f"Verification exception for key: {api_key}, incrementing failure count")

        return JSONResponse({"status": "invalid", "error": str(e)})
292
+
293
+
294
@router.post("/verify-selected-keys")
async def verify_selected_keys(
    request: VerifySelectedKeysRequest,
    chat_service: GeminiChatService = Depends(get_chat_service),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Bulk-verify the selected Gemini API keys.

    Each key is probed concurrently with a minimal "hi" generation request.
    Failing keys get their failure count incremented; the response lists
    which keys passed and which failed (with the error message).
    """
    logger.info("-" * 50 + "verify_selected_gemini_keys" + "-" * 50)
    keys_to_verify = request.keys
    logger.info(f"Received verification request for {len(keys_to_verify)} selected keys.")

    if not keys_to_verify:
        return JSONResponse({"success": False, "message": "没有提供需要验证的密钥"}, status_code=400)

    successful_keys = []
    failed_keys = {}

    async def _verify_single_key(api_key: str):
        """Verify one key, recording its outcome and bumping failure counts."""
        try:
            gemini_request = GeminiRequest(
                contents=[GeminiContent(role="user", parts=[{"text": "hi"}])],
                generation_config={"temperature": 0.7, "top_p": 1.0, "max_output_tokens": 10}
            )
            await chat_service.generate_content(
                settings.TEST_MODEL,
                gemini_request,
                api_key
            )
            successful_keys.append(api_key)
            return api_key, "valid", None
        except Exception as e:
            error_message = str(e)
            logger.warning(f"Key verification failed for {api_key}: {error_message}")
            async with key_manager.failure_count_lock:
                if api_key in key_manager.key_failure_counts:
                    key_manager.key_failure_counts[api_key] += 1
                    logger.warning(f"Bulk verification exception for key: {api_key}, incrementing failure count")
                else:
                    key_manager.key_failure_counts[api_key] = 1
                    logger.warning(f"Bulk verification exception for key: {api_key}, initializing failure count to 1")
            failed_keys[api_key] = error_message
            return api_key, "invalid", error_message

    tasks = [_verify_single_key(key) for key in keys_to_verify]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Bug fix: the original result loop contained dead branches (an `elif`
    # that re-tested isinstance(result, Exception) after it was already
    # excluded) and never used the unpacked tuple. Only unexpected task
    # exceptions need handling here; per-key bookkeeping already happened
    # inside _verify_single_key.
    for result in results:
        if isinstance(result, Exception):
            logger.error(f"An unexpected error occurred during bulk verification task: {result}")

    valid_count = len(successful_keys)
    invalid_count = len(failed_keys)
    logger.info(f"Bulk verification finished. Valid: {valid_count}, Invalid: {invalid_count}")

    if failed_keys:
        message = f"批量验证完成。成功: {valid_count}, 失败: {invalid_count}。"
    else:
        message = f"批量验证成功完成。所有 {valid_count} 个密钥均有效。"
    return JSONResponse({
        "success": True,
        "message": message,
        "successful_keys": successful_keys,
        "failed_keys": failed_keys,
        "valid_count": valid_count,
        "invalid_count": invalid_count
    })
app/router/openai_compatiable_routes.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends
2
+ from fastapi.responses import StreamingResponse
3
+
4
+ from app.config.config import settings
5
+ from app.core.security import SecurityService
6
+ from app.domain.openai_models import (
7
+ ChatRequest,
8
+ EmbeddingRequest,
9
+ ImageGenerationRequest,
10
+ )
11
+ from app.handler.retry_handler import RetryHandler
12
+ from app.handler.error_handler import handle_route_errors
13
+ from app.log.logger import get_openai_compatible_logger
14
+ from app.service.key.key_manager import KeyManager, get_key_manager_instance
15
+ from app.service.openai_compatiable.openai_compatiable_service import OpenAICompatiableService
16
+
17
+
18
+ router = APIRouter()
19
+ logger = get_openai_compatible_logger()
20
+
21
+ security_service = SecurityService()
22
+
23
async def get_key_manager():
    """FastAPI dependency: return the shared KeyManager singleton."""
    return await get_key_manager_instance()
25
+
26
+
27
async def get_next_working_key_wrapper(
    key_manager: KeyManager = Depends(get_key_manager),
):
    """FastAPI dependency: pick the next working API key from the manager."""
    return await key_manager.get_next_working_key()
31
+
32
+
33
async def get_openai_service(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: build an OpenAI-compatible service bound to BASE_URL."""
    return OpenAICompatiableService(settings.BASE_URL, key_manager)
36
+
37
+
38
@router.get("/openai/v1/models")
async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Return the list of available models from the upstream service."""
    operation_name = "list_models"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling models list request")
        # Any valid key works for a read-only model listing.
        api_key = await key_manager.get_first_valid_key()
        logger.info(f"Using API key: {api_key}")
        models = await openai_service.get_models(api_key)
        return models
51
+
52
+
53
@router.post("/openai/v1/chat/completions")
@RetryHandler(key_arg="api_key")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Handle a chat-completion request.

    Supports streaming responses; the image-chat model is routed through a
    paid key instead of the rotating working key.
    """
    operation_name = "chat_completion"
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    # Image-chat requests must use a paid key; everything else keeps the
    # rotating key injected by the dependency.
    current_api_key = await key_manager.get_paid_key() if is_image_chat else api_key

    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling chat completion request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {current_api_key}")

        if is_image_chat:
            return await openai_service.create_image_chat_completion(request, current_api_key)

        response = await openai_service.create_chat_completion(request, current_api_key)
        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        return response
82
+
83
+
84
@router.post("/openai/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Handle an image-generation request using the configured image model."""
    operation_name = "generate_image"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling image generation request for prompt: {request.prompt}")
        # The client-supplied model is ignored; the configured image model
        # is always used.
        request.model = settings.CREATE_IMAGE_MODEL
        return await openai_service.generate_images(request)
96
+
97
+
98
@router.post("/openai/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
    openai_service: OpenAICompatiableService = Depends(get_openai_service),
):
    """Handle a text-embedding request via the OpenAI-compatible service."""
    operation_name = "embedding"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling embedding request for model: {request.model}")
        api_key = await key_manager.get_next_working_key()
        logger.info(f"Using API key: {api_key}")
        result = await openai_service.create_embeddings(
            input_text=request.input, model=request.model, api_key=api_key
        )
        return result
app/router/openai_routes.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Response
2
+ from fastapi.responses import StreamingResponse
3
+
4
+ from app.config.config import settings
5
+ from app.core.security import SecurityService
6
+ from app.domain.openai_models import (
7
+ ChatRequest,
8
+ EmbeddingRequest,
9
+ ImageGenerationRequest,
10
+ TTSRequest,
11
+ )
12
+ from app.handler.retry_handler import RetryHandler
13
+ from app.handler.error_handler import handle_route_errors
14
+ from app.log.logger import get_openai_logger
15
+ from app.service.chat.openai_chat_service import OpenAIChatService
16
+ from app.service.embedding.embedding_service import EmbeddingService
17
+ from app.service.image.image_create_service import ImageCreateService
18
+ from app.service.tts.tts_service import TTSService
19
+ from app.service.key.key_manager import KeyManager, get_key_manager_instance
20
+ from app.service.model.model_service import ModelService
21
+
22
+ router = APIRouter()
23
+ logger = get_openai_logger()
24
+
25
+ security_service = SecurityService()
26
+ model_service = ModelService()
27
+ embedding_service = EmbeddingService()
28
+ image_create_service = ImageCreateService()
29
+ tts_service = TTSService()
30
+
31
+
32
async def get_key_manager():
    """FastAPI dependency: return the shared KeyManager singleton."""
    return await get_key_manager_instance()
34
+
35
+
36
async def get_next_working_key_wrapper(
    key_manager: KeyManager = Depends(get_key_manager),
):
    """FastAPI dependency: pick the next working API key from the manager."""
    return await key_manager.get_next_working_key()
40
+
41
+
42
async def get_openai_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: build an OpenAI chat service bound to BASE_URL."""
    return OpenAIChatService(settings.BASE_URL, key_manager)
45
+
46
+
47
async def get_tts_service():
    """FastAPI dependency: return the module-level TTS service instance."""
    return tts_service
50
+
51
+
52
@router.get("/v1/models")
@router.get("/hf/v1/models")
async def list_models(
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Return the available OpenAI-style model list (Gemini + OpenAI compatible)."""
    operation_name = "list_models"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling models list request")
        # A read-only listing only needs any valid key.
        api_key = await key_manager.get_first_valid_key()
        logger.info(f"Using API key: {api_key}")
        models = await model_service.get_gemini_openai_models(api_key)
        return models
65
+
66
+
67
@router.post("/v1/chat/completions")
@router.post("/hf/v1/chat/completions")
@RetryHandler(key_arg="api_key")
async def chat_completion(
    request: ChatRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: OpenAIChatService = Depends(get_openai_chat_service),
):
    """Handle an OpenAI chat-completion request.

    Supports streaming responses; the image-chat model is routed through a
    paid key instead of the rotating working key. Unsupported models are
    rejected with HTTP 400.
    """
    operation_name = "chat_completion"
    is_image_chat = request.model == f"{settings.CREATE_IMAGE_MODEL}-chat"
    current_api_key = api_key
    if is_image_chat:
        # Image chat must be served with a paid key.
        current_api_key = await key_manager.get_paid_key()

    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling chat completion request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {current_api_key}")

        if not await model_service.check_model_support(request.model):
            raise HTTPException(
                status_code=400, detail=f"Model {request.model} is not supported"
            )

        # Refactor: both branches previously duplicated the identical
        # stream-check/return tail; only the service call differs.
        if is_image_chat:
            response = await chat_service.create_image_chat_completion(request, current_api_key)
        else:
            response = await chat_service.create_chat_completion(request, current_api_key)
        if request.stream:
            return StreamingResponse(response, media_type="text/event-stream")
        return response
104
+
105
+
106
@router.post("/v1/images/generations")
@router.post("/hf/v1/images/generations")
async def generate_image(
    request: ImageGenerationRequest,
    _=Depends(security_service.verify_authorization),
):
    """Handle an OpenAI-style image-generation request."""
    operation_name = "generate_image"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling image generation request for prompt: {request.prompt}")
        # NOTE(review): generate_images is not awaited here, unlike the
        # awaited call in the OpenAI-compatible router — confirm whether
        # ImageCreateService.generate_images is synchronous.
        response = image_create_service.generate_images(request)
        return response
118
+
119
+
120
@router.post("/v1/embeddings")
@router.post("/hf/v1/embeddings")
async def embedding(
    request: EmbeddingRequest,
    _=Depends(security_service.verify_authorization),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Handle an OpenAI-style text-embedding request."""
    operation_name = "embedding"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling embedding request for model: {request.model}")
        api_key = await key_manager.get_next_working_key()
        logger.info(f"Using API key: {api_key}")
        return await embedding_service.create_embedding(
            input_text=request.input, model=request.model, api_key=api_key
        )
137
+
138
+
139
@router.get("/v1/keys/list")
@router.get("/hf/v1/keys/list")
async def get_keys_list(
    _=Depends(security_service.verify_auth_token),
    key_manager: KeyManager = Depends(get_key_manager),
):
    """Return the valid and invalid API key lists (admin-token auth required)."""
    operation_name = "get_keys_list"
    async with handle_route_errors(logger, operation_name):
        logger.info("Handling keys list request")
        keys_status = await key_manager.get_keys_by_status()
        valid = keys_status["valid_keys"]
        invalid = keys_status["invalid_keys"]
        return {
            "status": "success",
            "data": {"valid_keys": valid, "invalid_keys": invalid},
            "total": len(valid) + len(invalid),
        }
158
+
159
+
160
@router.post("/v1/audio/speech")
@router.post("/hf/v1/audio/speech")
async def text_to_speech(
    request: TTSRequest,
    _=Depends(security_service.verify_authorization),
    api_key: str = Depends(get_next_working_key_wrapper),
    tts_service: TTSService = Depends(get_tts_service),
):
    """Handle an OpenAI-style text-to-speech request, returning WAV audio."""
    operation_name = "text_to_speech"
    async with handle_route_errors(logger, operation_name):
        logger.info(f"Handling TTS request for model: {request.model}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        logger.info(f"Using API key: {api_key}")
        return Response(
            content=await tts_service.create_tts(request, api_key),
            media_type="audio/wav",
        )
app/router/routes.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 路由配置模块,负责设置和配置应用程序的路由
3
+ """
4
+
5
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.templating import Jinja2Templates

from app.core.security import verify_auth_token
from app.log.logger import get_routes_logger
from app.router import error_log_routes, gemini_routes, openai_routes, config_routes, scheduler_routes, stats_routes, version_routes, openai_compatiable_routes, vertex_express_routes
from app.service.key.key_manager import get_key_manager_instance
from app.service.stats.stats_service import StatsService
14
+
15
+ logger = get_routes_logger()
16
+
17
+ templates = Jinja2Templates(directory="app/templates")
18
+
19
+
20
def setup_routers(app: FastAPI) -> None:
    """
    Register every API router plus the page/health/stats routes on the app.

    Args:
        app: FastAPI application instance
    """
    api_routers = (
        openai_routes.router,
        gemini_routes.router,
        gemini_routes.router_v1beta,
        config_routes.router,
        error_log_routes.router,
        scheduler_routes.router,
        stats_routes.router,
        version_routes.router,
        openai_compatiable_routes.router,
        vertex_express_routes.router,
    )
    for api_router in api_routers:
        app.include_router(api_router)

    setup_page_routes(app)

    setup_health_routes(app)
    setup_api_stats_routes(app)
42
+
43
+
44
def setup_page_routes(app: FastAPI) -> None:
    """
    Register the HTML page routes (auth, keys, config, logs).

    All pages except the auth page require a valid auth_token cookie and
    redirect to "/" when the check fails.

    Args:
        app: FastAPI application instance
    """

    @app.get("/", response_class=HTMLResponse)
    async def auth_page(request: Request):
        """Authentication (login) page."""
        return templates.TemplateResponse("auth.html", {"request": request})

    @app.post("/auth")
    async def authenticate(request: Request):
        """Handle the login form: set the auth cookie and redirect to /config."""
        try:
            form = await request.form()
            auth_token = form.get("auth_token")
            if not auth_token:
                logger.warning("Authentication attempt with empty token")
                return RedirectResponse(url="/", status_code=302)

            if verify_auth_token(auth_token):
                logger.info("Successful authentication")
                response = RedirectResponse(url="/config", status_code=302)
                # Cookie lives for one hour; httponly keeps it away from JS.
                response.set_cookie(
                    key="auth_token", value=auth_token, httponly=True, max_age=3600
                )
                return response
            logger.warning("Failed authentication attempt with invalid token")
            return RedirectResponse(url="/", status_code=302)
        except Exception as e:
            # Any failure during auth falls back to the login page.
            logger.error(f"Authentication error: {str(e)}")
            return RedirectResponse(url="/", status_code=302)

    @app.get("/keys", response_class=HTMLResponse)
    async def keys_page(request: Request):
        """Key management page: key status plus API usage statistics."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to keys page")
                return RedirectResponse(url="/", status_code=302)

            key_manager = await get_key_manager_instance()
            keys_status = await key_manager.get_keys_by_status()
            total_keys = len(keys_status["valid_keys"]) + len(keys_status["invalid_keys"])
            valid_key_count = len(keys_status["valid_keys"])
            invalid_key_count = len(keys_status["invalid_keys"])

            stats_service = StatsService()
            api_stats = await stats_service.get_api_usage_stats()
            logger.info(f"API stats retrieved: {api_stats}")

            logger.info(f"Keys status retrieved successfully. Total keys: {total_keys}")
            return templates.TemplateResponse(
                "keys_status.html",
                {
                    "request": request,
                    "valid_keys": keys_status["valid_keys"],
                    "invalid_keys": keys_status["invalid_keys"],
                    "total_keys": total_keys,
                    "valid_key_count": valid_key_count,
                    "invalid_key_count": invalid_key_count,
                    "api_stats": api_stats,
                },
            )
        except Exception as e:
            logger.error(f"Error retrieving keys status or API stats: {str(e)}")
            raise

    @app.get("/config", response_class=HTMLResponse)
    async def config_page(request: Request):
        """Configuration editor page."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to config page")
                return RedirectResponse(url="/", status_code=302)

            logger.info("Config page accessed successfully")
            return templates.TemplateResponse("config_editor.html", {"request": request})
        except Exception as e:
            logger.error(f"Error accessing config page: {str(e)}")
            raise

    @app.get("/logs", response_class=HTMLResponse)
    async def logs_page(request: Request):
        """Error log viewer page."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to logs page")
                return RedirectResponse(url="/", status_code=302)

            logger.info("Logs page accessed successfully")
            return templates.TemplateResponse("error_logs.html", {"request": request})
        except Exception as e:
            logger.error(f"Error accessing logs page: {str(e)}")
            raise
146
+
147
def setup_health_routes(app: FastAPI) -> None:
    """
    Register the health-check route.

    Args:
        app: FastAPI application instance
    """

    @app.get("/health")
    async def health_check(request: Request):
        """Health-check endpoint; always reports healthy."""
        logger.info("Health check endpoint called")
        return {"status": "healthy"}
160
+
161
+
162
def setup_api_stats_routes(app: FastAPI) -> None:
    """
    Register the API statistics detail route.

    Args:
        app: FastAPI application instance
    """
    @app.get("/api/stats/details")
    async def api_stats_details(request: Request, period: str):
        """Return API call details for the given period (cookie-token auth)."""
        try:
            auth_token = request.cookies.get("auth_token")
            if not auth_token or not verify_auth_token(auth_token):
                logger.warning("Unauthorized access attempt to API stats details")
                # Bug fix: returning a (dict, int) tuple from a FastAPI
                # handler serializes the tuple as the body with HTTP 200;
                # a response object is required for the status to apply.
                return JSONResponse({"error": "Unauthorized"}, status_code=401)

            logger.info(f"Fetching API call details for period: {period}")
            stats_service = StatsService()
            details = await stats_service.get_api_call_details(period)
            return details
        except ValueError as e:
            # Invalid period string from the client.
            logger.warning(f"Invalid period requested for API stats details: {period} - {str(e)}")
            return JSONResponse({"error": str(e)}, status_code=400)
        except Exception as e:
            logger.error(f"Error fetching API stats details for period {period}: {str(e)}")
            return JSONResponse({"error": "Internal server error"}, status_code=500)
app/router/scheduler_routes.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 定时任务控制路由模块
3
+ """
4
+
5
+ from fastapi import APIRouter, Request, HTTPException, status
6
+ from fastapi.responses import JSONResponse
7
+
8
+ from app.core.security import verify_auth_token
9
+ from app.scheduler.scheduled_tasks import start_scheduler, stop_scheduler
10
+ from app.log.logger import get_scheduler_routes
11
+
12
+ logger = get_scheduler_routes()
13
+
14
+ router = APIRouter(
15
+ prefix="/api/scheduler",
16
+ tags=["Scheduler"]
17
+ )
18
+
19
async def verify_token(request: Request):
    """Raise 401 unless the request carries a valid auth_token cookie."""
    token = request.cookies.get("auth_token")
    if token and verify_auth_token(token):
        return
    logger.warning("Unauthorized access attempt to scheduler API")
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
        headers={"WWW-Authenticate": "Bearer"},
    )
28
+
29
@router.post("/start", summary="启动定时任务")
async def start_scheduler_endpoint(request: Request):
    """Start the background scheduler task (requires cookie auth)."""
    await verify_token(request)
    try:
        logger.info("Received request to start scheduler.")
        start_scheduler()
    except Exception as e:
        logger.error(f"Error starting scheduler: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to start scheduler: {str(e)}"
        )
    else:
        return JSONResponse(content={"message": "Scheduler started successfully."}, status_code=status.HTTP_200_OK)
43
+
44
@router.post("/stop", summary="停止定时任务")
async def stop_scheduler_endpoint(request: Request):
    """Stop the background scheduler task (requires cookie auth)."""
    await verify_token(request)
    try:
        logger.info("Received request to stop scheduler.")
        stop_scheduler()
    except Exception as e:
        logger.error(f"Error stopping scheduler: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to stop scheduler: {str(e)}"
        )
    else:
        return JSONResponse(content={"message": "Scheduler stopped successfully."}, status_code=status.HTTP_200_OK)
app/router/stats_routes.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException, Request
2
+ from starlette import status
3
+ from app.core.security import verify_auth_token
4
+ from app.service.stats.stats_service import StatsService
5
+ from app.log.logger import get_stats_logger
6
+
7
+ logger = get_stats_logger()
8
+
9
+
10
async def verify_token(request: Request):
    """Raise 401 unless the request carries a valid auth_token cookie."""
    auth_token = request.cookies.get("auth_token")
    if not auth_token or not verify_auth_token(auth_token):
        # Bug fix: the log message was copy-pasted from the scheduler router
        # and misattributed the access attempt to the scheduler API.
        logger.warning("Unauthorized access attempt to stats API")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
19
+
20
+ router = APIRouter(
21
+ prefix="/api",
22
+ tags=["stats"],
23
+ dependencies=[Depends(verify_token)]
24
+ )
25
+
26
+ stats_service = StatsService()
27
+
28
@router.get("/key-usage-details/{key}",
            summary="获取指定密钥最近24小时的模型调用次数",
            description="根据提供的 API 密钥,返回过去24小时内每个模型被调用的次数统计。")
async def get_key_usage_details(key: str):
    """
    Return per-model call counts for one API key over the last 24 hours.

    Args:
        key: The API key to get usage details for.

    Returns:
        A mapping of model name -> call count, e.g.
        {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}. An empty dict is
        returned when no usage data exists.

    Raises:
        HTTPException: 500 if the underlying stats lookup fails.
    """
    try:
        usage_details = await stats_service.get_key_usage_details_last_24h(key)
        return usage_details if usage_details is not None else {}
    except Exception as e:
        logger.error(f"Error fetching key usage details for key {key[:4]}...: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取密钥使用详情时出错: {e}"
        )
app/router/version_routes.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+ from pydantic import BaseModel, Field
3
+ from typing import Optional
4
+
5
+ from app.service.update.update_service import check_for_updates
6
+ from app.utils.helpers import get_current_version
7
+ from app.log.logger import get_update_logger
8
+
9
# Version-check router and its dedicated logger.
router = APIRouter(prefix="/api/version", tags=["Version"])
logger = get_update_logger()
11
+
12
class VersionInfo(BaseModel):
    """Response model for the version-check endpoint."""
    current_version: str = Field(..., description="当前应用程序版本")
    latest_version: Optional[str] = Field(None, description="可用的最新版本")
    update_available: bool = Field(False, description="是否有可用更新")
    error_message: Optional[str] = Field(None, description="检查更新时发生的错误信息")
17
+
18
@router.get("/check", response_model=VersionInfo, summary="检查应用程序更新")
async def get_version_info():
    """Compare the running application version against the latest GitHub release.

    Returns:
        VersionInfo: current/latest version, an update-available flag, and any
        error message produced while checking.

    Raises:
        HTTPException: 500 on unexpected internal failure.
    """
    try:
        current_version = get_current_version()
        update_available, latest_version, error_message = await check_for_updates()

        logger.info(f"Version check API result: current={current_version}, latest={latest_version}, available={update_available}, error='{error_message}'")

        return VersionInfo(
            current_version=current_version,
            latest_version=latest_version,
            update_available=update_available,
            error_message=error_message
        )
    except Exception as e:
        logger.error(f"Error in /api/version/check endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="检查版本信息时发生内部错误")
app/router/vertex_express_routes.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ from copy import deepcopy
4
+ from app.config.config import settings
5
+ from app.log.logger import get_vertex_express_logger
6
+ from app.core.security import SecurityService
7
+ from app.domain.gemini_models import GeminiRequest
8
+ from app.service.chat.vertex_express_chat_service import GeminiChatService
9
+ from app.service.key.key_manager import KeyManager, get_key_manager_instance
10
+ from app.service.model.model_service import ModelService
11
+ from app.handler.retry_handler import RetryHandler
12
+ from app.handler.error_handler import handle_route_errors
13
+ from app.core.constants import API_VERSION
14
+
15
# Vertex Express compatible endpoints, versioned by API_VERSION.
router = APIRouter(prefix=f"/vertex-express/{API_VERSION}")
logger = get_vertex_express_logger()

# Module-level singletons shared by all requests on this router.
security_service = SecurityService()
model_service = ModelService()
20
+
21
+
22
async def get_key_manager():
    """FastAPI dependency: return the process-wide KeyManager instance."""
    return await get_key_manager_instance()
25
+
26
+
27
async def get_next_working_key(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: return the next usable Vertex Express API key."""
    return await key_manager.get_next_working_vertex_key()
30
+
31
+
32
async def get_chat_service(key_manager: KeyManager = Depends(get_key_manager)):
    """FastAPI dependency: build a chat service bound to the Vertex Express base URL."""
    return GeminiChatService(settings.VERTEX_EXPRESS_BASE_URL, key_manager)
35
+
36
+
37
@router.get("/models")
async def list_models(
    _=Depends(security_service.verify_key_or_goog_api_key),
    key_manager: KeyManager = Depends(get_key_manager)
):
    """Return the available Gemini model list with derived variants appended.

    Derived entries (``-search``, ``-image``, ``-non-thinking``) are cloned
    from base models named in the corresponding settings lists.

    Raises:
        HTTPException: 503 when no valid API key exists, 500 on upstream failure.
    """
    operation_name = "list_gemini_models"
    logger.info("-" * 50 + operation_name + "-" * 50)
    logger.info("Handling Gemini models list request")

    try:
        api_key = await key_manager.get_first_valid_key()
        if not api_key:
            raise HTTPException(status_code=503, detail="No valid API keys available to fetch models.")
        # Security fix: never write the raw API key (a credential) to logs.
        redacted_key = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "***"
        logger.info(f"Using API key: {redacted_key}")

        models_data = await model_service.get_gemini_models(api_key)
        if not models_data or "models" not in models_data:
            raise HTTPException(status_code=500, detail="Failed to fetch base models list.")

        models_json = deepcopy(models_data)
        # Map short model name (after "models/") to its full entry.
        model_mapping = {x.get("name", "").split("/", maxsplit=1)[-1]: x for x in models_json.get("models", [])}

        def add_derived_model(base_name, suffix, display_suffix):
            # Clone the base model entry and re-label it as a derived variant.
            model = model_mapping.get(base_name)
            if not model:
                logger.warning(f"Base model '{base_name}' not found for derived model '{suffix}'.")
                return
            item = deepcopy(model)
            item["name"] = f"models/{base_name}{suffix}"
            display_name = f'{item.get("displayName", base_name)}{display_suffix}'
            item["displayName"] = display_name
            item["description"] = display_name
            models_json["models"].append(item)

        if settings.SEARCH_MODELS:
            for name in settings.SEARCH_MODELS:
                add_derived_model(name, "-search", " For Search")
        if settings.IMAGE_MODELS:
            for name in settings.IMAGE_MODELS:
                add_derived_model(name, "-image", " For Image")
        if settings.THINKING_MODELS:
            for name in settings.THINKING_MODELS:
                add_derived_model(name, "-non-thinking", " Non Thinking")

        logger.info("Gemini models list request successful")
        return models_json
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        logger.error(f"Error getting Gemini models list: {str(e)}")
        raise HTTPException(
            status_code=500, detail="Internal server error while fetching Gemini models list"
        ) from e
91
+
92
+
93
@router.post("/models/{model_name}:generateContent")
@RetryHandler(key_arg="api_key")
async def generate_content(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a non-streaming Gemini content-generation request.

    Raises:
        HTTPException: 400 when the model is not supported.
    """
    operation_name = "gemini_generate_content"
    async with handle_route_errors(logger, operation_name, failure_message="Content generation failed"):
        logger.info(f"Handling Gemini content generation request for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # Security fix: never write the raw API key (a credential) to logs.
        redacted_key = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "***"
        logger.info(f"Using API key: {redacted_key}")

        if not await model_service.check_model_support(model_name):
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        response = await chat_service.generate_content(
            model=model_name,
            request=request,
            api_key=api_key
        )
        return response
119
+
120
+
121
@router.post("/models/{model_name}:streamGenerateContent")
@RetryHandler(key_arg="api_key")
async def stream_generate_content(
    model_name: str,
    request: GeminiRequest,
    _=Depends(security_service.verify_key_or_goog_api_key),
    api_key: str = Depends(get_next_working_key),
    key_manager: KeyManager = Depends(get_key_manager),
    chat_service: GeminiChatService = Depends(get_chat_service)
):
    """Handle a streaming Gemini content-generation request (SSE response).

    Raises:
        HTTPException: 400 when the model is not supported.
    """
    operation_name = "gemini_stream_generate_content"
    async with handle_route_errors(logger, operation_name, failure_message="Streaming request initiation failed"):
        logger.info(f"Handling Gemini streaming content generation for model: {model_name}")
        logger.debug(f"Request: \n{request.model_dump_json(indent=2)}")
        # Security fix: never write the raw API key (a credential) to logs.
        redacted_key = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "***"
        logger.info(f"Using API key: {redacted_key}")

        if not await model_service.check_model_support(model_name):
            raise HTTPException(status_code=400, detail=f"Model {model_name} is not supported")

        response_stream = chat_service.stream_generate_content(
            model=model_name,
            request=request,
            api_key=api_key
        )
        return StreamingResponse(response_stream, media_type="text/event-stream")
app/scheduler/scheduled_tasks.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from apscheduler.schedulers.asyncio import AsyncIOScheduler
3
+
4
+ from app.config.config import settings
5
+ from app.domain.gemini_models import GeminiContent, GeminiRequest
6
+ from app.log.logger import Logger
7
+ from app.service.chat.gemini_chat_service import GeminiChatService
8
+ from app.service.error_log.error_log_service import delete_old_error_logs
9
+ from app.service.key.key_manager import get_key_manager_instance
10
+ from app.service.request_log.request_log_service import delete_old_request_logs_task
11
+
12
# Dedicated logger for background scheduler tasks.
logger = Logger.setup_logger("scheduler")
13
+
14
+
15
async def check_failed_keys():
    """
    Periodically verify every API key whose recorded failure count is > 0.

    A key that passes a lightweight test request has its failure count reset;
    a key that fails has its count incremented (capped at MAX_FAILURES).
    """
    logger.info("Starting scheduled check for failed API keys...")
    try:
        key_manager = await get_key_manager_instance()
        # Make sure the KeyManager singleton has been fully initialized.
        if not key_manager or not hasattr(key_manager, "key_failure_counts"):
            logger.warning(
                "KeyManager instance not available or not initialized. Skipping check."
            )
            return

        # Build a chat service directly (no dependency injection here,
        # because this runs as a background task outside any request context).
        chat_service = GeminiChatService(settings.BASE_URL, key_manager)

        # Collect the keys to check (failure count > 0); hold the lock while
        # reading shared state and copy the dict to avoid mutation during iteration.
        keys_to_check = []
        async with key_manager.failure_count_lock:
            failure_counts_copy = key_manager.key_failure_counts.copy()
            keys_to_check = [
                key for key, count in failure_counts_copy.items() if count > 0
            ]

        if not keys_to_check:
            logger.info("No keys with failure count > 0 found. Skipping verification.")
            return

        logger.info(
            f"Found {len(keys_to_check)} keys with failure count > 0 to verify."
        )

        for key in keys_to_check:
            # Redact the key for logging (credential hygiene).
            log_key = f"{key[:4]}...{key[-4:]}" if len(key) > 8 else key
            logger.info(f"Verifying key: {log_key}...")
            try:
                # Minimal "hi" request against the configured test model.
                gemini_request = GeminiRequest(
                    contents=[
                        GeminiContent(
                            role="user",
                            parts=[{"text": "hi"}],
                        )
                    ]
                )
                await chat_service.generate_content(
                    settings.TEST_MODEL, gemini_request, key
                )
                logger.info(
                    f"Key {log_key} verification successful. Resetting failure count."
                )
                await key_manager.reset_key_failure_count(key)
            except Exception as e:
                logger.warning(
                    f"Key {log_key} verification failed: {str(e)}. Incrementing failure count."
                )
                # Mutate the shared counter only while holding the lock.
                async with key_manager.failure_count_lock:
                    # Re-check: the key may have been removed or already capped
                    # by another task since the snapshot was taken.
                    if (
                        key in key_manager.key_failure_counts
                        and key_manager.key_failure_counts[key]
                        < key_manager.MAX_FAILURES
                    ):
                        key_manager.key_failure_counts[key] += 1
                        logger.info(
                            f"Failure count for key {log_key} incremented to {key_manager.key_failure_counts[key]}."
                        )
                    elif key in key_manager.key_failure_counts:
                        logger.warning(
                            f"Key {log_key} reached MAX_FAILURES ({key_manager.MAX_FAILURES}). Not incrementing further."
                        )

    except Exception as e:
        logger.error(
            f"An error occurred during the scheduled key check: {str(e)}", exc_info=True
        )
97
+
98
+
99
def setup_scheduler():
    """Create, configure and start the APScheduler instance; return it."""
    scheduler = AsyncIOScheduler(timezone=str(settings.TIMEZONE))  # timezone comes from config
    # Periodic job: re-verify API keys that have recorded failures.
    scheduler.add_job(
        check_failed_keys,
        "interval",
        hours=settings.CHECK_INTERVAL_HOURS,
        id="check_failed_keys_job",
        name="Check Failed API Keys",
    )
    logger.info(
        f"Key check job scheduled to run every {settings.CHECK_INTERVAL_HOURS} hour(s)."
    )

    # Daily job at 03:00: purge old error logs.
    scheduler.add_job(
        delete_old_error_logs,
        "cron",
        hour=3,
        minute=0,
        id="delete_old_error_logs_job",
        name="Delete Old Error Logs",
    )
    logger.info("Auto-delete error logs job scheduled to run daily at 3:00 AM.")

    # Daily job at 03:05: purge old request logs (the task itself honors
    # the AUTO_DELETE_REQUEST_LOGS settings).
    scheduler.add_job(
        delete_old_request_logs_task,
        "cron",
        hour=3,
        minute=5,
        id="delete_old_request_logs_job",
        name="Delete Old Request Logs",
    )
    logger.info(
        f"Auto-delete request logs job scheduled to run daily at 3:05 AM, if enabled and AUTO_DELETE_REQUEST_LOGS_DAYS is set to {settings.AUTO_DELETE_REQUEST_LOGS_DAYS} days."
    )

    scheduler.start()
    logger.info("Scheduler started with all jobs.")
    return scheduler
141
+
142
+
143
# Global scheduler handle so the application can stop it gracefully on shutdown.
scheduler_instance = None
145
+
146
+
147
def start_scheduler():
    """Start the global scheduler unless one is already running."""
    global scheduler_instance
    if scheduler_instance is None or not scheduler_instance.running:
        logger.info("Starting scheduler...")
        scheduler_instance = setup_scheduler()
    else:
        # Bug fix: this message used to be logged unconditionally, i.e. even
        # right after starting a fresh scheduler. It belongs to the
        # already-running branch only.
        logger.info("Scheduler is already running.")
153
+
154
+
155
def stop_scheduler():
    """Shut down the global scheduler if it exists and is currently running."""
    global scheduler_instance
    if not scheduler_instance:
        return
    if not scheduler_instance.running:
        return
    scheduler_instance.shutdown()
    logger.info("Scheduler stopped.")
app/service/chat/gemini_chat_service.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/services/chat_service.py
2
+
3
+ import json
4
+ import re
5
+ import datetime
6
+ import time
7
+ from typing import Any, AsyncGenerator, Dict, List
8
+ from app.config.config import settings
9
+ from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
10
+ from app.domain.gemini_models import GeminiRequest
11
+ from app.handler.response_handler import GeminiResponseHandler
12
+ from app.handler.stream_optimizer import gemini_optimizer
13
+ from app.log.logger import get_gemini_logger
14
+ from app.service.client.api_client import GeminiApiClient
15
+ from app.service.key.key_manager import KeyManager
16
+ from app.database.services import add_error_log, add_request_log
17
+
18
+ logger = get_gemini_logger()
19
+
20
+
21
+ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
22
+ """判断消息是否包含图片部分"""
23
+ for content in contents:
24
+ if "parts" in content:
25
+ for part in content["parts"]:
26
+ if "image_url" in part or "inline_data" in part:
27
+ return True
28
+ return False
29
+
30
+
31
+ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
32
+ """构建工具"""
33
+
34
+ def _merge_tools(tools: List[Dict[str, Any]]) -> Dict[str, Any]:
35
+ record = dict()
36
+ for item in tools:
37
+ if not item or not isinstance(item, dict):
38
+ continue
39
+
40
+ for k, v in item.items():
41
+ if k == "functionDeclarations" and v and isinstance(v, list):
42
+ functions = record.get("functionDeclarations", [])
43
+ functions.extend(v)
44
+ record["functionDeclarations"] = functions
45
+ else:
46
+ record[k] = v
47
+ return record
48
+
49
+ tool = dict()
50
+ if payload and isinstance(payload, dict) and "tools" in payload:
51
+ if payload.get("tools") and isinstance(payload.get("tools"), dict):
52
+ payload["tools"] = [payload.get("tools")]
53
+ items = payload.get("tools", [])
54
+ if items and isinstance(items, list):
55
+ tool.update(_merge_tools(items))
56
+
57
+ if (
58
+ settings.TOOLS_CODE_EXECUTION_ENABLED
59
+ and not (model.endswith("-search") or "-thinking" in model)
60
+ and not _has_image_parts(payload.get("contents", []))
61
+ ):
62
+ tool["codeExecution"] = {}
63
+ if model.endswith("-search"):
64
+ tool["googleSearch"] = {}
65
+
66
+ # 解决 "Tool use with function calling is unsupported" 问题
67
+ if tool.get("functionDeclarations"):
68
+ tool.pop("googleSearch", None)
69
+ tool.pop("codeExecution", None)
70
+
71
+ return [tool] if tool else []
72
+
73
+
74
+ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
75
+ """获取安全设置"""
76
+ if model == "gemini-2.0-flash-exp":
77
+ return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
78
+ return settings.SAFETY_SETTINGS
79
+
80
+
81
+ def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
82
+ """构建请求payload"""
83
+ request_dict = request.model_dump()
84
+ if request.generationConfig:
85
+ if request.generationConfig.maxOutputTokens is None:
86
+ # 如果未指定最大输出长度,则不传递该字段,解决截断的问题
87
+ request_dict["generationConfig"].pop("maxOutputTokens")
88
+
89
+ payload = {
90
+ "contents": request_dict.get("contents", []),
91
+ "tools": _build_tools(model, request_dict),
92
+ "safetySettings": _get_safety_settings(model),
93
+ "generationConfig": request_dict.get("generationConfig"),
94
+ "systemInstruction": request_dict.get("systemInstruction"),
95
+ }
96
+
97
+ if model.endswith("-image") or model.endswith("-image-generation"):
98
+ payload.pop("systemInstruction")
99
+ payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
100
+
101
+ # 处理思考配置:优先使用客户端提供的配置,否则使用默认配置
102
+ client_thinking_config = None
103
+ if request.generationConfig and request.generationConfig.thinkingConfig:
104
+ client_thinking_config = request.generationConfig.thinkingConfig
105
+
106
+ if client_thinking_config is not None:
107
+ # 客户端提供了思考配置,直接使用
108
+ payload["generationConfig"]["thinkingConfig"] = client_thinking_config
109
+ else:
110
+ # 客户端没有提供思考配置,使用默认配置
111
+ if model.endswith("-non-thinking"):
112
+ payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
113
+ elif model in settings.THINKING_BUDGET_MAP:
114
+ payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}
115
+
116
+ return payload
117
+
118
+
119
+ class GeminiChatService:
120
+ """聊天服务"""
121
+
122
+ def __init__(self, base_url: str, key_manager: KeyManager):
123
+ self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
124
+ self.key_manager = key_manager
125
+ self.response_handler = GeminiResponseHandler()
126
+
127
+ def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
128
+ """从响应中提取文本内容"""
129
+ if not response.get("candidates"):
130
+ return ""
131
+
132
+ candidate = response["candidates"][0]
133
+ content = candidate.get("content", {})
134
+ parts = content.get("parts", [])
135
+
136
+ if parts and "text" in parts[0]:
137
+ return parts[0].get("text", "")
138
+ return ""
139
+
140
+ def _create_char_response(
141
+ self, original_response: Dict[str, Any], text: str
142
+ ) -> Dict[str, Any]:
143
+ """创建包含指定文本的响应"""
144
+ response_copy = json.loads(json.dumps(original_response))
145
+ if response_copy.get("candidates") and response_copy["candidates"][0].get(
146
+ "content", {}
147
+ ).get("parts"):
148
+ response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
149
+ return response_copy
150
+
151
+ async def generate_content(
152
+ self, model: str, request: GeminiRequest, api_key: str
153
+ ) -> Dict[str, Any]:
154
+ """生成内容"""
155
+ payload = _build_payload(model, request)
156
+ start_time = time.perf_counter()
157
+ request_datetime = datetime.datetime.now()
158
+ is_success = False
159
+ status_code = None
160
+ response = None
161
+
162
+ try:
163
+ response = await self.api_client.generate_content(payload, model, api_key)
164
+ is_success = True
165
+ status_code = 200
166
+ return self.response_handler.handle_response(response, model, stream=False)
167
+ except Exception as e:
168
+ is_success = False
169
+ error_log_msg = str(e)
170
+ logger.error(f"Normal API call failed with error: {error_log_msg}")
171
+ match = re.search(r"status code (\d+)", error_log_msg)
172
+ if match:
173
+ status_code = int(match.group(1))
174
+ else:
175
+ status_code = 500
176
+
177
+ await add_error_log(
178
+ gemini_key=api_key,
179
+ model_name=model,
180
+ error_type="gemini-chat-non-stream",
181
+ error_log=error_log_msg,
182
+ error_code=status_code,
183
+ request_msg=payload
184
+ )
185
+ raise e
186
+ finally:
187
+ end_time = time.perf_counter()
188
+ latency_ms = int((end_time - start_time) * 1000)
189
+ await add_request_log(
190
+ model_name=model,
191
+ api_key=api_key,
192
+ is_success=is_success,
193
+ status_code=status_code,
194
+ latency_ms=latency_ms,
195
+ request_time=request_datetime
196
+ )
197
+
198
+ async def stream_generate_content(
199
+ self, model: str, request: GeminiRequest, api_key: str
200
+ ) -> AsyncGenerator[str, None]:
201
+ """流式生成内容"""
202
+ retries = 0
203
+ max_retries = settings.MAX_RETRIES
204
+ payload = _build_payload(model, request)
205
+ is_success = False
206
+ status_code = None
207
+ final_api_key = api_key
208
+
209
+ while retries < max_retries:
210
+ request_datetime = datetime.datetime.now()
211
+ start_time = time.perf_counter()
212
+ current_attempt_key = api_key
213
+ final_api_key = current_attempt_key
214
+ try:
215
+ async for line in self.api_client.stream_generate_content(
216
+ payload, model, current_attempt_key
217
+ ):
218
+ # print(line)
219
+ if line.startswith("data:"):
220
+ line = line[6:]
221
+ response_data = self.response_handler.handle_response(
222
+ json.loads(line), model, stream=True
223
+ )
224
+ text = self._extract_text_from_response(response_data)
225
+ # 如果有文本内容,且开启了流式输出优化器,则使用流式输出优化器处理
226
+ if text and settings.STREAM_OPTIMIZER_ENABLED:
227
+ # 使用流式输出优化器处理文本输出
228
+ async for (
229
+ optimized_chunk
230
+ ) in gemini_optimizer.optimize_stream_output(
231
+ text,
232
+ lambda t: self._create_char_response(response_data, t),
233
+ lambda c: "data: " + json.dumps(c) + "\n\n",
234
+ ):
235
+ yield optimized_chunk
236
+ else:
237
+ # 如果没有文本内容(如工具调用等),整块输出
238
+ yield "data: " + json.dumps(response_data) + "\n\n"
239
+ logger.info("Streaming completed successfully")
240
+ is_success = True
241
+ status_code = 200
242
+ break
243
+ except Exception as e:
244
+ retries += 1
245
+ is_success = False
246
+ error_log_msg = str(e)
247
+ logger.warning(
248
+ f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
249
+ )
250
+ match = re.search(r"status code (\d+)", error_log_msg)
251
+ if match:
252
+ status_code = int(match.group(1))
253
+ else:
254
+ status_code = 500
255
+
256
+ await add_error_log(
257
+ gemini_key=current_attempt_key,
258
+ model_name=model,
259
+ error_type="gemini-chat-stream",
260
+ error_log=error_log_msg,
261
+ error_code=status_code,
262
+ request_msg=payload
263
+ )
264
+
265
+ api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
266
+ if api_key:
267
+ logger.info(f"Switched to new API key: {api_key}")
268
+ else:
269
+ logger.error(f"No valid API key available after {retries} retries.")
270
+ break
271
+
272
+ if retries >= max_retries:
273
+ logger.error(
274
+ f"Max retries ({max_retries}) reached for streaming."
275
+ )
276
+ break
277
+ finally:
278
+ end_time = time.perf_counter()
279
+ latency_ms = int((end_time - start_time) * 1000)
280
+ await add_request_log(
281
+ model_name=model,
282
+ api_key=final_api_key,
283
+ is_success=is_success,
284
+ status_code=status_code,
285
+ latency_ms=latency_ms,
286
+ request_time=request_datetime
287
+ )
app/service/chat/openai_chat_service.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/services/chat_service.py
2
+
3
+ import asyncio
4
+ import datetime
5
+ import json
6
+ import re
7
+ import time
8
+ from copy import deepcopy
9
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
10
+
11
+ from app.config.config import settings
12
+ from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
13
+ from app.database.services import (
14
+ add_error_log,
15
+ add_request_log,
16
+ )
17
+ from app.domain.openai_models import ChatRequest, ImageGenerationRequest
18
+ from app.handler.message_converter import OpenAIMessageConverter
19
+ from app.handler.response_handler import OpenAIResponseHandler
20
+ from app.handler.stream_optimizer import openai_optimizer
21
+ from app.log.logger import get_openai_logger
22
+ from app.service.client.api_client import GeminiApiClient
23
+ from app.service.image.image_create_service import ImageCreateService
24
+ from app.service.key.key_manager import KeyManager
25
+
26
# Module-level logger for the OpenAI-compatible chat service.
logger = get_openai_logger()
27
+
28
+
29
+ def _has_media_parts(contents: List[Dict[str, Any]]) -> bool:
30
+ """判断消息是否包含图片、音频或视频部分 (inline_data)"""
31
+ for content in contents:
32
+ if content and "parts" in content and isinstance(content["parts"], list):
33
+ for part in content["parts"]:
34
+ if isinstance(part, dict) and "inline_data" in part:
35
+ return True
36
+ return False
37
+
38
+
39
def _build_tools(
    request: ChatRequest, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Assemble the Gemini ``tools`` array for an OpenAI-format *request*.

    Adds code execution / Google Search where the model name and message
    content allow it, merges caller-declared functions (de-duplicated by
    name), and drops the built-in tools when function declarations are
    present because the API rejects mixing them.
    """
    tool = dict()
    model = request.model

    if (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not (
            model.endswith("-search")
            or "-thinking" in model
            or model.endswith("-image")
            or model.endswith("-image-generation")
        )
        and not _has_media_parts(messages)
    ):
        tool["codeExecution"] = {}
        logger.debug("Code execution tool enabled.")
    elif _has_media_parts(messages):
        logger.debug("Code execution tool disabled due to media parts presence.")

    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Merge tools declared on the request into the Gemini tool definition.
    if request.tools:
        function_declarations = []
        for item in request.tools:
            if not item or not isinstance(item, dict):
                continue

            if item.get("type", "") == "function" and item.get("function"):
                function = deepcopy(item.get("function"))
                parameters = function.get("parameters", {})
                # Gemini rejects empty object schemas; drop them entirely.
                if parameters.get("type") == "object" and not parameters.get(
                    "properties", {}
                ):
                    function.pop("parameters", None)

                function_declarations.append(function)

        if function_declarations:
            # De-duplicate functions by name.
            names, functions = set(), []
            for fc in function_declarations:
                if fc.get("name") not in names:
                    if fc.get("name")=="googleSearch":
                        # Cherry Studio's built-in search arrives as a function
                        # named "googleSearch"; map it to the native tool.
                        tool["googleSearch"] = {}
                    else:
                        # Regular function: add to functionDeclarations.
                        names.add(fc.get("name"))
                        functions.append(fc)

            tool["functionDeclarations"] = functions

    # Work around "Tool use with function calling is unsupported": drop the
    # built-in tools whenever function declarations are present.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
102
+
103
+
104
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety settings for *model* (gemini-2.0-flash-exp has its own set)."""
    if model == "gemini-2.0-flash-exp":
        return GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
    return settings.SAFETY_SETTINGS
114
+
115
+
116
def _build_payload(
    request: ChatRequest,
    messages: List[Dict[str, Any]],
    instruction: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the upstream Gemini payload from an OpenAI-format chat request.

    Args:
        request: The incoming OpenAI-style chat request.
        messages: Messages already converted to Gemini ``contents`` format.
        instruction: Optional system instruction (Gemini format).
    """
    payload = {
        "contents": messages,
        "generationConfig": {
            "temperature": request.temperature,
            "stopSequences": request.stop,
            "topP": request.top_p,
            "topK": request.top_k,
        },
        "tools": _build_tools(request, messages),
        "safetySettings": _get_safety_settings(request.model),
    }
    if request.max_tokens is not None:
        payload["generationConfig"]["maxOutputTokens"] = request.max_tokens
    if request.model.endswith("-image") or request.model.endswith("-image-generation"):
        # Image models must request both text and image modalities.
        payload["generationConfig"]["responseModalities"] = ["Text", "Image"]
    if request.model.endswith("-non-thinking"):
        # "-non-thinking" variants disable thinking entirely.
        payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
    if request.model in settings.THINKING_BUDGET_MAP:
        # A configured per-model budget overrides the default (including the
        # zero budget set just above, if both conditions match).
        payload["generationConfig"]["thinkingConfig"] = {
            "thinkingBudget": settings.THINKING_BUDGET_MAP.get(request.model, 1000)
        }

    # Attach the system instruction except for image models, which reject it.
    if (
        instruction
        and isinstance(instruction, dict)
        and instruction.get("role") == "system"
        and instruction.get("parts")
        and not request.model.endswith("-image")
        and not request.model.endswith("-image-generation")
    ):
        payload["systemInstruction"] = instruction

    return payload
155
+
156
+
157
+ class OpenAIChatService:
158
+ """聊天服务"""
159
+
160
    def __init__(self, base_url: str, key_manager: Optional[KeyManager] = None):
        """Wire up message converter, response handler, Gemini client and image service.

        Args:
            base_url: Base URL of the upstream Gemini API.
            key_manager: Optional key manager used for key rotation on retries.
        """
        self.message_converter = OpenAIMessageConverter()
        self.response_handler = OpenAIResponseHandler(config=None)
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.image_create_service = ImageCreateService()
166
+
167
+ def _extract_text_from_openai_chunk(self, chunk: Dict[str, Any]) -> str:
168
+ """从OpenAI响应块中提取文本内容"""
169
+ if not chunk.get("choices"):
170
+ return ""
171
+
172
+ choice = chunk["choices"][0]
173
+ if "delta" in choice and "content" in choice["delta"]:
174
+ return choice["delta"]["content"]
175
+ return ""
176
+
177
+ def _create_char_openai_chunk(
178
+ self, original_chunk: Dict[str, Any], text: str
179
+ ) -> Dict[str, Any]:
180
+ """创建包含指定文本的OpenAI响应块"""
181
+ chunk_copy = json.loads(json.dumps(original_chunk))
182
+ if chunk_copy.get("choices") and "delta" in chunk_copy["choices"][0]:
183
+ chunk_copy["choices"][0]["delta"]["content"] = text
184
+ return chunk_copy
185
+
186
    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion, dispatching to the streaming or normal flow.

        Returns:
            A response dict for non-streaming requests, or an async generator
            of SSE strings when ``request.stream`` is set.
        """
        messages, instruction = self.message_converter.convert(request.messages)

        payload = _build_payload(request, messages, instruction)

        if request.stream:
            return self._handle_stream_completion(request.model, payload, api_key)
        return await self._handle_normal_completion(request.model, payload, api_key)
199
+
200
    async def _handle_normal_completion(
        self, model: str, payload: Dict[str, Any], api_key: str
    ) -> Dict[str, Any]:
        """Perform a non-streaming completion call; records error and request logs."""
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            usage_metadata = response.get("usageMetadata", {})
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(
                response,
                model,
                stream=False,
                finish_reason="stop",
                usage_metadata=usage_metadata,
            )
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Best effort: recover the upstream HTTP status from the error text.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload,
            )
            raise e
        finally:
            # Always record the request latency/outcome, success or failure.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )
251
+
252
+ async def _fake_stream_logic_impl(
253
+ self, model: str, payload: Dict[str, Any], api_key: str
254
+ ) -> AsyncGenerator[str, None]:
255
+ """处理伪流式 (fake stream) 的核心逻辑"""
256
+ logger.info(
257
+ f"Fake streaming enabled for model: {model}. Calling non-streaming endpoint."
258
+ )
259
+ keep_sending_empty_data = True
260
+
261
+ async def send_empty_data_locally() -> AsyncGenerator[str, None]:
262
+ """定期发送空数据以保持连接"""
263
+ while keep_sending_empty_data:
264
+ await asyncio.sleep(settings.FAKE_STREAM_EMPTY_DATA_INTERVAL_SECONDS)
265
+ if keep_sending_empty_data:
266
+ empty_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
267
+ yield f"data: {json.dumps(empty_chunk)}\n\n"
268
+ logger.debug("Sent empty data chunk for fake stream heartbeat.")
269
+
270
+ empty_data_generator = send_empty_data_locally()
271
+ api_response_task = asyncio.create_task(
272
+ self.api_client.generate_content(payload, model, api_key)
273
+ )
274
+
275
+ try:
276
+ while not api_response_task.done():
277
+ try:
278
+ next_empty_chunk = await asyncio.wait_for(
279
+ empty_data_generator.__anext__(), timeout=0.1
280
+ )
281
+ yield next_empty_chunk
282
+ except asyncio.TimeoutError:
283
+ pass
284
+ except (
285
+ StopAsyncIteration
286
+ ):
287
+ break
288
+
289
+ response = await api_response_task
290
+ finally:
291
+ keep_sending_empty_data = False
292
+
293
+ if response and response.get("candidates"):
294
+ response = self.response_handler.handle_response(response, model, stream=True, finish_reason='stop', usage_metadata=response.get("usageMetadata", {}))
295
+ yield f"data: {json.dumps(response)}\n\n"
296
+ logger.info(f"Sent full response content for fake stream: {model}")
297
+ else:
298
+ error_message = "Failed to get response from model"
299
+ if (
300
+ response and isinstance(response, dict) and response.get("error")
301
+ ):
302
+ error_details = response.get("error")
303
+ if isinstance(error_details, dict):
304
+ error_message = error_details.get("message", error_message)
305
+
306
+ logger.error(
307
+ f"No candidates or error in response for fake stream model {model}: {response}"
308
+ )
309
+ error_chunk = self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=None)
310
+ yield f"data: {json.dumps(error_chunk)}\n\n"
311
+
312
+ async def _real_stream_logic_impl(
313
+ self, model: str, payload: Dict[str, Any], api_key: str
314
+ ) -> AsyncGenerator[str, None]:
315
+ """处理真实流式 (real stream) 的核心逻辑"""
316
+ tool_call_flag = False
317
+ usage_metadata = None
318
+ async for line in self.api_client.stream_generate_content(
319
+ payload, model, api_key
320
+ ):
321
+ if line.startswith("data:"):
322
+ chunk_str = line[6:]
323
+ if not chunk_str or chunk_str.isspace():
324
+ logger.debug(
325
+ f"Received empty data line for model {model}, skipping."
326
+ )
327
+ continue
328
+ try:
329
+ chunk = json.loads(chunk_str)
330
+ usage_metadata = chunk.get("usageMetadata", {})
331
+ except json.JSONDecodeError:
332
+ logger.error(
333
+ f"Failed to decode JSON from stream for model {model}: {chunk_str}"
334
+ )
335
+ continue
336
+ openai_chunk = self.response_handler.handle_response(
337
+ chunk, model, stream=True, finish_reason=None, usage_metadata=usage_metadata
338
+ )
339
+ if openai_chunk:
340
+ text = self._extract_text_from_openai_chunk(openai_chunk)
341
+ if text and settings.STREAM_OPTIMIZER_ENABLED:
342
+ async for (
343
+ optimized_chunk_data
344
+ ) in openai_optimizer.optimize_stream_output(
345
+ text,
346
+ lambda t: self._create_char_openai_chunk(openai_chunk, t),
347
+ lambda c: f"data: {json.dumps(c)}\n\n",
348
+ ):
349
+ yield optimized_chunk_data
350
+ else:
351
+ if openai_chunk.get("choices") and openai_chunk["choices"][0].get("delta", {}).get("tool_calls"):
352
+ tool_call_flag = True
353
+
354
+ yield f"data: {json.dumps(openai_chunk)}\n\n"
355
+
356
+ if tool_call_flag:
357
+ yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='tool_calls', usage_metadata=usage_metadata))}\n\n"
358
+ else:
359
+ yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop', usage_metadata=usage_metadata))}\n\n"
360
+
361
+ async def _handle_stream_completion(
362
+ self, model: str, payload: Dict[str, Any], api_key: str
363
+ ) -> AsyncGenerator[str, None]:
364
+ """处理流式聊天完成,添加重试逻辑和假流式支持"""
365
+ retries = 0
366
+ max_retries = settings.MAX_RETRIES
367
+ is_success = False
368
+ status_code = None
369
+ final_api_key = api_key
370
+
371
+ while retries < max_retries:
372
+ start_time = time.perf_counter()
373
+ request_datetime = datetime.datetime.now()
374
+ current_attempt_key = final_api_key
375
+
376
+ try:
377
+ stream_generator = None
378
+ if settings.FAKE_STREAM_ENABLED:
379
+ logger.info(
380
+ f"Using fake stream logic for model: {model}, Attempt: {retries + 1}"
381
+ )
382
+ stream_generator = self._fake_stream_logic_impl(
383
+ model, payload, current_attempt_key
384
+ )
385
+ else:
386
+ logger.info(
387
+ f"Using real stream logic for model: {model}, Attempt: {retries + 1}"
388
+ )
389
+ stream_generator = self._real_stream_logic_impl(
390
+ model, payload, current_attempt_key
391
+ )
392
+
393
+ async for chunk_data in stream_generator:
394
+ yield chunk_data
395
+
396
+ yield "data: [DONE]\n\n"
397
+ logger.info(
398
+ f"Streaming completed successfully for model: {model}, FakeStream: {settings.FAKE_STREAM_ENABLED}, Attempt: {retries + 1}"
399
+ )
400
+ is_success = True
401
+ status_code = 200
402
+ break
403
+
404
+ except Exception as e:
405
+ retries += 1
406
+ is_success = False
407
+ error_log_msg = str(e)
408
+ logger.warning(
409
+ f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries} with key {current_attempt_key}"
410
+ )
411
+
412
+ match = re.search(r"status code (\\d+)", error_log_msg)
413
+ if match:
414
+ status_code = int(match.group(1))
415
+ else:
416
+ if isinstance(e, asyncio.TimeoutError):
417
+ status_code = 408
418
+ else:
419
+ status_code = 500
420
+
421
+ await add_error_log(
422
+ gemini_key=current_attempt_key,
423
+ model_name=model,
424
+ error_type="openai-chat-stream",
425
+ error_log=error_log_msg,
426
+ error_code=status_code,
427
+ request_msg=payload,
428
+ )
429
+
430
+ if self.key_manager:
431
+ new_api_key = await self.key_manager.handle_api_failure(
432
+ current_attempt_key, retries
433
+ )
434
+ if new_api_key and new_api_key != current_attempt_key:
435
+ final_api_key = new_api_key
436
+ logger.info(
437
+ f"Switched to new API key for next attempt: {final_api_key}"
438
+ )
439
+ elif not new_api_key:
440
+ logger.error(
441
+ f"No valid API key available after {retries} retries, ceasing attempts for this request."
442
+ )
443
+ break
444
+ else:
445
+ logger.error(
446
+ "KeyManager not available, cannot switch API key. Ceasing attempts for this request."
447
+ )
448
+ break
449
+
450
+ if retries >= max_retries:
451
+ logger.error(
452
+ f"Max retries ({max_retries}) reached for streaming model {model}."
453
+ )
454
+ finally:
455
+ end_time = time.perf_counter()
456
+ latency_ms = int((end_time - start_time) * 1000)
457
+ await add_request_log(
458
+ model_name=model,
459
+ api_key=current_attempt_key,
460
+ is_success=is_success,
461
+ status_code=status_code,
462
+ latency_ms=latency_ms,
463
+ request_time=request_datetime,
464
+ )
465
+
466
+ if not is_success:
467
+ logger.error(
468
+ f"Streaming failed permanently for model {model} after {retries} attempts."
469
+ )
470
+ yield f"data: {json.dumps({'error': f'Streaming failed after {retries} retries.'})}\n\n"
471
+ yield "data: [DONE]\n\n"
472
+
473
+ async def create_image_chat_completion(
474
+ self, request: ChatRequest, api_key: str
475
+ ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
476
+
477
+ image_generate_request = ImageGenerationRequest()
478
+ image_generate_request.prompt = request.messages[-1]["content"]
479
+ image_res = self.image_create_service.generate_images_chat(
480
+ image_generate_request
481
+ )
482
+
483
+ if request.stream:
484
+ return self._handle_stream_image_completion(
485
+ request.model, image_res, api_key
486
+ )
487
+ else:
488
+ return await self._handle_normal_image_completion(
489
+ request.model, image_res, api_key
490
+ )
491
+
492
+ async def _handle_stream_image_completion(
493
+ self, model: str, image_data: str, api_key: str
494
+ ) -> AsyncGenerator[str, None]:
495
+ logger.info(f"Starting stream image completion for model: {model}")
496
+ start_time = time.perf_counter()
497
+ request_datetime = datetime.datetime.now()
498
+ is_success = False
499
+ status_code = None
500
+
501
+ try:
502
+ if image_data:
503
+ openai_chunk = self.response_handler.handle_image_chat_response(
504
+ image_data, model, stream=True, finish_reason=None
505
+ )
506
+ if openai_chunk:
507
+ # 提取文本内容
508
+ text = self._extract_text_from_openai_chunk(openai_chunk)
509
+ if text:
510
+ # 使用流式输出优化器处理文本输出
511
+ async for (
512
+ optimized_chunk
513
+ ) in openai_optimizer.optimize_stream_output(
514
+ text,
515
+ lambda t: self._create_char_openai_chunk(openai_chunk, t),
516
+ lambda c: f"data: {json.dumps(c)}\n\n",
517
+ ):
518
+ yield optimized_chunk
519
+ else:
520
+ # 如果没有文本内容(如图片URL等),整块输出
521
+ yield f"data: {json.dumps(openai_chunk)}\n\n"
522
+ yield f"data: {json.dumps(self.response_handler.handle_response({}, model, stream=True, finish_reason='stop'))}\n\n"
523
+ logger.info(
524
+ f"Stream image completion finished successfully for model: {model}"
525
+ )
526
+ is_success = True
527
+ status_code = 200
528
+ yield "data: [DONE]\n\n"
529
+ except Exception as e:
530
+ is_success = False
531
+ error_log_msg = f"Stream image completion failed for model {model}: {e}"
532
+ logger.error(error_log_msg)
533
+ status_code = 500
534
+ await add_error_log(
535
+ gemini_key=api_key,
536
+ model_name=model,
537
+ error_type="openai-image-stream",
538
+ error_log=error_log_msg,
539
+ error_code=status_code,
540
+ request_msg={"image_data_truncated": image_data[:1000]},
541
+ )
542
+ yield f"data: {json.dumps({'error': error_log_msg})}\n\n"
543
+ yield "data: [DONE]\n\n"
544
+ finally:
545
+ end_time = time.perf_counter()
546
+ latency_ms = int((end_time - start_time) * 1000)
547
+ logger.info(
548
+ f"Stream image completion for model {model} took {latency_ms} ms. Success: {is_success}"
549
+ )
550
+ await add_request_log(
551
+ model_name=model,
552
+ api_key=api_key,
553
+ is_success=is_success,
554
+ status_code=status_code,
555
+ latency_ms=latency_ms,
556
+ request_time=request_datetime,
557
+ )
558
+
559
+ async def _handle_normal_image_completion(
560
+ self, model: str, image_data: str, api_key: str
561
+ ) -> Dict[str, Any]:
562
+ logger.info(f"Starting normal image completion for model: {model}")
563
+ start_time = time.perf_counter()
564
+ request_datetime = datetime.datetime.now()
565
+ is_success = False
566
+ status_code = None
567
+ result = None
568
+
569
+ try:
570
+ result = self.response_handler.handle_image_chat_response(
571
+ image_data, model, stream=False, finish_reason="stop"
572
+ )
573
+ logger.info(
574
+ f"Normal image completion finished successfully for model: {model}"
575
+ )
576
+ is_success = True
577
+ status_code = 200
578
+ return result
579
+ except Exception as e:
580
+ is_success = False
581
+ error_log_msg = f"Normal image completion failed for model {model}: {e}"
582
+ logger.error(error_log_msg)
583
+ status_code = 500
584
+ await add_error_log(
585
+ gemini_key=api_key,
586
+ model_name=model,
587
+ error_type="openai-image-non-stream",
588
+ error_log=error_log_msg,
589
+ error_code=status_code,
590
+ request_msg={"image_data_truncated": image_data[:1000]},
591
+ )
592
+ raise e
593
+ finally:
594
+ end_time = time.perf_counter()
595
+ latency_ms = int((end_time - start_time) * 1000)
596
+ logger.info(
597
+ f"Normal image completion for model {model} took {latency_ms} ms. Success: {is_success}"
598
+ )
599
+ await add_request_log(
600
+ model_name=model,
601
+ api_key=api_key,
602
+ is_success=is_success,
603
+ status_code=status_code,
604
+ latency_ms=latency_ms,
605
+ request_time=request_datetime,
606
+ )
app/service/chat/vertex_express_chat_service.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/services/chat_service.py
2
+
3
+ import json
4
+ import re
5
+ import datetime
6
+ import time
7
+ from typing import Any, AsyncGenerator, Dict, List
8
+ from app.config.config import settings
9
+ from app.core.constants import GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
10
+ from app.domain.gemini_models import GeminiRequest
11
+ from app.handler.response_handler import GeminiResponseHandler
12
+ from app.handler.stream_optimizer import gemini_optimizer
13
+ from app.log.logger import get_gemini_logger
14
+ from app.service.client.api_client import GeminiApiClient
15
+ from app.service.key.key_manager import KeyManager
16
+ from app.database.services import add_error_log, add_request_log
17
+
18
+ logger = get_gemini_logger()
19
+
20
+
21
+ def _has_image_parts(contents: List[Dict[str, Any]]) -> bool:
22
+ """判断消息是否包含图片部分"""
23
+ for content in contents:
24
+ if "parts" in content:
25
+ for part in content["parts"]:
26
+ if "image_url" in part or "inline_data" in part:
27
+ return True
28
+ return False
29
+
30
+
31
def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Assemble the Gemini ``tools`` entry for a request payload.

    Merges any caller-declared tools, then conditionally enables the
    built-in ``codeExecution`` and ``googleSearch`` tools. Built-in tools
    are dropped when function declarations are present, since Gemini
    rejects that combination.
    """

    def _merge(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Fold a list of tool dicts into one, concatenating functionDeclarations."""
        merged: Dict[str, Any] = {}
        for entry in candidates:
            if not entry or not isinstance(entry, dict):
                continue
            for key, value in entry.items():
                if key == "functionDeclarations" and value and isinstance(value, list):
                    merged.setdefault("functionDeclarations", []).extend(value)
                else:
                    merged[key] = value
        return merged

    tool: Dict[str, Any] = {}
    if payload and isinstance(payload, dict) and "tools" in payload:
        declared = payload.get("tools")
        if declared and isinstance(declared, dict):
            # Normalize a bare tool dict into a one-element list (in place,
            # matching the original side effect on the payload).
            payload["tools"] = [declared]
        declared = payload.get("tools", [])
        if declared and isinstance(declared, list):
            tool.update(_merge(declared))

    allow_code_execution = (
        settings.TOOLS_CODE_EXECUTION_ENABLED
        and not (model.endswith("-search") or "-thinking" in model)
        and not _has_image_parts(payload.get("contents", []))
    )
    if allow_code_execution:
        tool["codeExecution"] = {}
    if model.endswith("-search"):
        tool["googleSearch"] = {}

    # Avoid "Tool use with function calling is unsupported" upstream errors.
    if tool.get("functionDeclarations"):
        tool.pop("googleSearch", None)
        tool.pop("codeExecution", None)

    return [tool] if tool else []
72
+
73
+
74
def _get_safety_settings(model: str) -> List[Dict[str, str]]:
    """Return the safety-settings list appropriate for *model*.

    ``gemini-2.0-flash-exp`` requires its own settings; every other model
    uses the globally configured ones.
    """
    return (
        GEMINI_2_FLASH_EXP_SAFETY_SETTINGS
        if model == "gemini-2.0-flash-exp"
        else settings.SAFETY_SETTINGS
    )
79
+
80
+
81
def _build_payload(model: str, request: GeminiRequest) -> Dict[str, Any]:
    """Build the upstream Gemini request payload from a GeminiRequest.

    Combines contents, tools, safety settings, generation config and the
    system instruction, then applies model-suffix-specific tweaks
    (image output modalities, thinking budget).
    """
    request_dict = request.model_dump()
    if request.generationConfig:
        if request.generationConfig.maxOutputTokens is None:
            # Omit the field entirely when unset to avoid truncated output.
            request_dict["generationConfig"].pop("maxOutputTokens")

    payload = {
        "contents": request_dict.get("contents", []),
        "tools": _build_tools(model, request_dict),
        "safetySettings": _get_safety_settings(model),
        "generationConfig": request_dict.get("generationConfig"),
        "systemInstruction": request_dict.get("systemInstruction"),
    }

    # NOTE(review): the branches below index into payload["generationConfig"],
    # which is None when the request carried no generationConfig — that would
    # raise TypeError for "-image*" / "-non-thinking" models; confirm callers
    # always supply a generationConfig.
    if model.endswith("-image") or model.endswith("-image-generation"):
        # Image models reject system instructions and need both modalities.
        payload.pop("systemInstruction")
        payload["generationConfig"]["responseModalities"] = ["Text", "Image"]

    if model.endswith("-non-thinking"):
        payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": 0}
    if model in settings.THINKING_BUDGET_MAP:
        # Configured per-model budget overrides the "-non-thinking" zero budget.
        payload["generationConfig"]["thinkingConfig"] = {"thinkingBudget": settings.THINKING_BUDGET_MAP.get(model,1000)}

    return payload
107
+
108
+
109
class GeminiChatService:
    """Chat service for the Gemini API: non-streaming and streaming generation
    with request/error logging and API-key rotation on streaming failures."""

    def __init__(self, base_url: str, key_manager: KeyManager):
        # settings.TIME_OUT bounds every upstream HTTP call.
        self.api_client = GeminiApiClient(base_url, settings.TIME_OUT)
        self.key_manager = key_manager
        self.response_handler = GeminiResponseHandler()

    def _extract_text_from_response(self, response: Dict[str, Any]) -> str:
        """Return the text of the first part of the first candidate ("" if none)."""
        if not response.get("candidates"):
            return ""

        candidate = response["candidates"][0]
        content = candidate.get("content", {})
        parts = content.get("parts", [])

        if parts and "text" in parts[0]:
            return parts[0].get("text", "")
        return ""

    def _create_char_response(
        self, original_response: Dict[str, Any], text: str
    ) -> Dict[str, Any]:
        """Return a deep copy of *original_response* with its first part's text
        replaced by *text* (used by the stream optimizer)."""
        response_copy = json.loads(json.dumps(original_response))  # deep copy via JSON round-trip
        if response_copy.get("candidates") and response_copy["candidates"][0].get(
            "content", {}
        ).get("parts"):
            response_copy["candidates"][0]["content"]["parts"][0]["text"] = text
        return response_copy

    async def generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> Dict[str, Any]:
        """Perform one non-streaming generation call, logging outcome and latency."""
        payload = _build_payload(model, request)
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None

        try:
            response = await self.api_client.generate_content(payload, model, api_key)
            is_success = True
            status_code = 200
            return self.response_handler.handle_response(response, model, stream=False)
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # The API client embeds "status code <n>" in its exception message.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="gemini-chat-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=payload
            )
            raise e
        finally:
            # Request log is written on success and failure alike.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )

    async def stream_generate_content(
        self, model: str, request: GeminiRequest, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream generated content as SSE lines, retrying with rotated API
        keys (up to settings.MAX_RETRIES) on failure."""
        retries = 0
        max_retries = settings.MAX_RETRIES
        payload = _build_payload(model, request)
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            request_datetime = datetime.datetime.now()
            start_time = time.perf_counter()
            # api_key may have been replaced by key rotation in a prior attempt.
            current_attempt_key = api_key
            final_api_key = current_attempt_key  # Update final key used
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, model, current_attempt_key
                ):
                    # NOTE(review): the slice below assumes a "data: " prefix
                    # with a trailing space (6 chars) — confirm upstream SSE
                    # formatting.
                    if line.startswith("data:"):
                        line = line[6:]
                        response_data = self.response_handler.handle_response(
                            json.loads(line), model, stream=True
                        )
                        text = self._extract_text_from_response(response_data)
                        # Route text chunks through the stream optimizer when enabled.
                        if text and settings.STREAM_OPTIMIZER_ENABLED:
                            async for (
                                optimized_chunk
                            ) in gemini_optimizer.optimize_stream_output(
                                text,
                                lambda t: self._create_char_response(response_data, t),
                                lambda c: "data: " + json.dumps(c) + "\n\n",
                            ):
                                yield optimized_chunk
                        else:
                            # Non-text chunks (e.g. tool calls) are forwarded whole.
                            yield "data: " + json.dumps(response_data) + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="gemini-chat-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload
                )

                # Rotate to a new key for the next loop iteration.
                api_key = await self.key_manager.handle_api_failure(current_attempt_key, retries)
                if api_key:
                    logger.info(f"Switched to new API key: {api_key}")
                else:
                    logger.error(f"No valid API key available after {retries} retries.")
                    break

                if retries >= max_retries:
                    logger.error(
                        f"Max retries ({max_retries}) reached for streaming."
                    )
                    break
            finally:
                # Per-attempt request log (success and failure).
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime
                )
app/service/client/api_client.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/services/chat/api_client.py
2
+
3
+ from typing import Dict, Any, AsyncGenerator, Optional
4
+ import httpx
5
+ import random
6
+ from abc import ABC, abstractmethod
7
+ from app.config.config import settings
8
+ from app.log.logger import get_api_client_logger
9
+ from app.core.constants import DEFAULT_TIMEOUT
10
+
11
+ logger = get_api_client_logger()
12
+
13
class ApiClient(ABC):
    """Abstract base class for upstream API clients."""

    @abstractmethod
    async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """Perform a non-streaming generation call and return the parsed JSON response."""
        pass

    @abstractmethod
    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """Perform a streaming generation call, yielding raw response lines."""
        pass
+ pass
23
+
24
+
25
class GeminiApiClient(ApiClient):
    """Gemini API client: model listing, generateContent, streamGenerateContent,
    with optional per-request proxy selection."""

    # Gateway-only pseudo-model suffixes stripped before calling upstream.
    _PSEUDO_SUFFIXES = ("-search", "-image", "-non-thinking")

    def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
        self.base_url = base_url
        self.timeout = timeout

    def _get_real_model(self, model: str) -> str:
        """Strip gateway pseudo-suffixes (in any order/combination) from *model*.

        FIX: the previous implementation's combined check
        (``model[:-20]`` when both "-search" and "-non-thinking" were present)
        ran after the single suffixes had already been stripped and could
        never match, so names like ``…-search-non-thinking`` kept a stale
        ``-search`` suffix. Stripping in a loop handles every ordering.
        """
        stripped = True
        while stripped:
            stripped = False
            for suffix in self._PSEUDO_SUFFIXES:
                if model.endswith(suffix):
                    model = model[: -len(suffix)]
                    stripped = True
        return model

    def _select_proxy(self, api_key: str) -> Optional[str]:
        """Return the proxy to use for this request, or None when unconfigured.

        NOTE(review): ``hash(str)`` is randomized per process
        (PYTHONHASHSEED), so the hash-based "consistent" choice is only
        stable within a single process — confirm that scope is intended.
        """
        if not settings.PROXIES:
            return None
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            proxy = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy = random.choice(settings.PROXIES)
        logger.info(f"Using proxy: {proxy}")
        return proxy

    async def get_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch the list of available Gemini models; returns None on failure."""
        timeout = httpx.Timeout(timeout=5)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/models?key={api_key}&pageSize=1000"
            try:
                response = await client.get(url)
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                logger.error(f"获取模型列表失败: {e.response.status_code}")
                logger.error(e.response.text)
                return None
            except httpx.RequestError as e:
                logger.error(f"请求模型列表失败: {e}")
                return None

    async def generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> Dict[str, Any]:
        """POST a non-streaming generateContent request and return its JSON body.

        Raises:
            Exception: on non-200 responses; the message embeds
                "status code <n>", which callers parse.
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        model = self._get_real_model(model)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/models/{model}:generateContent?key={api_key}"
            response = await client.post(url, json=payload)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def stream_generate_content(self, payload: Dict[str, Any], model: str, api_key: str) -> AsyncGenerator[str, None]:
        """POST a streaming (SSE) generateContent request, yielding raw lines.

        Raises:
            Exception: on non-200 responses, with "status code <n>" embedded.
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        model = self._get_real_model(model)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse&key={api_key}"
            async with client.stream(method="POST", url=url, json=payload) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    error_msg = error_content.decode("utf-8")
                    raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
                async for line in response.aiter_lines():
                    yield line
+
111
+
112
class OpenaiApiClient(ApiClient):
    """OpenAI-compatible API client (Gemini's ``/openai`` endpoints).

    NOTE(review): ``generate_content`` / ``stream_generate_content`` take
    ``(payload, api_key)`` rather than the base class's
    ``(payload, model, api_key)`` — the model is expected inside *payload*;
    confirm this divergence from ``ApiClient`` is intentional.
    """

    def __init__(self, base_url: str, timeout: int = DEFAULT_TIMEOUT):
        self.base_url = base_url
        self.timeout = timeout

    def _select_proxy(self, api_key: str) -> Optional[str]:
        """Return the proxy to use for this request, or None when unconfigured.

        NOTE(review): ``hash(str)`` is randomized per process
        (PYTHONHASHSEED), so the hash-based "consistent" choice is only
        stable within a single process — confirm that scope is intended.
        """
        if not settings.PROXIES:
            return None
        if settings.PROXIES_USE_CONSISTENCY_HASH_BY_API_KEY:
            proxy = settings.PROXIES[hash(api_key) % len(settings.PROXIES)]
        else:
            proxy = random.choice(settings.PROXIES)
        logger.info(f"Using proxy: {proxy}")
        return proxy

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """GET the available model list; raises on non-200 responses."""
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/models"
            headers = {"Authorization": f"Bearer {api_key}"}
            response = await client.get(url, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def generate_content(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
        """POST a non-streaming chat completion; raises on non-200 responses."""
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/chat/completions"
            headers = {"Authorization": f"Bearer {api_key}"}
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def stream_generate_content(self, payload: Dict[str, Any], api_key: str) -> AsyncGenerator[str, None]:
        """POST a streaming chat completion, yielding raw SSE lines."""
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/chat/completions"
            headers = {"Authorization": f"Bearer {api_key}"}
            async with client.stream(method="POST", url=url, json=payload, headers=headers) as response:
                if response.status_code != 200:
                    error_content = await response.aread()
                    error_msg = error_content.decode("utf-8")
                    raise Exception(f"API call failed with status code {response.status_code}, {error_msg}")
                async for line in response.aiter_lines():
                    yield line

    async def create_embeddings(self, input: str, model: str, api_key: str) -> Dict[str, Any]:
        """POST an embeddings request for *input* with *model*.

        NOTE: the parameter name ``input`` shadows the builtin but is kept
        unchanged for backward compatibility with keyword callers.
        """
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/embeddings"
            headers = {"Authorization": f"Bearer {api_key}"}
            payload = {
                "input": input,
                "model": model,
            }
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()

    async def generate_images(self, payload: Dict[str, Any], api_key: str) -> Dict[str, Any]:
        """POST an image-generation request; raises on non-200 responses."""
        timeout = httpx.Timeout(self.timeout, read=self.timeout)
        proxy_to_use = self._select_proxy(api_key)

        async with httpx.AsyncClient(timeout=timeout, proxy=proxy_to_use) as client:
            url = f"{self.base_url}/openai/images/generations"
            headers = {"Authorization": f"Bearer {api_key}"}
            response = await client.post(url, json=payload, headers=headers)
            if response.status_code != 200:
                error_content = response.text
                raise Exception(f"API call failed with status code {response.status_code}, {error_content}")
            return response.json()
app/service/config/config_service.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 配置服务模块
3
+ """
4
+
5
+ import datetime
6
+ import json
7
+ from typing import Any, Dict, List
8
+
9
+ from dotenv import find_dotenv, load_dotenv
10
+ from fastapi import HTTPException
11
+ from sqlalchemy import insert, update
12
+
13
+ from app.config.config import Settings as ConfigSettings
14
+ from app.config.config import settings
15
+ from app.database.connection import database
16
+ from app.database.models import Settings
17
+ from app.database.services import get_all_settings
18
+ from app.log.logger import get_config_routes_logger
19
+ from app.service.key.key_manager import (
20
+ get_key_manager_instance,
21
+ reset_key_manager_instance,
22
+ )
23
+ from app.service.model.model_service import ModelService
24
+
25
+ logger = get_config_routes_logger()
26
+
27
+
28
class ConfigService:
    """Service layer for reading, updating, and resetting application
    configuration.

    Keeps three things in sync: the in-memory ``settings`` object, the
    settings persisted in the database, and the KeyManager singleton.
    """

    @staticmethod
    async def get_config() -> Dict[str, Any]:
        """Return the current in-memory configuration as a plain dict."""
        return settings.model_dump()

    @staticmethod
    async def update_config(config_data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply ``config_data`` in memory, persist changed values to the
        database, and re-initialize the KeyManager.

        Args:
            config_data: Mapping of setting name to new value. Names not
                present on the settings object are still persisted to the DB.

        Returns:
            The full configuration after the update.
        """
        # 1. Update the in-memory settings object first.
        for key, value in config_data.items():
            if hasattr(settings, key):
                setattr(settings, key, value)
                logger.debug(f"Updated setting in memory: {key}")

        # 2. Load persisted settings so we only write actual changes.
        existing_settings_raw: List[Dict[str, Any]] = await get_all_settings()
        existing_settings_map: Dict[str, Dict[str, Any]] = {
            s["key"]: s for s in existing_settings_raw
        }
        existing_keys = set(existing_settings_map.keys())

        settings_to_update: List[Dict[str, Any]] = []
        settings_to_insert: List[Dict[str, Any]] = []
        # Timestamps are recorded in UTC+8.
        now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))

        # Prepare rows to insert or update.
        for key, value in config_data.items():
            # Serialize values to the string form used by the settings table.
            if isinstance(value, (list, dict)):
                db_value = json.dumps(value)
            elif isinstance(value, bool):
                db_value = str(value).lower()
            else:
                db_value = str(value)

            # Skip keys whose persisted value is already up to date.
            if key in existing_keys and existing_settings_map[key]["value"] == db_value:
                continue

            description = f"{key}配置项"

            data = {
                "key": key,
                "value": db_value,
                "description": description,
                "updated_at": now,
            }

            if key in existing_keys:
                # Preserve any existing (possibly hand-edited) description.
                data["description"] = existing_settings_map[key].get(
                    "description", description
                )
                settings_to_update.append(data)
            else:
                data["created_at"] = now
                settings_to_insert.append(data)

        # 3. Persist inserts and updates inside a single transaction.
        if settings_to_insert or settings_to_update:
            try:
                async with database.transaction():
                    if settings_to_insert:
                        query_insert = insert(Settings).values(settings_to_insert)
                        await database.execute(query=query_insert)
                        logger.info(
                            f"Bulk inserted {len(settings_to_insert)} settings."
                        )

                    if settings_to_update:
                        for setting_data in settings_to_update:
                            query_update = (
                                update(Settings)
                                .where(Settings.key == setting_data["key"])
                                .values(
                                    value=setting_data["value"],
                                    description=setting_data["description"],
                                    updated_at=setting_data["updated_at"],
                                )
                            )
                            await database.execute(query=query_update)
                        logger.info(f"Updated {len(settings_to_update)} settings.")
            except Exception as e:
                logger.error(f"Failed to bulk update/insert settings: {str(e)}")
                raise

        # 4. Reset and re-initialize the KeyManager with the new key lists.
        try:
            await reset_key_manager_instance()
            await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
            logger.info("KeyManager instance re-initialized with updated settings.")
        except Exception as e:
            logger.error(f"Failed to re-initialize KeyManager: {str(e)}")

        return await ConfigService.get_config()

    @staticmethod
    async def delete_key(key_to_delete: str) -> Dict[str, Any]:
        """Delete a single API key and persist the change.

        Returns:
            A dict with ``success`` and a user-facing ``message``.
        """
        # Ensure settings.API_KEYS is a list before filtering.
        if not isinstance(settings.API_KEYS, list):
            settings.API_KEYS = []

        original_keys_count = len(settings.API_KEYS)
        # Build a new list without the key to delete.
        updated_api_keys = [k for k in settings.API_KEYS if k != key_to_delete]

        if len(updated_api_keys) < original_keys_count:
            # Key was found and removed; update memory first, then persist
            # via update_config (which also refreshes the KeyManager).
            settings.API_KEYS = updated_api_keys
            await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
            logger.info(f"密钥 '{key_to_delete}' 已成功删除。")
            return {"success": True, "message": f"密钥 '{key_to_delete}' 已成功删除。"}
        else:
            # Key was not present.
            logger.warning(f"尝试删除密钥 '{key_to_delete}',但未找到该密钥。")
            return {"success": False, "message": f"未找到密钥 '{key_to_delete}'。"}

    @staticmethod
    async def delete_selected_keys(keys_to_delete: List[str]) -> Dict[str, Any]:
        """Delete a batch of API keys and persist the change.

        Returns:
            A dict with ``success``, ``message``, ``deleted_count`` and the
            list of ``not_found_keys``.
        """
        if not isinstance(settings.API_KEYS, list):
            settings.API_KEYS = []

        deleted_count = 0
        not_found_keys: List[str] = []

        current_api_keys = list(settings.API_KEYS)
        keys_actually_removed: List[str] = []

        for key_to_del in keys_to_delete:
            if key_to_del in current_api_keys:
                current_api_keys.remove(key_to_del)
                keys_actually_removed.append(key_to_del)
                deleted_count += 1
            else:
                not_found_keys.append(key_to_del)

        if deleted_count > 0:
            # Persist via update_config (also refreshes the KeyManager).
            settings.API_KEYS = current_api_keys
            await ConfigService.update_config({"API_KEYS": settings.API_KEYS})
            logger.info(
                f"成功删除 {deleted_count} 个密钥。密钥: {keys_actually_removed}"
            )
            message = f"成功删除 {deleted_count} 个密钥。"
            if not_found_keys:
                message += f" {len(not_found_keys)} 个密钥未找到: {not_found_keys}。"
            return {
                "success": True,
                "message": message,
                "deleted_count": deleted_count,
                "not_found_keys": not_found_keys,
            }
        else:
            message = "没有密钥被删除。"
            if not_found_keys:
                message = f"所有 {len(not_found_keys)} 个指定的密钥均未找到: {not_found_keys}。"
            elif not keys_to_delete:
                message = "未指定要删除的密钥。"
            logger.warning(message)
            return {
                "success": False,
                "message": message,
                "deleted_count": 0,
                "not_found_keys": not_found_keys,
            }

    @staticmethod
    async def reset_config() -> Dict[str, Any]:
        """Reset configuration from system environment variables and the
        .env file, then refresh the KeyManager.

        Returns:
            The configuration dict after the reset.
        """
        # 1. Reload the settings object (env vars take precedence over .env).
        _reload_settings()
        logger.info(
            "Settings object reloaded, prioritizing system environment variables then .env file."
        )

        # 2. Reset and re-initialize the KeyManager.
        try:
            await reset_key_manager_instance()
            # Bug fix: VERTEX_API_KEYS was missing here; because the singleton
            # was just reset, re-initialization raised ValueError (swallowed
            # below) and the KeyManager was never re-created after a reset.
            await get_key_manager_instance(settings.API_KEYS, settings.VERTEX_API_KEYS)
            logger.info("KeyManager instance re-initialized with reloaded settings.")
        except Exception as e:
            # Log and continue: a failed KeyManager refresh should not block
            # returning the (already reloaded) configuration.
            logger.error(f"Failed to re-initialize KeyManager during reset: {str(e)}")

        # 3. Return the refreshed configuration.
        return await ConfigService.get_config()

    @staticmethod
    async def fetch_ui_models() -> List[Dict[str, Any]]:
        """Fetch the model list used by the admin UI.

        Raises:
            HTTPException: 500 when no valid key exists or the fetch fails.
        """
        try:
            key_manager = await get_key_manager_instance()
            model_service = ModelService()

            api_key = await key_manager.get_first_valid_key()
            if not api_key:
                logger.error("No valid API keys available to fetch model list for UI.")
                raise HTTPException(
                    status_code=500,
                    detail="No valid API keys available to fetch model list.",
                )

            models = await model_service.get_gemini_openai_models(api_key)
            return models
        except HTTPException as e:
            raise e
        except Exception as e:
            logger.error(
                f"Failed to fetch models for UI in ConfigService: {e}", exc_info=True
            )
            raise HTTPException(
                status_code=500, detail=f"Failed to fetch models for UI: {str(e)}"
            )
252
+
253
+
254
+ # 重新加载配置的函数
255
def _reload_settings():
    """Reload environment variables and refresh the shared settings object.

    Re-reads the .env file (overriding current process environment
    variables), then copies every field from a freshly constructed Settings
    instance onto the shared ``settings`` object so that modules holding a
    reference to it see the reloaded values.
    """
    # Explicitly load the .env file, overriding existing environment variables.
    load_dotenv(find_dotenv(), override=True)
    # Mutate the existing settings object in place rather than replacing it.
    fresh_values = ConfigSettings().model_dump()
    for field_name, field_value in fresh_values.items():
        setattr(settings, field_name, field_value)
app/service/embedding/embedding_service.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import time
3
+ import re
4
+ from typing import List, Union
5
+
6
+ import openai
7
+ from openai import APIStatusError
8
+ from openai.types import CreateEmbeddingResponse
9
+
10
+ from app.config.config import settings
11
+ from app.log.logger import get_embeddings_logger
12
+ from app.database.services import add_error_log, add_request_log
13
+
14
+ logger = get_embeddings_logger()
15
+
16
+
17
class EmbeddingService:
    """Creates embeddings via the OpenAI-compatible endpoint and records
    request/error logs in the database."""

    async def create_embedding(
        self, input_text: Union[str, List[str]], model: str, api_key: str
    ) -> CreateEmbeddingResponse:
        """Create embeddings using OpenAI API with database logging.

        Args:
            input_text: A single string or a list of strings to embed.
            model: Embedding model identifier.
            api_key: API key used for the upstream request.

        Returns:
            The ``CreateEmbeddingResponse`` from the OpenAI client.

        Raises:
            APIStatusError: Propagated from the OpenAI client.
            Exception: Any other failure; a request/error log is written
                either way in the ``finally`` block.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        error_log_msg = ""

        # Build a truncated preview of the input for the error log.
        if isinstance(input_text, list):
            preview = [
                str(item)[:100] + "..." if len(str(item)) > 100 else str(item)
                for item in input_text[:5]
            ]
            if len(input_text) > 5:
                preview.append("...")
            request_msg_log = {"input_truncated": preview}
        else:
            truncated = (
                input_text[:1000] + "..." if len(input_text) > 1000 else input_text
            )
            request_msg_log = {"input_truncated": truncated}

        try:
            # NOTE(review): this is the *synchronous* OpenAI client inside an
            # async method, so the HTTP call blocks the event loop — confirm
            # whether AsyncOpenAI should be used here.
            client = openai.OpenAI(api_key=api_key, base_url=settings.BASE_URL)
            response = client.embeddings.create(input=input_text, model=model)
            is_success = True
            status_code = 200
            return response
        except APIStatusError as e:
            status_code = e.status_code
            error_log_msg = f"OpenAI API error: {e}"
            logger.error(f"Error creating embedding (APIStatusError): {error_log_msg}")
            raise e
        except Exception as e:
            error_log_msg = f"Generic error: {e}"
            logger.error(f"Error creating embedding (Exception): {error_log_msg}")
            # Try to recover an HTTP status code from the error text.
            match = re.search(r"status code (\d+)", str(e))
            status_code = int(match.group(1)) if match else 500
            raise e
        finally:
            # Always record the request; record an error log only on failure.
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            if not is_success:
                await add_error_log(
                    gemini_key=api_key,
                    model_name=model,
                    error_type="openai-embedding",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=request_msg_log,
                )
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )
app/service/error_log/error_log_service.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta, timezone
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from sqlalchemy import delete, func, select
5
+
6
+ from app.config.config import settings
7
+ from app.database import services as db_services
8
+ from app.database.connection import database
9
+ from app.database.models import ErrorLog
10
+ from app.log.logger import get_error_log_logger
11
+
12
+ logger = get_error_log_logger()
13
+
14
+
15
async def delete_old_error_logs():
    """Remove error logs older than the configured retention window.

    Controlled by AUTO_DELETE_ERROR_LOGS_ENABLED and
    AUTO_DELETE_ERROR_LOGS_DAYS; silently returns when disabled or when the
    retention setting is not a positive integer.
    """
    if not settings.AUTO_DELETE_ERROR_LOGS_ENABLED:
        logger.info("Auto-deletion of error logs is disabled. Skipping.")
        return

    days_to_keep = settings.AUTO_DELETE_ERROR_LOGS_DAYS
    if not isinstance(days_to_keep, int) or days_to_keep <= 0:
        logger.error(
            f"Invalid AUTO_DELETE_ERROR_LOGS_DAYS value: {days_to_keep}. Must be a positive integer. Skipping deletion."
        )
        return

    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
    logger.info(
        f"Attempting to delete error logs older than {days_to_keep} days (before {cutoff_date.strftime('%Y-%m-%d %H:%M:%S %Z')})."
    )

    try:
        if not database.is_connected:
            await database.connect()
            logger.info("Database connection established for deleting error logs.")

        # Same predicate for counting and deleting.
        stale = ErrorLog.request_time < cutoff_date

        # Count first so we can log how many rows are affected.
        num_logs_to_delete = await database.fetch_val(
            select(func.count(ErrorLog.id)).where(stale)
        )
        if num_logs_to_delete == 0:
            logger.info(
                "No error logs found older than the specified period. No deletion needed."
            )
            return

        logger.info(f"Found {num_logs_to_delete} error logs to delete.")
        await database.execute(delete(ErrorLog).where(stale))
        logger.info(
            f"Successfully deleted {num_logs_to_delete} error logs older than {days_to_keep} days."
        )
    except Exception as e:
        # Best-effort maintenance task: log and swallow failures.
        logger.error(
            f"Error during automatic deletion of error logs: {e}", exc_info=True
        )
68
+
69
async def process_get_error_logs(
    limit: int,
    offset: int,
    key_search: Optional[str],
    error_search: Optional[str],
    error_code_search: Optional[str],
    start_date: Optional[datetime],
    end_date: Optional[datetime],
    sort_by: str,
    sort_order: str,
) -> Dict[str, Any]:
    """Fetch one page of error logs plus the total count of matches.

    Returns:
        ``{"logs": [...], "total": <int>}``
    """
    # The page query and the count query share the same filters.
    filter_kwargs = {
        "key_search": key_search,
        "error_search": error_search,
        "error_code_search": error_code_search,
        "start_date": start_date,
        "end_date": end_date,
    }
    try:
        page = await db_services.get_error_logs(
            limit=limit,
            offset=offset,
            sort_by=sort_by,
            sort_order=sort_order,
            **filter_kwargs,
        )
        total = await db_services.get_error_logs_count(**filter_kwargs)
        return {"logs": page, "total": total}
    except Exception as e:
        logger.error(f"Service error in process_get_error_logs: {e}", exc_info=True)
        raise
106
+
107
+
108
async def process_get_error_log_details(log_id: int) -> Optional[Dict[str, Any]]:
    """Fetch the details of one error log by ID.

    Returns:
        The log's detail dict, or None when no such log exists.
    """
    try:
        return await db_services.get_error_log_details(log_id=log_id)
    except Exception as e:
        logger.error(
            f"Service error in process_get_error_log_details for ID {log_id}: {e}",
            exc_info=True,
        )
        raise
122
+
123
+
124
async def process_delete_error_logs_by_ids(log_ids: List[int]) -> int:
    """Bulk-delete error logs by ID.

    Returns:
        The number of logs the delete attempted to remove; 0 for an empty
        input list.
    """
    if not log_ids:
        return 0
    try:
        return await db_services.delete_error_logs_by_ids(log_ids)
    except Exception as e:
        logger.error(
            f"Service error in process_delete_error_logs_by_ids for IDs {log_ids}: {e}",
            exc_info=True,
        )
        raise
140
+
141
+
142
async def process_delete_error_log_by_id(log_id: int) -> bool:
    """Delete a single error log by ID.

    Returns:
        True when the log was found and deletion was attempted, else False.
    """
    try:
        return await db_services.delete_error_log_by_id(log_id)
    except Exception as e:
        logger.error(
            f"Service error in process_delete_error_log_by_id for ID {log_id}: {e}",
            exc_info=True,
        )
        raise
156
+
157
+
158
async def process_delete_all_error_logs() -> int:
    """Delete every error log in the database.

    Returns:
        The number of logs removed.
    """
    try:
        # Ensure a live connection before issuing the bulk delete.
        if not database.is_connected:
            await database.connect()
            logger.info("Database connection established for deleting all error logs.")

        deleted_count = await db_services.delete_all_error_logs()
        logger.info(
            f"Successfully processed request to delete all error logs. Count: {deleted_count}"
        )
        return deleted_count
    except Exception as e:
        logger.error(
            f"Service error in process_delete_all_error_logs: {e}",
            exc_info=True,
        )
        raise
app/service/image/image_create_service.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import time
3
+ import uuid
4
+
5
+ from google import genai
6
+ from google.genai import types
7
+
8
+ from app.config.config import settings
9
+ from app.core.constants import VALID_IMAGE_RATIOS
10
+ from app.domain.openai_models import ImageGenerationRequest
11
+ from app.log.logger import get_image_create_logger
12
+ from app.utils.uploader import ImageUploaderFactory
13
+
14
+ logger = get_image_create_logger()
15
+
16
+
17
+ class ImageCreateService:
18
+ def __init__(self, aspect_ratio="1:1"):
19
+ self.image_model = settings.CREATE_IMAGE_MODEL
20
+ self.aspect_ratio = aspect_ratio
21
+
22
+ def parse_prompt_parameters(self, prompt: str) -> tuple:
23
+ """从prompt中解析参数
24
+ 支持的格式:
25
+ - {n:数量} 例如: {n:2} 生成2张图片
26
+ - {ratio:比例} 例如: {ratio:16:9} 使用16:9比例
27
+ """
28
+ import re
29
+
30
+ # 默认值
31
+ n = 1
32
+ aspect_ratio = self.aspect_ratio
33
+
34
+ # 解析n参数
35
+ n_match = re.search(r"{n:(\d+)}", prompt)
36
+ if n_match:
37
+ n = int(n_match.group(1))
38
+ if n < 1 or n > 4:
39
+ raise ValueError(f"Invalid n value: {n}. Must be between 1 and 4.")
40
+ prompt = prompt.replace(n_match.group(0), "").strip()
41
+
42
+ # 解析ratio参数
43
+ ratio_match = re.search(r"{ratio:(\d+:\d+)}", prompt)
44
+ if ratio_match:
45
+ aspect_ratio = ratio_match.group(1)
46
+ if aspect_ratio not in VALID_IMAGE_RATIOS:
47
+ raise ValueError(
48
+ f"Invalid ratio: {aspect_ratio}. Must be one of: {', '.join(VALID_IMAGE_RATIOS)}"
49
+ )
50
+ prompt = prompt.replace(ratio_match.group(0), "").strip()
51
+
52
+ return prompt, n, aspect_ratio
53
+
54
+ def generate_images(self, request: ImageGenerationRequest):
55
+ client = genai.Client(api_key=settings.PAID_KEY)
56
+
57
+ if request.size == "1024x1024":
58
+ self.aspect_ratio = "1:1"
59
+ elif request.size == "1792x1024":
60
+ self.aspect_ratio = "16:9"
61
+ elif request.size == "1027x1792":
62
+ self.aspect_ratio = "9:16"
63
+ else:
64
+ raise ValueError(
65
+ f"Invalid size: {request.size}. Supported sizes are 1024x1024, 1792x1024, and 1024x1792."
66
+ )
67
+
68
+ # 解析prompt中的参数
69
+ cleaned_prompt, prompt_n, prompt_ratio = self.parse_prompt_parameters(
70
+ request.prompt
71
+ )
72
+ request.prompt = cleaned_prompt
73
+
74
+ # 如果prompt中指定了n,则覆盖请求中的n
75
+ if prompt_n > 1:
76
+ request.n = prompt_n
77
+
78
+ # 如果prompt中指定了ratio,则覆盖默认的aspect_ratio
79
+ if prompt_ratio != self.aspect_ratio:
80
+ self.aspect_ratio = prompt_ratio
81
+
82
+ response = client.models.generate_images(
83
+ model=self.image_model,
84
+ prompt=request.prompt,
85
+ config=types.GenerateImagesConfig(
86
+ number_of_images=request.n,
87
+ output_mime_type="image/png",
88
+ aspect_ratio=self.aspect_ratio,
89
+ safety_filter_level="BLOCK_LOW_AND_ABOVE",
90
+ person_generation="ALLOW_ADULT",
91
+ ),
92
+ )
93
+
94
+ if response.generated_images:
95
+ images_data = []
96
+ for index, generated_image in enumerate(response.generated_images):
97
+ image_data = generated_image.image.image_bytes
98
+ image_uploader = None
99
+
100
+ if request.response_format == "b64_json":
101
+ base64_image = base64.b64encode(image_data).decode("utf-8")
102
+ images_data.append(
103
+ {"b64_json": base64_image, "revised_prompt": request.prompt}
104
+ )
105
+ else:
106
+ current_date = time.strftime("%Y/%m/%d")
107
+ filename = f"{current_date}/{uuid.uuid4().hex[:8]}.png"
108
+
109
+ if settings.UPLOAD_PROVIDER == "smms":
110
+ image_uploader = ImageUploaderFactory.create(
111
+ provider=settings.UPLOAD_PROVIDER,
112
+ api_key=settings.SMMS_SECRET_TOKEN,
113
+ )
114
+ elif settings.UPLOAD_PROVIDER == "picgo":
115
+ image_uploader = ImageUploaderFactory.create(
116
+ provider=settings.UPLOAD_PROVIDER,
117
+ api_key=settings.PICGO_API_KEY,
118
+ )
119
+ elif settings.UPLOAD_PROVIDER == "cloudflare_imgbed":
120
+ image_uploader = ImageUploaderFactory.create(
121
+ provider=settings.UPLOAD_PROVIDER,
122
+ base_url=settings.CLOUDFLARE_IMGBED_URL,
123
+ auth_code=settings.CLOUDFLARE_IMGBED_AUTH_CODE,
124
+ )
125
+ else:
126
+ raise ValueError(
127
+ f"Unsupported upload provider: {settings.UPLOAD_PROVIDER}"
128
+ )
129
+
130
+ upload_response = image_uploader.upload(image_data, filename)
131
+
132
+ images_data.append(
133
+ {
134
+ "url": f"{upload_response.data.url}",
135
+ "revised_prompt": request.prompt,
136
+ }
137
+ )
138
+
139
+ response_data = {
140
+ "created": int(time.time()),
141
+ "data": images_data,
142
+ }
143
+ return response_data
144
+ else:
145
+ raise Exception("I can't generate these images")
146
+
147
+ def generate_images_chat(self, request: ImageGenerationRequest) -> str:
148
+ response = self.generate_images(request)
149
+ image_datas = response["data"]
150
+ if image_datas:
151
+ markdown_images = []
152
+ for index, image_data in enumerate(image_datas):
153
+ if "url" in image_data:
154
+ markdown_images.append(
155
+ f"![Generated Image {index+1}]({image_data['url']})"
156
+ )
157
+ else:
158
+ # 如果是base64格式,创建data URL
159
+ markdown_images.append(
160
+ f"![Generated Image {index+1}](data:image/png;base64,{image_data['b64_json']})"
161
+ )
162
+ return "\n".join(markdown_images)
app/service/key/key_manager.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from itertools import cycle
3
+ from typing import Dict, Union
4
+
5
+ from app.config.config import settings
6
+ from app.log.logger import get_key_manager_logger
7
+
8
+ logger = get_key_manager_logger()
9
+
10
+
11
class KeyManager:
    """Round-robin manager for Gemini and Vertex API keys with per-key
    failure tracking.

    A key is considered invalid once its failure count reaches
    ``MAX_FAILURES``; counts can be reset individually or in bulk. Rotation
    state and failure counts are guarded by separate asyncio locks.
    """

    def __init__(self, api_keys: list, vertex_api_keys: list):
        # NOTE(review): with an empty key list, next() on the corresponding
        # cycle raises StopIteration — callers appear to be expected to pass
        # at least one key; confirm upstream guards.
        self.api_keys = api_keys
        self.vertex_api_keys = vertex_api_keys
        self.key_cycle = cycle(api_keys)
        self.vertex_key_cycle = cycle(vertex_api_keys)
        self.key_cycle_lock = asyncio.Lock()
        self.vertex_key_cycle_lock = asyncio.Lock()
        self.failure_count_lock = asyncio.Lock()
        self.vertex_failure_count_lock = asyncio.Lock()
        self.key_failure_counts: Dict[str, int] = {key: 0 for key in api_keys}
        self.vertex_key_failure_counts: Dict[str, int] = {
            key: 0 for key in vertex_api_keys
        }
        self.MAX_FAILURES = settings.MAX_FAILURES
        self.paid_key = settings.PAID_KEY

    async def get_paid_key(self) -> str:
        """Return the configured paid API key."""
        return self.paid_key

    async def get_next_key(self) -> str:
        """Return the next API key in round-robin order."""
        async with self.key_cycle_lock:
            return next(self.key_cycle)

    async def get_next_vertex_key(self) -> str:
        """Return the next Vertex API key in round-robin order."""
        async with self.vertex_key_cycle_lock:
            return next(self.vertex_key_cycle)

    async def is_key_valid(self, key: str) -> bool:
        """Return True while the key's failure count is below MAX_FAILURES."""
        async with self.failure_count_lock:
            return self.key_failure_counts[key] < self.MAX_FAILURES

    async def is_vertex_key_valid(self, key: str) -> bool:
        """Return True while the Vertex key's failure count is below MAX_FAILURES."""
        async with self.vertex_failure_count_lock:
            return self.vertex_key_failure_counts[key] < self.MAX_FAILURES

    async def reset_failure_counts(self):
        """Reset the failure count of every API key to zero."""
        async with self.failure_count_lock:
            for key in self.key_failure_counts:
                self.key_failure_counts[key] = 0

    async def reset_vertex_failure_counts(self):
        """Reset the failure count of every Vertex API key to zero."""
        async with self.vertex_failure_count_lock:
            for key in self.vertex_key_failure_counts:
                self.vertex_key_failure_counts[key] = 0

    async def reset_key_failure_count(self, key: str) -> bool:
        """Reset one key's failure count; returns False for unknown keys."""
        async with self.failure_count_lock:
            if key in self.key_failure_counts:
                self.key_failure_counts[key] = 0
                logger.info(f"Reset failure count for key: {key}")
                return True
        logger.warning(
            f"Attempt to reset failure count for non-existent key: {key}"
        )
        return False

    async def reset_vertex_key_failure_count(self, key: str) -> bool:
        """Reset one Vertex key's failure count; returns False for unknown keys."""
        async with self.vertex_failure_count_lock:
            if key in self.vertex_key_failure_counts:
                self.vertex_key_failure_counts[key] = 0
                logger.info(f"Reset failure count for Vertex key: {key}")
                return True
        logger.warning(
            f"Attempt to reset failure count for non-existent Vertex key: {key}"
        )
        return False

    async def get_next_working_key(self) -> str:
        """Return the next valid API key.

        Walks the cycle at most one full revolution; if every key is invalid,
        the starting key is returned anyway as a best-effort fallback.
        """
        initial_key = await self.get_next_key()
        current_key = initial_key

        while True:
            if await self.is_key_valid(current_key):
                return current_key

            current_key = await self.get_next_key()
            if current_key == initial_key:
                return current_key

    async def get_next_working_vertex_key(self) -> str:
        """Return the next valid Vertex API key (same fallback semantics as
        ``get_next_working_key``)."""
        initial_key = await self.get_next_vertex_key()
        current_key = initial_key

        while True:
            if await self.is_vertex_key_valid(current_key):
                return current_key

            current_key = await self.get_next_vertex_key()
            if current_key == initial_key:
                return current_key

    async def handle_api_failure(self, api_key: str, retries: int) -> str:
        """Record a failure for ``api_key``.

        Returns the next working key while retries remain, or "" once the
        retry budget is exhausted.
        """
        async with self.failure_count_lock:
            self.key_failure_counts[api_key] += 1
            if self.key_failure_counts[api_key] >= self.MAX_FAILURES:
                logger.warning(
                    f"API key {api_key} has failed {self.MAX_FAILURES} times"
                )
        # Outside the lock: get_next_working_key re-acquires failure_count_lock
        # (asyncio.Lock is not reentrant).
        if retries < settings.MAX_RETRIES:
            return await self.get_next_working_key()
        else:
            return ""

    async def handle_vertex_api_failure(self, api_key: str, retries: int) -> None:
        """Record a failure for a Vertex key.

        Annotation fixed from ``-> str``: this method only updates the
        failure count and never returned a value.
        """
        async with self.vertex_failure_count_lock:
            self.vertex_key_failure_counts[api_key] += 1
            if self.vertex_key_failure_counts[api_key] >= self.MAX_FAILURES:
                logger.warning(
                    f"Vertex API key {api_key} has failed {self.MAX_FAILURES} times"
                )

    def get_fail_count(self, key: str) -> int:
        """Return the failure count for ``key`` (0 for unknown keys)."""
        return self.key_failure_counts.get(key, 0)

    def get_vertex_fail_count(self, key: str) -> int:
        """Return the failure count for a Vertex ``key`` (0 for unknown keys)."""
        return self.vertex_key_failure_counts.get(key, 0)

    async def get_keys_by_status(self) -> dict:
        """Partition API keys into valid/invalid maps of key -> fail count."""
        valid_keys = {}
        invalid_keys = {}

        async with self.failure_count_lock:
            for key in self.api_keys:
                fail_count = self.key_failure_counts[key]
                if fail_count < self.MAX_FAILURES:
                    valid_keys[key] = fail_count
                else:
                    invalid_keys[key] = fail_count

        return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}

    async def get_vertex_keys_by_status(self) -> dict:
        """Partition Vertex keys into valid/invalid maps of key -> fail count."""
        valid_keys = {}
        invalid_keys = {}

        async with self.vertex_failure_count_lock:
            for key in self.vertex_api_keys:
                fail_count = self.vertex_key_failure_counts[key]
                if fail_count < self.MAX_FAILURES:
                    valid_keys[key] = fail_count
                else:
                    invalid_keys[key] = fail_count
        return {"valid_keys": valid_keys, "invalid_keys": invalid_keys}

    async def get_first_valid_key(self) -> str:
        """Return the first key whose failure count is below MAX_FAILURES.

        Falls back to the first configured key when every key is invalid,
        and to "" when no keys are configured at all. (Cleaned up: the
        original had redundant branches and an unreachable final return.)
        """
        async with self.failure_count_lock:
            for key, count in self.key_failure_counts.items():
                if count < self.MAX_FAILURES:
                    return key
        if not self.api_keys:
            logger.warning(
                "API key list is empty, cannot get first valid key.")
            return ""
        return self.api_keys[0]
185
+
186
+
187
# Module-level singleton state for the KeyManager.
_singleton_instance = None
# Guards creation/reset of the singleton across concurrent callers.
_singleton_lock = asyncio.Lock()
# State captured at reset time so a re-created manager can inherit failure
# counts and resume the key rotation near where the old instance left off.
_preserved_failure_counts: Union[Dict[str, int], None] = None
_preserved_vertex_failure_counts: Union[Dict[str, int], None] = None
_preserved_old_api_keys_for_reset: Union[list, None] = None
_preserved_vertex_old_api_keys_for_reset: Union[list, None] = None
_preserved_next_key_in_cycle: Union[str, None] = None
_preserved_vertex_next_key_in_cycle: Union[str, None] = None
195
+
196
+
197
async def get_key_manager_instance(
    api_keys: list = None, vertex_api_keys: list = None
) -> KeyManager:
    """
    Return the KeyManager singleton.

    If no instance exists yet, it is initialized from the provided
    ``api_keys`` and ``vertex_api_keys`` (both are required in that case).
    If an instance already exists, the arguments are ignored and the
    existing singleton is returned.
    When called after a reset, this attempts to restore the previous
    state (failure counts and cycle positions) preserved by
    ``reset_key_manager_instance``.
    """
    global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle

    async with _singleton_lock:
        if _singleton_instance is None:
            if api_keys is None:
                raise ValueError(
                    "API keys are required to initialize or re-initialize the KeyManager instance."
                )
            if vertex_api_keys is None:
                raise ValueError(
                    "Vertex API keys are required to initialize or re-initialize the KeyManager instance."
                )

            # Empty lists are allowed but almost certainly misconfiguration.
            if not api_keys:
                logger.warning(
                    "Initializing KeyManager with an empty list of API keys."
                )
            if not vertex_api_keys:
                logger.warning(
                    "Initializing KeyManager with an empty list of Vertex API keys."
                )

            _singleton_instance = KeyManager(api_keys, vertex_api_keys)
            logger.info(
                f"KeyManager instance created/re-created with {len(api_keys)} API keys and {len(vertex_api_keys)} Vertex API keys."
            )

            # 1. Restore preserved failure counts, but only for keys that
            #    still exist in the new key list; new keys start at 0.
            if _preserved_failure_counts:
                current_failure_counts = {
                    key: 0 for key in _singleton_instance.api_keys
                }
                for key, count in _preserved_failure_counts.items():
                    if key in current_failure_counts:
                        current_failure_counts[key] = count
                _singleton_instance.key_failure_counts = current_failure_counts
                logger.info("Inherited failure counts for applicable keys.")
            _preserved_failure_counts = None

            if _preserved_vertex_failure_counts:
                current_vertex_failure_counts = {
                    key: 0 for key in _singleton_instance.vertex_api_keys
                }
                for key, count in _preserved_vertex_failure_counts.items():
                    if key in current_vertex_failure_counts:
                        current_vertex_failure_counts[key] = count
                _singleton_instance.vertex_key_failure_counts = (
                    current_vertex_failure_counts
                )
                logger.info(
                    "Inherited failure counts for applicable Vertex keys.")
            _preserved_vertex_failure_counts = None

            # 2. Determine where the new key_cycle should start so rotation
            #    resumes close to where the old instance left off. Walk the
            #    old key order starting at the preserved "next" key and pick
            #    the first key that still exists in the new list.
            start_key_for_new_cycle = None
            if (
                _preserved_old_api_keys_for_reset
                and _preserved_next_key_in_cycle
                and _singleton_instance.api_keys
            ):
                try:
                    start_idx_in_old = _preserved_old_api_keys_for_reset.index(
                        _preserved_next_key_in_cycle
                    )

                    for i in range(len(_preserved_old_api_keys_for_reset)):
                        current_old_key_idx = (start_idx_in_old + i) % len(
                            _preserved_old_api_keys_for_reset
                        )
                        key_candidate = _preserved_old_api_keys_for_reset[
                            current_old_key_idx
                        ]
                        if key_candidate in _singleton_instance.api_keys:
                            start_key_for_new_cycle = key_candidate
                            break
                except ValueError:
                    logger.warning(
                        f"Preserved next key '{_preserved_next_key_in_cycle}' not found in preserved old API keys. "
                        "New cycle will start from the beginning of the new list."
                    )
                except Exception as e:
                    logger.error(
                        f"Error determining start key for new cycle from preserved state: {e}. "
                        "New cycle will start from the beginning."
                    )

            # Advance the fresh cycle until the chosen key is next; a cycle
            # cannot be repositioned directly, so we consume entries.
            if start_key_for_new_cycle and _singleton_instance.api_keys:
                try:
                    target_idx = _singleton_instance.api_keys.index(
                        start_key_for_new_cycle
                    )
                    for _ in range(target_idx):
                        next(_singleton_instance.key_cycle)
                    logger.info(
                        f"Key cycle in new instance advanced. Next call to get_next_key() will yield: {start_key_for_new_cycle}"
                    )
                except ValueError:
                    logger.warning(
                        f"Determined start key '{start_key_for_new_cycle}' not found in new API keys during cycle advancement. "
                        "New cycle will start from the beginning."
                    )
                except StopIteration:
                    logger.error(
                        "StopIteration while advancing key cycle, implies empty new API key list previously missed."
                    )
                except Exception as e:
                    logger.error(
                        f"Error advancing new key cycle: {e}. Cycle will start from beginning."
                    )
            else:
                if _singleton_instance.api_keys:
                    logger.info(
                        "New key cycle will start from the beginning of the new API key list (no specific start key determined or needed)."
                    )
                else:
                    logger.info(
                        "New key cycle not applicable as the new API key list is empty."
                    )

            # Clear preserved cycle state now that it has been consumed.
            _preserved_old_api_keys_for_reset = None
            _preserved_next_key_in_cycle = None

            # 3. Same cycle-repositioning procedure for the Vertex keys.
            start_key_for_new_vertex_cycle = None
            if (
                _preserved_vertex_old_api_keys_for_reset
                and _preserved_vertex_next_key_in_cycle
                and _singleton_instance.vertex_api_keys
            ):
                try:
                    start_idx_in_old = _preserved_vertex_old_api_keys_for_reset.index(
                        _preserved_vertex_next_key_in_cycle
                    )

                    for i in range(len(_preserved_vertex_old_api_keys_for_reset)):
                        current_old_key_idx = (start_idx_in_old + i) % len(
                            _preserved_vertex_old_api_keys_for_reset
                        )
                        key_candidate = _preserved_vertex_old_api_keys_for_reset[
                            current_old_key_idx
                        ]
                        if key_candidate in _singleton_instance.vertex_api_keys:
                            start_key_for_new_vertex_cycle = key_candidate
                            break
                except ValueError:
                    logger.warning(
                        f"Preserved next key '{_preserved_vertex_next_key_in_cycle}' not found in preserved old Vertex API keys. "
                        "New cycle will start from the beginning of the new list."
                    )
                except Exception as e:
                    logger.error(
                        f"Error determining start key for new Vertex key cycle from preserved state: {e}. "
                        "New cycle will start from the beginning."
                    )

            if start_key_for_new_vertex_cycle and _singleton_instance.vertex_api_keys:
                try:
                    target_idx = _singleton_instance.vertex_api_keys.index(
                        start_key_for_new_vertex_cycle
                    )
                    for _ in range(target_idx):
                        next(_singleton_instance.vertex_key_cycle)
                    logger.info(
                        f"Vertex key cycle in new instance advanced. Next call to get_next_vertex_key() will yield: {start_key_for_new_vertex_cycle}"
                    )
                except ValueError:
                    logger.warning(
                        f"Determined start key '{start_key_for_new_vertex_cycle}' not found in new Vertex API keys during cycle advancement. "
                        "New cycle will start from the beginning."
                    )
                except StopIteration:
                    logger.error(
                        "StopIteration while advancing Vertex key cycle, implies empty new Vertex API key list previously missed."
                    )
                except Exception as e:
                    logger.error(
                        f"Error advancing new Vertex key cycle: {e}. Cycle will start from beginning."
                    )
            else:
                if _singleton_instance.vertex_api_keys:
                    logger.info(
                        "New Vertex key cycle will start from the beginning of the new Vertex API key list (no specific start key determined or needed)."
                    )
                else:
                    logger.info(
                        "New Vertex key cycle not applicable as the new Vertex API key list is empty."
                    )

            # Clear preserved Vertex cycle state.
            _preserved_vertex_old_api_keys_for_reset = None
            _preserved_vertex_next_key_in_cycle = None

        return _singleton_instance
401
+
402
+
403
async def reset_key_manager_instance():
    """
    Reset the KeyManager singleton.

    Saves the current instance's state (failure counts, old API key lists,
    and the next-key hints for both cycles) into module-level variables so
    the next call to ``get_key_manager_instance`` can restore it.
    """
    global _singleton_instance, _preserved_failure_counts, _preserved_vertex_failure_counts, _preserved_old_api_keys_for_reset, _preserved_vertex_old_api_keys_for_reset, _preserved_next_key_in_cycle, _preserved_vertex_next_key_in_cycle
    async with _singleton_lock:
        if _singleton_instance:
            # 1. Preserve failure counts (copies, so the old instance can be GC'd).
            _preserved_failure_counts = _singleton_instance.key_failure_counts.copy()
            _preserved_vertex_failure_counts = _singleton_instance.vertex_key_failure_counts.copy()

            # 2. Preserve the old API key lists (needed to map cycle position
            #    onto the new key lists during re-initialization).
            _preserved_old_api_keys_for_reset = _singleton_instance.api_keys.copy()
            _preserved_vertex_old_api_keys_for_reset = _singleton_instance.vertex_api_keys.copy()

            # 3. Preserve the next key the cycle would have produced.
            #    NOTE: this consumes one element from the old cycle, which is
            #    fine because the instance is being discarded.
            try:
                if _singleton_instance.api_keys:
                    _preserved_next_key_in_cycle = (
                        await _singleton_instance.get_next_key()
                    )
                else:
                    _preserved_next_key_in_cycle = None
            except StopIteration:
                logger.warning(
                    "Could not preserve next key hint: key cycle was empty or exhausted in old instance."
                )
                _preserved_next_key_in_cycle = None
            except Exception as e:
                logger.error(
                    f"Error preserving next key hint during reset: {e}")
                _preserved_next_key_in_cycle = None

            # 4. Same next-key hint for the Vertex key cycle.
            try:
                if _singleton_instance.vertex_api_keys:
                    _preserved_vertex_next_key_in_cycle = (
                        await _singleton_instance.get_next_vertex_key()
                    )
                else:
                    _preserved_vertex_next_key_in_cycle = None
            except StopIteration:
                logger.warning(
                    "Could not preserve next key hint: Vertex key cycle was empty or exhausted in old instance."
                )
                _preserved_vertex_next_key_in_cycle = None
            except Exception as e:
                logger.error(
                    f"Error preserving next key hint during reset: {e}")
                _preserved_vertex_next_key_in_cycle = None

            _singleton_instance = None
            logger.info(
                "KeyManager instance has been reset. State (failure counts, old keys, next key hint) preserved for next instantiation."
            )
        else:
            logger.info(
                "KeyManager instance was not set (or already reset), no reset action performed."
            )
app/service/model/model_service.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timezone
2
+ from typing import Any, Dict, Optional
3
+
4
+ from app.config.config import settings
5
+ from app.log.logger import get_model_logger
6
+ from app.service.client.api_client import GeminiApiClient
7
+
8
+ logger = get_model_logger()
9
+
10
+
11
class ModelService:
    """Fetches Gemini model listings and adapts them to the OpenAI models format."""

    async def get_gemini_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch the Gemini model list, dropping models named in settings.FILTERED_MODELS.

        Returns the (mutated) upstream payload, or None when the fetch or
        post-processing fails.
        """
        api_client = GeminiApiClient(base_url=settings.BASE_URL)
        gemini_models = await api_client.get_models(api_key)

        if gemini_models is None:
            logger.error("从 API 客户端获取模型列表失败。")
            return None

        try:
            filtered_models_list = []
            for model in gemini_models.get("models", []):
                # Upstream names look like "models/<id>"; keep only the trailing id.
                model_id = model["name"].split("/")[-1]
                if model_id not in settings.FILTERED_MODELS:
                    filtered_models_list.append(model)
                else:
                    logger.debug(f"Filtered out model: {model_id}")

            gemini_models["models"] = filtered_models_list
            return gemini_models
        except Exception as e:
            logger.error(f"处理模型列表时出错: {e}")
            return None

    async def get_gemini_openai_models(self, api_key: str) -> Optional[Dict[str, Any]]:
        """Fetch the Gemini models and convert them to OpenAI list format."""
        gemini_models = await self.get_gemini_models(api_key)
        if gemini_models is None:
            return None

        return await self.convert_to_openai_models_format(gemini_models)

    async def convert_to_openai_models_format(
        self, gemini_models: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert a Gemini model listing into an OpenAI-style models payload.

        For models listed in SEARCH_MODELS / IMAGE_MODELS / THINKING_MODELS,
        synthetic "-search" / "-image" / "-non-thinking" variants are added.
        """
        openai_format = {"object": "list", "data": [], "success": True}

        for model in gemini_models.get("models", []):
            model_id = model["name"].split("/")[-1]
            openai_model = {
                "id": model_id,
                "object": "model",
                # Listing has no creation time; stamp with "now" (UTC).
                "created": int(datetime.now(timezone.utc).timestamp()),
                "owned_by": "google",
                "permission": [],
                "root": model["name"],
                "parent": None,
            }
            openai_format["data"].append(openai_model)

            if model_id in settings.SEARCH_MODELS:
                search_model = openai_model.copy()
                search_model["id"] = f"{model_id}-search"
                openai_format["data"].append(search_model)
            if model_id in settings.IMAGE_MODELS:
                image_model = openai_model.copy()
                image_model["id"] = f"{model_id}-image"
                openai_format["data"].append(image_model)
            if model_id in settings.THINKING_MODELS:
                non_thinking_model = openai_model.copy()
                non_thinking_model["id"] = f"{model_id}-non-thinking"
                openai_format["data"].append(non_thinking_model)

        # NOTE(review): runs after the loop, so the "-chat" entry copies the
        # metadata of the LAST model processed — confirm this is intended.
        if settings.CREATE_IMAGE_MODEL:
            image_model = openai_model.copy()
            image_model["id"] = f"{settings.CREATE_IMAGE_MODEL}-chat"
            openai_format["data"].append(image_model)
        return openai_format

    async def check_model_support(self, model: str) -> bool:
        """Return True if *model* (possibly with a "-search"/"-image" suffix) is usable.

        NOTE(review): "-non-thinking" variants advertised by
        convert_to_openai_models_format have no special case here and fall
        through to the FILTERED_MODELS check — verify that is intended.
        """
        if not model or not isinstance(model, str):
            return False

        model = model.strip()
        if model.endswith("-search"):
            model = model[:-7]
            return model in settings.SEARCH_MODELS
        if model.endswith("-image"):
            model = model[:-6]
            return model in settings.IMAGE_MODELS

        return model not in settings.FILTERED_MODELS
app/service/openai_compatiable/openai_compatiable_service.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import datetime
3
+ import json
4
+ import re
5
+ import time
6
+ from typing import Any, AsyncGenerator, Dict, Union
7
+
8
+ from app.config.config import settings
9
+ from app.database.services import (
10
+ add_error_log,
11
+ add_request_log,
12
+ )
13
+ from app.domain.openai_models import ChatRequest, ImageGenerationRequest
14
+ from app.service.client.api_client import OpenaiApiClient
15
+ from app.service.key.key_manager import KeyManager
16
+ from app.log.logger import get_openai_compatible_logger
17
+
18
+ logger = get_openai_compatible_logger()
19
+
20
class OpenAICompatiableService:
    """OpenAI-compatible chat/image/embedding proxy backed by an upstream base_url."""

    def __init__(self, base_url: str, key_manager: KeyManager = None):
        # key_manager is optional; without it, streaming retries cannot rotate keys.
        self.key_manager = key_manager
        self.base_url = base_url
        self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """Return the upstream model list for the given API key."""
        return await self.api_client.get_models(api_key)

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion: a dict when non-streaming, an SSE async generator when streaming."""
        request_dict = request.model_dump()
        # Drop keys whose value is null.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        # top_k is not supported upstream. Bug fix: use pop() with a default —
        # when the client omits top_k it serializes as None and is already
        # removed by the filter above, so the original
        # `del request_dict["top_k"]` raised KeyError on every such request.
        request_dict.pop("top_k", None)
        if request.stream:
            return self._handle_stream_completion(request.model, request_dict, api_key)
        return await self._handle_normal_completion(request.model, request_dict, api_key)

    async def generate_images(
        self,
        request: ImageGenerationRequest,
    ) -> Dict[str, Any]:
        """Generate images; always billed against the configured PAID_KEY."""
        request_dict = request.model_dump()
        # Drop keys whose value is null.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        api_key = settings.PAID_KEY
        return await self.api_client.generate_images(request_dict, api_key)

    async def create_embeddings(
        self,
        input_text: str,
        model: str,
        api_key: str,
    ) -> Dict[str, Any]:
        """Create embeddings for *input_text* with the given model."""
        return await self.api_client.create_embeddings(input_text, model, api_key)

    async def _handle_normal_completion(
        self, model: str, request: dict, api_key: str
    ) -> Dict[str, Any]:
        """Non-streaming completion; logs the request outcome and re-raises upstream errors."""
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        try:
            response = await self.api_client.generate_content(request, api_key)
            is_success = True
            status_code = 200
            return response
        except Exception as e:
            is_success = False
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            # Best effort: recover the HTTP status from the exception text.
            match = re.search(r"status code (\d+)", error_log_msg)
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500

            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-compatiable-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=request,
            )
            raise e
        finally:
            # Request accounting runs on both success and failure paths.
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_stream_completion(
        self, model: str, payload: dict, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Streaming completion with retry: on failure, rotates keys via the
        KeyManager (when available) up to settings.MAX_RETRIES attempts.

        Yields SSE "data:" lines; on final failure emits an error event and [DONE].
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None
        final_api_key = api_key

        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = api_key
            final_api_key = current_attempt_key
            try:
                async for line in self.api_client.stream_generate_content(
                    payload, current_attempt_key
                ):
                    if line.startswith("data:"):
                        yield line + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                is_success = False
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                match = re.search(r"status code (\d+)", error_log_msg)
                if match:
                    status_code = int(match.group(1))
                else:
                    status_code = 500

                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-compatiable-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )

                if self.key_manager:
                    # Report the failure and obtain a replacement key for the next attempt.
                    api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if api_key:
                        logger.info(f"Switched to new API key: {api_key}")
                    else:
                        logger.error(
                            f"No valid API key available after {retries} retries."
                        )
                        break
                else:
                    logger.error("KeyManager not available for retry logic.")
                    break

                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached for streaming.")
                    break
            finally:
                # Per-attempt accounting (runs once per loop iteration).
                end_time = time.perf_counter()
                latency_ms = int((end_time - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=final_api_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )
        if not is_success and retries >= max_retries:
            yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
        yield "data: [DONE]\n\n"
190
+
app/service/request_log/request_log_service.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Service for request log operations.
3
+ """
4
+
5
+ from datetime import datetime, timedelta, timezone
6
+
7
+ from sqlalchemy import delete
8
+
9
+ from app.database.connection import database
10
+ from app.config.config import settings
11
+ from app.database.models import RequestLog
12
+ from app.log.logger import get_request_log_logger
13
+
14
+ logger = get_request_log_logger()
15
+
16
+
17
+ async def delete_old_request_logs_task():
18
+ """
19
+ 定时删除旧的请求日志。
20
+ """
21
+ if not settings.AUTO_DELETE_REQUEST_LOGS_ENABLED:
22
+ logger.info(
23
+ "Auto-delete for request logs is disabled by settings. Skipping task."
24
+ )
25
+ return
26
+
27
+ days_to_keep = settings.AUTO_DELETE_REQUEST_LOGS_DAYS
28
+ logger.info(
29
+ f"Starting scheduled task to delete old request logs older than {days_to_keep} days."
30
+ )
31
+
32
+ try:
33
+ cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
34
+
35
+ query = delete(RequestLog).where(RequestLog.request_time < cutoff_date)
36
+
37
+ if not database.is_connected:
38
+ logger.info("Connecting to database for request log deletion.")
39
+ await database.connect()
40
+
41
+ result = await database.execute(query)
42
+ logger.info(
43
+ f"Request logs older than {cutoff_date} potentially deleted. Rows affected: {result.rowcount if result else 'N/A'}"
44
+ )
45
+
46
+ except Exception as e:
47
+ logger.error(
48
+ f"An error occurred during the scheduled request log deletion: {str(e)}",
49
+ exc_info=True,
50
+ )
app/service/stats/stats_service.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/service/stats_service.py
2
+
3
+ import datetime
4
+ from typing import Union
5
+
6
+ from sqlalchemy import and_, case, func, or_, select
7
+
8
+ from app.database.connection import database
9
+ from app.database.models import RequestLog
10
+ from app.log.logger import get_stats_logger
11
+
12
+ logger = get_stats_logger()
13
+
14
+
15
class StatsService:
    """Service class for handling statistics related operations."""

    @staticmethod
    def _build_status_counts_query(start_time: datetime.datetime):
        """Build a query counting total/success/failure RequestLog rows since *start_time*.

        Success means 200 <= status_code < 300; everything else — including a
        NULL status_code — counts as a failure.
        """
        return select(
            func.count(RequestLog.id).label("total"),
            func.sum(
                case(
                    (
                        and_(
                            RequestLog.status_code >= 200,
                            RequestLog.status_code < 300,
                        ),
                        1,
                    ),
                    else_=0,
                )
            ).label("success"),
            func.sum(
                case(
                    (
                        or_(
                            RequestLog.status_code < 200,
                            RequestLog.status_code >= 300,
                        ),
                        1,
                    ),
                    # Bug fix: the original wrote `RequestLog.status_code is None`,
                    # a Python identity test that evaluates to the constant False
                    # at query-build time and never emits SQL "IS NULL". NULL
                    # status codes were therefore never counted as failures;
                    # SQLAlchemy requires `.is_(None)` for a NULL comparison.
                    (RequestLog.status_code.is_(None), 1),
                    else_=0,
                )
            ).label("failure"),
        ).where(RequestLog.request_time >= start_time)

    async def _fetch_status_counts(
        self, start_time: datetime.datetime
    ) -> dict[str, int]:
        """Execute the counts query and normalize NULL sums to zero."""
        result = await database.fetch_one(self._build_status_counts_query(start_time))
        if result:
            return {
                "total": result["total"] or 0,
                "success": result["success"] or 0,
                "failure": result["failure"] or 0,
            }
        return {"total": 0, "success": 0, "failure": 0}

    async def get_calls_in_last_seconds(self, seconds: int) -> dict[str, int]:
        """Return call counts (total/success/failure) for the last N seconds."""
        try:
            cutoff_time = datetime.datetime.now() - datetime.timedelta(seconds=seconds)
            return await self._fetch_status_counts(cutoff_time)
        except Exception as e:
            logger.error(f"Failed to get calls in last {seconds} seconds: {e}")
            return {"total": 0, "success": 0, "failure": 0}

    async def get_calls_in_last_minutes(self, minutes: int) -> dict[str, int]:
        """Return call counts (total/success/failure) for the last N minutes."""
        return await self.get_calls_in_last_seconds(minutes * 60)

    async def get_calls_in_last_hours(self, hours: int) -> dict[str, int]:
        """Return call counts (total/success/failure) for the last N hours."""
        return await self.get_calls_in_last_seconds(hours * 3600)

    async def get_calls_in_current_month(self) -> dict[str, int]:
        """Return call counts (total/success/failure) for the current calendar month."""
        try:
            now = datetime.datetime.now()
            start_of_month = now.replace(
                day=1, hour=0, minute=0, second=0, microsecond=0
            )
            return await self._fetch_status_counts(start_of_month)
        except Exception as e:
            logger.error(f"Failed to get calls in current month: {e}")
            return {"total": 0, "success": 0, "failure": 0}

    async def get_api_usage_stats(self) -> dict:
        """Return the aggregated usage stats for 1m / 1h / 24h / current month."""
        try:
            stats_1m = await self.get_calls_in_last_minutes(1)
            stats_1h = await self.get_calls_in_last_hours(1)
            stats_24h = await self.get_calls_in_last_hours(24)
            stats_month = await self.get_calls_in_current_month()

            return {
                "calls_1m": stats_1m,
                "calls_1h": stats_1h,
                "calls_24h": stats_24h,
                "calls_month": stats_month,
            }
        except Exception as e:
            logger.error(f"Failed to get API usage stats: {e}")
            default_stat = {"total": 0, "success": 0, "failure": 0}
            return {
                "calls_1m": default_stat.copy(),
                "calls_1h": default_stat.copy(),
                "calls_24h": default_stat.copy(),
                "calls_month": default_stat.copy(),
            }

    async def get_api_call_details(self, period: str) -> list[dict]:
        """
        Return individual API call records for the given period.

        Args:
            period: One of '1m', '1h', '24h'.

        Returns:
            A list of dicts with keys: timestamp (ISO string), key, model, status.

        Raises:
            ValueError: If *period* is not one of the supported identifiers.
        """
        now = datetime.datetime.now()
        if period == "1m":
            start_time = now - datetime.timedelta(minutes=1)
        elif period == "1h":
            start_time = now - datetime.timedelta(hours=1)
        elif period == "24h":
            start_time = now - datetime.timedelta(hours=24)
        else:
            raise ValueError(f"无效的时间段标识: {period}")

        try:
            query = (
                select(
                    RequestLog.request_time.label("timestamp"),
                    RequestLog.api_key.label("key"),
                    RequestLog.model_name.label("model"),
                    RequestLog.status_code,
                )
                .where(RequestLog.request_time >= start_time)
                .order_by(RequestLog.request_time.desc())
            )

            results = await database.fetch_all(query)

            details = []
            for row in results:
                # A row with no status code is reported as a failure.
                status = "failure"
                if row["status_code"] is not None:
                    status = "success" if 200 <= row["status_code"] < 300 else "failure"
                details.append(
                    {
                        "timestamp": row["timestamp"].isoformat(),
                        "key": row["key"],
                        "model": row["model"],
                        "status": status,
                    }
                )
            logger.info(
                f"Retrieved {len(details)} API call details for period '{period}'"
            )
            return details

        except Exception as e:
            logger.error(
                f"Failed to get API call details for period '{period}': {e}")
            raise

    async def get_key_usage_details_last_24h(self, key: str) -> Union[dict, None]:
        """
        Return per-model call counts for *key* over the last 24 hours.

        Args:
            key: The API key to query.

        Returns:
            A dict mapping model name to call count (empty when no records),
            e.g. {"gemini-pro": 10, "gemini-1.5-pro-latest": 5}.

        Raises:
            Exception: Propagates database errors after logging.
        """
        logger.info(
            f"Fetching usage details for key ending in ...{key[-4:]} for the last 24h."
        )
        cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=24)

        try:
            query = (
                select(
                    RequestLog.model_name, func.count(
                        RequestLog.id).label("call_count")
                )
                .where(
                    RequestLog.api_key == key,
                    RequestLog.request_time >= cutoff_time,
                    RequestLog.model_name.isnot(None),
                )
                .group_by(RequestLog.model_name)
                .order_by(func.count(RequestLog.id).desc())
            )

            results = await database.fetch_all(query)

            if not results:
                logger.info(
                    f"No usage details found for key ending in ...{key[-4:]} in the last 24h."
                )
                return {}

            usage_details = {row["model_name"]: row["call_count"]
                             for row in results}
            logger.info(
                f"Successfully fetched usage details for key ending in ...{key[-4:]}: {usage_details}"
            )
            return usage_details

        except Exception as e:
            logger.error(
                f"Failed to get key usage details for key ending in ...{key[-4:]}: {e}",
                exc_info=True,
            )
            raise
+ raise
app/service/tts/tts_service.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import io
3
+ import re
4
+ import time
5
+ import wave
6
+ from typing import Optional
7
+
8
+ from google import genai
9
+
10
+ from app.config.config import settings
11
+ from app.database.services import add_error_log, add_request_log
12
+ from app.domain.openai_models import TTSRequest
13
+ from app.log.logger import get_openai_logger
14
+
15
+ logger = get_openai_logger()
16
+
17
+
18
def _create_wav_file(audio_data: bytes) -> bytes:
    """Wrap raw PCM audio in a WAV container and return it as bytes.

    The container is mono, 16-bit samples, 24 kHz — matching the raw audio
    the TTS model returns.
    """
    buffer = io.BytesIO()
    writer = wave.open(buffer, "wb")
    try:
        writer.setnchannels(1)    # Mono
        writer.setsampwidth(2)    # 16-bit samples
        writer.setframerate(24000)  # 24kHz sample rate
        writer.writeframes(audio_data)
    finally:
        writer.close()
    return buffer.getvalue()
27
+
28
+
29
class TTSService:
    """Text-to-speech via the Google Gemini SDK, with request/error accounting."""

    async def create_tts(self, request: TTSRequest, api_key: str) -> Optional[bytes]:
        """
        Create audio with the Google Gemini SDK.

        Returns WAV bytes on success. Returns None when the response carries
        no audio part; upstream exceptions are re-raised after being logged.
        Both outcomes are recorded via add_request_log (and add_error_log on
        failure) in the finally block.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        response = None
        error_log_msg = ""
        try:
            client = genai.Client(api_key=api_key)
            response = await client.aio.models.generate_content(
                model=settings.TTS_MODEL,
                contents=f"Speak in a {settings.TTS_SPEED} speed voice: {request.input}",
                config={
                    "response_modalities": ["Audio"],
                    "speech_config": {
                        "voice_config": {
                            "prebuilt_voice_config": {
                                "voice_name": settings.TTS_VOICE_NAME
                            }
                        }
                    },
                },
            )
            if (
                response.candidates
                and response.candidates[0].content.parts
                and response.candidates[0].content.parts[0].inline_data
            ):
                raw_audio_data = response.candidates[0].content.parts[0].inline_data.data
                is_success = True
                status_code = 200
                return _create_wav_file(raw_audio_data)
            # NOTE(review): when no audio part is present we fall through and
            # return None, and the finally block records an error log with an
            # empty message and a NULL status code — confirm this is intended.
        except Exception as e:
            is_success = False
            error_log_msg = f"Generic error: {e}"
            logger.error(f"An error occurred in TTSService: {error_log_msg}")
            # Best effort: recover the HTTP status from the exception text.
            match = re.search(r"status code (\d+)", str(e))
            if match:
                status_code = int(match.group(1))
            else:
                status_code = 500
            raise
        finally:
            # Accounting runs on every path (success, no-audio, exception).
            end_time = time.perf_counter()
            latency_ms = int((end_time - start_time) * 1000)
            if not is_success:
                await add_error_log(
                    gemini_key=api_key,
                    model_name=settings.TTS_MODEL,
                    error_type="google-tts",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=request.input
                )
            await add_request_log(
                model_name=settings.TTS_MODEL,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime
            )