bibibi12345 committed on
Commit
db96d95
·
1 Parent(s): 5c95d1b

added global region to express mode

Browse files
Files changed (1) hide show
  1. app/direct_vertex_client.py +423 -0
app/direct_vertex_client.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import aiohttp
2
+ import asyncio
3
+ import json
4
+ import re
5
+ from typing import Dict, Any, List, Union, Optional, AsyncGenerator
6
+ import time
7
+
8
# Module-level memo shared by every DirectVertexClient instance.
# Maps an API key to the project ID discovered for it: {api_key: project_id}
PROJECT_ID_CACHE: Dict[str, str] = dict()
10
+
11
+
12
class DirectVertexClient:
    """
    A client that connects to Vertex AI using direct URLs instead of the SDK.
    Mimics the interface of genai.Client for seamless integration.

    Usage mirrors the SDK's async surface:
    ``await client.aio.models.generate_content(...)`` and
    ``await client.aio.models.generate_content_stream(...)``.
    ``discover_project_id()`` must have succeeded before either call, because
    the request URL embeds the project ID.
    """

    # Config keys that become top-level request fields; every other key is
    # forwarded under "generationConfig".
    _TOP_LEVEL_CONFIG_KEYS = ("system_instruction", "safety_settings", "tools")

    # Pattern of the project reference found in API error messages,
    # e.g. "projects/39982734461/locations/...".
    _PROJECT_ID_PATTERN = r'projects/(\d+)/locations/'

    def __init__(self, api_key: str):
        """Store the API key; no network I/O happens until a request is made."""
        self.api_key = api_key
        self.project_id: Optional[str] = None
        self.base_url = "https://aiplatform.googleapis.com/v1"
        self.session: Optional[aiohttp.ClientSession] = None
        # Mimic the model_name attribute that might be accessed
        self.model_name = "direct_vertex_client"

        # Create nested structure to mimic genai.Client interface
        self.aio = self._AioNamespace(self)

    class _AioNamespace:
        """Stand-in for ``genai.Client.aio``; exposes ``.models``."""

        def __init__(self, parent):
            self.parent = parent
            self.models = self._ModelsNamespace(parent)

        class _ModelsNamespace:
            """Stand-in for ``genai.Client.aio.models``; delegates to the client."""

            def __init__(self, parent):
                self.parent = parent

            async def generate_content(self, model: str, contents: Any, config: Optional[Dict[str, Any]] = None) -> Any:
                """Non-streaming content generation"""
                return await self.parent._generate_content(model, contents, config, stream=False)

            async def generate_content_stream(self, model: str, contents: Any, config: Optional[Dict[str, Any]] = None):
                """Streaming content generation - returns an async generator"""
                # This needs to be an async method that returns the generator
                # to match the SDK's interface where you await the method call
                return self.parent._generate_content_stream(model, contents, config)

    class AttrDict:
        """Attribute bag holding the fields of one decoded JSON object."""

        def __repr__(self):
            return f"{type(self).__name__}({vars(self)!r})"

    async def _ensure_session(self):
        """Ensure a usable aiohttp session exists, recreating a closed one."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()

    async def close(self):
        """Clean up resources"""
        if self.session:
            await self.session.close()
            self.session = None

    @staticmethod
    def _extract_project_id(response_text: str) -> tuple:
        """
        Find the numeric project ID in an API error response body.

        Returns ``(project_id, from_raw)``. ``project_id`` is ``None`` when no
        ``projects/<digits>/locations/`` reference is present; ``from_raw`` is
        True when the ID came from scanning the raw text rather than the
        parsed ``error.message`` field.
        """
        pattern = DirectVertexClient._PROJECT_ID_PATTERN
        try:
            error_data = json.loads(response_text)
        except json.JSONDecodeError:
            error_data = None

        # Handle array response format
        if isinstance(error_data, list) and error_data:
            error_data = error_data[0]

        if isinstance(error_data, dict):
            error = error_data.get("error")
            message = error.get("message", "") if isinstance(error, dict) else ""
            if isinstance(message, str):
                match = re.search(pattern, message)
                if match:
                    return match.group(1), False

        # Not JSON, or JSON without the ID in error.message: scan the raw text.
        match = re.search(pattern, response_text)
        if match:
            return match.group(1), True
        return None, False

    async def discover_project_id(self) -> None:
        """
        Discover project ID by triggering an intentional error.

        A request for a non-existent model is sent without a project in the
        URL; the project ID is then read out of the error response. The result
        is stored on ``self.project_id`` and in ``PROJECT_ID_CACHE``.

        Raises:
            Exception: if the response contains no project reference, or the
                HTTP request itself fails (the error is logged and re-raised).
        """
        # Check cache first
        if self.api_key in PROJECT_ID_CACHE:
            self.project_id = PROJECT_ID_CACHE[self.api_key]
            print(f"INFO: Using cached project ID: {self.project_id}")
            return

        await self._ensure_session()

        # Use a non-existent model to trigger error
        error_url = f"{self.base_url}/publishers/google/models/gemini-2.7-pro-preview-05-06:streamGenerateContent?key={self.api_key}"

        try:
            # Send minimal request to trigger error
            payload = {
                "contents": [{"role": "user", "parts": [{"text": "test"}]}]
            }

            async with self.session.post(error_url, json=payload) as response:
                response_text = await response.text()
                status = response.status

            project_id, from_raw = self._extract_project_id(response_text)
            if project_id is not None:
                self.project_id = project_id
                PROJECT_ID_CACHE[self.api_key] = project_id
                if from_raw:
                    print(f"INFO: Discovered project ID from raw response: {self.project_id}")
                else:
                    print(f"INFO: Discovered project ID: {self.project_id}")
                return

            raise Exception(f"Failed to discover project ID. Status: {status}, Response: {response_text[:500]}")

        except Exception as e:
            print(f"ERROR: Failed to discover project ID: {e}")
            raise

    def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]:
        """Convert SDK Content objects to REST API format (always a list)."""
        if isinstance(contents, list):
            return [self._convert_content_item(item) for item in contents]
        return [self._convert_content_item(contents)]

    def _convert_content_item(self, content: Any) -> Dict[str, Any]:
        """
        Convert a single content item to REST API format.

        Dicts are returned unchanged. A plain string becomes a single-part
        user message. Any other object is read through its ``role`` and
        ``parts`` attributes; parts that carry neither text nor inline data
        are skipped.
        """
        import base64

        if isinstance(content, dict):
            return content
        if isinstance(content, str):
            return {'role': 'user', 'parts': [{'text': content}]}

        # Handle SDK Content objects
        result = {}
        if hasattr(content, 'role'):
            result['role'] = content.role
        if hasattr(content, 'parts'):
            result['parts'] = []
            for part in content.parts or []:
                if isinstance(part, dict):
                    result['parts'].append(part)
                    continue
                # Test the values, not mere attribute presence: an object may
                # expose `text` as None while the payload sits in `inline_data`.
                text = getattr(part, 'text', None)
                inline_data = getattr(part, 'inline_data', None)
                if text is not None:
                    result['parts'].append({'text': text})
                elif inline_data is not None:
                    data = inline_data.data
                    if isinstance(data, (bytes, bytearray)):
                        # Raw bytes cannot be JSON-encoded; send them as base64 text.
                        data = base64.b64encode(data).decode('ascii')
                    result['parts'].append({
                        'inline_data': {
                            'mime_type': inline_data.mime_type,
                            'data': data
                        }
                    })
        return result

    def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, str]]:
        """
        Convert SDK SafetySetting objects to REST API format.

        Dicts pass through; objects with both ``category`` and ``threshold``
        are converted; anything else is dropped.
        """
        if not safety_settings:
            return []

        result = []
        for setting in safety_settings:
            if isinstance(setting, dict):
                result.append(setting)
            elif hasattr(setting, 'category') and hasattr(setting, 'threshold'):
                # Convert SDK SafetySetting to dict
                result.append({
                    'category': setting.category,
                    'threshold': setting.threshold
                })
        return result

    def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]:
        """Convert SDK Tool objects to REST API format"""
        if not tools:
            return []
        return [self._convert_tool_item(tool) for tool in tools]

    def _convert_tool_item(self, tool: Any) -> Dict[str, Any]:
        """
        Convert a single tool item to REST API format.

        Dicts are returned unchanged. For objects, every public attribute
        whose value is not None is emitted under its camelCase name.
        """
        if isinstance(tool, dict):
            return tool

        tool_dict = {}

        # Convert all non-private attributes
        for attr_name, attr_value in getattr(tool, '__dict__', {}).items():
            if attr_name.startswith('_') or attr_value is None:
                continue
            # Convert attribute names from snake_case to camelCase for REST API
            rest_api_name = self._to_camel_case(attr_name)
            if attr_name == 'google_search':
                tool_dict[rest_api_name] = {}  # GoogleSearch is empty object in REST
            else:
                # Recursively convert nested SDK objects (plain dicts/lists,
                # e.g. dict-based function declarations, come back unchanged).
                tool_dict[rest_api_name] = self._convert_sdk_object(attr_value)

        return tool_dict

    def _to_camel_case(self, snake_str: str) -> str:
        """Convert snake_case to camelCase, leaving existing capitals intact."""
        components = snake_str.split('_')
        return components[0] + ''.join(x[:1].upper() + x[1:] for x in components[1:])

    def _convert_sdk_object(self, obj: Any) -> Any:
        """
        Generic SDK object converter.

        Primitives are returned as-is, dicts and lists are converted
        recursively (dict keys are kept), and any object with a ``__dict__``
        becomes a dict of its public attributes with camelCase keys.
        """
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        if isinstance(obj, dict):
            return {k: self._convert_sdk_object(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._convert_sdk_object(item) for item in obj]
        if hasattr(obj, '__dict__'):
            # Convert SDK object to dict
            return {
                self._to_camel_case(key): self._convert_sdk_object(value)
                for key, value in obj.__dict__.items()
                if not key.startswith('_')
            }
        return obj

    def _model_url(self, model: str, endpoint: str) -> str:
        """Build the global-region request URL for ``model`` and ``endpoint``."""
        return f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"

    def _build_payload(self, contents: Any, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Build the JSON request body shared by both generation endpoints.

        ``system_instruction``, ``safety_settings`` and ``tools`` become
        top-level fields; all remaining config keys are forwarded unchanged
        under ``generationConfig``. A ``None`` config is treated as empty.
        """
        config = config or {}

        # Convert contents to REST API format
        payload = {
            "contents": self._convert_contents(contents)
        }

        # Extract specific config sections
        if "system_instruction" in config:
            # System instruction should be a content object (dicts pass through)
            payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])

        if "safety_settings" in config:
            payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])

        if "tools" in config:
            payload["tools"] = self._convert_tools(config["tools"])

        # All other config goes under generationConfig
        generation_config = {
            key: value
            for key, value in config.items()
            if key not in self._TOP_LEVEL_CONFIG_KEYS
        }
        if generation_config:
            payload["generationConfig"] = generation_config

        return payload

    @staticmethod
    def _error_message_from_body(body: str, status: int) -> str:
        """
        Derive a readable error message from a non-200 response body.

        Handles the object form ``{"error": {"message": ...}}``, the same
        object wrapped in a one-element array, and bodies that are not JSON
        at all (returned truncated to 500 characters).
        """
        fallback = f"HTTP {status}"
        try:
            error_data = json.loads(body)
        except json.JSONDecodeError:
            return body[:500] or fallback

        # Handle array response format
        if isinstance(error_data, list) and error_data:
            error_data = error_data[0]
        if not isinstance(error_data, dict):
            return str(error_data)

        error = error_data.get("error", {})
        if isinstance(error, dict):
            return error.get("message", fallback)
        return str(error)

    async def _read_error_message(self, response) -> str:
        """Read the body of a failed response and extract its error message."""
        return self._error_message_from_body(await response.text(), response.status)

    async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
        """
        Internal method for content generation.

        Posts the request and returns the whole decoded JSON response wrapped
        for attribute access. With ``stream=True`` the streaming endpoint is
        called but its response is still read in one piece.

        Raises:
            ValueError: if the project ID has not been discovered yet.
            Exception: on a non-200 response or transport failure (logged and
                re-raised).
        """
        if not self.project_id:
            raise ValueError("Project ID not discovered. Call discover_project_id() first.")

        await self._ensure_session()

        # Build URL
        endpoint = "streamGenerateContent" if stream else "generateContent"
        url = self._model_url(model, endpoint)
        payload = self._build_payload(contents, config)

        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    error_msg = await self._read_error_message(response)
                    raise Exception(f"Vertex AI API error: {error_msg}")

                # Get the JSON response
                response_data = await response.json()

                # Convert dict to object with attributes for compatibility
                return self._dict_to_obj(response_data)

        except Exception as e:
            print(f"ERROR: Direct Vertex API call failed: {e}")
            raise

    def _dict_to_obj(self, data):
        """
        Convert decoded JSON into objects with attribute access.

        Dicts become ``AttrDict`` instances whose attributes are the dict's
        keys, lists are converted element-wise, and other values are returned
        unchanged.
        """
        if isinstance(data, dict):
            obj = self.AttrDict()
            for key, value in data.items():
                setattr(obj, key, self._dict_to_obj(value))
            return obj
        if isinstance(data, list):
            return [self._dict_to_obj(item) for item in data]
        return data

    @staticmethod
    def _extract_json_objects(buffer: str) -> tuple:
        """
        Pull every complete top-level JSON object off the front of ``buffer``.

        The streaming endpoint sends one JSON array whose elements arrive
        incrementally, so array brackets, commas and whitespace between
        objects are skipped. Objects that fail to parse are logged and
        dropped.

        Returns ``(objects, remainder, finished)``: the decoded objects, the
        unconsumed tail of the buffer (an incomplete object or unrecognized
        text), and whether the closing ``]`` of the array was reached.
        """
        objects = []
        while True:
            # Skip whitespace, the opening bracket and separators
            buffer = buffer.lstrip()
            if buffer.startswith(('[', ',')):
                buffer = buffer[1:]
                continue
            if buffer.startswith(']'):
                # End of array
                return objects, buffer, True
            if not buffer.startswith('{'):
                # Nothing recognizable to process in the current buffer
                return objects, buffer, False

            # Find the matching closing brace, ignoring braces inside strings
            end = None
            depth = 0
            in_string = False
            escape_next = False
            for i, char in enumerate(buffer):
                if escape_next:
                    escape_next = False
                    continue
                if in_string:
                    if char == '\\':
                        escape_next = True
                    elif char == '"':
                        in_string = False
                    continue
                if char == '"':
                    in_string = True
                elif char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        end = i
                        break

            if end is None:
                # No complete object found, need more data
                return objects, buffer, False

            obj_str = buffer[:end + 1]
            buffer = buffer[end + 1:]
            try:
                objects.append(json.loads(obj_str))
            except json.JSONDecodeError as e:
                print(f"ERROR: DirectVertexClient - Failed to parse JSON: {e}")

    async def _generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]) -> AsyncGenerator:
        """
        Internal method for streaming content generation.

        Yields one attribute-access object per JSON element of the streamed
        response array, as soon as each element is complete.

        Raises (on first iteration):
            ValueError: if the project ID has not been discovered yet.
            Exception: on a non-200 response or transport failure (logged and
                re-raised).
        """
        import codecs

        if not self.project_id:
            raise ValueError("Project ID not discovered. Call discover_project_id() first.")

        await self._ensure_session()

        # Build URL for streaming
        url = self._model_url(model, "streamGenerateContent")
        payload = self._build_payload(contents, config)

        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    error_msg = await self._read_error_message(response)
                    raise Exception(f"Vertex AI API error: {error_msg}")

                # The Vertex AI streaming endpoint returns JSON array elements
                # We need to parse these as they arrive
                buffer = ""
                # Network chunks can end in the middle of a multi-byte UTF-8
                # character; the incremental decoder holds the partial bytes
                # until the rest arrives.
                decoder = codecs.getincrementaldecoder('utf-8')()

                async for chunk in response.content.iter_any():
                    buffer += decoder.decode(chunk)

                    objects, buffer, finished = self._extract_json_objects(buffer)
                    for chunk_data in objects:
                        yield self._dict_to_obj(chunk_data)
                    if finished:
                        return

        except Exception as e:
            print(f"ERROR: Direct Vertex streaming API call failed: {e}")
            raise