bibibi12345 committed on
Commit b5abced · 1 Parent(s): da7a18e

Changed OpenAI CoT streaming handling; added round-robin mode for credentials; various refactoring.

Files changed (2)
  1. app/express_key_manager.py +97 -0
  2. app/openai_handler.py +284 -0
app/express_key_manager.py ADDED
@@ -0,0 +1,97 @@
+ import random
+ from typing import List, Optional, Tuple
+ import config as app_config
+
+
+ class ExpressKeyManager:
+     """
+     Manager for Vertex Express API keys with support for both random and round-robin selection strategies.
+     Similar to CredentialManager but specifically for Express API keys.
+     """
+
+     def __init__(self):
+         """Initialize the Express Key Manager with API keys from config."""
+         self.express_keys: List[str] = app_config.VERTEX_EXPRESS_API_KEY_VAL
+         self.round_robin_index: int = 0
+
+     def get_total_keys(self) -> int:
+         """Get the total number of available Express API keys."""
+         return len(self.express_keys)
+
+     def _get_key_with_index(self, key: str, index: int) -> Tuple[str, int]:
+         """Return a tuple of (key, original_index) for logging purposes."""
+         return (key, index)
+
+     def get_random_express_key(self) -> Optional[Tuple[str, int]]:
+         """
+         Get a random Express API key.
+         Returns (key, original_index) tuple or None if no keys available.
+         """
+         if not self.express_keys:
+             print("WARNING: No Express API keys available for selection.")
+             return None
+
+         print("DEBUG: Using random Express API key selection strategy.")
+
+         # Create list of indexed keys
+         indexed_keys = list(enumerate(self.express_keys))
+         # Shuffle to randomize order
+         random.shuffle(indexed_keys)
+
+         # Return the first key (which is random due to shuffle)
+         original_idx, key = indexed_keys[0]
+         return self._get_key_with_index(key, original_idx)
+
+     def get_roundrobin_express_key(self) -> Optional[Tuple[str, int]]:
+         """
+         Get an Express API key using round-robin selection.
+         Returns (key, original_index) tuple or None if no keys available.
+         """
+         if not self.express_keys:
+             print("WARNING: No Express API keys available for selection.")
+             return None
+
+         print("DEBUG: Using round-robin Express API key selection strategy.")
+
+         # Ensure round_robin_index is within bounds
+         if self.round_robin_index >= len(self.express_keys):
+             self.round_robin_index = 0
+
+         # Get the key at current index
+         key = self.express_keys[self.round_robin_index]
+         original_idx = self.round_robin_index
+
+         # Move to next index for next call
+         self.round_robin_index = (self.round_robin_index + 1) % len(self.express_keys)
+
+         return self._get_key_with_index(key, original_idx)
+
+     def get_express_api_key(self) -> Optional[Tuple[str, int]]:
+         """
+         Get an Express API key based on the configured selection strategy.
+         Checks ROUNDROBIN config and calls the appropriate method.
+         Returns (key, original_index) tuple or None if no keys available.
+         """
+         if app_config.ROUNDROBIN:
+             return self.get_roundrobin_express_key()
+         else:
+             return self.get_random_express_key()
+
+     def get_all_keys_indexed(self) -> List[Tuple[int, str]]:
+         """
+         Get all Express API keys with their indices.
+         Useful for retry logic where we need to try all keys.
+         Returns list of (original_index, key) tuples.
+         """
+         return list(enumerate(self.express_keys))
+
+     def refresh_keys(self):
+         """
+         Refresh the Express API keys from config.
+         This allows for dynamic updates if the config is reloaded.
+         """
+         self.express_keys = app_config.VERTEX_EXPRESS_API_KEY_VAL
+         # Reset round-robin index if keys changed
+         if self.round_robin_index >= len(self.express_keys):
+             self.round_robin_index = 0
+         print(f"INFO: Express API keys refreshed. Total keys: {self.get_total_keys()}")
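
A minimal usage sketch for the new manager (illustrative, not part of this commit): it assumes config.VERTEX_EXPRESS_API_KEY_VAL holds a list of key strings and config.ROUNDROBIN is a boolean, as the code above implies; the key values here are hypothetical.

import config as app_config
from express_key_manager import ExpressKeyManager

# Assumed config shape, inferred from the diff above (hypothetical values).
app_config.VERTEX_EXPRESS_API_KEY_VAL = ["key-a", "key-b", "key-c"]
app_config.ROUNDROBIN = True  # False selects the random strategy instead

manager = ExpressKeyManager()
for _ in range(4):
    result = manager.get_express_api_key()
    if result is not None:
        key, original_idx = result
        print(original_idx, key)  # round-robin yields indices 0, 1, 2, 0

Round-robin advances a persistent index modulo the key count, so successive calls spread load evenly across keys; the random strategy shuffles an indexed copy of the key list and returns its first element.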
app/openai_handler.py ADDED
@@ -0,0 +1,284 @@
+ """
+ OpenAI handler module for creating clients and processing OpenAI Direct mode responses.
+ This module encapsulates all OpenAI-specific logic that was previously in chat_api.py.
+ """
+ import json
+ import time
+ import asyncio
+ from typing import Dict, Any, AsyncGenerator
+
+ from fastapi.responses import JSONResponse, StreamingResponse
+ import openai
+ from google.auth.transport.requests import Request as AuthRequest
+
+ from models import OpenAIRequest
+ from config import VERTEX_REASONING_TAG
+ import config as app_config
+ from api_helpers import (
+     create_openai_error_response,
+     openai_fake_stream_generator,
+     StreamingReasoningProcessor
+ )
+ from message_processing import extract_reasoning_by_tags
+ from credentials_manager import _refresh_auth
+
+
+ class OpenAIDirectHandler:
+     """Handles OpenAI Direct mode operations including client creation and response processing."""
+
+     def __init__(self, credential_manager):
+         self.credential_manager = credential_manager
+         self.safety_settings = [
+             {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
+             {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
+             {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
+             {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
+             {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
+         ]
+
+     def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
+         """Create an OpenAI client configured for the Vertex AI endpoint."""
+         endpoint_url = (
+             f"https://aiplatform.googleapis.com/v1beta1/"
+             f"projects/{project_id}/locations/{location}/endpoints/openapi"
+         )
+
+         return openai.AsyncOpenAI(
+             base_url=endpoint_url,
+             api_key=gcp_token,  # OAuth token
+         )
+
+     def prepare_openai_params(self, request: OpenAIRequest, model_id: str) -> Dict[str, Any]:
+         """Prepare parameters for the OpenAI API call."""
+         params = {
+             "model": model_id,
+             "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
+             "temperature": request.temperature,
+             "max_tokens": request.max_tokens,
+             "top_p": request.top_p,
+             "stream": request.stream,
+             "stop": request.stop,
+             "seed": request.seed,
+             "n": request.n,
+         }
+         # Remove None values
+         return {k: v for k, v in params.items() if v is not None}
+
+     def prepare_extra_body(self) -> Dict[str, Any]:
+         """Prepare extra body parameters for the OpenAI API call."""
+         return {
+             "extra_body": {
+                 "google": {
+                     "safety_settings": self.safety_settings,
+                     "thought_tag_marker": VERTEX_REASONING_TAG
+                 }
+             }
+         }
+
+     async def handle_streaming_response(
+         self,
+         openai_client: openai.AsyncOpenAI,
+         openai_params: Dict[str, Any],
+         openai_extra_body: Dict[str, Any],
+         request: OpenAIRequest
+     ) -> StreamingResponse:
+         """Handle streaming responses for OpenAI Direct mode."""
+         if app_config.FAKE_STREAMING_ENABLED:
+             print(f"INFO: OpenAI Fake Streaming (SSE Simulation) ENABLED for model '{request.model}'.")
+             return StreamingResponse(
+                 openai_fake_stream_generator(
+                     openai_client=openai_client,
+                     openai_params=openai_params,
+                     openai_extra_body=openai_extra_body,
+                     request_obj=request,
+                     is_auto_attempt=False
+                 ),
+                 media_type="text/event-stream"
+             )
+         else:
+             print(f"INFO: OpenAI True Streaming ENABLED for model '{request.model}'.")
+             return StreamingResponse(
+                 self._true_stream_generator(openai_client, openai_params, openai_extra_body, request),
+                 media_type="text/event-stream"
+             )
+
+     async def _true_stream_generator(
+         self,
+         openai_client: openai.AsyncOpenAI,
+         openai_params: Dict[str, Any],
+         openai_extra_body: Dict[str, Any],
+         request: OpenAIRequest
+     ) -> AsyncGenerator[str, None]:
+         """Generate a true streaming response."""
+         try:
+             # Ensure stream=True is explicitly passed for real streaming
+             openai_params_for_stream = {**openai_params, "stream": True}
+             stream_response = await openai_client.chat.completions.create(
+                 **openai_params_for_stream,
+                 extra_body=openai_extra_body["extra_body"]
+             )
+
+             # Create processor for tag-based extraction across chunks
+             reasoning_processor = StreamingReasoningProcessor(VERTEX_REASONING_TAG)
+
+             async for chunk in stream_response:
+                 try:
+                     chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)
+
+                     choices = chunk_as_dict.get('choices')
+                     if choices and isinstance(choices, list) and len(choices) > 0:
+                         delta = choices[0].get('delta')
+                         if delta and isinstance(delta, dict):
+                             # Always remove extra_content if present
+                             if 'extra_content' in delta:
+                                 del delta['extra_content']
+
+                             content = delta.get('content', '')
+                             if content:
+                                 # Use the processor to extract reasoning
+                                 processed_content, current_reasoning = reasoning_processor.process_chunk(content)
+
+                                 # Update delta with processed content
+                                 if current_reasoning:
+                                     delta['reasoning_content'] = current_reasoning
+                                 if processed_content:
+                                     delta['content'] = processed_content
+                                 elif 'content' in delta:
+                                     del delta['content']
+
+                     yield f"data: {json.dumps(chunk_as_dict)}\n\n"
+
+                 except Exception as chunk_error:
+                     error_msg = f"Error processing OpenAI chunk for {request.model}: {str(chunk_error)}"
+                     print(f"ERROR: {error_msg}")
+                     if len(error_msg) > 1024:
+                         error_msg = error_msg[:1024] + "..."
+                     error_response = create_openai_error_response(500, error_msg, "server_error")
+                     yield f"data: {json.dumps(error_response)}\n\n"
+                     yield "data: [DONE]\n\n"
+                     return
+
+             # Handle any remaining buffer content
+             if reasoning_processor.tag_buffer and not reasoning_processor.inside_tag:
+                 # Output any remaining content
+                 final_chunk = {
+                     "id": f"chatcmpl-{int(time.time())}",
+                     "object": "chat.completion.chunk",
+                     "created": int(time.time()),
+                     "model": request.model,
+                     "choices": [{"index": 0, "delta": {"content": reasoning_processor.tag_buffer}, "finish_reason": None}]
+                 }
+                 yield f"data: {json.dumps(final_chunk)}\n\n"
+             elif reasoning_processor.inside_tag and reasoning_processor.reasoning_buffer:
+                 # We were inside a tag but never found the closing tag
+                 print(f"WARNING: Unclosed reasoning tag detected. Partial reasoning: {reasoning_processor.reasoning_buffer[:100]}...")
+
+             yield "data: [DONE]\n\n"
+
+         except Exception as stream_error:
+             error_msg = str(stream_error)
+             if len(error_msg) > 1024:
+                 error_msg = error_msg[:1024] + "..."
+             error_msg_full = f"Error during OpenAI streaming for {request.model}: {error_msg}"
+             print(f"ERROR: {error_msg_full}")
+             error_response = create_openai_error_response(500, error_msg_full, "server_error")
+             yield f"data: {json.dumps(error_response)}\n\n"
+             yield "data: [DONE]\n\n"
+
+     async def handle_non_streaming_response(
+         self,
+         openai_client: openai.AsyncOpenAI,
+         openai_params: Dict[str, Any],
+         openai_extra_body: Dict[str, Any],
+         request: OpenAIRequest
+     ) -> JSONResponse:
+         """Handle non-streaming responses for OpenAI Direct mode."""
+         try:
+             # Ensure stream=False is explicitly passed
+             openai_params_non_stream = {**openai_params, "stream": False}
+             response = await openai_client.chat.completions.create(
+                 **openai_params_non_stream,
+                 extra_body=openai_extra_body["extra_body"]
+             )
+             response_dict = response.model_dump(exclude_unset=True, exclude_none=True)
+
+             try:
+                 choices = response_dict.get('choices')
+                 if choices and isinstance(choices, list) and len(choices) > 0:
+                     message_dict = choices[0].get('message')
+                     if message_dict and isinstance(message_dict, dict):
+                         # Always remove extra_content from the message if it exists
+                         if 'extra_content' in message_dict:
+                             del message_dict['extra_content']
+
+                         # Extract reasoning from content
+                         full_content = message_dict.get('content')
+                         actual_content = full_content if isinstance(full_content, str) else ""
+
+                         if actual_content:
+                             print(f"INFO: OpenAI Direct Non-Streaming - Applying tag extraction with fixed marker: '{VERTEX_REASONING_TAG}'")
+                             reasoning_text, actual_content = extract_reasoning_by_tags(actual_content, VERTEX_REASONING_TAG)
+                             message_dict['content'] = actual_content
+                             if reasoning_text:
+                                 message_dict['reasoning_content'] = reasoning_text
+                                 print(f"DEBUG: Tag extraction success. Reasoning len: {len(reasoning_text)}, Content len: {len(actual_content)}")
+                             else:
+                                 print(f"DEBUG: No content found within fixed tag '{VERTEX_REASONING_TAG}'.")
+                         else:
+                             print("WARNING: OpenAI Direct Non-Streaming - No initial content found in message.")
+                             message_dict['content'] = ""
+
+             except Exception as e_reasoning:
+                 print(f"WARNING: Error during non-streaming reasoning processing for model {request.model}: {e_reasoning}")
+
+             return JSONResponse(content=response_dict)
+
+         except Exception as e:
+             error_msg = f"Error calling OpenAI client for {request.model}: {str(e)}"
+             print(f"ERROR: {error_msg}")
+             return JSONResponse(
+                 status_code=500,
+                 content=create_openai_error_response(500, error_msg, "server_error")
+             )
+
+     async def process_request(self, request: OpenAIRequest, base_model_name: str):
+         """Main entry point for processing OpenAI Direct mode requests."""
+         print(f"INFO: Using OpenAI Direct Path for model: {request.model}")
+
+         # Get credentials
+         rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()
+
+         if not rotated_credentials or not rotated_project_id:
+             error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
+             print(f"ERROR: {error_msg}")
+             return JSONResponse(
+                 status_code=500,
+                 content=create_openai_error_response(500, error_msg, "server_error")
+             )
+
+         print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
+         gcp_token = _refresh_auth(rotated_credentials)
+
+         if not gcp_token:
+             error_msg = f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id})."
+             print(f"ERROR: {error_msg}")
+             return JSONResponse(
+                 status_code=500,
+                 content=create_openai_error_response(500, error_msg, "server_error")
+             )
+
+         # Create client and prepare parameters
+         openai_client = self.create_openai_client(rotated_project_id, gcp_token)
+         model_id = f"google/{base_model_name}"
+         openai_params = self.prepare_openai_params(request, model_id)
+         openai_extra_body = self.prepare_extra_body()
+
+         # Handle streaming vs non-streaming
+         if request.stream:
+             return await self.handle_streaming_response(
+                 openai_client, openai_params, openai_extra_body, request
+             )
+         else:
+             return await self.handle_non_streaming_response(
+                 openai_client, openai_params, openai_extra_body, request
+             )
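
A minimal wiring sketch for the handler (illustrative, not part of this commit): the route path, the model name, and the CredentialManager constructor are assumptions; OpenAIRequest, process_request, and the automatic "google/" model prefix come from the diff above.

from fastapi import FastAPI

from credentials_manager import CredentialManager  # assumed constructor
from models import OpenAIRequest
from openai_handler import OpenAIDirectHandler

app = FastAPI()
handler = OpenAIDirectHandler(CredentialManager())

@app.post("/v1/chat/completions")  # assumed route
async def chat_completions(request: OpenAIRequest):
    # process_request prefixes the model with "google/" and dispatches to the
    # streaming or non-streaming path based on request.stream.
    return await handler.process_request(request, base_model_name="gemini-2.5-pro")  # illustrative model name

In the true-streaming path, each chunk's delta content is run through StreamingReasoningProcessor, which moves text enclosed by the VERTEX_REASONING_TAG marker into a separate reasoning_content field, so clients receive model thoughts and answer text as distinct fields in the SSE stream.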