tanbushi's picture
Sun Jun 8 16:41:37 CST 2025
7183ec8
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse
import httpx
import logging
import os, json
import re # 用于URL路径处理
from key_selector import KeySelector # 自动选择key
from app.routers import key_management # Import the new router
def get_target_url(url: str) -> str:
"""将url参数变了转换为合法的目标url;from http/ or https/ to http:// or https://"""
url = re.sub(r"^http/", "http://", url)
url = re.sub(r"^https/", "https://", url)
return url
app = FastAPI()
# Include the new key management router
app.include_router(key_management.router, prefix="/api/keys", tags=["Key Management"])
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")
# 从环境变量获取配置
# X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "")
@app.get("/")
async def read_root():
return {"message": "FastAPI Proxy is running"}
@app.post("/v25/{path:path}")
async def proxy(request: Request, path: str):
# 添加流式请求判断逻辑
is_streaming = ":streamGenerateContent" in path.lower()
target_url = get_target_url(path)
method = request.method
headers = {k: v for k, v in request.headers.items()
# if k.lower() not in ["host", "connection", "Postman-Token", "content-length"]}
if k.lower() not in ["host", "content-length"]}
key_selector = KeySelector()
headers["X-Goog-Api-Key"] = key_selector.get_api_key_info()['key_value'] # 从数据库获取API密钥
try:
# 关键修复:禁用KeepAlive防止连接冲突
transport = httpx.AsyncHTTPTransport(retries=3, http1=True)
async with httpx.AsyncClient(
transport=transport,
timeout=httpx.Timeout(300.0, connect=30.0)
) as client:
# 处理请求体
req_content = await request.body()
# 发送请求到上游服务
response = await client.request(
method=method,
url=target_url,
headers=headers,
content=req_content,
follow_redirects=True # 自动处理重定向
)
if is_streaming:
# 流式响应处理
async def stream_generator():
try:
async for chunk in response.aiter_bytes():
yield chunk
except Exception as e:
logger.error(f"Stream interrupted: {str(e)}")
yield json.dumps({"error": "流中断"}).encode()
# 移除冲突头部
headers = dict(response.headers)
headers.pop("Content-Length", None)
return StreamingResponse(
content=stream_generator(),
status_code=response.status_code,
headers=headers,
media_type="application/x-ndjson" # Gemini流式格式
)
else:
# 非流式响应处理
# 解析 JSON 字符串
try:
data = json.loads(response.text)
# 格式化输出 JSON 数据
formatted_json = json.dumps(data, ensure_ascii=False, indent=4)
# return formatted_json
return Response(
content=formatted_json,
media_type="application/json"
)
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
except httpx.ConnectError as e:
logger.error(f"Connection failed to {target_url}: {e}")
raise HTTPException(502, f"无法连接到上游服务: {target_url}") # Modified error message
except httpx.ReadTimeout as e:
logger.error(f"Timeout: {e}")
raise HTTPException(504, "上游服务响应超时")
except httpx.HTTPError as e: # 捕获所有HTTP异常
try:
# 安全地获取异常信息
error_type = type(e).__name__
# 尝试获取状态码(如果存在)
status_code = getattr(e, 'response', None) and e.response.status_code
# 安全地获取错误详情
error_detail = ""
try:
# 尝试获取文本响应(限制长度)
if hasattr(e, 'response') and e.response:
error_detail = e.response.text[:500] # 只取前500个字符
except Exception as ex:
error_detail = f"无法获取错误详情: {type(ex).__name__}"
# 安全地记录日志
logger.error(
"HTTP代理错误 | "
f"类型: {error_type} | "
f"状态码: {status_code or 'N/A'} | "
f"目标URL: {target_url} | "
f"详情: {error_detail[:200]}" # 日志中只记录前200字符
)
# 打印到控制台以便调试
print(f"目标URL: {target_url}")
print(f"状态码: {status_code or 'N/A'}")
print(f"错误详情: {error_detail[:500]}")
except Exception as ex:
# 如果记录日志本身出错,使用最安全的方式记录
logger.error(f"记录HTTP错误时发生异常: {type(ex).__name__}")
print(f"严重错误: 记录HTTP错误时发生异常: {ex}")
# 返回用户友好的错误响应
raise HTTPException(
status_code=502,
detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})"
)
except Exception as e:
logger.exception("Unexpected proxy error")
raise HTTPException(500, f"内部服务器错误: {str(e)}")
if __name__ == "__main__":
# In a real application, you would typically run uvicorn here:
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)
pass # Placeholder to fix indentation error