File size: 6,104 Bytes
543ec94
 
 
 
 
 
df1a318
543ec94
7183ec8
 
543ec94
 
 
 
 
 
 
 
7183ec8
 
 
543ec94
 
 
 
 
df1a318
543ec94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7791c2
543ec94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df1a318
543ec94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7791c2
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse
import httpx
import logging
import os, json
import re  # 用于URL路径处理
from key_selector import KeySelector # 自动选择key

from app.routers import key_management # Import the new router

def get_target_url(url: str) -> str:
    """将url参数变了转换为合法的目标url;from http/ or https/ to http:// or https://"""
    url = re.sub(r"^http/", "http://", url)
    url = re.sub(r"^https/", "https://", url)
    return url

app = FastAPI()

# Include the new key management router
app.include_router(key_management.router, prefix="/api/keys", tags=["Key Management"])

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")

# 从环境变量获取配置
# X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "")

@app.get("/")
async def read_root():
    return {"message": "FastAPI Proxy is running"}

@app.post("/v25/{path:path}")
async def proxy(request: Request, path: str):
    # 添加流式请求判断逻辑
    is_streaming = ":streamGenerateContent" in path.lower()
    target_url = get_target_url(path)
    method = request.method
    headers = {k: v for k, v in request.headers.items()
        # if k.lower() not in ["host", "connection", "Postman-Token", "content-length"]}
        if k.lower() not in ["host", "content-length"]}
    
    key_selector = KeySelector()
    headers["X-Goog-Api-Key"] = key_selector.get_api_key_info()['key_value'] # 从数据库获取API密钥

    try:
        # 关键修复:禁用KeepAlive防止连接冲突
        transport = httpx.AsyncHTTPTransport(retries=3, http1=True)
        async with httpx.AsyncClient(
            transport=transport,
            timeout=httpx.Timeout(300.0, connect=30.0)
        ) as client:
            # 处理请求体
            req_content = await request.body()
            # 发送请求到上游服务
            response = await client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=req_content,
                follow_redirects=True  # 自动处理重定向
            )
            if is_streaming:
                # 流式响应处理
                async def stream_generator():
                    try:
                        async for chunk in response.aiter_bytes():
                            yield chunk
                    except Exception as e:
                        logger.error(f"Stream interrupted: {str(e)}")
                        yield json.dumps({"error": "流中断"}).encode()
            
                # 移除冲突头部
                headers = dict(response.headers)
                headers.pop("Content-Length", None)
                return StreamingResponse(
                    content=stream_generator(),
                    status_code=response.status_code,
                    headers=headers,
                    media_type="application/x-ndjson"  # Gemini流式格式
                )
            else:
                # 非流式响应处理
                # 解析 JSON 字符串
                try:
                    data = json.loads(response.text)
                    # 格式化输出 JSON 数据
                    formatted_json = json.dumps(data, ensure_ascii=False, indent=4)
                    # return formatted_json
                    return Response(
                        content=formatted_json,
                        media_type="application/json"
                    )
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
    except httpx.ConnectError as e:
        logger.error(f"Connection failed to {target_url}: {e}")
        raise HTTPException(502, f"无法连接到上游服务: {target_url}") # Modified error message
    except httpx.ReadTimeout as e:
        logger.error(f"Timeout: {e}")
        raise HTTPException(504, "上游服务响应超时")
    except httpx.HTTPError as e:  # 捕获所有HTTP异常
        try:
            # 安全地获取异常信息
            error_type = type(e).__name__
            # 尝试获取状态码(如果存在)
            status_code = getattr(e, 'response', None) and e.response.status_code
            # 安全地获取错误详情
            error_detail = ""
            try:
                # 尝试获取文本响应(限制长度)
                if hasattr(e, 'response') and e.response:
                    error_detail = e.response.text[:500]  # 只取前500个字符
            except Exception as ex:
                error_detail = f"无法获取错误详情: {type(ex).__name__}"
            # 安全地记录日志
            logger.error(
                "HTTP代理错误 | "
                f"类型: {error_type} | "
                f"状态码: {status_code or 'N/A'} | "
                f"目标URL: {target_url} | "
                f"详情: {error_detail[:200]}"  # 日志中只记录前200字符
            )
            # 打印到控制台以便调试
            print(f"目标URL: {target_url}")
            print(f"状态码: {status_code or 'N/A'}")
            print(f"错误详情: {error_detail[:500]}")
        except Exception as ex:
            # 如果记录日志本身出错,使用最安全的方式记录
            logger.error(f"记录HTTP错误时发生异常: {type(ex).__name__}")
            print(f"严重错误: 记录HTTP错误时发生异常: {ex}")

        # 返回用户友好的错误响应
        raise HTTPException(
            status_code=502,
            detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})"
        )
    except Exception as e:
        logger.exception("Unexpected proxy error")
        raise HTTPException(500, f"内部服务器错误: {str(e)}")

if __name__ == "__main__":
    # In a real application, you would typically run uvicorn here:
    # import uvicorn
    # uvicorn.run(app, host="0.0.0.0", port=8000)
    pass # Placeholder to fix indentation error