bibibi12345 committed · verified
Commit a5cc545 · 1 Parent(s): d426ec5

Update app/main.py

Files changed (1)
  app/main.py  +45 -32
app/main.py CHANGED
@@ -808,8 +808,9 @@ def create_final_chunk(model: str, response_id: str, candidate_count: int = 1) -
808
 
809
  # /v1/models endpoint
810
  @app.get("/v1/models")
811
- async def list_models(api_key: str = Depends(get_api_key)):
812
  # Based on current information for Vertex AI models
 
813
  models = [
814
  {
815
  "id": "gemini-2.5-pro-exp-03-25",
@@ -1174,68 +1175,79 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
         print("Prompt structure: Unknown format")
 
 
+    model_instance = client.get_model(f"models/{model_name}") # Get the model instance
+
     if request.stream:
-        # Streaming call
+        # Streaming call (Async)
         response_id = f"chatcmpl-{int(time.time())}"
         candidate_count = request.n or 1
-
+
         async def stream_generator_inner():
             all_chunks_empty = True # Track if we receive any content
             first_chunk_received = False
             try:
-                for candidate_index in range(candidate_count):
-                    print(f"Sending streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-                    responses = client.models.generate_content_stream(
-                        model=model_name,
-                        contents=prompt,
-                        config=current_gen_config,
-                    )
-
-                    # Use regular for loop, not async for
-                    for chunk in responses:
-                        first_chunk_received = True
-                        if hasattr(chunk, 'text') and chunk.text:
-                            all_chunks_empty = False
-                        yield convert_chunk_to_openai(chunk, request.model, response_id, candidate_index)
-
+                # No need to loop candidate_index here, the stream handles multiple candidates if config asks for it
+                print(f"Sending async streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                async_responses = await model_instance.generate_content_stream_async( # Use await and async method
+                    contents=prompt,
+                    generation_config=current_gen_config, # Use generation_config parameter
+                    # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
+                )
+
+                # Use async for loop
+                async for chunk in async_responses: # Use async for
+                    first_chunk_received = True
+                    # Determine candidate_index based on the chunk itself if possible, fallback to 0
+                    # Note: Adjust this if the async stream chunk structure provides candidate index differently
+                    candidate_index = 0 # Assuming default index for now
+                    if hasattr(chunk, '_candidate_index'): # Check for potential internal attribute (may change)
+                        candidate_index = chunk._candidate_index
+                    elif hasattr(chunk, 'candidates') and chunk.candidates and hasattr(chunk.candidates[0], 'index'):
+                        # Or check standard candidate structure if available on chunk
+                        candidate_index = chunk.candidates[0].index
+
+                    if hasattr(chunk, 'text') and chunk.text:
+                        all_chunks_empty = False
+                    yield convert_chunk_to_openai(chunk, request.model, response_id, candidate_index)
+
                 # Check if any chunk was received at all
                 if not first_chunk_received:
                     raise ValueError("Stream connection established but no chunks received")
 
                 yield create_final_chunk(request.model, response_id, candidate_count)
                 yield "data: [DONE]\n\n"
-
+
                 # Return status based on content received
                 if all_chunks_empty and first_chunk_received: # Check if we got chunks but they were all empty
                     raise ValueError("Streamed response contained only empty chunks") # Treat empty stream as failure for retry
 
             except Exception as stream_error:
-                error_msg = f"Error during streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
+                error_msg = f"Error during async streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
                 print(error_msg)
                 # Yield error in SSE format but also raise to signal failure
                 error_response_content = create_openai_error_response(500, error_msg, "server_error")
                 yield f"data: {json.dumps(error_response_content)}\n\n"
                 yield "data: [DONE]\n\n"
                 raise stream_error # Propagate error for retry logic
-
+
         return StreamingResponse(stream_generator_inner(), media_type="text/event-stream")
 
     else:
-        # Non-streaming call
+        # Non-streaming call (Async)
         try:
-            print(f"Sending request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
-            response = client.models.generate_content(
-                model=model_name,
+            print(f"Sending async request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+            response = await model_instance.generate_content_async( # Use await and async method
                 contents=prompt,
-                config=current_gen_config,
+                generation_config=current_gen_config, # Use generation_config parameter
+                # safety_settings=current_gen_config.get("safety_settings", None) # Pass safety separately if needed
             )
             if not is_response_valid(response):
                 raise ValueError("Invalid or empty response received") # Trigger retry
-
+
             openai_response = convert_to_openai_format(response, request.model)
             return JSONResponse(content=openai_response)
         except Exception as generate_error:
-            error_msg = f"Error generating content (Model: {model_name}, Format: {prompt_func.__name__}): {str(generate_error)}"
+            error_msg = f"Error generating async content (Model: {model_name}, Format: {prompt_func.__name__}): {str(generate_error)}"
             print(error_msg)
             # Raise error to signal failure for retry logic
             raise generate_error
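The new streaming branch leans on one pattern worth seeing in isolation: an async generator that yields SSE-framed strings, handed to FastAPI's StreamingResponse. Below is a minimal, self-contained sketch of that pattern; fake_chunks, the /sse-demo route, and the "delta" payload are illustrative stand-ins, not code from this repository.

import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_chunks():
    # Stand-in for the awaited Gemini stream: objects exposing a text field.
    for text in ["Hel", "lo", "!"]:
        await asyncio.sleep(0.05)  # simulate gaps between upstream chunks
        yield {"text": text}

@app.get("/sse-demo")
async def sse_demo():
    async def event_stream():
        async for chunk in fake_chunks():
            # One SSE event per chunk: a "data: " prefix, a JSON payload,
            # and a blank line, the same framing the handler above emits.
            yield f"data: {json.dumps({'delta': chunk['text']})}\n\n"
        yield "data: [DONE]\n\n"
    return StreamingResponse(event_stream(), media_type="text/event-stream")

Because the generator is async and awaited chunk by chunk, the event loop stays free to serve other requests while the upstream model is still producing output, which is the point of the sync-to-async switch in this hunk.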
@@ -1378,10 +1390,11 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
 
 # Health check endpoint
 @app.get("/health")
-def health_check(api_key: str = Depends(get_api_key)):
-    # Refresh the credentials list to get the latest status
-    credential_manager.refresh_credentials_list()
-
+async def health_check(api_key: str = Depends(get_api_key)): # Made async
+    # Refresh the credentials list (still sync I/O, consider wrapping later if needed)
+    # For now, just call the sync method. If it blocks significantly, wrap with asyncio.to_thread
+    credential_manager.refresh_credentials_list() # Keep sync call for now
+
     return {
         "status": "ok",
         "credentials": {