File size: 8,360 Bytes
76b9762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from app.config.config import settings
from app.log.logger import get_main_logger
import re

logger = get_main_logger()

class SmartRoutingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        """Pass ``app`` to the base middleware constructor.

        No routing table is set up here: routing decisions are made directly
        from what is detected in each request.
        """
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Normalize the request path, then forward the request.

        When ``settings.URL_NORMALIZATION_ENABLED`` is falsy the request is
        forwarded untouched. Otherwise the path is run through
        ``fix_request_url``; if that yields a different path, the request
        scope is rewritten before the request is forwarded.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request and returns the
                response.

        Returns:
            Whatever ``call_next(request)`` returns.
        """
        if not settings.URL_NORMALIZATION_ENABLED:
            return await call_next(request)

        logger.debug(f"request: {request}")
        original_path = str(request.url.path)
        method = request.method

        # Try to repair the URL.
        fixed_path, fix_info = self.fix_request_url(original_path, method, request)

        if fixed_path != original_path:
            # Separate the old and new path so the log line stays readable.
            logger.info(f"URL fixed: {method} {original_path} -> {fixed_path}")
            if fix_info:
                logger.debug(f"Fix details: {fix_info}")

            # Rewrite the request path; raw_path is the bytes form of the
            # same value, so both scope entries are updated together.
            request.scope["path"] = fixed_path
            request.scope["raw_path"] = fixed_path.encode()

        return await call_next(request)

    def fix_request_url(self, path: str, method: str, request: Request) -> tuple:
        """简化的URL修复逻辑"""

        # 首先检查是否已经是正确的格式,如果是则不处理
        if self.is_already_correct_format(path):
            return path, None

        # 1. 最高优先级:包含generateContent → Gemini格式
        if "generatecontent" in path.lower() or "v1beta/models" in path.lower():
            return self.fix_gemini_by_operation(path, method, request)

        # 2. 第二优先级:包含/openai/ → OpenAI格式
        if "/openai/" in path.lower():
            return self.fix_openai_by_operation(path, method)

        # 3. 第三优先级:包含/v1/ → v1格式
        if "/v1/" in path.lower():
            return self.fix_v1_by_operation(path, method)

        # 4. 第四优先级:包含/chat/completions → chat功能
        if "/chat/completions" in path.lower():
            return "/v1/chat/completions", {"type": "v1_chat"}

        # 5. 默认:原样传递
        return path, None

    def is_already_correct_format(self, path: str) -> bool:
        """检查是否已经是正确的API格式"""
        # 检查是否已经是正确的端点格式
        correct_patterns = [
            r"^/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # Gemini原生
            r"^/gemini/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # Gemini带前缀
            r"^/v1beta/models$",  # Gemini模型列表
            r"^/gemini/v1beta/models$",  # Gemini带前缀的模型列表
            r"^/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # v1格式
            r"^/openai/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # OpenAI格式
            r"^/hf/v1/(chat/completions|models|embeddings|images/generations|audio/speech)$",  # HF格式
            r"^/vertex-express/v1beta/models/[^/:]+:(generate|streamGenerate)Content$",  # Vertex Express Gemini格式
            r"^/vertex-express/v1beta/models$",  # Vertex Express模型列表
            r"^/vertex-express/v1/(chat/completions|models|embeddings|images/generations)$",  # Vertex Express OpenAI格式
        ]

        for pattern in correct_patterns:
            if re.match(pattern, path):
                return True

        return False

    def fix_gemini_by_operation(
        self, path: str, method: str, request: Request
    ) -> tuple:
        """根据Gemini操作修复,考虑端点偏好"""
        if method == "GET":
            return "/v1beta/models", {
                "role": "gemini_models",
            }

        # 提取模型名称
        try:
            model_name = self.extract_model_name(path, request)
        except ValueError:
            # 无法提取模型名称,返回原路径不做处理
            return path, None

        # 检测是否为流式请求
        is_stream = self.detect_stream_request(path, request)

        # 检查是否有vertex-express偏好
        if "/vertex-express/" in path.lower():
            if is_stream:
                target_url = (
                    f"/vertex-express/v1beta/models/{model_name}:streamGenerateContent"
                )
            else:
                target_url = (
                    f"/vertex-express/v1beta/models/{model_name}:generateContent"
                )

            fix_info = {
                "rule": (
                    "vertex_express_generate"
                    if not is_stream
                    else "vertex_express_stream"
                ),
                "preference": "vertex_express_format",
                "is_stream": is_stream,
                "model": model_name,
            }
        else:
            # 标准Gemini端点
            if is_stream:
                target_url = f"/v1beta/models/{model_name}:streamGenerateContent"
            else:
                target_url = f"/v1beta/models/{model_name}:generateContent"

            fix_info = {
                "rule": "gemini_generate" if not is_stream else "gemini_stream",
                "preference": "gemini_format",
                "is_stream": is_stream,
                "model": model_name,
            }

        return target_url, fix_info

    def fix_openai_by_operation(self, path: str, method: str) -> tuple:
        """根据操作类型修复OpenAI格式"""
        if method == "POST":
            if "chat" in path.lower() or "completion" in path.lower():
                return "/openai/v1/chat/completions", {"type": "openai_chat"}
            elif "embedding" in path.lower():
                return "/openai/v1/embeddings", {"type": "openai_embeddings"}
            elif "image" in path.lower():
                return "/openai/v1/images/generations", {"type": "openai_images"}
            elif "audio" in path.lower():
                return "/openai/v1/audio/speech", {"type": "openai_audio"}
        elif method == "GET":
            if "model" in path.lower():
                return "/openai/v1/models", {"type": "openai_models"}

        return path, None

    def fix_v1_by_operation(self, path: str, method: str) -> tuple:
        """根据操作类型修复v1格式"""
        if method == "POST":
            if "chat" in path.lower() or "completion" in path.lower():
                return "/v1/chat/completions", {"type": "v1_chat"}
            elif "embedding" in path.lower():
                return "/v1/embeddings", {"type": "v1_embeddings"}
            elif "image" in path.lower():
                return "/v1/images/generations", {"type": "v1_images"}
            elif "audio" in path.lower():
                return "/v1/audio/speech", {"type": "v1_audio"}
        elif method == "GET":
            if "model" in path.lower():
                return "/v1/models", {"type": "v1_models"}

        return path, None

    def detect_stream_request(self, path: str, request: Request) -> bool:
        """检测是否为流式请求"""
        # 1. 路径中包含stream关键词
        if "stream" in path.lower():
            return True

        # 2. 查询参数
        if request.query_params.get("stream") == "true":
            return True

        return False

    def extract_model_name(self, path: str, request: Request) -> str:
        """从请求中提取模型名称,用于构建Gemini API URL"""
        # 1. 从请求体中提取
        try:
            if hasattr(request, "_body") and request._body:
                import json

                body = json.loads(request._body.decode())
                if "model" in body and body["model"]:
                    return body["model"]
        except Exception:
            pass

        # 2. 从查询参数中提取
        model_param = request.query_params.get("model")
        if model_param:
            return model_param

        # 3. 从路径中提取(用于已包含模型名称的路径)
        match = re.search(r"/models/([^/:]+)", path, re.IGNORECASE)
        if match:
            return match.group(1)

        # 4. 如果无法提取模型名称,抛出异常
        raise ValueError("Unable to extract model name from request")