Commit d342ca5
Parent(s): 0527a50

added openai mode for express

Files changed:
- app/model_loader.py +1 -3
- app/openai_handler.py +148 -44
- app/routes/chat_api.py +10 -5
- app/routes/models_api.py +47 -107
app/model_loader.py
CHANGED

@@ -33,11 +33,9 @@ async def fetch_and_parse_models_config() -> Optional[Dict[str, List[str]]]:
         print("Successfully fetched and parsed model configuration.")

         # Add [EXPRESS] prefix to express models
-        prefixed_express_models = [f"[EXPRESS] {model_name}" for model_name in data["vertex_express_models"]]
-
         return {
             "vertex_models": data["vertex_models"],
-            "vertex_express_models": prefixed_express_models
+            "vertex_express_models": data["vertex_express_models"]
         }
     else:
         print(f"ERROR: Fetched model configuration has an invalid structure: {data}")
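
Note: with this change the loader stops prepending "[EXPRESS] " and returns the express model names exactly as they appear in the remote config; display prefixing becomes the responsibility of models_api.py. A minimal sketch of the before/after behaviour, assuming a hypothetical config payload (not from this commit):

# Hypothetical payload parsed from the remote models config:
data = {"vertex_models": ["gemini-2.5-pro"], "vertex_express_models": ["gemini-2.0-flash"]}

# Before this commit the loader returned display-ready names:
#   {"vertex_models": ["gemini-2.5-pro"], "vertex_express_models": ["[EXPRESS] gemini-2.0-flash"]}
# After it, the raw names pass through untouched:
#   {"vertex_models": ["gemini-2.5-pro"], "vertex_express_models": ["gemini-2.0-flash"]}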
app/openai_handler.py
CHANGED

@@ -5,7 +5,8 @@ This module encapsulates all OpenAI-specific logic that was previously in chat_api.py
 import json
 import time
 import asyncio
-
+import httpx
+from typing import Dict, Any, AsyncGenerator, Optional

 from fastapi.responses import JSONResponse, StreamingResponse
 import openai
@@ -21,13 +22,104 @@ from api_helpers import (
 )
 from message_processing import extract_reasoning_by_tags
 from credentials_manager import _refresh_auth
+from project_id_discovery import discover_project_id
+
+
+# Wrapper classes to mimic OpenAI SDK responses for direct httpx calls
+class FakeChatCompletionChunk:
+    """A fake ChatCompletionChunk to wrap the dictionary from a direct API stream."""
+    def __init__(self, data: Dict[str, Any]):
+        self._data = data
+
+    def model_dump(self, exclude_unset=True, exclude_none=True) -> Dict[str, Any]:
+        return self._data
+
+class FakeChatCompletion:
+    """A fake ChatCompletion to wrap the dictionary from a direct non-streaming API call."""
+    def __init__(self, data: Dict[str, Any]):
+        self._data = data
+
+    def model_dump(self, exclude_unset=True, exclude_none=True) -> Dict[str, Any]:
+        return self._data
+
+class ExpressClientWrapper:
+    """
+    A wrapper that mimics the openai.AsyncOpenAI client interface but uses direct
+    httpx calls for Vertex AI Express Mode. This allows it to be used with the
+    existing response handling logic.
+    """
+    def __init__(self, project_id: str, api_key: str, location: str = "global"):
+        self.project_id = project_id
+        self.api_key = api_key
+        self.location = location
+        self.base_url = f"https://aiplatform.googleapis.com/v1beta1/projects/{self.project_id}/locations/{self.location}/endpoints/openapi"
+
+        # The 'chat.completions' structure mimics the real OpenAI client
+        self.chat = self
+        self.completions = self
+
+    async def _stream_generator(self, response: httpx.Response) -> AsyncGenerator[FakeChatCompletionChunk, None]:
+        """Processes the SSE stream from httpx and yields fake chunk objects."""
+        async for line in response.aiter_lines():
+            if line.startswith("data:"):
+                json_str = line[len("data: "):].strip()
+                if json_str == "[DONE]":
+                    break
+                try:
+                    data = json.loads(json_str)
+                    yield FakeChatCompletionChunk(data)
+                except json.JSONDecodeError:
+                    print(f"Warning: Could not decode JSON from stream line: {json_str}")
+                    continue
+
+    async def _streaming_create(self, **kwargs) -> AsyncGenerator[FakeChatCompletionChunk, None]:
+        """Handles the creation of a streaming request using httpx."""
+        endpoint = f"{self.base_url}/chat/completions"
+        headers = {"Content-Type": "application/json"}
+        params = {"key": self.api_key}
+
+        payload = kwargs.copy()
+        if 'extra_body' in payload:
+            payload.update(payload.pop('extra_body'))
+
+        async with httpx.AsyncClient(timeout=300) as client:
+            async with client.stream("POST", endpoint, headers=headers, params=params, json=payload, timeout=None) as response:
+                response.raise_for_status()
+                async for chunk in self._stream_generator(response):
+                    yield chunk
+
+    async def create(self, **kwargs) -> Any:
+        """
+        Mimics the 'create' method of the OpenAI client.
+        It builds and sends a direct HTTP request using httpx, delegating
+        to the appropriate streaming or non-streaming handler.
+        """
+        is_streaming = kwargs.get("stream", False)
+
+        if is_streaming:
+            return self._streaming_create(**kwargs)
+
+        # Non-streaming logic
+        endpoint = f"{self.base_url}/chat/completions"
+        headers = {"Content-Type": "application/json"}
+        params = {"key": self.api_key}
+
+        payload = kwargs.copy()
+        if 'extra_body' in payload:
+            payload.update(payload.pop('extra_body'))
+
+        async with httpx.AsyncClient(timeout=300) as client:
+            response = await client.post(endpoint, headers=headers, params=params, json=payload, timeout=None)
+            response.raise_for_status()
+            return FakeChatCompletion(response.json())


 class OpenAIDirectHandler:
     """Handles OpenAI Direct mode operations including client creation and response processing."""

-    def __init__(self, credential_manager):
+    def __init__(self, credential_manager=None, express_key_manager=None):
         self.credential_manager = credential_manager
+        self.express_key_manager = express_key_manager
         self.safety_settings = [
             {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
             {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
@@ -35,7 +127,7 @@ class OpenAIDirectHandler:
             {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
             {"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": 'OFF'}
         ]
-
+
     def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
         """Create an OpenAI client configured for Vertex AI endpoint."""
         endpoint_url = (
@@ -80,7 +172,7 @@ class OpenAIDirectHandler:

     async def handle_streaming_response(
         self,
-        openai_client: openai.AsyncOpenAI,
+        openai_client: Any,  # Can be openai.AsyncOpenAI or our wrapper
         openai_params: Dict[str, Any],
         openai_extra_body: Dict[str, Any],
         request: OpenAIRequest
@@ -107,7 +199,7 @@ class OpenAIDirectHandler:

     async def _true_stream_generator(
         self,
-        openai_client: openai.AsyncOpenAI,
+        openai_client: Any,  # Can be openai.AsyncOpenAI or our wrapper
         openai_params: Dict[str, Any],
         openai_extra_body: Dict[str, Any],
         request: OpenAIRequest
@@ -136,6 +228,7 @@ class OpenAIDirectHandler:
                     delta = choices[0].get('delta')
                     if delta and isinstance(delta, dict):
                         # Always remove extra_content if present
+
                         if 'extra_content' in delta:
                             del delta['extra_content']

@@ -242,7 +335,7 @@ class OpenAIDirectHandler:

     async def handle_non_streaming_response(
         self,
-        openai_client: openai.AsyncOpenAI,
+        openai_client: Any,  # Can be openai.AsyncOpenAI or our wrapper
         openai_params: Dict[str, Any],
         openai_extra_body: Dict[str, Any],
         request: OpenAIRequest
@@ -296,44 +389,55 @@ class OpenAIDirectHandler:
                 content=create_openai_error_response(500, error_msg, "server_error")
             )

-    async def process_request(self, request: OpenAIRequest, base_model_name: str):
+    async def process_request(self, request: OpenAIRequest, base_model_name: str, is_express: bool = False):
         """Main entry point for processing OpenAI Direct mode requests."""
-        print(f"INFO: Using OpenAI Direct Path for model: {request.model}")
-
-        # Get credentials
-        rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()
-
-        if not rotated_credentials or not rotated_project_id:
-            error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
-            print(f"ERROR: {error_msg}")
-            return JSONResponse(
-                status_code=500,
-                content=create_openai_error_response(500, error_msg, "server_error")
-            )
-
-        print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
-        gcp_token = _refresh_auth(rotated_credentials)
+        print(f"INFO: Using OpenAI Direct Path for model: {request.model} (Express: {is_express})")

-        if not gcp_token:
-            error_msg = f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id})."
+        client: Any = None  # Can be openai.AsyncOpenAI or our wrapper
+
+        try:
+            if is_express:
+                if not self.express_key_manager:
+                    raise Exception("Express mode requires an ExpressKeyManager, but it was not provided.")
+
+                key_tuple = self.express_key_manager.get_express_api_key()
+                if not key_tuple:
+                    raise Exception("OpenAI Express Mode requires an API key, but none were available.")
+
+                _, express_api_key = key_tuple
+                project_id = await discover_project_id(express_api_key)
+
+                client = ExpressClientWrapper(project_id=project_id, api_key=express_api_key)
+                print(f"INFO: [OpenAI Express Path] Using ExpressClientWrapper for project: {project_id}")
+
+            else:  # Standard SA-based OpenAI SDK Path
+                if not self.credential_manager:
+                    raise Exception("Standard OpenAI Direct mode requires a CredentialManager.")
+
+                rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()
+                if not rotated_credentials or not rotated_project_id:
+                    raise Exception("OpenAI Direct Mode requires GCP credentials, but none were available.")
+
+                print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
+                gcp_token = _refresh_auth(rotated_credentials)
+                if not gcp_token:
+                    raise Exception(f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id}).")
+
+                client = self.create_openai_client(rotated_project_id, gcp_token)
+
+            model_id = f"google/{base_model_name}"
+            openai_params = self.prepare_openai_params(request, model_id)
+            openai_extra_body = self.prepare_extra_body()
+
+            if request.stream:
+                return await self.handle_streaming_response(
+                    client, openai_params, openai_extra_body, request
+                )
+            else:
+                return await self.handle_non_streaming_response(
+                    client, openai_params, openai_extra_body, request
+                )
+        except Exception as e:
+            error_msg = f"Error in process_request for {request.model}: {e}"
             print(f"ERROR: {error_msg}")
-            return JSONResponse(
-                status_code=500,
-                content=create_openai_error_response(500, error_msg, "server_error")
-            )
-
-        # Create client and prepare parameters
-        openai_client = self.create_openai_client(rotated_project_id, gcp_token)
-        model_id = f"google/{base_model_name}"
-        openai_params = self.prepare_openai_params(request, model_id)
-        openai_extra_body = self.prepare_extra_body()
-
-        # Handle streaming vs non-streaming
-        if request.stream:
-            return await self.handle_streaming_response(
-                openai_client, openai_params, openai_extra_body, request
-            )
-        else:
-            return await self.handle_non_streaming_response(
-                openai_client, openai_params, openai_extra_body, request
-            )
+            return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
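
Note: because ExpressClientWrapper aliases both self.chat and self.completions to itself, call sites written against the OpenAI SDK's client.chat.completions.create(...) keep working unchanged. A minimal usage sketch with hypothetical values (in the handler the project ID comes from discover_project_id() and the key from the ExpressKeyManager):

import asyncio

async def demo():
    # Hypothetical project and key, for illustration only.
    client = ExpressClientWrapper(project_id="example-project", api_key="example-key")

    # chat/completions are aliases for the wrapper itself, so this resolves
    # to ExpressClientWrapper.create(), which POSTs to
    # {base_url}/chat/completions?key=<api_key> via httpx.
    completion = await client.chat.completions.create(
        model="google/gemini-2.0-flash",
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )
    print(completion.model_dump())

asyncio.run(demo())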
app/routes/chat_api.py
CHANGED

@@ -46,9 +46,10 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
     is_openai_direct_model = False
     if request.model.endswith(OPENAI_DIRECT_SUFFIX):
         temp_name_for_marker_check = request.model[:-len(OPENAI_DIRECT_SUFFIX)]
-        # An OpenAI model can be prefixed with PAY or contain EXP
-        if temp_name_for_marker_check.startswith(PAY_PREFIX) or \
-           EXPERIMENTAL_MARKER in temp_name_for_marker_check:
+        # An OpenAI model can be prefixed with PAY, EXPRESS, or contain EXP
+        if temp_name_for_marker_check.startswith(PAY_PREFIX) or \
+           temp_name_for_marker_check.startswith(EXPRESS_PREFIX) or \
+           EXPERIMENTAL_MARKER in temp_name_for_marker_check:
             is_openai_direct_model = True
     is_auto_model = request.model.endswith("-auto")
     is_grounded_search = request.model.endswith("-search")
@@ -175,8 +176,12 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api

     if is_openai_direct_model:
         # Use the new OpenAI handler
-        openai_handler = OpenAIDirectHandler(credential_manager_instance)
-        return await openai_handler.process_request(request, base_model_name)
+        if is_express_model_request:
+            openai_handler = OpenAIDirectHandler(express_key_manager=express_key_manager_instance)
+            return await openai_handler.process_request(request, base_model_name, is_express=True)
+        else:
+            openai_handler = OpenAIDirectHandler(credential_manager=credential_manager_instance)
+            return await openai_handler.process_request(request, base_model_name)
     elif is_auto_model:
         print(f"Processing auto model: {request.model}")
         attempts = [
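
Note: the -openai suffix is stripped before the prefix check, so a model such as "[EXPRESS] gemini-2.5-flash-openai" still carries its express prefix when the markers are tested. A standalone sketch of the routing predicate (constants inlined; values as defined in models_api.py):

OPENAI_DIRECT_SUFFIX = "-openai"
PAY_PREFIX = "[PAY]"
EXPRESS_PREFIX = "[EXPRESS] "
EXPERIMENTAL_MARKER = "-exp-"

def is_openai_direct(model: str) -> bool:
    # Strip the suffix first, then look for the routing markers.
    if not model.endswith(OPENAI_DIRECT_SUFFIX):
        return False
    name = model[:-len(OPENAI_DIRECT_SUFFIX)]
    return (name.startswith(PAY_PREFIX)
            or name.startswith(EXPRESS_PREFIX)
            or EXPERIMENTAL_MARKER in name)

assert is_openai_direct("[EXPRESS] gemini-2.5-flash-openai")   # express path
assert is_openai_direct("[PAY]gemini-2.5-pro-openai")          # SA path
assert not is_openai_direct("[EXPRESS] gemini-2.5-flash")      # no -openai suffix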
app/routes/models_api.py
CHANGED

@@ -1,10 +1,10 @@
 import time
 from fastapi import APIRouter, Depends, Request
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Set
 from auth import get_api_key
 from model_loader import get_vertex_models, get_vertex_express_models, refresh_models_config_cache
 import config as app_config
 from credentials_manager import CredentialManager

 router = APIRouter()

@@ -12,10 +12,10 @@ router = APIRouter()
 async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_key)):
     await refresh_models_config_cache()

-    OPENAI_DIRECT_SUFFIX = "-openai"
-    EXPERIMENTAL_MARKER = "-exp-"
     PAY_PREFIX = "[PAY]"
-
+    EXPRESS_PREFIX = "[EXPRESS] "
+    OPENAI_DIRECT_SUFFIX = "-openai"
+
     credential_manager_instance: CredentialManager = fastapi_request.app.state.credential_manager
     express_key_manager_instance = fastapi_request.app.state.express_key_manager

@@ -25,109 +25,49 @@ async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_key)):
     raw_vertex_models = await get_vertex_models()
     raw_express_models = await get_vertex_express_models()

-    if has_express_key:
-        candidate_model_ids.update(raw_express_models)
-        # If *only* express key is available, only express models (and their variants) should be listed.
-        # The current `vertex_model_ids` from remote config might contain non-express models.
-        # The `get_vertex_express_models()` should be the source of truth for express-eligible base models.
-        if not has_sa_creds:
-            # Only list models that are explicitly in the express list.
-            # Suffix generation will apply only to these if they are not gemini-2.0
-            all_model_ids = set(raw_express_models)
-        else:
-            # Both SA and Express are available, combine all known models
-            all_model_ids = set(raw_vertex_models + raw_express_models)
-    elif has_sa_creds:
-        # Only SA creds available, use all vertex_models (which might include express-eligible ones)
-        all_model_ids = set(raw_vertex_models)
-    else:
-        # No credentials available
-        all_model_ids = set()
-
-    # Create extended model list with variations (search, encrypt, auto etc.)
-    # This logic might need to be more sophisticated based on actual supported features per base model.
-    # For now, let's assume for each base model, we might have these variations.
-    # A better approach would be if the remote config specified these variations.
-
-    dynamic_models_data: List[Dict[str, Any]] = []
+    final_model_list: List[Dict[str, Any]] = []
+    processed_ids: Set[str] = set()
     current_time = int(time.time())

-    for original_model_id in all_model_ids:
-        current_display_prefix = ""
-        # Only add PAY_PREFIX if the model is not already an EXPRESS model (which has its own prefix)
-        # Apply PAY_PREFIX if SA creds are present, it's a model from raw_vertex_models,
-        # it's not experimental, and not already an EXPRESS model.
-        if has_sa_creds and \
-           original_model_id in raw_vertex_models_set and \
-           EXPERIMENTAL_MARKER not in original_model_id and \
-           not original_model_id.startswith("[EXPRESS]"):
-            current_display_prefix = PAY_PREFIX
-
-        final_display_id = f"{current_display_prefix}{original_model_id}"
-
-        dynamic_models_data.append({
-            "id": final_display_id, "object": "model", "created": current_time, "owned_by": "google",
-            "permission": [], "root": original_model_id, "parent": None
-        })
-
-        # Conditionally add common variations (standard suffixes)
-        if not original_model_id.startswith("gemini-2.0"): # Suffix rules based on original_model_id
-            standard_suffixes = ["-search", "-encrypt", "-encrypt-full", "-auto"]
-            for suffix in standard_suffixes:
-                # Suffix is applied to the original model ID part
-                suffixed_model_part = f"{original_model_id}{suffix}"
-                # Then the whole thing is prefixed
-                final_suffixed_display_id = f"{current_display_prefix}{suffixed_model_part}"
-
-                # Check if this suffixed ID is already in all_model_ids (unlikely with prefix) or already added
-                if final_suffixed_display_id not in all_model_ids and not any(m['id'] == final_suffixed_display_id for m in dynamic_models_data):
-                    dynamic_models_data.append({
-                        "id": final_suffixed_display_id, "object": "model", "created": current_time, "owned_by": "google",
-                        "permission": [], "root": original_model_id, "parent": None
-                    })
-
-        # Apply special suffixes for models starting with "gemini-2.5-flash" or containing "gemini-2.5-pro"
-        # This includes both regular and EXPRESS versions
-        if "gemini-2.5-flash" in original_model_id or "gemini-2.5-pro" in original_model_id: # Suffix rules based on original_model_id
-            special_thinking_suffixes = ["-nothinking", "-max"]
-            for special_suffix in special_thinking_suffixes:
-                suffixed_model_part = f"{original_model_id}{special_suffix}"
-                final_special_suffixed_display_id = f"{current_display_prefix}{suffixed_model_part}"
-
-    # model_list = list(final_models_data_map.values())
-    # model_list.sort()
-
-    return {"object": "list", "data": sorted(dynamic_models_data, key=lambda x: x['id'])}
+    def add_model_and_variants(base_id: str, prefix: str):
+        """Adds a model and its variants to the list if not already present."""
+
+        # Define all possible suffixes for a given model
+        suffixes = [""]  # For the base model itself
+        if not base_id.startswith("gemini-2.0"):
+            suffixes.extend(["-search", "-encrypt", "-encrypt-full", "-auto"])
+        if "gemini-2.5-flash" in base_id or "gemini-2.5-pro" in base_id:
+            suffixes.extend(["-nothinking", "-max"])
+
+        # Add the openai variant for all models
+        suffixes.append(OPENAI_DIRECT_SUFFIX)
+
+        for suffix in suffixes:
+            model_id_with_suffix = f"{base_id}{suffix}"
+
+            # Experimental models have no prefix
+            final_id = f"{prefix}{model_id_with_suffix}" if "-exp-" not in base_id else model_id_with_suffix
+
+            if final_id not in processed_ids:
+                final_model_list.append({
+                    "id": final_id,
+                    "object": "model",
+                    "created": current_time,
+                    "owned_by": "google",
+                    "permission": [],
+                    "root": base_id,
+                    "parent": None
+                })
+                processed_ids.add(final_id)
+
+    # Process Express Key models first
+    if has_express_key:
+        for model_id in raw_express_models:
+            add_model_and_variants(model_id, EXPRESS_PREFIX)
+
+    # Process Service Account (PAY) models, they have lower priority
+    if has_sa_creds:
+        for model_id in raw_vertex_models:
+            add_model_and_variants(model_id, PAY_PREFIX)
+
+    return {"object": "list", "data": sorted(final_model_list, key=lambda x: x['id'])}
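
Note: a worked example of what add_model_and_variants emits for one express base model. For "gemini-2.5-flash" the suffix list expands to ["", "-search", "-encrypt", "-encrypt-full", "-auto", "-nothinking", "-max", "-openai"], each joined to the "[EXPRESS] " prefix (models containing "-exp-" would skip the prefix). Standalone sketch of the expansion:

base_id, prefix = "gemini-2.5-flash", "[EXPRESS] "

suffixes = [""]
if not base_id.startswith("gemini-2.0"):
    suffixes.extend(["-search", "-encrypt", "-encrypt-full", "-auto"])
if "gemini-2.5-flash" in base_id or "gemini-2.5-pro" in base_id:
    suffixes.extend(["-nothinking", "-max"])
suffixes.append("-openai")

for s in suffixes:
    print(f"{prefix}{base_id}{s}")
# [EXPRESS] gemini-2.5-flash
# [EXPRESS] gemini-2.5-flash-search
# [EXPRESS] gemini-2.5-flash-encrypt
# [EXPRESS] gemini-2.5-flash-encrypt-full
# [EXPRESS] gemini-2.5-flash-auto
# [EXPRESS] gemini-2.5-flash-nothinking
# [EXPRESS] gemini-2.5-flash-max
# [EXPRESS] gemini-2.5-flash-openai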