# Page residue from a hosted file viewer - Spaces status: Sleeping; file size: 7,898 bytes; revision: 543ec94.
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse
import httpx
import logging
import os, json
import re # 用于URL路径处理
from pprint import pprint
from key_selector import KeySelector
def get_target_url(url: str) -> str:
    """Turn a path-encoded URL into a real target URL.

    A leading ``http/`` is rewritten to ``http://`` and a leading ``https/``
    to ``https://``. Input without either prefix is returned unchanged.
    """
    scheme_rewrites = (
        (r"^http/", "http://"),
        (r"^https/", "https://"),
    )
    target = url
    # Apply the rewrites in the same order as the prefixes are listed above.
    for pattern, replacement in scheme_rewrites:
        target = re.sub(pattern, replacement, target)
    return target
# ASGI application object; the route decorators in this file register on it.
app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.INFO)
# Reuse the "uvicorn.error" logger name for this module's log output.
logger = logging.getLogger("uvicorn.error")
# Read configuration from environment variables
# UPSTREAM_HOST = os.getenv("UPSTREAM_HOST", "http://127.0.0.1") # must include the port number
# AUTH_TOKEN = os.getenv("AUTH_TOKEN", "<redacted>") # security authentication token
# NOTE(review): the commented-out AUTH_TOKEN default above used to embed a literal
# token value; if that value was ever a real credential it should be rotated.
# Read from the X_Goog_Api_Key environment variable, defaulting to "".
# NOTE(review): no live code in the visible part of this file reads this value - confirm it is still needed.
X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "")
@app.get("/")
async def read_root():
    """Liveness endpoint: return a fixed message showing the proxy is up."""
    status = {"message": "FastAPI Proxy is running"}
    return status
# Request headers that are never copied to the upstream request:
# - host / content-length are recomputed by httpx for the new request;
# - accept-encoding is left to httpx so it only advertises encodings it can
#   decode (the proxy forwards decoded bytes);
# - x-goog-api-key is replaced by the key chosen by KeySelector;
# - the rest are hop-by-hop headers that describe the client<->proxy connection.
_SKIPPED_REQUEST_HEADERS = frozenset({
    "host",
    "content-length",
    "accept-encoding",
    "x-goog-api-key",
    "connection",
    "keep-alive",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# Upstream response headers that no longer describe the body the proxy sends:
# httpx hands back decoded bytes, so length/encoding headers would be wrong.
_SKIPPED_RESPONSE_HEADERS = frozenset({
    "content-length",
    "content-encoding",
    "transfer-encoding",
    "connection",
    "keep-alive",
})


def _is_streaming_path(path: str) -> bool:
    """Return True when the path targets a ``:streamGenerateContent`` method."""
    # Both sides must be lower-cased; comparing a mixed-case needle against a
    # lower-cased path can never match.
    return ":streamgeneratecontent" in path.lower()


def _forwardable_headers(headers, skipped) -> dict:
    """Copy ``headers`` into a dict, dropping names listed in ``skipped`` (case-insensitive)."""
    return {k: v for k, v in headers.items() if k.lower() not in skipped}


def _build_streaming_response(response, client) -> StreamingResponse:
    """Relay an open upstream response chunk by chunk.

    Takes ownership of ``response`` and ``client``: both are closed when the
    body iterator finishes, fails, or is closed by the server.
    """
    async def stream_generator():
        try:
            async for chunk in response.aiter_bytes():
                yield chunk
        except Exception as e:
            logger.error(f"Stream interrupted: {str(e)}")
            yield json.dumps({"error": "流中断"}).encode()
        finally:
            await response.aclose()
            await client.aclose()

    return StreamingResponse(
        content=stream_generator(),
        status_code=response.status_code,
        headers=_forwardable_headers(response.headers, _SKIPPED_RESPONSE_HEADERS),
        # Only used when the upstream response carries no content-type header.
        media_type="application/x-ndjson",
    )


def _build_buffered_response(response) -> Response:
    """Convert a fully-read upstream response into the reply sent to the caller.

    JSON bodies are re-serialised with indentation (non-ASCII kept as-is);
    anything else is passed through unchanged. The upstream status code is
    preserved in both cases.
    """
    if response.status_code >= 400:
        logger.error("Upstream error %s: %s", response.status_code, response.text[:1000])
    try:
        data = json.loads(response.text)
    except json.JSONDecodeError as e:
        # Not JSON: relay the body as received instead of returning nothing.
        logger.warning("Upstream body is not valid JSON (%s); passing it through", e)
        return Response(
            content=response.content,
            status_code=response.status_code,
            media_type=response.headers.get("content-type"),
        )
    formatted_json = json.dumps(data, ensure_ascii=False, indent=4)
    return Response(
        content=formatted_json,
        status_code=response.status_code,
        media_type="application/json",
    )


@app.post("/v25/{path:path}")
async def proxy(request: Request, path: str):
    """Forward a POST request to the URL encoded in ``path`` and relay the reply.

    ``path`` carries the upstream URL with its scheme written as ``http/`` or
    ``https/``; ``get_target_url`` turns it back into a real URL. The request
    body, forwardable headers and query string (minus ``key``) are passed on,
    and the ``x-goog-api-key`` header is set from ``KeySelector``.

    Paths containing ``:streamGenerateContent`` (any letter case) are relayed
    as a stream; every other response is buffered and, when it is JSON,
    pretty-printed.

    Raises:
        HTTPException: 502 when the upstream cannot be reached or another
            httpx error occurs, 504 on an upstream timeout, 500 on any other
            unexpected failure.
    """
    is_streaming = _is_streaming_path(path)

    # NOTE(review): the target host comes straight from the request path and
    # the API key is attached whatever that host is, including after
    # redirects. Consider restricting targets to an allowlist of hosts.
    target_url = get_target_url(path)
    method = request.method

    headers = _forwardable_headers(request.headers, _SKIPPED_REQUEST_HEADERS)
    key_selector = KeySelector()
    # Lower-case name so it cannot coexist with a differently-cased duplicate.
    headers["x-goog-api-key"] = key_selector.get_api_key()['key_value']

    # Forward the caller's query string (e.g. ``alt=sse``). ``key`` is dropped
    # because the proxy supplies its own credential in the header above.
    query_params = [
        (name, value)
        for name, value in request.query_params.multi_items()
        if name.lower() != "key"
    ]

    # Deliberately log neither the headers nor the body: the headers contain the API key.
    logger.info("Proxying %s %s (streaming=%s)", method, target_url, is_streaming)

    client = None
    handed_off = False  # True once the streaming response owns the client
    try:
        req_content = await request.body()

        # retries=3 retries failed connection attempts; http1=True keeps HTTP/1.1 enabled.
        transport = httpx.AsyncHTTPTransport(retries=3, http1=True)
        client = httpx.AsyncClient(
            transport=transport,
            timeout=httpx.Timeout(300.0, connect=30.0),
        )
        upstream_request = client.build_request(
            method=method,
            url=target_url,
            headers=headers,
            content=req_content,
            params=query_params or None,
        )
        # stream=True returns as soon as the upstream headers arrive, leaving
        # the body to be read lazily; stream=False reads the body in full.
        response = await client.send(
            upstream_request,
            stream=is_streaming,
            follow_redirects=True,
        )

        if is_streaming:
            streaming_response = _build_streaming_response(response, client)
            handed_off = True
            return streaming_response
        return _build_buffered_response(response)
    except httpx.ConnectError as e:
        logger.error(f"Connection failed to {target_url}: {e}")
        raise HTTPException(502, f"无法连接到上游服务: {target_url}") from e
    except httpx.TimeoutException as e:
        logger.error(f"Timeout: {e}")
        raise HTTPException(504, "上游服务响应超时") from e
    except httpx.HTTPError as e:
        error_type = type(e).__name__
        # Only some httpx errors carry a response; look it up defensively.
        upstream = getattr(e, "response", None)
        status_code = upstream.status_code if upstream is not None else None
        error_detail = ""
        if upstream is not None:
            try:
                error_detail = upstream.text[:500]
            except Exception as ex:
                error_detail = f"无法获取错误详情: {type(ex).__name__}"
        logger.error(
            "HTTP代理错误 | "
            f"类型: {error_type} | "
            f"状态码: {status_code or 'N/A'} | "
            f"目标URL: {target_url} | "
            f"详情: {error_detail[:200]}"
        )
        raise HTTPException(
            status_code=502,
            detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})"
        ) from e
    except Exception as e:
        logger.exception("Unexpected proxy error")
        raise HTTPException(500, f"内部服务器错误: {str(e)}") from e
    finally:
        # The streaming generator closes the client itself once it is done.
        if client is not None and not handed_off:
            await client.aclose()
|