Spaces:
Running
Running
File size: 9,716 Bytes
d0dd276 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 |
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from app.models.schemas import ErrorResponse
from app.services import GeminiClient
from app.utils import (
APIKeyManager,
test_api_key,
ResponseCacheManager,
ActiveRequestsManager,
check_version,
schedule_cache_cleanup,
handle_exception,
log
)
from app.config.persistence import save_settings, load_settings
from app.api import router, init_router, dashboard_router, init_dashboard_router
from app.vertex.vertex_ai_init import init_vertex_ai
from app.vertex.credentials_manager import CredentialManager
import app.config.settings as settings
from app.config.safety import SAFETY_SETTINGS, SAFETY_SETTINGS_G2
import asyncio
import sys
import pathlib
import os
# Template directory, resolved relative to this file.
BASE_DIR = pathlib.Path(__file__).parent
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
# NOTE(review): FastAPI has no `limit` parameter. Unknown keyword arguments are
# stored in `app.extra` and otherwise ignored, so this does NOT cap the request
# body size. Kept as-is for compatibility; enforce a limit elsewhere if needed.
app = FastAPI(limit="50M")

# --------------- Persisted settings ---------------
# Load persisted settings before anything below reads `settings`. Otherwise the
# CORS configuration would be built from the defaults rather than from the
# persisted values. (load_settings() is called again in the startup event.)
load_settings()

# --------------- CORS middleware ---------------
# If ALLOWED_ORIGINS is an empty list, no CORS middleware is installed, so no
# cross-origin requests are granted.
if settings.ALLOWED_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.ALLOWED_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# --------------- Global instances ---------------
# API key manager.
key_manager = APIKeyManager()
# Global cache dict, used as the cache manager's backing storage.
response_cache = {}
# Cache manager backed by the global dict above.
response_cache_manager = ResponseCacheManager(
    expiry_time=settings.CACHE_EXPIRY_TIME,
    max_entries=settings.MAX_CACHE_ENTRIES,
    cache_dict=response_cache
)
# Active-request pool, used as the active-requests manager's backing storage.
active_requests_pool = {}
active_requests_manager = ActiveRequestsManager(requests_pool=active_requests_pool)
# When the env var equals "true" (case-insensitive), the keys remaining after
# the first valid one are not validated in the background at startup.
SKIP_CHECK_API_KEY = os.environ.get("SKIP_CHECK_API_KEY", "").lower() == "true"

# --------------- Helpers ---------------
# Debug-only middleware that logs each incoming request and its JSON body.
# Left commented out; uncomment to enable.
# @app.middleware("http")
# async def log_requests(request: Request, call_next):
#     """
#     For debugging: receive and print the request contents.
#     """
#     log('info', f"接收到请求: {request.method} {request.url}")
#     try:
#         body = await request.json()
#         log('info', f"请求体: {body}")
#     except Exception:
#         log('info', "请求体不是 JSON 格式或者为空")
#     response = await call_next(request)
#     return response
async def check_remaining_keys_async(keys_to_check: list, initial_invalid_keys: list):
    """
    Validate the remaining API keys in the background.

    Each key in ``keys_to_check`` is tested with ``test_api_key``:

    * Valid keys are appended to ``key_manager.api_keys`` unless already present.
      If any key was added, the key stack is reset.
    * Invalid keys, together with ``initial_invalid_keys``, are merged into
      ``settings.INVALID_API_KEYS`` (a comma-separated string). The settings are
      saved only when that set actually changes.

    A key whose check raises an exception is logged and left unclassified. It is
    neither added nor recorded as invalid, so one failing check cannot abort the
    whole task.

    Args:
        keys_to_check: keys still to be validated.
        initial_invalid_keys: keys already found invalid during startup.
    """
    local_invalid_keys = []
    found_valid_keys = False
    log('info', " 开始在后台检查剩余 API Key 是否有效")
    for key in keys_to_check:
        try:
            is_valid = await test_api_key(key)
        except Exception as e:
            # This runs as a fire-and-forget task: an uncaught exception would end
            # it silently and skip the stack reset / persistence below.
            log('warning', f" API Key {key[:8]}... 检查出错", extra={'error_message': str(e)})
        else:
            if is_valid:
                if key not in key_manager.api_keys:  # avoid duplicates
                    key_manager.api_keys.append(key)
                    found_valid_keys = True
            else:
                local_invalid_keys.append(key)
                log('warning', f" API Key {key[:8]}... 无效")
        await asyncio.sleep(0.05)  # brief pause so requests are not too dense
    if found_valid_keys:
        key_manager._reset_key_stack()  # new valid keys were added; rebuild the stack
    # Invalid keys currently stored in settings (comma-separated string).
    current_invalid_keys_str = settings.INVALID_API_KEYS or ""
    current_invalid_keys_set = {k.strip() for k in current_invalid_keys_str.split(',') if k.strip()}
    # Union of stored, initially-invalid and newly-found invalid keys.
    new_invalid_keys_set = current_invalid_keys_set | set(initial_invalid_keys) | set(local_invalid_keys)
    # Persist only when the invalid-key set changed.
    if new_invalid_keys_set != current_invalid_keys_set:
        settings.INVALID_API_KEYS = ','.join(sorted(new_invalid_keys_set))
        save_settings()
    log('info', f"密钥检查任务完成。当前总可用密钥数量: {len(key_manager.api_keys)}")
# Install the global handler for uncaught exceptions: the interpreter calls
# sys.excepthook, now handle_exception, for exceptions that reach top level.
sys.excepthook = handle_exception
# --------------- Event handlers ---------------
@app.on_event("startup")
async def startup_event():
    """
    Initialise services and API keys when the application starts.

    Steps, in order:
      1. Reload persisted settings and the Vertex configuration.
      2. Create the CredentialManager (stored as ``app.state.credential_manager``)
         and initialise Vertex AI with it.
      3. Schedule cache cleanup and run the version check.
      4. Scan for the first valid API key (blocking) and use it to load the
         available model list. Then either:
         - hand the remaining keys to a background validation task, or
         - if SKIP_CHECK_API_KEY is set, add them without validation.
      5. Initialise the API router and the dashboard router.
    """
    # Load persisted settings first so every value read below is current.
    load_settings()
    # Reload the Vertex config so it also picks up the persisted settings.
    import app.vertex.config as vertex_config
    vertex_config.reload_config()
    # Create the CredentialManager and expose it on the application state.
    credential_manager_instance = CredentialManager()
    app.state.credential_manager = credential_manager_instance
    # Initialise the Vertex AI service.
    await init_vertex_ai(credential_manager=credential_manager_instance)
    schedule_cache_cleanup(response_cache_manager, active_requests_manager)
    # Version check.
    await check_version()

    # --- API key validation ---
    initial_keys = key_manager.api_keys.copy()
    key_manager.api_keys = []  # cleared; repopulated only with keys that pass validation
    first_valid_key = None
    initial_invalid_keys = []
    keys_to_check_later = []
    # Blocking scan for the first valid key; keys after it are checked later.
    for index, key in enumerate(initial_keys):
        is_valid = await test_api_key(key)
        if is_valid:
            log('info', f"找到第一个有效密钥: {key[:8]}...")
            first_valid_key = key
            key_manager.api_keys.append(key)
            key_manager._reset_key_stack()
            # Everything after this key goes to the background check list.
            keys_to_check_later = initial_keys[index + 1:]
            break
        else:
            log('warning', f"密钥 {key[:8]}... 无效")
            initial_invalid_keys.append(key)
    if not first_valid_key:
        log('error', "启动时未能找到任何有效 API 密钥!")
        keys_to_check_later = []  # no valid key at all, nothing to check in the background
    else:
        # Load the available models using the first valid key.
        try:
            all_models = await GeminiClient.list_available_models(first_valid_key)
            GeminiClient.AVAILABLE_MODELS = [model.replace("models/", "") for model in all_models]
            log('info', f"使用密钥 {first_valid_key[:8]}... 加载可用模型成功")
        except Exception as e:
            log('warning', f"使用密钥 {first_valid_key[:8]}... 加载可用模型失败", extra={'error_message': str(e)})
    if not SKIP_CHECK_API_KEY:
        if keys_to_check_later:
            # Keep a strong reference to the background task: the event loop only
            # holds weak references, so an unreferenced task can be garbage-collected
            # before it finishes.
            app.state.key_check_task = asyncio.create_task(
                check_remaining_keys_async(keys_to_check_later, initial_invalid_keys)
            )
        else:
            # No background check will run, so persist the initially-invalid keys here.
            current_invalid_keys_str = settings.INVALID_API_KEYS or ""
            current_invalid_keys_set = set(k.strip() for k in current_invalid_keys_str.split(',') if k.strip())
            new_invalid_keys_set = current_invalid_keys_set.union(set(initial_invalid_keys))
            if new_invalid_keys_set != current_invalid_keys_set:
                settings.INVALID_API_KEYS = ','.join(sorted(new_invalid_keys_set))
                save_settings()
                log('info', f"更新初始无效密钥列表完成,总无效密钥数: {len(new_invalid_keys_set)}")
    else:
        # Checking skipped: add the remaining keys without validating them.
        log('info', "跳过 API 密钥检查")
        key_manager.api_keys.extend(keys_to_check_later)
        key_manager._reset_key_stack()
    # Initialise the API router.
    init_router(
        key_manager,
        response_cache_manager,
        active_requests_manager,
        SAFETY_SETTINGS,
        SAFETY_SETTINGS_G2,
        first_valid_key,
        settings.FAKE_STREAMING,
        settings.FAKE_STREAMING_INTERVAL,
        settings.PASSWORD,
        settings.MAX_REQUESTS_PER_MINUTE,
        settings.MAX_REQUESTS_PER_DAY_PER_IP
    )
    # Initialise the dashboard router.
    init_dashboard_router(
        key_manager,
        response_cache_manager,
        active_requests_manager,
        credential_manager_instance
    )
# --------------- Exception handling ---------------
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log an unhandled exception and reply with HTTP 500.

    The log entry records the message as translated by ``translate_error``.
    The JSON response body carries ``str(exc)`` in an ``ErrorResponse`` with
    type ``"internal_error"``.
    """
    from app.utils import translate_error

    raw_message = str(exc)
    translated = translate_error(raw_message)
    log(
        'error',
        f"Unhandled exception: {translated}",
        extra={'status_code': 500, 'error_message': translated},
    )
    payload = ErrorResponse(message=raw_message, type="internal_error").dict()
    return JSONResponse(status_code=500, content=payload)
# --------------- Routes ---------------
app.include_router(router)
app.include_router(dashboard_router)
# Mount the static assets directory. It is resolved from BASE_DIR, like the
# templates directory, rather than from the process working directory, so it
# works regardless of where the server is launched.
# NOTE(review): assumes BASE_DIR is the `app/` package directory, i.e. the old
# CWD-relative "app/templates/assets" — confirm this module lives in app/.
app.mount("/assets", StaticFiles(directory=str(BASE_DIR / "templates" / "assets")), name="assets")
# Root route path: "/<DASHBOARD_URL>" when configured, otherwise "/".
dashboard_path = f"/{settings.DASHBOARD_URL}" if settings.DASHBOARD_URL else "/"


@app.get(dashboard_path, response_class=HTMLResponse)
async def root(request: Request):
    """
    Root route: render ``index.html`` with the public API base URL.

    ``api_url`` is the request's base URL, forced to the https scheme, with
    ``v1`` appended (adding a "/" separator if the base URL lacks one).
    """
    base_url = str(request.base_url)
    # Rewrite only the scheme prefix. A blanket str.replace("http", "https")
    # would turn an already-https URL into "httpss://" and also alter any
    # "http" appearing inside the host or path.
    if base_url.startswith("http://"):
        base_url = "https://" + base_url[len("http://"):]
    api_url = f"{base_url}v1" if base_url.endswith("/") else f"{base_url}/v1"
    # Render the index.html template.
    return templates.TemplateResponse(
        "index.html", {"request": request, "api_url": api_url}
    )
|