File size: 7,898 Bytes
543ec94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import json
import logging
import os
import re  # used by get_target_url for URL path handling
from pprint import pprint

import httpx
from fastapi import BackgroundTasks, FastAPI, HTTPException, Request, Response
from fastapi.responses import StreamingResponse

from key_selector import KeySelector

def get_target_url(url: str) -> str:
    """Turn a collapsed-scheme path into an absolute URL.

    ``http/host/...`` becomes ``http://host/...`` and ``https/host/...``
    becomes ``https://host/...``. Only a prefix at the very start of *url*
    is rewritten; any other input is returned unchanged.
    """
    for scheme in ("http", "https"):
        collapsed = scheme + "/"
        if url.startswith(collapsed):
            return f"{scheme}://{url[len(collapsed):]}"
    return url

app = FastAPI()

# Logging configuration: messages go through the "uvicorn.error" logger
# (presumably so they share uvicorn's output handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")

# Settings read from environment variables.
# SECURITY: never put credentials in source, not even in commented-out code.
# A literal AUTH_TOKEN value used to be committed in a comment here; treat it
# as leaked and rotate it.
# NOTE(review): the request handler in this file takes upstream keys from
# KeySelector; X_Goog_Api_Key is not referenced by any live code visible
# here. Confirm nothing else imports it before removing.
X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "")

@app.get("/")
async def read_root():
    """Liveness endpoint: report that the proxy process is up."""
    status = {"message": "FastAPI Proxy is running"}
    return status

# Hop-by-hop headers describe a single connection and must not be relayed by
# a proxy (RFC 9110, section 7.6.1).
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)

# Incoming headers that are not forwarded upstream:
# - host / content-length: httpx derives them from the target URL and body;
# - x-goog-api-key: replaced by a KeySelector key, so exactly one key is sent
#   (dropping it here also catches any client spelling, e.g. "X-GOOG-API-KEY");
# - accept-encoding: httpx's own default Accept-Encoding is used instead, so
#   upstream only picks encodings httpx can decode.
_DROPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {
    "host",
    "content-length",
    "x-goog-api-key",
    "accept-encoding",
}

# Upstream headers not relayed on a streamed reply: aiter_bytes() yields the
# decoded body, so upstream framing/length/encoding headers no longer apply.
_DROPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {
    "content-length",
    "content-encoding",
}


def _filter_headers(headers, dropped):
    """Return a plain dict of ``headers`` without the names in ``dropped``.

    Names are compared in lower case, so ``dropped`` must be lower-case.
    """
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in dropped
    }


async def _close_upstream(upstream: httpx.Response, client: httpx.AsyncClient) -> None:
    """Close the upstream response, then its client; safe to call repeatedly."""
    try:
        await upstream.aclose()
    finally:
        await client.aclose()


def _upstream_error(exc: Exception, target_url: str) -> HTTPException:
    """Log ``exc`` and translate it into the HTTPException sent to the caller.

    ConnectError -> 502, any httpx timeout -> 504, any other httpx error ->
    502, anything else -> 500.
    """
    if isinstance(exc, httpx.ConnectError):
        logger.error(f"Connection failed to {target_url}: {exc}")
        return HTTPException(502, f"无法连接到上游服务: {target_url}")
    if isinstance(exc, httpx.TimeoutException):
        logger.error(f"Timeout: {exc}")
        return HTTPException(504, "上游服务响应超时")
    if isinstance(exc, httpx.HTTPError):
        error_type = type(exc).__name__
        # Only some httpx errors (HTTPStatusError) carry a response.
        response = getattr(exc, "response", None)
        status_code = response.status_code if response is not None else None
        logger.error(
            "HTTP代理错误 | "
            f"类型: {error_type} | "
            f"状态码: {status_code or 'N/A'} | "
            f"目标URL: {target_url} | "
            f"详情: {str(exc)[:200]}"
        )
        return HTTPException(
            status_code=502,
            detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})",
        )
    logger.error("Unexpected proxy error", exc_info=exc)
    return HTTPException(500, f"内部服务器错误: {str(exc)}")


def _stream_response(upstream: httpx.Response, client: httpx.AsyncClient) -> StreamingResponse:
    """Relay an open, streamed upstream reply to the caller chunk by chunk.

    Cleanup (closing the upstream response and client) is attempted both when
    the relay generator ends and in a background task after the response has
    been sent, so it still happens on paths where one of the two cannot run.
    """

    async def relay():
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        except Exception as exc:
            # Status and headers have already been sent; the best remaining
            # option is to append an error marker to the body.
            logger.error(f"Stream interrupted: {str(exc)}")
            yield json.dumps({"error": "流中断"}).encode()
        finally:
            await _close_upstream(upstream, client)

    cleanup = BackgroundTasks()
    cleanup.add_task(_close_upstream, upstream, client)
    return StreamingResponse(
        content=relay(),
        status_code=upstream.status_code,
        headers=_filter_headers(upstream.headers, _DROPPED_RESPONSE_HEADERS),
        # Gemini streaming format; Starlette applies it only when the relayed
        # headers carry no Content-Type of their own.
        media_type="application/x-ndjson",
        background=cleanup,
    )


def _buffered_response(upstream: httpx.Response) -> Response:
    """Convert a fully read upstream reply into the caller-facing Response.

    JSON bodies are re-serialized with 4-space indentation (non-ASCII kept
    as-is); any other body is passed through unchanged. The upstream status
    code is preserved in both cases.
    """
    try:
        data = json.loads(upstream.text)
    except json.JSONDecodeError:
        # Not JSON (e.g. an HTML error page or an empty body): relay it as-is
        # rather than answering 200 with a null body.
        logger.warning("Upstream returned a non-JSON body (status %s)", upstream.status_code)
        return Response(
            content=upstream.content,
            status_code=upstream.status_code,
            media_type=upstream.headers.get("content-type"),
        )
    return Response(
        content=json.dumps(data, ensure_ascii=False, indent=4),
        status_code=upstream.status_code,
        media_type="application/json",
    )


@app.post("/v25/{path:path}")
async def proxy(request: Request, path: str):
    """Forward a POST to the upstream URL encoded in ``path``, injecting an API key.

    ``path`` has the form ``https/<host>/<rest>`` (see ``get_target_url``);
    the query string is forwarded except for any ``key`` parameter. Paths
    containing ``:streamGenerateContent`` (any case) are relayed as a stream;
    everything else is read in full and returned via ``_buffered_response``
    with the upstream status code.

    Raises:
        HTTPException: 502 if the upstream is unreachable or fails at the HTTP
            layer, 504 on an upstream timeout, 500 on any other failure.

    NOTE(security): the upstream host is chosen by the caller and the selected
    key is attached to every request (and kept on redirects, which are
    followed), so any caller can make this proxy send a key to an arbitrary
    host. Consider restricting upstream hosts to an allow-list.
    """
    # Compare both sides in lower case so any casing of the method matches.
    is_streaming = ":streamgeneratecontent" in path.lower()
    target_url = get_target_url(path)

    headers = _filter_headers(request.headers, _DROPPED_REQUEST_HEADERS)
    headers["X-Goog-Api-Key"] = KeySelector().get_api_key()["key_value"]
    # Forward e.g. ``alt=sse`` but never a client ``key``, which would be a
    # second, conflicting credential.
    params = [
        (name, value)
        for name, value in request.query_params.multi_items()
        if name.lower() != "key"
    ]
    body = await request.body()
    # Deliberately does not log headers: they contain the API key.
    logger.info("proxy %s %s (streaming=%s)", request.method, target_url, is_streaming)

    # One client per request, so connections are never shared between
    # requests; retries=3 is passed to the transport (httpx documents it as
    # retries for establishing a connection).
    client = httpx.AsyncClient(
        transport=httpx.AsyncHTTPTransport(retries=3, http1=True),
        timeout=httpx.Timeout(300.0, connect=30.0),
    )
    try:
        upstream_request = client.build_request(
            request.method,
            target_url,
            headers=headers,
            params=params,
            content=body,
        )
        # stream=True returns once the headers arrive, so the body can be
        # relayed as it is produced instead of being buffered first.
        upstream = await client.send(upstream_request, stream=True, follow_redirects=True)
    except Exception as exc:
        await client.aclose()
        raise _upstream_error(exc, target_url) from exc

    if is_streaming:
        return _stream_response(upstream, client)

    try:
        await upstream.aread()
    except Exception as exc:
        raise _upstream_error(exc, target_url) from exc
    finally:
        await _close_upstream(upstream, client)

    if upstream.status_code >= 400:
        logger.error("Upstream error %s: %s", upstream.status_code, upstream.text[:1000])
    return _buffered_response(upstream)