"""FastAPI reverse proxy that forwards POST requests to an upstream HTTP API.

The upstream URL is carried in the request path (``/v25/https/host/...``) and
an API key obtained from ``KeySelector`` is injected as ``X-Goog-Api-Key``.
"""
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import StreamingResponse
import httpx
import logging
import os
import json
import re  # used for URL path handling
from pprint import pprint  # currently unused here; kept for other users of this module
from typing import AsyncIterator
from key_selector import KeySelector


# Request headers that are never copied to the upstream request:
# - host / content-length are recomputed by httpx for the new request;
# - x-goog-api-key is dropped so the key injected by the proxy is the only
#   one sent (header names are case-insensitive, so a client-supplied
#   lowercase copy would otherwise be sent next to the injected one).
_SKIPPED_REQUEST_HEADERS = frozenset({"host", "content-length", "x-goog-api-key"})

# Upstream response headers that no longer describe the body relayed to the
# client: httpx hands us decoded bytes and the ASGI server frames the response
# itself, so length/encoding/framing headers from upstream would be wrong.
_SKIPPED_RESPONSE_HEADERS = frozenset(
    {"content-length", "content-encoding", "transfer-encoding"}
)

# Headers whose values must never be written to logs or stdout.
_REDACTED_HEADERS = frozenset({"x-goog-api-key", "authorization"})


def get_target_url(url: str) -> str:
    """Turn the path-encoded URL into a real URL.

    A leading ``http/`` becomes ``http://`` and a leading ``https/`` becomes
    ``https://``; any other input is returned unchanged.
    """
    return re.sub(r"^(https?)/", r"\1://", url)


def _redact_headers(headers: dict) -> dict:
    """Return a copy of *headers* that is safe to print (secrets masked)."""
    return {
        name: ("***" if name.lower() in _REDACTED_HEADERS else value)
        for name, value in headers.items()
    }


def _client_response_headers(upstream_headers: httpx.Headers) -> dict:
    """Return the upstream headers that may be relayed to the client."""
    return {
        name: value
        for name, value in upstream_headers.items()
        if name.lower() not in _SKIPPED_RESPONSE_HEADERS
    }


app = FastAPI()

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")

# Configuration read from environment variables
# UPSTREAM_HOST = os.getenv("UPSTREAM_HOST", "http://127.0.0.1")  # must include the port
# AUTH_TOKEN = os.getenv("AUTH_TOKEN", "")  # auth token; never commit a default secret here
X_Goog_Api_Key = os.getenv("X_Goog_Api_Key", "")


@app.get("/")
async def read_root():
    """Liveness endpoint: report that the proxy is up."""
    return {"message": "FastAPI Proxy is running"}


async def _relay_stream(
    response: httpx.Response, client: httpx.AsyncClient
) -> AsyncIterator[bytes]:
    """Yield the upstream body chunk by chunk, then release the connection.

    The generator owns *response* and *client*: both are closed when the
    stream ends, fails, or is abandoned by the caller.
    """
    try:
        async for chunk in response.aiter_bytes():
            yield chunk
    except Exception as e:
        logger.error(f"Stream interrupted: {str(e)}")
        yield json.dumps({"error": "流中断"}).encode()
    finally:
        await response.aclose()
        await client.aclose()


def _buffered_response(response: httpx.Response) -> Response:
    """Build the client response for a fully-read (non-streaming) upstream reply.

    JSON bodies are re-serialised pretty-printed; anything else is passed
    through unchanged. The upstream status code is preserved in both cases.
    """
    print(f"response.status_code: {response.status_code}")
    print(f"response.text: {response.text}")
    try:
        data = json.loads(response.text)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        # Not JSON (empty body, HTML error page, ...): relay it as-is instead
        # of falling through and answering with an implicit `null`.
        content_type = response.headers.get("content-type")
        return Response(
            content=response.content,
            status_code=response.status_code,
            headers={"content-type": content_type} if content_type else None,
        )

    formatted_json = json.dumps(data, ensure_ascii=False, indent=4)
    print('formatted_json')
    print(formatted_json)
    return Response(
        content=formatted_json,
        status_code=response.status_code,
        media_type="application/json",
    )


@app.post("/v25/{path:path}")
async def proxy(request: Request, path: str):
    """Forward the request to the upstream URL encoded in *path*.

    Args:
        request: Incoming request; its method, headers, query string and body
            are forwarded (minus the headers/params the proxy owns).
        path: Upstream URL with the scheme separator written as ``/``, e.g.
            ``https/generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent``.

    Returns:
        A ``StreamingResponse`` relaying the upstream body for
        ``:streamGenerateContent`` calls, otherwise a buffered ``Response``.

    Raises:
        HTTPException: 502 when the upstream cannot be reached or another
            httpx error occurs, 504 on upstream read timeout, 500 on any
            unexpected error.
    """
    # NOTE(security): the upstream host comes straight from the request path,
    # so any caller can make the proxy send the injected API key to a host of
    # their choosing. Consider restricting target hosts to an allow-list.

    # Compare lowercase against lowercase; a mixed-case needle can never be
    # found in `path.lower()`.
    is_streaming = ":streamgeneratecontent" in path.lower()
    print(f"path: {path}")
    target_url = get_target_url(path)
    print(f"target_url: {target_url}")

    method = request.method
    headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in _SKIPPED_REQUEST_HEADERS
    }
    key_selector = KeySelector()
    headers["X-Goog-Api-Key"] = key_selector.get_api_key()['key_value']

    # Forward the caller's query string (e.g. `alt=sse`), except `key`: the
    # credential is supplied by the proxy, not by the caller.
    forwarded_query = [
        (name, value)
        for name, value in request.query_params.multi_items()
        if name != "key"
    ]

    client = None
    handed_off = False  # True once the streaming generator owns the client
    try:
        # A dedicated client per request, so no connection is shared between
        # proxied requests; `retries=3` retries failed connection attempts.
        transport = httpx.AsyncHTTPTransport(retries=3, http1=True)
        client = httpx.AsyncClient(
            transport=transport,
            timeout=httpx.Timeout(300.0, connect=30.0),
        )

        # Read the request body
        req_content = await request.body()

        # Secrets are masked: this output ends up in process logs.
        print(f"Request headers: {_redact_headers(headers)}")
        print(f'target_url: {target_url}')
        print(f'method: {method}')

        # Send the request upstream. With stream=True only the response
        # headers are read here; the body is pulled lazily by the generator.
        upstream_request = client.build_request(
            method=method,
            url=target_url,
            headers=headers,
            content=req_content,
            params=forwarded_query or None,
        )
        response = await client.send(
            upstream_request,
            stream=is_streaming,
            follow_redirects=True,  # follow redirects automatically
        )

        if is_streaming:
            streaming_response = StreamingResponse(
                content=_relay_stream(response, client),
                status_code=response.status_code,
                headers=_client_response_headers(response.headers),
                # Used only when upstream sent no content-type of its own
                media_type="application/x-ndjson",
            )
            handed_off = True
            return streaming_response

        return _buffered_response(response)

    except httpx.ConnectError as e:
        logger.error(f"Connection failed to {target_url}: {e}")
        raise HTTPException(502, f"无法连接到上游服务: {target_url}") from e
    except httpx.ReadTimeout as e:
        logger.error(f"Timeout: {e}")
        raise HTTPException(504, "上游服务响应超时") from e
    except httpx.HTTPError as e:  # every other httpx error
        print('111111111111111111111')
        # Computed outside any try block so the final `raise` can always
        # reference them.
        error_type = type(e).__name__
        upstream = getattr(e, 'response', None)
        status_code = upstream.status_code if upstream is not None else None
        try:
            error_detail = ""
            if upstream is not None:
                try:
                    # Body text, truncated to the first 500 characters
                    error_detail = upstream.text[:500]
                except Exception as ex:
                    error_detail = f"无法获取错误详情: {type(ex).__name__}"

            logger.error(
                "HTTP代理错误 | "
                f"类型: {error_type} | "
                f"状态码: {status_code or 'N/A'} | "
                f"目标URL: {target_url} | "
                f"详情: {error_detail[:200]}"  # only the first 200 chars in the log
            )

            # Also print to the console for debugging
            print(f"111111111111111111111 - HTTP代理错误: {error_type}")
            print(f"目标URL: {target_url}")
            print(f"状态码: {status_code or 'N/A'}")
            print(f"错误详情: {error_detail[:500]}")
        except Exception as ex:
            # Reporting must never mask the original error
            logger.error(f"记录HTTP错误时发生异常: {type(ex).__name__}")
            print(f"严重错误: 记录HTTP错误时发生异常: {ex}")

        raise HTTPException(
            status_code=502,
            detail=f"网关服务错误: {error_type} (上游状态: {status_code or '未知'})"
        ) from e
    except Exception as e:
        logger.exception("Unexpected proxy error")
        raise HTTPException(500, f"内部服务器错误: {str(e)}") from e
    finally:
        # For streaming responses the generator closes the client once the
        # body has been relayed; in every other case close it here.
        if client is not None and not handed_off:
            await client.aclose()