Commit
·
02a1c2c
1
Parent(s):
0746311
added openai model back
Browse files- app/api_helpers.py +1 -1
- app/credentials_manager.py +17 -0
- app/requirements.txt +3 -1
- app/routes/chat_api.py +102 -2
- app/routes/models_api.py +16 -1
app/api_helpers.py
CHANGED
@@ -40,7 +40,7 @@ def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
|
|
40 |
types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
|
41 |
types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
|
42 |
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
|
43 |
-
types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="
|
44 |
]
|
45 |
return config
|
46 |
|
|
|
40 |
types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
|
41 |
types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
|
42 |
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
|
43 |
+
types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="NONE")
|
44 |
]
|
45 |
return config
|
46 |
|
app/credentials_manager.py
CHANGED
@@ -3,6 +3,7 @@ import glob
|
|
3 |
import random
|
4 |
import json
|
5 |
from typing import List, Dict, Any
|
|
|
6 |
from google.oauth2 import service_account
|
7 |
import config as app_config # Changed from relative
|
8 |
|
@@ -51,6 +52,22 @@ def parse_multiple_json_credentials(json_str: str) -> List[Dict[str, Any]]:
|
|
51 |
|
52 |
print(f"DEBUG: Parsed {len(credentials_list)} credential objects from the input string.")
|
53 |
return credentials_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
# Credential Manager for handling multiple service accounts
|
|
|
3 |
import random
|
4 |
import json
|
5 |
from typing import List, Dict, Any
|
6 |
+
from google.auth.transport.requests import Request as AuthRequest
|
7 |
from google.oauth2 import service_account
|
8 |
import config as app_config # Changed from relative
|
9 |
|
|
|
52 |
|
53 |
print(f"DEBUG: Parsed {len(credentials_list)} credential objects from the input string.")
|
54 |
return credentials_list
|
55 |
+
def _refresh_auth(credentials):
|
56 |
+
"""Helper function to refresh GCP token."""
|
57 |
+
if not credentials:
|
58 |
+
print("ERROR: _refresh_auth called with no credentials.")
|
59 |
+
return None
|
60 |
+
try:
|
61 |
+
# Assuming credentials object has a project_id attribute for logging
|
62 |
+
project_id_for_log = getattr(credentials, 'project_id', 'Unknown')
|
63 |
+
print(f"INFO: Attempting to refresh token for project: {project_id_for_log}...")
|
64 |
+
credentials.refresh(AuthRequest())
|
65 |
+
print(f"INFO: Token refreshed successfully for project: {project_id_for_log}")
|
66 |
+
return credentials.token
|
67 |
+
except Exception as e:
|
68 |
+
project_id_for_log = getattr(credentials, 'project_id', 'Unknown')
|
69 |
+
print(f"ERROR: Error refreshing GCP token for project {project_id_for_log}: {e}")
|
70 |
+
return None
|
71 |
|
72 |
|
73 |
# Credential Manager for handling multiple service accounts
|
app/requirements.txt
CHANGED
@@ -4,4 +4,6 @@ google-auth==2.38.0
|
|
4 |
google-cloud-aiplatform==1.86.0
|
5 |
pydantic==2.6.1
|
6 |
google-genai==1.13.0
|
7 |
-
httpx>=0.25.0
|
|
|
|
|
|
4 |
google-cloud-aiplatform==1.86.0
|
5 |
pydantic==2.6.1
|
6 |
google-genai==1.13.0
|
7 |
+
httpx>=0.25.0
|
8 |
+
openai
|
9 |
+
google-auth-oauthlib
|
app/routes/chat_api.py
CHANGED
@@ -7,6 +7,8 @@ from typing import List, Dict, Any
|
|
7 |
# Google and OpenAI specific imports
|
8 |
from google.genai import types
|
9 |
from google import genai
|
|
|
|
|
10 |
|
11 |
# Local module imports
|
12 |
from models import OpenAIRequest, OpenAIMessage
|
@@ -31,6 +33,8 @@ router = APIRouter()
|
|
31 |
async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api_key: str = Depends(get_api_key)):
|
32 |
try:
|
33 |
credential_manager_instance = fastapi_request.app.state.credential_manager
|
|
|
|
|
34 |
|
35 |
# Dynamically fetch allowed models for validation
|
36 |
vertex_model_ids = await get_vertex_models()
|
@@ -58,9 +62,16 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
58 |
all_allowed_model_ids.update(vertex_express_model_ids)
|
59 |
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
if not request.model or request.model not in all_allowed_model_ids:
|
62 |
return JSONResponse(status_code=400, content=create_openai_error_response(400, f"Model '{request.model}' not found or not supported by this adapter. Valid models are: {sorted(list(all_allowed_model_ids))}", "invalid_request_error"))
|
63 |
|
|
|
64 |
is_auto_model = request.model.endswith("-auto")
|
65 |
is_grounded_search = request.model.endswith("-search")
|
66 |
is_encrypted_model = request.model.endswith("-encrypt")
|
@@ -71,7 +82,9 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
71 |
|
72 |
# Determine base_model_name by stripping known suffixes
|
73 |
# This order matters if a model could have multiple (e.g. -encrypt-auto, though not currently a pattern)
|
74 |
-
if
|
|
|
|
|
75 |
elif is_grounded_search: base_model_name = request.model[:-len("-search")]
|
76 |
elif is_encrypted_full_model: base_model_name = request.model[:-len("-encrypt-full")] # Must be before -encrypt
|
77 |
elif is_encrypted_model: base_model_name = request.model[:-len("-encrypt")]
|
@@ -113,8 +126,95 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
|
|
113 |
return JSONResponse(status_code=500, content=create_openai_error_response(500, "Vertex AI client not available. Ensure credentials are set up correctly (env var or files).", "server_error"))
|
114 |
|
115 |
encryption_instructions_placeholder = ["// Protocol Instructions Placeholder //"] # Actual instructions are in message_processing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
print(f"Processing auto model: {request.model}")
|
119 |
attempts = [
|
120 |
{"name": "base", "model": base_model_name, "prompt_func": create_gemini_prompt, "config_modifier": lambda c: c},
|
|
|
7 |
# Google and OpenAI specific imports
|
8 |
from google.genai import types
|
9 |
from google import genai
|
10 |
+
import openai
|
11 |
+
from app.credentials_manager import _refresh_auth
|
12 |
|
13 |
# Local module imports
|
14 |
from models import OpenAIRequest, OpenAIMessage
|
|
|
33 |
async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api_key: str = Depends(get_api_key)):
|
34 |
try:
|
35 |
credential_manager_instance = fastapi_request.app.state.credential_manager
|
36 |
+
OPENAI_DIRECT_SUFFIX = "-openai"
|
37 |
+
EXPERIMENTAL_MARKER = "-exp-"
|
38 |
|
39 |
# Dynamically fetch allowed models for validation
|
40 |
vertex_model_ids = await get_vertex_models()
|
|
|
62 |
all_allowed_model_ids.update(vertex_express_model_ids)
|
63 |
|
64 |
|
65 |
+
# Add potential -openai models if they contain -exp-
|
66 |
+
potential_openai_direct_models = set()
|
67 |
+
for base_id in vertex_model_ids: # vertex_model_ids are base models
|
68 |
+
if EXPERIMENTAL_MARKER in base_id:
|
69 |
+
potential_openai_direct_models.add(f"{base_id}{OPENAI_DIRECT_SUFFIX}")
|
70 |
+
all_allowed_model_ids.update(potential_openai_direct_models)
|
71 |
if not request.model or request.model not in all_allowed_model_ids:
|
72 |
return JSONResponse(status_code=400, content=create_openai_error_response(400, f"Model '{request.model}' not found or not supported by this adapter. Valid models are: {sorted(list(all_allowed_model_ids))}", "invalid_request_error"))
|
73 |
|
74 |
+
is_openai_direct_model = request.model.endswith(OPENAI_DIRECT_SUFFIX) and EXPERIMENTAL_MARKER in request.model
|
75 |
is_auto_model = request.model.endswith("-auto")
|
76 |
is_grounded_search = request.model.endswith("-search")
|
77 |
is_encrypted_model = request.model.endswith("-encrypt")
|
|
|
82 |
|
83 |
# Determine base_model_name by stripping known suffixes
|
84 |
# This order matters if a model could have multiple (e.g. -encrypt-auto, though not currently a pattern)
|
85 |
+
if is_openai_direct_model:
|
86 |
+
base_model_name = request.model[:-len(OPENAI_DIRECT_SUFFIX)]
|
87 |
+
elif is_auto_model: base_model_name = request.model[:-len("-auto")]
|
88 |
elif is_grounded_search: base_model_name = request.model[:-len("-search")]
|
89 |
elif is_encrypted_full_model: base_model_name = request.model[:-len("-encrypt-full")] # Must be before -encrypt
|
90 |
elif is_encrypted_model: base_model_name = request.model[:-len("-encrypt")]
|
|
|
126 |
return JSONResponse(status_code=500, content=create_openai_error_response(500, "Vertex AI client not available. Ensure credentials are set up correctly (env var or files).", "server_error"))
|
127 |
|
128 |
encryption_instructions_placeholder = ["// Protocol Instructions Placeholder //"] # Actual instructions are in message_processing
|
129 |
+
if is_openai_direct_model:
|
130 |
+
print(f"INFO: Using OpenAI Direct Path for model: {request.model}")
|
131 |
+
# This mode exclusively uses rotated credentials, not express keys.
|
132 |
+
rotated_credentials, rotated_project_id = credential_manager_instance.get_random_credentials()
|
133 |
+
|
134 |
+
if not rotated_credentials or not rotated_project_id:
|
135 |
+
error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
|
136 |
+
print(f"ERROR: {error_msg}")
|
137 |
+
return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
|
138 |
+
|
139 |
+
print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
|
140 |
+
gcp_token = _refresh_auth(rotated_credentials)
|
141 |
+
|
142 |
+
if not gcp_token:
|
143 |
+
error_msg = f"Failed to obtain valid GCP token for OpenAI client (Source: Credential Manager, Project: {rotated_project_id})."
|
144 |
+
print(f"ERROR: {error_msg}")
|
145 |
+
return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
|
146 |
+
|
147 |
+
PROJECT_ID = rotated_project_id
|
148 |
+
LOCATION = "us-central1" # Fixed as per user confirmation
|
149 |
+
VERTEX_AI_OPENAI_ENDPOINT_URL = (
|
150 |
+
f"https://{LOCATION}-aiplatform.googleapis.com/v1beta1/"
|
151 |
+
f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi"
|
152 |
+
)
|
153 |
+
# base_model_name is already extracted (e.g., "gemini-1.5-pro-exp-v1")
|
154 |
+
UNDERLYING_MODEL_ID = f"google/{base_model_name}"
|
155 |
+
|
156 |
+
openai_client = openai.AsyncOpenAI(
|
157 |
+
base_url=VERTEX_AI_OPENAI_ENDPOINT_URL,
|
158 |
+
api_key=gcp_token, # OAuth token
|
159 |
+
)
|
160 |
|
161 |
+
openai_safety_settings = [
|
162 |
+
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
|
163 |
+
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
|
164 |
+
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
|
165 |
+
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
|
166 |
+
{"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": 'NONE'}
|
167 |
+
]
|
168 |
+
|
169 |
+
openai_params = {
|
170 |
+
"model": UNDERLYING_MODEL_ID,
|
171 |
+
"messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
|
172 |
+
"temperature": request.temperature,
|
173 |
+
"max_tokens": request.max_tokens,
|
174 |
+
"top_p": request.top_p,
|
175 |
+
"stream": request.stream,
|
176 |
+
"stop": request.stop,
|
177 |
+
"seed": request.seed,
|
178 |
+
"n": request.n,
|
179 |
+
}
|
180 |
+
openai_params = {k: v for k, v in openai_params.items() if v is not None}
|
181 |
+
|
182 |
+
openai_extra_body = {
|
183 |
+
'google': {
|
184 |
+
'safety_settings': openai_safety_settings
|
185 |
+
}
|
186 |
+
}
|
187 |
+
|
188 |
+
if request.stream:
|
189 |
+
async def openai_stream_generator():
|
190 |
+
try:
|
191 |
+
stream_response = await openai_client.chat.completions.create(
|
192 |
+
**openai_params,
|
193 |
+
extra_body=openai_extra_body
|
194 |
+
)
|
195 |
+
async for chunk in stream_response:
|
196 |
+
yield f"data: {chunk.model_dump_json()}\n\n"
|
197 |
+
yield "data: [DONE]\n\n"
|
198 |
+
except Exception as stream_error:
|
199 |
+
error_msg_stream = f"Error during OpenAI client streaming for {request.model}: {str(stream_error)}"
|
200 |
+
print(f"ERROR: {error_msg_stream}")
|
201 |
+
error_response_content = create_openai_error_response(500, error_msg_stream, "server_error")
|
202 |
+
yield f"data: {json.dumps(error_response_content)}\n\n" # Ensure json is imported
|
203 |
+
yield "data: [DONE]\n\n"
|
204 |
+
return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
|
205 |
+
else: # Not streaming
|
206 |
+
try:
|
207 |
+
response = await openai_client.chat.completions.create(
|
208 |
+
**openai_params,
|
209 |
+
extra_body=openai_extra_body
|
210 |
+
)
|
211 |
+
return JSONResponse(content=response.model_dump(exclude_unset=True))
|
212 |
+
except Exception as generate_error:
|
213 |
+
error_msg_generate = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
|
214 |
+
print(f"ERROR: {error_msg_generate}")
|
215 |
+
error_response = create_openai_error_response(500, error_msg_generate, "server_error")
|
216 |
+
return JSONResponse(status_code=500, content=error_response)
|
217 |
+
elif is_auto_model:
|
218 |
print(f"Processing auto model: {request.model}")
|
219 |
attempts = [
|
220 |
{"name": "base", "model": base_model_name, "prompt_func": create_gemini_prompt, "config_modifier": lambda c: c},
|
app/routes/models_api.py
CHANGED
@@ -12,6 +12,8 @@ router = APIRouter()
|
|
12 |
async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_key)):
|
13 |
await refresh_models_config_cache()
|
14 |
|
|
|
|
|
15 |
# Access credential_manager from app state
|
16 |
credential_manager_instance: CredentialManager = fastapi_request.app.state.credential_manager
|
17 |
|
@@ -80,7 +82,20 @@ async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_k
|
|
80 |
"permission": [], "root": model_id, "parent": None
|
81 |
})
|
82 |
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
final_models_data_map = {m["id"]: m for m in dynamic_models_data}
|
85 |
|
86 |
return {"object": "list", "data": list(final_models_data_map.values())}
|
|
|
12 |
async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_key)):
|
13 |
await refresh_models_config_cache()
|
14 |
|
15 |
+
OPENAI_DIRECT_SUFFIX = "-openai"
|
16 |
+
EXPERIMENTAL_MARKER = "-exp-"
|
17 |
# Access credential_manager from app state
|
18 |
credential_manager_instance: CredentialManager = fastapi_request.app.state.credential_manager
|
19 |
|
|
|
82 |
"permission": [], "root": model_id, "parent": None
|
83 |
})
|
84 |
|
85 |
+
# Ensure uniqueness again after adding suffixes
|
86 |
+
# Add OpenAI direct variations for experimental models if SA creds are available
|
87 |
+
if has_sa_creds: # OpenAI direct mode only works with SA credentials
|
88 |
+
# We should iterate through the base models that could be experimental.
|
89 |
+
# `raw_vertex_models` should contain these.
|
90 |
+
for model_id in raw_vertex_models: # Iterate through the original list of base models
|
91 |
+
if EXPERIMENTAL_MARKER in model_id:
|
92 |
+
suffixed_id = f"{model_id}{OPENAI_DIRECT_SUFFIX}"
|
93 |
+
# Check if already added (e.g. if remote config somehow already listed it)
|
94 |
+
if not any(m['id'] == suffixed_id for m in dynamic_models_data):
|
95 |
+
dynamic_models_data.append({
|
96 |
+
"id": suffixed_id, "object": "model", "created": current_time, "owned_by": "google",
|
97 |
+
"permission": [], "root": model_id, "parent": None
|
98 |
+
})
|
99 |
final_models_data_map = {m["id"]: m for m in dynamic_models_data}
|
100 |
|
101 |
return {"object": "list", "data": list(final_models_data_map.values())}
|