bibibi12345 committed on
Commit
0796d75
·
1 Parent(s): b1f7ea4

use official sdk for express mode 0605

Browse files
app/api_helpers.py CHANGED
@@ -254,7 +254,7 @@ def is_gemini_response_valid(response: Any) -> bool:
254
  if hasattr(response, 'text') and isinstance(response.text, str) and response.text.strip():
255
  return True
256
 
257
- # Check for candidates (both SDK and DirectVertexClient responses)
258
  if hasattr(response, 'candidates') and response.candidates:
259
  for candidate in response.candidates:
260
  # Check for direct text on candidate
 
254
  if hasattr(response, 'text') and isinstance(response.text, str) and response.text.strip():
255
  return True
256
 
257
+ # Check for candidates in the response
258
  if hasattr(response, 'candidates') and response.candidates:
259
  for candidate in response.candidates:
260
  # Check for direct text on candidate
app/direct_vertex_client.py DELETED
@@ -1,423 +0,0 @@
1
- import aiohttp
2
- import asyncio
3
- import json
4
- import re
5
- from typing import Dict, Any, List, Union, Optional, AsyncGenerator
6
- import time
7
-
8
- # Global cache for project IDs: {api_key: project_id}
9
- PROJECT_ID_CACHE: Dict[str, str] = {}
10
-
11
-
12
- class DirectVertexClient:
13
- """
14
- A client that connects to Vertex AI using direct URLs instead of the SDK.
15
- Mimics the interface of genai.Client for seamless integration.
16
- """
17
-
18
- def __init__(self, api_key: str):
19
- self.api_key = api_key
20
- self.project_id: Optional[str] = None
21
- self.base_url = "https://aiplatform.googleapis.com/v1"
22
- self.session: Optional[aiohttp.ClientSession] = None
23
- # Mimic the model_name attribute that might be accessed
24
- self.model_name = "direct_vertex_client"
25
-
26
- # Create nested structure to mimic genai.Client interface
27
- self.aio = self._AioNamespace(self)
28
-
29
- class _AioNamespace:
30
- def __init__(self, parent):
31
- self.parent = parent
32
- self.models = self._ModelsNamespace(parent)
33
-
34
- class _ModelsNamespace:
35
- def __init__(self, parent):
36
- self.parent = parent
37
-
38
- async def generate_content(self, model: str, contents: Any, config: Dict[str, Any]) -> Any:
39
- """Non-streaming content generation"""
40
- return await self.parent._generate_content(model, contents, config, stream=False)
41
-
42
- async def generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]):
43
- """Streaming content generation - returns an async generator"""
44
- # This needs to be an async method that returns the generator
45
- # to match the SDK's interface where you await the method call
46
- return self.parent._generate_content_stream(model, contents, config)
47
-
48
- async def _ensure_session(self):
49
- """Ensure aiohttp session is created"""
50
- if self.session is None:
51
- self.session = aiohttp.ClientSession()
52
-
53
- async def close(self):
54
- """Clean up resources"""
55
- if self.session:
56
- await self.session.close()
57
- self.session = None
58
-
59
- async def discover_project_id(self) -> None:
60
- """Discover project ID by triggering an intentional error"""
61
- # Check cache first
62
- if self.api_key in PROJECT_ID_CACHE:
63
- self.project_id = PROJECT_ID_CACHE[self.api_key]
64
- print(f"INFO: Using cached project ID: {self.project_id}")
65
- return
66
-
67
- await self._ensure_session()
68
-
69
- # Use a non-existent model to trigger error
70
- error_url = f"{self.base_url}/publishers/google/models/gemini-2.7-pro-preview-05-06:streamGenerateContent?key={self.api_key}"
71
-
72
- try:
73
- # Send minimal request to trigger error
74
- payload = {
75
- "contents": [{"role": "user", "parts": [{"text": "test"}]}]
76
- }
77
-
78
- async with self.session.post(error_url, json=payload) as response:
79
- response_text = await response.text()
80
-
81
- try:
82
- # Try to parse as JSON first
83
- error_data = json.loads(response_text)
84
-
85
- # Handle array response format
86
- if isinstance(error_data, list) and len(error_data) > 0:
87
- error_data = error_data[0]
88
-
89
- if "error" in error_data:
90
- error_message = error_data["error"].get("message", "")
91
- # Extract project ID from error message
92
- # Pattern: "projects/39982734461/locations/..."
93
- match = re.search(r'projects/(\d+)/locations/', error_message)
94
- if match:
95
- self.project_id = match.group(1)
96
- PROJECT_ID_CACHE[self.api_key] = self.project_id
97
- print(f"INFO: Discovered project ID: {self.project_id}")
98
- return
99
- except json.JSONDecodeError:
100
- # If not JSON, try to find project ID in raw text
101
- match = re.search(r'projects/(\d+)/locations/', response_text)
102
- if match:
103
- self.project_id = match.group(1)
104
- PROJECT_ID_CACHE[self.api_key] = self.project_id
105
- print(f"INFO: Discovered project ID from raw response: {self.project_id}")
106
- return
107
-
108
- raise Exception(f"Failed to discover project ID. Status: {response.status}, Response: {response_text[:500]}")
109
-
110
- except Exception as e:
111
- print(f"ERROR: Failed to discover project ID: {e}")
112
- raise
113
-
114
- def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]:
115
- """Convert SDK Content objects to REST API format"""
116
- if isinstance(contents, list):
117
- return [self._convert_content_item(item) for item in contents]
118
- else:
119
- return [self._convert_content_item(contents)]
120
-
121
- def _convert_content_item(self, content: Any) -> Dict[str, Any]:
122
- """Convert a single content item to REST API format"""
123
- if isinstance(content, dict):
124
- return content
125
-
126
- # Handle SDK Content objects
127
- result = {}
128
- if hasattr(content, 'role'):
129
- result['role'] = content.role
130
- if hasattr(content, 'parts'):
131
- result['parts'] = []
132
- for part in content.parts:
133
- if isinstance(part, dict):
134
- result['parts'].append(part)
135
- elif hasattr(part, 'text'):
136
- result['parts'].append({'text': part.text})
137
- elif hasattr(part, 'inline_data'):
138
- result['parts'].append({
139
- 'inline_data': {
140
- 'mime_type': part.inline_data.mime_type,
141
- 'data': part.inline_data.data
142
- }
143
- })
144
- return result
145
-
146
- def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, str]]:
147
- """Convert SDK SafetySetting objects to REST API format"""
148
- if not safety_settings:
149
- return []
150
-
151
- result = []
152
- for setting in safety_settings:
153
- if isinstance(setting, dict):
154
- result.append(setting)
155
- elif hasattr(setting, 'category') and hasattr(setting, 'threshold'):
156
- # Convert SDK SafetySetting to dict
157
- result.append({
158
- 'category': setting.category,
159
- 'threshold': setting.threshold
160
- })
161
- return result
162
-
163
- def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]:
164
- """Convert SDK Tool objects to REST API format"""
165
- if not tools:
166
- return []
167
-
168
- result = []
169
- for tool in tools:
170
- if isinstance(tool, dict):
171
- result.append(tool)
172
- else:
173
- # Convert SDK Tool object to dict
174
- result.append(self._convert_tool_item(tool))
175
- return result
176
-
177
- def _convert_tool_item(self, tool: Any) -> Dict[str, Any]:
178
- """Convert a single tool item to REST API format"""
179
- if isinstance(tool, dict):
180
- return tool
181
-
182
- tool_dict = {}
183
-
184
- # Convert all non-private attributes
185
- if hasattr(tool, '__dict__'):
186
- for attr_name, attr_value in tool.__dict__.items():
187
- if not attr_name.startswith('_'):
188
- # Convert attribute names from snake_case to camelCase for REST API
189
- rest_api_name = self._to_camel_case(attr_name)
190
-
191
- # Special handling for known types
192
- if attr_name == 'google_search' and attr_value is not None:
193
- tool_dict[rest_api_name] = {} # GoogleSearch is empty object in REST
194
- elif attr_name == 'function_declarations' and attr_value is not None:
195
- tool_dict[rest_api_name] = attr_value
196
- elif attr_value is not None:
197
- # Recursively convert any other SDK objects
198
- tool_dict[rest_api_name] = self._convert_sdk_object(attr_value)
199
-
200
- return tool_dict
201
-
202
- def _to_camel_case(self, snake_str: str) -> str:
203
- """Convert snake_case to camelCase"""
204
- components = snake_str.split('_')
205
- return components[0] + ''.join(x.title() for x in components[1:])
206
-
207
- def _convert_sdk_object(self, obj: Any) -> Any:
208
- """Generic SDK object converter"""
209
- if isinstance(obj, (str, int, float, bool, type(None))):
210
- return obj
211
- elif isinstance(obj, dict):
212
- return {k: self._convert_sdk_object(v) for k, v in obj.items()}
213
- elif isinstance(obj, list):
214
- return [self._convert_sdk_object(item) for item in obj]
215
- elif hasattr(obj, '__dict__'):
216
- # Convert SDK object to dict
217
- result = {}
218
- for key, value in obj.__dict__.items():
219
- if not key.startswith('_'):
220
- result[self._to_camel_case(key)] = self._convert_sdk_object(value)
221
- return result
222
- else:
223
- return obj
224
-
225
- async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
226
- """Internal method for content generation"""
227
- if not self.project_id:
228
- raise ValueError("Project ID not discovered. Call discover_project_id() first.")
229
-
230
- await self._ensure_session()
231
-
232
- # Build URL
233
- endpoint = "streamGenerateContent" if stream else "generateContent"
234
- url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"
235
-
236
- # Convert contents to REST API format
237
- payload = {
238
- "contents": self._convert_contents(contents)
239
- }
240
-
241
- # Extract specific config sections
242
- if "system_instruction" in config:
243
- # System instruction should be a content object
244
- if isinstance(config["system_instruction"], dict):
245
- payload["systemInstruction"] = config["system_instruction"]
246
- else:
247
- payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
248
-
249
- if "safety_settings" in config:
250
- payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
251
-
252
- if "tools" in config:
253
- payload["tools"] = self._convert_tools(config["tools"])
254
-
255
- # All other config goes under generationConfig
256
- generation_config = {}
257
- for key, value in config.items():
258
- if key not in ["system_instruction", "safety_settings", "tools"]:
259
- generation_config[key] = value
260
-
261
- if generation_config:
262
- payload["generationConfig"] = generation_config
263
-
264
- try:
265
- async with self.session.post(url, json=payload) as response:
266
- if response.status != 200:
267
- error_data = await response.json()
268
- error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status}")
269
- raise Exception(f"Vertex AI API error: {error_msg}")
270
-
271
- # Get the JSON response
272
- response_data = await response.json()
273
-
274
- # Convert dict to object with attributes for compatibility
275
- return self._dict_to_obj(response_data)
276
-
277
- except Exception as e:
278
- print(f"ERROR: Direct Vertex API call failed: {e}")
279
- raise
280
-
281
- def _dict_to_obj(self, data):
282
- """Convert a dict to an object with attributes"""
283
- if isinstance(data, dict):
284
- # Create a simple object that allows attribute access
285
- class AttrDict:
286
- def __init__(self, d):
287
- for key, value in d.items():
288
- setattr(self, key, self._convert_value(value))
289
-
290
- def _convert_value(self, value):
291
- if isinstance(value, dict):
292
- return AttrDict(value)
293
- elif isinstance(value, list):
294
- return [self._convert_value(item) for item in value]
295
- else:
296
- return value
297
-
298
- return AttrDict(data)
299
- elif isinstance(data, list):
300
- return [self._dict_to_obj(item) for item in data]
301
- else:
302
- return data
303
-
304
- async def _generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]) -> AsyncGenerator:
305
- """Internal method for streaming content generation"""
306
- if not self.project_id:
307
- raise ValueError("Project ID not discovered. Call discover_project_id() first.")
308
-
309
- await self._ensure_session()
310
-
311
- # Build URL for streaming
312
- url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}"
313
-
314
- # Convert contents to REST API format
315
- payload = {
316
- "contents": self._convert_contents(contents)
317
- }
318
-
319
- # Extract specific config sections
320
- if "system_instruction" in config:
321
- # System instruction should be a content object
322
- if isinstance(config["system_instruction"], dict):
323
- payload["systemInstruction"] = config["system_instruction"]
324
- else:
325
- payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
326
-
327
- if "safety_settings" in config:
328
- payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
329
-
330
- if "tools" in config:
331
- payload["tools"] = self._convert_tools(config["tools"])
332
-
333
- # All other config goes under generationConfig
334
- generation_config = {}
335
- for key, value in config.items():
336
- if key not in ["system_instruction", "safety_settings", "tools"]:
337
- generation_config[key] = value
338
-
339
- if generation_config:
340
- payload["generationConfig"] = generation_config
341
-
342
- try:
343
- async with self.session.post(url, json=payload) as response:
344
- if response.status != 200:
345
- error_data = await response.json()
346
- # Handle array response format
347
- if isinstance(error_data, list) and len(error_data) > 0:
348
- error_data = error_data[0]
349
- error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status}") if isinstance(error_data, dict) else str(error_data)
350
- raise Exception(f"Vertex AI API error: {error_msg}")
351
-
352
- # The Vertex AI streaming endpoint returns JSON array elements
353
- # We need to parse these as they arrive
354
- buffer = ""
355
-
356
- async for chunk in response.content.iter_any():
357
- decoded_chunk = chunk.decode('utf-8')
358
- buffer += decoded_chunk
359
-
360
- # Try to extract complete JSON objects from the buffer
361
- while True:
362
- # Skip whitespace and array brackets
363
- buffer = buffer.lstrip()
364
- if buffer.startswith('['):
365
- buffer = buffer[1:].lstrip()
366
- continue
367
- if buffer.startswith(']'):
368
- # End of array
369
- return
370
-
371
- # Skip comma and whitespace between objects
372
- if buffer.startswith(','):
373
- buffer = buffer[1:].lstrip()
374
- continue
375
-
376
- # Look for a complete JSON object
377
- if buffer.startswith('{'):
378
- # Find the matching closing brace
379
- brace_count = 0
380
- in_string = False
381
- escape_next = False
382
-
383
- for i, char in enumerate(buffer):
384
- if escape_next:
385
- escape_next = False
386
- continue
387
-
388
- if char == '\\' and in_string:
389
- escape_next = True
390
- continue
391
-
392
- if char == '"' and not in_string:
393
- in_string = True
394
- elif char == '"' and in_string:
395
- in_string = False
396
- elif char == '{' and not in_string:
397
- brace_count += 1
398
- elif char == '}' and not in_string:
399
- brace_count -= 1
400
-
401
- if brace_count == 0:
402
- # Found complete object
403
- obj_str = buffer[:i+1]
404
- buffer = buffer[i+1:]
405
-
406
- try:
407
- chunk_data = json.loads(obj_str)
408
- converted_obj = self._dict_to_obj(chunk_data)
409
- yield converted_obj
410
- except json.JSONDecodeError as e:
411
- print(f"ERROR: DirectVertexClient - Failed to parse JSON: {e}")
412
-
413
- break
414
- else:
415
- # No complete object found, need more data
416
- break
417
- else:
418
- # No more objects to process in current buffer
419
- break
420
-
421
- except Exception as e:
422
- print(f"ERROR: Direct Vertex streaming API call failed: {e}")
423
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/project_id_discovery.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import aiohttp
2
+ import json
3
+ import re
4
+ from typing import Dict, Optional
5
+
6
+ # Global cache for project IDs: {api_key: project_id}
7
+ PROJECT_ID_CACHE: Dict[str, str] = {}
8
+
9
+
10
+ async def discover_project_id(api_key: str) -> str:
11
+ """
12
+ Discover project ID by triggering an intentional error with a non-existent model.
13
+ The project ID is extracted from the error message and cached for future use.
14
+
15
+ Args:
16
+ api_key: The Vertex AI Express API key
17
+
18
+ Returns:
19
+ The discovered project ID
20
+
21
+ Raises:
22
+ Exception: If project ID discovery fails
23
+ """
24
+ # Check cache first
25
+ if api_key in PROJECT_ID_CACHE:
26
+ print(f"INFO: Using cached project ID: {PROJECT_ID_CACHE[api_key]}")
27
+ return PROJECT_ID_CACHE[api_key]
28
+
29
+ # Use a non-existent model to trigger error
30
+ error_url = f"https://aiplatform.googleapis.com/v1/publishers/google/models/gemini-2.7-pro-preview-05-06:streamGenerateContent?key={api_key}"
31
+
32
+ # Create minimal request payload
33
+ payload = {
34
+ "contents": [{"role": "user", "parts": [{"text": "test"}]}]
35
+ }
36
+
37
+ async with aiohttp.ClientSession() as session:
38
+ try:
39
+ async with session.post(error_url, json=payload) as response:
40
+ response_text = await response.text()
41
+
42
+ try:
43
+ # Try to parse as JSON first
44
+ error_data = json.loads(response_text)
45
+
46
+ # Handle array response format
47
+ if isinstance(error_data, list) and len(error_data) > 0:
48
+ error_data = error_data[0]
49
+
50
+ if "error" in error_data:
51
+ error_message = error_data["error"].get("message", "")
52
+ # Extract project ID from error message
53
+ # Pattern: "projects/39982734461/locations/..."
54
+ match = re.search(r'projects/(\d+)/locations/', error_message)
55
+ if match:
56
+ project_id = match.group(1)
57
+ PROJECT_ID_CACHE[api_key] = project_id
58
+ print(f"INFO: Discovered project ID: {project_id}")
59
+ return project_id
60
+ except json.JSONDecodeError:
61
+ # If not JSON, try to find project ID in raw text
62
+ match = re.search(r'projects/(\d+)/locations/', response_text)
63
+ if match:
64
+ project_id = match.group(1)
65
+ PROJECT_ID_CACHE[api_key] = project_id
66
+ print(f"INFO: Discovered project ID from raw response: {project_id}")
67
+ return project_id
68
+
69
+ raise Exception(f"Failed to discover project ID. Status: {response.status}, Response: {response_text[:500]}")
70
+
71
+ except Exception as e:
72
+ print(f"ERROR: Failed to discover project ID: {e}")
73
+ raise
app/routes/chat_api.py CHANGED
@@ -24,7 +24,7 @@ from api_helpers import (
24
  execute_gemini_call,
25
  )
26
  from openai_handler import OpenAIDirectHandler
27
- from direct_vertex_client import DirectVertexClient
28
 
29
  router = APIRouter()
30
 
@@ -118,9 +118,14 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
118
  try:
119
  # Check if model contains "gemini-2.5-pro" or "gemini-2.5-flash" for direct URL approach
120
  if "gemini-2.5-pro" in base_model_name or "gemini-2.5-flash" in base_model_name:
121
- client_to_use = DirectVertexClient(api_key=key_val)
122
- await client_to_use.discover_project_id()
123
- print(f"INFO: Attempt {attempt+1}/{total_keys} - Using DirectVertexClient for model {request.model} (base: {base_model_name}) with API key (original index: {original_idx}).")
 
 
 
 
 
124
  else:
125
  client_to_use = genai.Client(vertexai=True, api_key=key_val)
126
  print(f"INFO: Attempt {attempt+1}/{total_keys} - Using Vertex Express Mode SDK for model {request.model} (base: {base_model_name}) with API key (original index: {original_idx}).")
@@ -185,9 +190,6 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
185
  try:
186
  # Pass is_auto_attempt=True for auto-mode calls
187
  result = await execute_gemini_call(client_to_use, attempt["model"], attempt["prompt_func"], current_gen_config, request, is_auto_attempt=True)
188
- # Clean up DirectVertexClient session if used
189
- if isinstance(client_to_use, DirectVertexClient):
190
- await client_to_use.close()
191
  return result
192
  except Exception as e_auto:
193
  last_err = e_auto
@@ -196,9 +198,6 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
196
 
197
  print(f"All auto attempts failed. Last error: {last_err}")
198
  err_msg = f"All auto-mode attempts failed for model {request.model}. Last error: {str(last_err)}"
199
- # Clean up DirectVertexClient session if used
200
- if isinstance(client_to_use, DirectVertexClient):
201
- await client_to_use.close()
202
  if not request.stream and last_err:
203
  return JSONResponse(status_code=500, content=create_openai_error_response(500, err_msg, "server_error"))
204
  elif request.stream:
@@ -245,17 +244,9 @@ async def chat_completions(fastapi_request: Request, request: OpenAIRequest, api
245
  # but the API call might need the full "gemini-1.5-pro-search".
246
  # Let's use `request.model` for the API call here, and `base_model_name` for checks like Express eligibility.
247
  # For non-auto mode, is_auto_attempt defaults to False in execute_gemini_call
248
- try:
249
- return await execute_gemini_call(client_to_use, base_model_name, current_prompt_func, generation_config, request)
250
- finally:
251
- # Clean up DirectVertexClient session if used
252
- if isinstance(client_to_use, DirectVertexClient):
253
- await client_to_use.close()
254
 
255
  except Exception as e:
256
  error_msg = f"Unexpected error in chat_completions endpoint: {str(e)}"
257
  print(error_msg)
258
- # Clean up DirectVertexClient session if it exists
259
- if 'client_to_use' in locals() and isinstance(client_to_use, DirectVertexClient):
260
- await client_to_use.close()
261
  return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))
 
24
  execute_gemini_call,
25
  )
26
  from openai_handler import OpenAIDirectHandler
27
+ from project_id_discovery import discover_project_id
28
 
29
  router = APIRouter()
30
 
 
118
  try:
119
  # Check if model contains "gemini-2.5-pro" or "gemini-2.5-flash" for direct URL approach
120
  if "gemini-2.5-pro" in base_model_name or "gemini-2.5-flash" in base_model_name:
121
+ project_id = await discover_project_id(key_val)
122
+ base_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global"
123
+ client_to_use = genai.Client(
124
+ vertexai=True,
125
+ api_key=key_val,
126
+ http_options=types.HttpOptions(base_url=base_url)
127
+ )
128
+ print(f"INFO: Attempt {attempt+1}/{total_keys} - Using Vertex Express Mode with custom base URL for model {request.model} (base: {base_model_name}) with API key (original index: {original_idx}).")
129
  else:
130
  client_to_use = genai.Client(vertexai=True, api_key=key_val)
131
  print(f"INFO: Attempt {attempt+1}/{total_keys} - Using Vertex Express Mode SDK for model {request.model} (base: {base_model_name}) with API key (original index: {original_idx}).")
 
190
  try:
191
  # Pass is_auto_attempt=True for auto-mode calls
192
  result = await execute_gemini_call(client_to_use, attempt["model"], attempt["prompt_func"], current_gen_config, request, is_auto_attempt=True)
 
 
 
193
  return result
194
  except Exception as e_auto:
195
  last_err = e_auto
 
198
 
199
  print(f"All auto attempts failed. Last error: {last_err}")
200
  err_msg = f"All auto-mode attempts failed for model {request.model}. Last error: {str(last_err)}"
 
 
 
201
  if not request.stream and last_err:
202
  return JSONResponse(status_code=500, content=create_openai_error_response(500, err_msg, "server_error"))
203
  elif request.stream:
 
244
  # but the API call might need the full "gemini-1.5-pro-search".
245
  # Let's use `request.model` for the API call here, and `base_model_name` for checks like Express eligibility.
246
  # For non-auto mode, is_auto_attempt defaults to False in execute_gemini_call
247
+ return await execute_gemini_call(client_to_use, base_model_name, current_prompt_func, generation_config, request)
 
 
 
 
 
248
 
249
  except Exception as e:
250
  error_msg = f"Unexpected error in chat_completions endpoint: {str(e)}"
251
  print(error_msg)
 
 
 
252
  return JSONResponse(status_code=500, content=create_openai_error_response(500, error_msg, "server_error"))