Update app/main.py
app/main.py  (+45 -32)  CHANGED
@@ -808,8 +808,9 @@ def create_final_chunk(model: str, response_id: str, candidate_count: int = 1) -
 
 # /v1/models endpoint
 @app.get("/v1/models")
-async def list_models(
+async def list_models(): # Removed api_key dependency as it wasn't used, kept async
     # Based on current information for Vertex AI models
+    # Note: Consider adding authentication back if needed later
     models = [
         {
             "id": "gemini-2.5-pro-exp-03-25",
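Note on the /v1/models hunk above: the endpoint stays async but drops its unused api_key dependency, and the added comment leaves the door open to restoring auth later. A minimal, hypothetical sketch of how that could look if auth comes back, using a stand-in get_api_key dependency (the real one lives elsewhere in app/main.py) and a response shape assumed to follow the OpenAI model-list format:

from fastapi import Depends, FastAPI, Header, HTTPException

app = FastAPI()

async def get_api_key(authorization: str = Header(default="")) -> str:
    # Minimal stand-in for the project's real get_api_key dependency.
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing or malformed API key")
    return authorization.removeprefix("Bearer ")

@app.get("/v1/models")
async def list_models(api_key: str = Depends(get_api_key)):  # auth restored (hypothetical)
    # Static catalogue; the wrapping object shape is assumed, not copied from the repo.
    return {
        "object": "list",
        "data": [
            {"id": "gemini-2.5-pro-exp-03-25", "object": "model", "owned_by": "google"},
        ],
    }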
@@ -1174,68 +1175,79 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
             print("Prompt structure: Unknown format")
 
 
+        model_instance = client.get_model(f"models/{model_name}") # Get the model instance
+
         if request.stream:
-            # Streaming call
+            # Streaming call (Async)
             response_id = f"chatcmpl-{int(time.time())}"
             candidate_count = request.n or 1
 
             async def stream_generator_inner():
                 all_chunks_empty = True # Track if we receive any content
                 first_chunk_received = False
                 try:
-                    ...  # removed lines 1186-1200 (collapsed in the source diff view)
+                    # No need to loop candidate_index here, the stream handles multiple candidates if config asks for it
+                    print(f"Sending async streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                    async_responses = await model_instance.generate_content_stream_async( # Use await and async method
+                        contents=prompt,
+                        generation_config=current_gen_config, # Use generation_config parameter
+                        # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
+                    )
+
+                    # Use async for loop
+                    async for chunk in async_responses: # Use async for
+                        first_chunk_received = True
+                        # Determine candidate_index based on the chunk itself if possible, fallback to 0
+                        # Note: Adjust this if the async stream chunk structure provides candidate index differently
+                        candidate_index = 0 # Assuming default index for now
+                        if hasattr(chunk, '_candidate_index'): # Check for potential internal attribute (may change)
+                            candidate_index = chunk._candidate_index
+                        elif hasattr(chunk, 'candidates') and chunk.candidates and hasattr(chunk.candidates[0], 'index'):
+                            # Or check standard candidate structure if available on chunk
+                            candidate_index = chunk.candidates[0].index
+
+                        if hasattr(chunk, 'text') and chunk.text:
+                            all_chunks_empty = False
+                        yield convert_chunk_to_openai(chunk, request.model, response_id, candidate_index)
+
                     # Check if any chunk was received at all
                     if not first_chunk_received:
                         raise ValueError("Stream connection established but no chunks received")
 
                     yield create_final_chunk(request.model, response_id, candidate_count)
                     yield "data: [DONE]\n\n"
 
                     # Return status based on content received
                     if all_chunks_empty and first_chunk_received: # Check if we got chunks but they were all empty
                         raise ValueError("Streamed response contained only empty chunks") # Treat empty stream as failure for retry
 
                 except Exception as stream_error:
-                    error_msg = f"Error during streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
+                    error_msg = f"Error during async streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
                     print(error_msg)
                     # Yield error in SSE format but also raise to signal failure
                     error_response_content = create_openai_error_response(500, error_msg, "server_error")
                     yield f"data: {json.dumps(error_response_content)}\n\n"
                     yield "data: [DONE]\n\n"
                     raise stream_error # Propagate error for retry logic
 
             return StreamingResponse(stream_generator_inner(), media_type="text/event-stream")
 
         else:
-            # Non-streaming call
+            # Non-streaming call (Async)
             try:
-                print(f"Sending request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                response =
-                    model=model_name,
+                print(f"Sending async request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                response = await model_instance.generate_content_async( # Use await and async method
                     contents=prompt,
-                    ...  # removed line 1230 (collapsed in the source diff view)
+                    generation_config=current_gen_config, # Use generation_config parameter
+                    # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
                 )
                 if not is_response_valid(response):
                     raise ValueError("Invalid or empty response received") # Trigger retry
 
                 openai_response = convert_to_openai_format(response, request.model)
                 return JSONResponse(content=openai_response)
             except Exception as generate_error:
-                error_msg = f"Error generating content (Model: {model_name}, Format: {prompt_func.__name__}): {str(generate_error)}"
+                error_msg = f"Error generating async content (Model: {model_name}, Format: {prompt_func.__name__}): {str(generate_error)}"
                 print(error_msg)
                 # Raise error to signal failure for retry logic
                 raise generate_error
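The hunk above swaps the synchronous streaming loop for an awaited generate_content_stream_async call consumed with async for, while keeping the same SSE framing (per-chunk "data: ..." payloads, a final chunk, then "data: [DONE]"). A self-contained sketch of that pattern, with a fake async stream standing in for the Gemini SDK and a hand-rolled chunk payload rather than this repo's convert_chunk_to_openai/create_final_chunk helpers:

import asyncio
import json
import time

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_model_stream():
    # Stand-in for model_instance.generate_content_stream_async(...)
    for piece in ["Hello", ", ", "world"]:
        await asyncio.sleep(0.05)
        yield piece

@app.post("/v1/chat/completions")
async def chat_completions():
    response_id = f"chatcmpl-{int(time.time())}"

    async def stream_generator_inner():
        async for text in fake_model_stream():
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
        # Final chunk with finish_reason, then the SSE terminator
        final = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
        yield f"data: {json.dumps(final)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_generator_inner(), media_type="text/event-stream")

Because the generator is async, the event loop can serve other requests while each chunk is awaited, which is the point of the change; the non-streaming branch gets the same benefit from awaiting generate_content_async.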
@@ -1378,10 +1390,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
 
 # Health check endpoint
 @app.get("/health")
-def health_check(api_key: str = Depends(get_api_key)):
-    # Refresh the credentials list
-    ...  # removed lines 1383-1384 (collapsed in the source diff view)
+async def health_check(api_key: str = Depends(get_api_key)): # Made async
+    # Refresh the credentials list (still sync I/O, consider wrapping later if needed)
+    # For now, just call the sync method. If it blocks significantly, wrap with asyncio.to_thread
+    credential_manager.refresh_credentials_list() # Keep sync call for now
+
     return {
         "status": "ok",
         "credentials": {
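For a quick end-to-end check of both changed endpoints, the proxy can be exercised with the standard openai Python client in streaming mode; the base URL, port, and API key below are placeholders, and the model id is one of those returned by /v1/models:

from openai import OpenAI

# Placeholder base URL and key; point these at wherever the proxy is running.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-proxy-key")

# Exercises the updated /v1/models endpoint.
print([m.id for m in client.models.list().data])

# Exercises the async streaming path of /v1/chat/completions.
stream = client.chat.completions.create(
    model="gemini-2.5-pro-exp-03-25",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()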