Commit b5abced
Parent(s): da7a18e
Changed OpenAI CoT streaming handling; added round-robin mode for credentials; various refactoring.
Files changed:
- app/express_key_manager.py +97 -0
- app/openai_handler.py +284 -0
app/express_key_manager.py
ADDED
@@ -0,0 +1,97 @@
import random
from typing import List, Optional, Tuple
import config as app_config


class ExpressKeyManager:
    """
    Manager for Vertex Express API keys, supporting both random and round-robin selection strategies.
    Similar to CredentialManager, but specifically for Express API keys.
    """

    def __init__(self):
        """Initialize the Express Key Manager with API keys from config."""
        self.express_keys: List[str] = app_config.VERTEX_EXPRESS_API_KEY_VAL
        self.round_robin_index: int = 0

    def get_total_keys(self) -> int:
        """Get the total number of available Express API keys."""
        return len(self.express_keys)

    def _get_key_with_index(self, key: str, index: int) -> Tuple[str, int]:
        """Return a (key, original_index) tuple for logging purposes."""
        return (key, index)

    def get_random_express_key(self) -> Optional[Tuple[str, int]]:
        """
        Get a random Express API key.
        Returns a (key, original_index) tuple, or None if no keys are available.
        """
        if not self.express_keys:
            print("WARNING: No Express API keys available for selection.")
            return None

        print("DEBUG: Using random Express API key selection strategy.")

        # Pair each key with its original index
        indexed_keys = list(enumerate(self.express_keys))
        # Shuffle to randomize the order
        random.shuffle(indexed_keys)

        # Return the first key (random due to the shuffle)
        original_idx, key = indexed_keys[0]
        return self._get_key_with_index(key, original_idx)

    def get_roundrobin_express_key(self) -> Optional[Tuple[str, int]]:
        """
        Get an Express API key using round-robin selection.
        Returns a (key, original_index) tuple, or None if no keys are available.
        """
        if not self.express_keys:
            print("WARNING: No Express API keys available for selection.")
            return None

        print("DEBUG: Using round-robin Express API key selection strategy.")

        # Ensure round_robin_index is within bounds
        if self.round_robin_index >= len(self.express_keys):
            self.round_robin_index = 0

        # Get the key at the current index
        key = self.express_keys[self.round_robin_index]
        original_idx = self.round_robin_index

        # Advance to the next index for the next call, wrapping around
        self.round_robin_index = (self.round_robin_index + 1) % len(self.express_keys)

        return self._get_key_with_index(key, original_idx)

    def get_express_api_key(self) -> Optional[Tuple[str, int]]:
        """
        Get an Express API key based on the configured selection strategy.
        Checks the ROUNDROBIN config flag and calls the appropriate method.
        Returns a (key, original_index) tuple, or None if no keys are available.
        """
        if app_config.ROUNDROBIN:
            return self.get_roundrobin_express_key()
        else:
            return self.get_random_express_key()

    def get_all_keys_indexed(self) -> List[Tuple[int, str]]:
        """
        Get all Express API keys with their indices.
        Useful for retry logic that needs to try every key.
        Returns a list of (original_index, key) tuples.
        """
        return list(enumerate(self.express_keys))

    def refresh_keys(self):
        """
        Refresh the Express API keys from config.
        This allows for dynamic updates if the config is reloaded.
        """
        self.express_keys = app_config.VERTEX_EXPRESS_API_KEY_VAL
        # Reset the round-robin index if it is now out of bounds
        if self.round_robin_index >= len(self.express_keys):
            self.round_robin_index = 0
        print(f"INFO: Express API keys refreshed. Total keys: {self.get_total_keys()}")
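To make the two selection strategies concrete, here is a minimal usage sketch. It is hypothetical wiring: the key values are placeholders, and the `config` module attributes are stubbed in-process exactly the way the manager reads them via `app_config`; in the real app they come from the environment.

import config as app_config
from express_key_manager import ExpressKeyManager

# Assumed config values for this sketch only.
app_config.VERTEX_EXPRESS_API_KEY_VAL = ["key-a", "key-b", "key-c"]
app_config.ROUNDROBIN = True

manager = ExpressKeyManager()

# Round-robin: successive calls walk the key list and wrap around.
for _ in range(4):
    selection = manager.get_express_api_key()
    if selection is None:
        break  # no keys configured
    key, idx = selection
    print(f"Selected key #{idx}")  # prints indices 0, 1, 2, 0

# Retry logic can instead iterate every key deterministically:
for idx, key in manager.get_all_keys_indexed():
    ...  # attempt the request with `key`; fall through to the next on failure

With ROUNDROBIN disabled, get_express_api_key() falls back to the shuffle-based random selection, but the return shape is the same (key, original_index) tuple either way.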
app/openai_handler.py
ADDED
@@ -0,0 +1,284 @@
"""
OpenAI handler module for creating clients and processing OpenAI Direct mode responses.
This module encapsulates all OpenAI-specific logic that was previously in chat_api.py.
"""
import json
import time
import asyncio
from typing import Dict, Any, AsyncGenerator

from fastapi.responses import JSONResponse, StreamingResponse
import openai
from google.auth.transport.requests import Request as AuthRequest

from models import OpenAIRequest
from config import VERTEX_REASONING_TAG
import config as app_config
from api_helpers import (
    create_openai_error_response,
    openai_fake_stream_generator,
    StreamingReasoningProcessor
)
from message_processing import extract_reasoning_by_tags
from credentials_manager import _refresh_auth


class OpenAIDirectHandler:
    """Handles OpenAI Direct mode operations, including client creation and response processing."""

    def __init__(self, credential_manager):
        self.credential_manager = credential_manager
        self.safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"}
        ]

    def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
        """Create an OpenAI client configured for the Vertex AI endpoint."""
        endpoint_url = (
            "https://aiplatform.googleapis.com/v1beta1/"
            f"projects/{project_id}/locations/{location}/endpoints/openapi"
        )

        return openai.AsyncOpenAI(
            base_url=endpoint_url,
            api_key=gcp_token,  # OAuth token
        )

    def prepare_openai_params(self, request: OpenAIRequest, model_id: str) -> Dict[str, Any]:
        """Prepare parameters for the OpenAI API call."""
        params = {
            "model": model_id,
            "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
            "stream": request.stream,
            "stop": request.stop,
            "seed": request.seed,
            "n": request.n,
        }
        # Remove None values
        return {k: v for k, v in params.items() if v is not None}

    def prepare_extra_body(self) -> Dict[str, Any]:
        """Prepare extra body parameters for the OpenAI API call."""
        return {
            "extra_body": {
                "google": {
                    "safety_settings": self.safety_settings,
                    "thought_tag_marker": VERTEX_REASONING_TAG
                }
            }
        }

    async def handle_streaming_response(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
    ) -> StreamingResponse:
        """Handle streaming responses for OpenAI Direct mode."""
        if app_config.FAKE_STREAMING_ENABLED:
            print(f"INFO: OpenAI Fake Streaming (SSE Simulation) ENABLED for model '{request.model}'.")
            return StreamingResponse(
                openai_fake_stream_generator(
                    openai_client=openai_client,
                    openai_params=openai_params,
                    openai_extra_body=openai_extra_body,
                    request_obj=request,
                    is_auto_attempt=False
                ),
                media_type="text/event-stream"
            )
        else:
            print(f"INFO: OpenAI True Streaming ENABLED for model '{request.model}'.")
            return StreamingResponse(
                self._true_stream_generator(openai_client, openai_params, openai_extra_body, request),
                media_type="text/event-stream"
            )

    async def _true_stream_generator(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
    ) -> AsyncGenerator[str, None]:
        """Generate a true streaming response."""
        try:
            # Ensure stream=True is explicitly passed for real streaming
            openai_params_for_stream = {**openai_params, "stream": True}
            stream_response = await openai_client.chat.completions.create(
                **openai_params_for_stream,
                extra_body=openai_extra_body['extra_body']
            )

            # Create a processor for tag-based extraction across chunks
            reasoning_processor = StreamingReasoningProcessor(VERTEX_REASONING_TAG)

            async for chunk in stream_response:
                try:
                    chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)

                    choices = chunk_as_dict.get('choices')
                    if choices and isinstance(choices, list) and len(choices) > 0:
                        delta = choices[0].get('delta')
                        if delta and isinstance(delta, dict):
                            # Always remove extra_content if present
                            if 'extra_content' in delta:
                                del delta['extra_content']

                            content = delta.get('content', '')
                            if content:
                                # Use the processor to extract reasoning
                                processed_content, current_reasoning = reasoning_processor.process_chunk(content)

                                # Update the delta with the processed content
                                if current_reasoning:
                                    delta['reasoning_content'] = current_reasoning
                                if processed_content:
                                    delta['content'] = processed_content
                                elif 'content' in delta:
                                    del delta['content']

                    yield f"data: {json.dumps(chunk_as_dict)}\n\n"

                except Exception as chunk_error:
                    error_msg = f"Error processing OpenAI chunk for {request.model}: {str(chunk_error)}"
                    print(f"ERROR: {error_msg}")
                    if len(error_msg) > 1024:
                        error_msg = error_msg[:1024] + "..."
                    error_response = create_openai_error_response(500, error_msg, "server_error")
                    yield f"data: {json.dumps(error_response)}\n\n"
                    yield "data: [DONE]\n\n"
                    return

            # Handle any remaining buffer content
            if reasoning_processor.tag_buffer and not reasoning_processor.inside_tag:
                # Flush any remaining content as a final chunk
                final_chunk = {
                    "id": f"chatcmpl-{int(time.time())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"content": reasoning_processor.tag_buffer}, "finish_reason": None}]
                }
                yield f"data: {json.dumps(final_chunk)}\n\n"
            elif reasoning_processor.inside_tag and reasoning_processor.reasoning_buffer:
                # We were inside a tag but never found the closing tag
                print(f"WARNING: Unclosed reasoning tag detected. Partial reasoning: {reasoning_processor.reasoning_buffer[:100]}...")

            yield "data: [DONE]\n\n"

        except Exception as stream_error:
            error_msg = str(stream_error)
            if len(error_msg) > 1024:
                error_msg = error_msg[:1024] + "..."
            error_msg_full = f"Error during OpenAI streaming for {request.model}: {error_msg}"
            print(f"ERROR: {error_msg_full}")
            error_response = create_openai_error_response(500, error_msg_full, "server_error")
            yield f"data: {json.dumps(error_response)}\n\n"
            yield "data: [DONE]\n\n"

    async def handle_non_streaming_response(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest
    ) -> JSONResponse:
        """Handle non-streaming responses for OpenAI Direct mode."""
        try:
            # Ensure stream=False is explicitly passed
            openai_params_non_stream = {**openai_params, "stream": False}
            response = await openai_client.chat.completions.create(
                **openai_params_non_stream,
                extra_body=openai_extra_body['extra_body']
            )
            response_dict = response.model_dump(exclude_unset=True, exclude_none=True)

            try:
                choices = response_dict.get('choices')
                if choices and isinstance(choices, list) and len(choices) > 0:
                    message_dict = choices[0].get('message')
                    if message_dict and isinstance(message_dict, dict):
                        # Always remove extra_content from the message if it exists
                        if 'extra_content' in message_dict:
                            del message_dict['extra_content']

                        # Extract reasoning from the content
                        full_content = message_dict.get('content')
                        actual_content = full_content if isinstance(full_content, str) else ""

                        if actual_content:
                            print(f"INFO: OpenAI Direct Non-Streaming - Applying tag extraction with fixed marker: '{VERTEX_REASONING_TAG}'")
                            reasoning_text, actual_content = extract_reasoning_by_tags(actual_content, VERTEX_REASONING_TAG)
                            message_dict['content'] = actual_content
                            if reasoning_text:
                                message_dict['reasoning_content'] = reasoning_text
                                print(f"DEBUG: Tag extraction success. Reasoning len: {len(reasoning_text)}, Content len: {len(actual_content)}")
                            else:
                                print(f"DEBUG: No content found within fixed tag '{VERTEX_REASONING_TAG}'.")
                        else:
                            print("WARNING: OpenAI Direct Non-Streaming - No initial content found in message.")
                            message_dict['content'] = ""

            except Exception as e_reasoning:
                print(f"WARNING: Error during non-streaming reasoning processing for model {request.model}: {e_reasoning}")

            return JSONResponse(content=response_dict)

        except Exception as e:
            error_msg = f"Error calling OpenAI client for {request.model}: {str(e)}"
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error")
            )

    async def process_request(self, request: OpenAIRequest, base_model_name: str):
        """Main entry point for processing OpenAI Direct mode requests."""
        print(f"INFO: Using OpenAI Direct Path for model: {request.model}")

        # Get credentials
        rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()

        if not rotated_credentials or not rotated_project_id:
            error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error")
            )

        print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
        gcp_token = _refresh_auth(rotated_credentials)

        if not gcp_token:
            error_msg = f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id})."
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error")
            )

        # Create the client and prepare parameters
        openai_client = self.create_openai_client(rotated_project_id, gcp_token)
        model_id = f"google/{base_model_name}"
        openai_params = self.prepare_openai_params(request, model_id)
        openai_extra_body = self.prepare_extra_body()

        # Handle streaming vs non-streaming
        if request.stream:
            return await self.handle_streaming_response(
                openai_client, openai_params, openai_extra_body, request
            )
        else:
            return await self.handle_non_streaming_response(
                openai_client, openai_params, openai_extra_body, request
            )
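For context, a minimal sketch of how the handler might be wired into a FastAPI route. The route path, the `CredentialManager` construction, and the model-name stripping are assumptions for illustration; the handler itself only requires an object exposing get_credentials():

from fastapi import FastAPI
from credentials_manager import CredentialManager  # assumed to provide get_credentials()
from models import OpenAIRequest
from openai_handler import OpenAIDirectHandler

app = FastAPI()
handler = OpenAIDirectHandler(CredentialManager())

@app.post("/v1/chat/completions")  # hypothetical route
async def chat_completions(request: OpenAIRequest):
    # Assumed convention: strip any router prefix to get the base model name,
    # which process_request maps to "google/<base_model_name>" on Vertex.
    base_model_name = request.model.removeprefix("openai/")
    # Returns a StreamingResponse (SSE) when request.stream is set,
    # otherwise a JSONResponse with reasoning split out of the content.
    return await handler.process_request(request, base_model_name)

In both paths, text the model emits inside the VERTEX_REASONING_TAG markers is surfaced as reasoning_content (per delta when streaming, on the final message otherwise) while the visible content field carries only the non-reasoning text.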