bibibi12345 committed on
Commit
02a1c2c
·
1 Parent(s): 0746311

added openai model back

Browse files
app/api_helpers.py CHANGED
@@ -40,7 +40,7 @@ def create_generation_config(request: OpenAIRequest) -> Dict[str, Any]:
40
  types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
41
  types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
42
  types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
43
- types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")
44
  ]
45
  return config
46
 
 
40
  types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
41
  types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
42
  types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
43
+ types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="NONE")
44
  ]
45
  return config
46
 
app/credentials_manager.py CHANGED
@@ -3,6 +3,7 @@ import glob
3
  import random
4
  import json
5
  from typing import List, Dict, Any
 
6
  from google.oauth2 import service_account
7
  import config as app_config # Changed from relative
8
 
@@ -51,6 +52,22 @@ def parse_multiple_json_credentials(json_str: str) -> List[Dict[str, Any]]:
51
 
52
  print(f"DEBUG: Parsed {len(credentials_list)} credential objects from the input string.")
53
  return credentials_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  # Credential Manager for handling multiple service accounts
 
3
  import random
4
  import json
5
  from typing import List, Dict, Any
6
+ from google.auth.transport.requests import Request as AuthRequest
7
  from google.oauth2 import service_account
8
  import config as app_config # Changed from relative
9
 
 
52
 
53
  print(f"DEBUG: Parsed {len(credentials_list)} credential objects from the input string.")
54
  return credentials_list
55
+ def _refresh_auth(credentials):
56
+ """Helper function to refresh GCP token."""
57
+ if not credentials:
58
+ print("ERROR: _refresh_auth called with no credentials.")
59
+ return None
60
+ try:
61
+ # Assuming credentials object has a project_id attribute for logging
62
+ project_id_for_log = getattr(credentials, 'project_id', 'Unknown')
63
+ print(f"INFO: Attempting to refresh token for project: {project_id_for_log}...")
64
+ credentials.refresh(AuthRequest())
65
+ print(f"INFO: Token refreshed successfully for project: {project_id_for_log}")
66
+ return credentials.token
67
+ except Exception as e:
68
+ project_id_for_log = getattr(credentials, 'project_id', 'Unknown')
69
+ print(f"ERROR: Error refreshing GCP token for project {project_id_for_log}: {e}")
70
+ return None
71
 
72
 
73
  # Credential Manager for handling multiple service accounts
app/requirements.txt CHANGED
@@ -4,4 +4,6 @@ google-auth==2.38.0
4
  google-cloud-aiplatform==1.86.0
5
  pydantic==2.6.1
6
  google-genai==1.13.0
7
- httpx>=0.25.0
 
 
 
4
  google-cloud-aiplatform==1.86.0
5
  pydantic==2.6.1
6
  google-genai==1.13.0
7
+ httpx>=0.25.0
8
+ openai
9
+ google-auth-oauthlib
app/routes/chat_api.py CHANGED
@@ -7,6 +7,8 @@ from typing import List, Dict, Any
7
  # Google and OpenAI specific imports
8
  from google.genai import types
9
  from google import genai
 
 
10
 
11
  # Local module imports
12
  from models import OpenAIRequest, OpenAIMessage
@@ -31,6 +33,8 @@ router = APIRouter()
31
  async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api_key: str = Depends(get_api_key)):
32
  try:
33
  credential_manager_instance = fastapi_request.app.state.credential_manager
 
 
34
 
35
  # Dynamically fetch allowed models for validation
36
  vertex_model_ids = await get_vertex_models()
@@ -58,9 +62,16 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
58
  all_allowed_model_ids.update(vertex_express_model_ids)
59
 
60
 
 
 
 
 
 
 
61
  if not request.model or request.model not in all_allowed_model_ids:
62
  return JSONResponse(status_code=400, content=create_openai_error_response(400, f"Model '{request.model}' not found or not supported by this adapter. Valid models are: {sorted(list(all_allowed_model_ids))}", "invalid_request_error"))
63
 
 
64
  is_auto_model = request.model.endswith("-auto")
65
  is_grounded_search = request.model.endswith("-search")
66
  is_encrypted_model = request.model.endswith("-encrypt")
@@ -71,7 +82,9 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
71
 
72
  # Determine base_model_name by stripping known suffixes
73
  # This order matters if a model could have multiple (e.g. -encrypt-auto, though not currently a pattern)
74
- if is_auto_model: base_model_name = request.model[:-len("-auto")]
 
 
75
  elif is_grounded_search: base_model_name = request.model[:-len("-search")]
76
  elif is_encrypted_full_model: base_model_name = request.model[:-len("-encrypt-full")] # Must be before -encrypt
77
  elif is_encrypted_model: base_model_name = request.model[:-len("-encrypt")]
@@ -113,8 +126,95 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
113
  return JSONResponse(status_code=500, content=create_openai_error_response(500, "Vertex AI client not available. Ensure credentials are set up correctly (env var or files).", "server_error"))
114
 
115
  encryption_instructions_placeholder = ["// Protocol Instructions Placeholder //"] # Actual instructions are in message_processing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- if is_auto_model:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  print(f"Processing auto model: {request.model}")
119
  attempts = [
120
  {"name": "base", "model": base_model_name, "prompt_func": create_gemini_prompt, "config_modifier": lambda c: c},
 
7
  # Google and OpenAI specific imports
8
  from google.genai import types
9
  from google import genai
10
+ import openai
11
+ from app.credentials_manager import _refresh_auth
12
 
13
  # Local module imports
14
  from models import OpenAIRequest, OpenAIMessage
 
33
  async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api_key: str = Depends(get_api_key)):
34
  try:
35
  credential_manager_instance = fastapi_request.app.state.credential_manager
36
+ OPENAI_DIRECT_SUFFIX = "-openai"
37
+ EXPERIMENTAL_MARKER = "-exp-"
38
 
39
  # Dynamically fetch allowed models for validation
40
  vertex_model_ids = await get_vertex_models()
 
62
  all_allowed_model_ids.update(vertex_express_model_ids)
63
 
64
 
65
+ # Add potential -openai models if they contain -exp-
66
+ potential_openai_direct_models = set()
67
+ for base_id in vertex_model_ids: # vertex_model_ids are base models
68
+ if EXPERIMENTAL_MARKER in base_id:
69
+ potential_openai_direct_models.add(f"{base_id}{OPENAI_DIRECT_SUFFIX}")
70
+ all_allowed_model_ids.update(potential_openai_direct_models)
71
  if not request.model or request.model not in all_allowed_model_ids:
72
  return JSONResponse(status_code=400, content=create_openai_error_response(400, f"Model '{request.model}' not found or not supported by this adapter. Valid models are: {sorted(list(all_allowed_model_ids))}", "invalid_request_error"))
73
 
74
+ is_openai_direct_model = request.model.endswith(OPENAI_DIRECT_SUFFIX) and EXPERIMENTAL_MARKER in request.model
75
  is_auto_model = request.model.endswith("-auto")
76
  is_grounded_search = request.model.endswith("-search")
77
  is_encrypted_model = request.model.endswith("-encrypt")
 
82
 
83
  # Determine base_model_name by stripping known suffixes
84
  # This order matters if a model could have multiple (e.g. -encrypt-auto, though not currently a pattern)
85
+ if is_openai_direct_model:
86
+ base_model_name = request.model[:-len(OPENAI_DIRECT_SUFFIX)]
87
+ elif is_auto_model: base_model_name = request.model[:-len("-auto")]
88
  elif is_grounded_search: base_model_name = request.model[:-len("-search")]
89
  elif is_encrypted_full_model: base_model_name = request.model[:-len("-encrypt-full")] # Must be before -encrypt
90
  elif is_encrypted_model: base_model_name = request.model[:-len("-encrypt")]
 
126
  return JSONResponse(status_code=500, content=create_openai_error_response(500, "Vertex AI client not available. Ensure credentials are set up correctly (env var or files).", "server_error"))
127
 
128
  encryption_instructions_placeholder = ["// Protocol Instructions Placeholder //"] # Actual instructions are in message_processing
129
+ if is_openai_direct_model:
130
+ print(f"INFO: Using OpenAI Direct Path for model: {request.model}")
131
+ # This mode exclusively uses rotated credentials, not express keys.
132
+ rotated_credentials, rotated_project_id = credential_manager_instance.get_random_credentials()
133
+
134
+ if not rotated_credentials or not rotated_project_id:
135
+ error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
136
+ print(f"ERROR: {error_msg}")
137
+ return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
138
+
139
+ print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
140
+ gcp_token = _refresh_auth(rotated_credentials)
141
+
142
+ if not gcp_token:
143
+ error_msg = f"Failed to obtain valid GCP token for OpenAI client (Source: Credential Manager, Project: {rotated_project_id})."
144
+ print(f"ERROR: {error_msg}")
145
+ return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
146
+
147
+ PROJECT_ID = rotated_project_id
148
+ LOCATION = "us-central1" # Fixed as per user confirmation
149
+ VERTEX_AI_OPENAI_ENDPOINT_URL = (
150
+ f"https://{LOCATION}-aiplatform.googleapis.com/v1beta1/"
151
+ f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi"
152
+ )
153
+ # base_model_name is already extracted (e.g., "gemini-1.5-pro-exp-v1")
154
+ UNDERLYING_MODEL_ID = f"google/{base_model_name}"
155
+
156
+ openai_client = openai.AsyncOpenAI(
157
+ base_url=VERTEX_AI_OPENAI_ENDPOINT_URL,
158
+ api_key=gcp_token, # OAuth token
159
+ )
160
 
161
+ openai_safety_settings = [
162
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
163
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
164
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
165
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
166
+ {"category": 'HARM_CATEGORY_CIVIC_INTEGRITY', "threshold": 'NONE'}
167
+ ]
168
+
169
+ openai_params = {
170
+ "model": UNDERLYING_MODEL_ID,
171
+ "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
172
+ "temperature": request.temperature,
173
+ "max_tokens": request.max_tokens,
174
+ "top_p": request.top_p,
175
+ "stream": request.stream,
176
+ "stop": request.stop,
177
+ "seed": request.seed,
178
+ "n": request.n,
179
+ }
180
+ openai_params = {k: v for k, v in openai_params.items() if v is not None}
181
+
182
+ openai_extra_body = {
183
+ 'google': {
184
+ 'safety_settings': openai_safety_settings
185
+ }
186
+ }
187
+
188
+ if request.stream:
189
+ async def openai_stream_generator():
190
+ try:
191
+ stream_response = await openai_client.chat.completions.create(
192
+ **openai_params,
193
+ extra_body=openai_extra_body
194
+ )
195
+ async for chunk in stream_response:
196
+ yield f"data: {chunk.model_dump_json()}\n\n"
197
+ yield "data: [DONE]\n\n"
198
+ except Exception as stream_error:
199
+ error_msg_stream = f"Error during OpenAI client streaming for {request.model}: {str(stream_error)}"
200
+ print(f"ERROR: {error_msg_stream}")
201
+ error_response_content = create_openai_error_response(500, error_msg_stream, "server_error")
202
+ yield f"data: {json.dumps(error_response_content)}\n\n" # Ensure json is imported
203
+ yield "data: [DONE]\n\n"
204
+ return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
205
+ else: # Not streaming
206
+ try:
207
+ response = await openai_client.chat.completions.create(
208
+ **openai_params,
209
+ extra_body=openai_extra_body
210
+ )
211
+ return JSONResponse(content=response.model_dump(exclude_unset=True))
212
+ except Exception as generate_error:
213
+ error_msg_generate = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
214
+ print(f"ERROR: {error_msg_generate}")
215
+ error_response = create_openai_error_response(500, error_msg_generate, "server_error")
216
+ return JSONResponse(status_code=500, content=error_response)
217
+ elif is_auto_model:
218
  print(f"Processing auto model: {request.model}")
219
  attempts = [
220
  {"name": "base", "model": base_model_name, "prompt_func": create_gemini_prompt, "config_modifier": lambda c: c},
app/routes/models_api.py CHANGED
@@ -12,6 +12,8 @@ router = APIRouter()
12
  async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_key)):
13
  await refresh_models_config_cache()
14
 
 
 
15
  # Access credential_manager from app state
16
  credential_manager_instance: CredentialManager = fastapi_request.app.state.credential_manager
17
 
@@ -80,7 +82,20 @@ async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_k
80
  "permission": [], "root": model_id, "parent": None
81
  })
82
 
83
- # Ensure uniqueness again after adding suffixes
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  final_models_data_map = {m["id"]: m for m in dynamic_models_data}
85
 
86
  return {"object": "list", "data": list(final_models_data_map.values())}
 
12
  async def list_models(fastapi_request: Request, api_key: str = Depends(get_api_key)):
13
  await refresh_models_config_cache()
14
 
15
+ OPENAI_DIRECT_SUFFIX = "-openai"
16
+ EXPERIMENTAL_MARKER = "-exp-"
17
  # Access credential_manager from app state
18
  credential_manager_instance: CredentialManager = fastapi_request.app.state.credential_manager
19
 
 
82
  "permission": [], "root": model_id, "parent": None
83
  })
84
 
85
+ # Ensure uniqueness again after adding suffixes
86
+ # Add OpenAI direct variations for experimental models if SA creds are available
87
+ if has_sa_creds: # OpenAI direct mode only works with SA credentials
88
+ # We should iterate through the base models that could be experimental.
89
+ # `raw_vertex_models` should contain these.
90
+ for model_id in raw_vertex_models: # Iterate through the original list of base models
91
+ if EXPERIMENTAL_MARKER in model_id:
92
+ suffixed_id = f"{model_id}{OPENAI_DIRECT_SUFFIX}"
93
+ # Check if already added (e.g. if remote config somehow already listed it)
94
+ if not any(m['id'] == suffixed_id for m in dynamic_models_data):
95
+ dynamic_models_data.append({
96
+ "id": suffixed_id, "object": "model", "created": current_time, "owned_by": "google",
97
+ "permission": [], "root": model_id, "parent": None
98
+ })
99
  final_models_data_map = {m["id"]: m for m in dynamic_models_data}
100
 
101
  return {"object": "list", "data": list(final_models_data_map.values())}