Spaces:
Sleeping
Sleeping
File size: 6,104 Bytes
543ec94 df1a318 543ec94 7183ec8 543ec94 7183ec8 543ec94 df1a318 543ec94 b7791c2 543ec94 df1a318 543ec94 b7791c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse
import httpx
import logging
import os, json
import re # 用于URL路径处理
from key_selector import KeySelector # 自动选择key
from app.routers import key_management # Import the new router
def get_target_url(url: str) -> str:
"""将url参数变了转换为合法的目标url;from http/ or https/ to http:// or https://"""
url = re.sub(r"^http/", "http://", url)
url = re.sub(r"^https/", "https://", url)
return url
app = FastAPI()
# Include the new key management router
app.include_router(key_management.router, prefix="/api/keys", tags=["Key Management"])
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")
# 从环境变量获取配置
# X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "")
@app.get("/")
async def read_root():
return {"message": "FastAPI Proxy is running"}
@app.post("/v25/{path:path}")
async def proxy(request: Request, path: str):
# 添加流式请求判断逻辑
is_streaming = ":streamGenerateContent" in path.lower()
target_url = get_target_url(path)
method = request.method
headers = {k: v for k, v in request.headers.items()
# if k.lower() not in ["host", "connection", "Postman-Token", "content-length"]}
if k.lower() not in ["host", "content-length"]}
key_selector = KeySelector()
headers["X-Goog-Api-Key"] = key_selector.get_api_key_info()['key_value'] # 从数据库获取API密钥
try:
# 关键修复:禁用KeepAlive防止连接冲突
transport = httpx.AsyncHTTPTransport(retries=3, http1=True)
async with httpx.AsyncClient(
transport=transport,
timeout=httpx.Timeout(300.0, connect=30.0)
) as client:
# 处理请求体
req_content = await request.body()
# 发送请求到上游服务
response = await client.request(
method=method,
url=target_url,
headers=headers,
content=req_content,
follow_redirects=True # 自动处理重定向
)
if is_streaming:
# 流式响应处理
async def stream_generator():
try:
async for chunk in response.aiter_bytes():
yield chunk
except Exception as e:
logger.error(f"Stream interrupted: {str(e)}")
yield json.dumps({"error": "流中断"}).encode()
# 移除冲突头部
headers = dict(response.headers)
headers.pop("Content-Length", None)
return StreamingResponse(
content=stream_generator(),
status_code=response.status_code,
headers=headers,
media_type="application/x-ndjson" # Gemini流式格式
)
else:
# 非流式响应处理
# 解析 JSON 字符串
try:
data = json.loads(response.text)
# 格式化输出 JSON 数据
formatted_json = json.dumps(data, ensure_ascii=False, indent=4)
# return formatted_json
return Response(
content=formatted_json,
media_type="application/json"
)
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
except httpx.ConnectError as e:
logger.error(f"Connection failed to {target_url}: {e}")
raise HTTPException(502, f"无法连接到上游服务: {target_url}") # Modified error message
except httpx.ReadTimeout as e:
logger.error(f"Timeout: {e}")
raise HTTPException(504, "上游服务响应超时")
except httpx.HTTPError as e: # 捕获所有HTTP异常
try:
# 安全地获取异常信息
error_type = type(e).__name__
# 尝试获取状态码(如果存在)
status_code = getattr(e, 'response', None) and e.response.status_code
# 安全地获取错误详情
error_detail = ""
try:
# 尝试获取文本响应(限制长度)
if hasattr(e, 'response') and e.response:
error_detail = e.response.text[:500] # 只取前500个字符
except Exception as ex:
error_detail = f"无法获取错误详情: {type(ex).__name__}"
# 安全地记录日志
logger.error(
"HTTP代理错误 | "
f"类型: {error_type} | "
f"状态码: {status_code or 'N/A'} | "
f"目标URL: {target_url} | "
f"详情: {error_detail[:200]}" # 日志中只记录前200字符
)
# 打印到控制台以便调试
print(f"目标URL: {target_url}")
print(f"状态码: {status_code or 'N/A'}")
print(f"错误详情: {error_detail[:500]}")
except Exception as ex:
# 如果记录日志本身出错,使用最安全的方式记录
logger.error(f"记录HTTP错误时发生异常: {type(ex).__name__}")
print(f"严重错误: 记录HTTP错误时发生异常: {ex}")
# 返回用户友好的错误响应
raise HTTPException(
status_code=502,
detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})"
)
except Exception as e:
logger.exception("Unexpected proxy error")
raise HTTPException(500, f"内部服务器错误: {str(e)}")
if __name__ == "__main__":
# In a real application, you would typically run uvicorn here:
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)
pass # Placeholder to fix indentation error
|