bibibi12345 committed
Commit c644d18 · verified · 1 Parent(s): e6545e7

Update app/main.py

Files changed (1): app/main.py (+451 -120)

app/main.py CHANGED
@@ -7,6 +7,7 @@ import base64
 import re
 import json
 import time
 import os
 import glob
 import random
@@ -276,6 +277,148 @@ async def startup_event():
 # Define supported roles for Gemini API
 SUPPORTED_ROLES = ["user", "model"]

 def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[types.Content, List[types.Content]]:
     """
     Convert OpenAI messages to Gemini format.
@@ -545,22 +688,45 @@ def convert_to_openai_format(gemini_response, model: str) -> Dict[str, Any]:
     if hasattr(gemini_response, 'candidates') and len(gemini_response.candidates) > 1:
         choices = []
         for i, candidate in enumerate(gemini_response.candidates):
             choices.append({
                 "index": i,
                 "message": {
                     "role": "assistant",
-                    "content": candidate.text
                 },
                 "finish_reason": "stop"
             })
     else:
         # Handle single response (backward compatibility)
         choices = [
             {
                 "index": 0,
                 "message": {
                     "role": "assistant",
-                    "content": gemini_response.text
                 },
                 "finish_reason": "stop"
             }
@@ -662,6 +828,51 @@ async def list_models(api_key: str = Depends(get_api_key)):
             "root": "gemini-2.5-pro-exp-03-25",
             "parent": None,
         },
         {
             "id": "gemini-2.0-flash",
             "object": "model",
@@ -782,157 +993,277 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
     try:
         # Validate model availability
         models_response = await list_models()
-        if not request.model or not any(model["id"] == request.model for model in models_response.get("data", [])):
             error_response = create_openai_error_response(
                 400, f"Model '{request.model}' not found", "invalid_request_error"
             )
             return JSONResponse(status_code=400, content=error_response)
-
-        # Check if this is a grounded search model or encrypted model
         is_grounded_search = request.model.endswith("-search")
-        is_encrypted_model = request.model == "gemini-2.5-pro-exp-03-25-encrypt"
-
-        # Extract the base model name
-        if is_grounded_search:
-            gemini_model = request.model.replace("-search", "")
         elif is_encrypted_model:
-            gemini_model = "gemini-2.5-pro-exp-03-25" # Use the base model
         else:
-            gemini_model = request.model
-
         # Create generation config
         generation_config = create_generation_config(request)
-
         # Use the globally initialized client (from startup)
         global client
         if client is None:
-            # This should ideally not happen if startup was successful
-            error_response = create_openai_error_response(
-                500, "Vertex AI client not initialized", "server_error"
-            )
-            return JSONResponse(status_code=500, content=error_response)
         print(f"Using globally initialized client.")
-
-        # Initialize Gemini model
-        search_tool = types.Tool(google_search=types.GoogleSearch())

         safety_settings = [
-            types.SafetySetting(
-                category="HARM_CATEGORY_HATE_SPEECH",
-                threshold="OFF"
-            ),types.SafetySetting(
-                category="HARM_CATEGORY_DANGEROUS_CONTENT",
-                threshold="OFF"
-            ),types.SafetySetting(
-                category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                threshold="OFF"
-            ),types.SafetySetting(
-                category="HARM_CATEGORY_HARASSMENT",
-                threshold="OFF"
-            )]
-
         generation_config["safety_settings"] = safety_settings
-        if is_grounded_search:
-            generation_config["tools"] = [search_tool]

-        # Create prompt from messages - use encrypted version if needed
-        if is_encrypted_model:
-            print(f"Using encrypted prompt for model: {request.model}")
-            prompt = create_encrypted_gemini_prompt(request.messages)
-        else:
-            prompt = create_gemini_prompt(request.messages)
-
-        # Log the structure of the prompt (without exposing sensitive content)
-        if isinstance(prompt, list):
-            print(f"Prompt structure: {len(prompt)} messages")
-            for i, msg in enumerate(prompt):
-                role = msg.role if hasattr(msg, 'role') else 'unknown'
-                parts_count = len(msg.parts) if hasattr(msg, 'parts') else 0
-                parts_types = [type(p).__name__ for p in (msg.parts if hasattr(msg, 'parts') else [])]
-                print(f" Message {i+1}: role={role}, parts={parts_count}, types={parts_types}")
-        elif isinstance(prompt, types.Content):
-            print("Prompt structure: 1 message")
-            role = prompt.role if hasattr(prompt, 'role') else 'unknown'
-            parts_count = len(prompt.parts) if hasattr(prompt, 'parts') else 0
-            parts_types = [type(p).__name__ for p in (prompt.parts if hasattr(prompt, 'parts') else [])]
-            print(f" Message 1: role={role}, parts={parts_count}, types={parts_types}")
-        else:
-            print("Prompt structure: Unknown format")

-        if request.stream:
-            # Handle streaming response
-            async def stream_generator():
                 response_id = f"chatcmpl-{int(time.time())}"
                 candidate_count = request.n or 1

-                try:
-                    # For streaming, we can only handle one candidate at a time
-                    # If multiple candidates are requested, we'll generate them sequentially
-                    for candidate_index in range(candidate_count):
-                        # Generate content with streaming
-                        # Handle the new message format for streaming using Gemini types
-                        print(f"Sending streaming request to Gemini API")

-                        # The prompt is now either a Content object or a list of Content objects
-                        responses = client.models.generate_content_stream(
-                            model=gemini_model,
-                            contents=prompt,
-                            config=generation_config,
-                        )

-                        # Convert and yield each chunk
-                        for response in responses:
-                            yield convert_chunk_to_openai(response, request.model, response_id, candidate_index)
-
-                    # Send final chunk with all candidates
-                    yield create_final_chunk(request.model, response_id, candidate_count)
-                    yield "data: [DONE]\n\n"

-                except Exception as stream_error:
-                    # Format streaming errors in SSE format
-                    error_msg = f"Error during streaming: {str(stream_error)}"
                     print(error_msg)
-                    error_response = create_openai_error_response(500, error_msg, "server_error")
-                    yield f"data: {json.dumps(error_response)}\n\n"
-                    yield "data: [DONE]\n\n"

-            return StreamingResponse(
-                stream_generator(),
-                media_type="text/event-stream"
-            )
         else:
-            # Handle non-streaming response
             try:
-                # If multiple candidates are requested, set candidate_count
-                if request.n and request.n > 1:
-                    # Make sure generation_config has candidate_count set
-                    if "candidate_count" not in generation_config:
-                        generation_config["candidate_count"] = request.n
-                # Handle the new message format using Gemini types
-                print(f"Sending request to Gemini API")
-
-                # The prompt is now either a Content object or a list of Content objects
-                response = client.models.generate_content(
-                    model=gemini_model,
-                    contents=prompt,
-                    config=generation_config,
-                )
-
-
-                openai_response = convert_to_openai_format(response, request.model)
-                return JSONResponse(content=openai_response)
-            except Exception as generate_error:
-                error_msg = f"Error generating content: {str(generate_error)}"
-                print(error_msg)
-                error_response = create_openai_error_response(500, error_msg, "server_error")
-                return JSONResponse(status_code=500, content=error_response)
-
     except Exception as e:
-        error_msg = f"Error processing request: {str(e)}"
         print(error_msg)
         error_response = create_openai_error_response(500, error_msg, "server_error")
         return JSONResponse(status_code=500, content=error_response)

 # Health check endpoint
 @app.get("/health")
 def health_check(api_key: str = Depends(get_api_key)):
 
@@ -7,6 +7,7 @@ import base64
 import re
 import json
 import time
+import asyncio # Add this import
 import os
 import glob
 import random
 
@@ -276,6 +277,148 @@ async def startup_event():
 # Define supported roles for Gemini API
 SUPPORTED_ROLES = ["user", "model"]

+# Conversion functions
+def create_gemini_prompt_old(messages: List[OpenAIMessage]) -> Union[str, List[Any]]:
+    """
+    Convert OpenAI messages to Gemini format.
+    Returns either a string prompt or a list of content parts if images are present.
+    """
+    # Check if any message contains image content
+    has_images = False
+    for message in messages:
+        if isinstance(message.content, list):
+            for part in message.content:
+                if isinstance(part, dict) and part.get('type') == 'image_url':
+                    has_images = True
+                    break
+                elif isinstance(part, ContentPartImage):
+                    has_images = True
+                    break
+        if has_images:
+            break
+
+    # If no images, use the text-only format
+    if not has_images:
+        prompt = ""
+
+        # Extract system message if present
+        system_message = None
+        # Process all messages in their original order
+        for message in messages:
+            if message.role == "system":
+                # Handle both string and list[dict] content types
+                if isinstance(message.content, str):
+                    system_message = message.content
+                elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                    system_message = message.content[0]['text']
+                else:
+                    # Handle unexpected format or raise error? For now, assume it's usable or skip.
+                    system_message = str(message.content) # Fallback, might need refinement
+                break
+
+        # If system message exists, prepend it
+        if system_message:
+            prompt += f"System: {system_message}\n\n"
+
+        # Add other messages
+        for message in messages:
+            if message.role == "system":
+                continue # Already handled
+
+            # Handle both string and list[dict] content types
+            content_text = ""
+            if isinstance(message.content, str):
+                content_text = message.content
+            elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                content_text = message.content[0]['text']
+            else:
+                # Fallback for unexpected format
+                content_text = str(message.content)
+
+            if message.role == "system":
+                prompt += f"System: {content_text}\n\n"
+            elif message.role == "user":
+                prompt += f"Human: {content_text}\n"
+            elif message.role == "assistant":
+                prompt += f"AI: {content_text}\n"
+
+        # Add final AI prompt if last message was from user
+        if messages[-1].role == "user":
+            prompt += "AI: "
+
+        return prompt
+
+    # If images are present, create a list of content parts
+    gemini_contents = []
+
+    # Extract system message if present and add it first
+    for message in messages:
+        if message.role == "system":
+            if isinstance(message.content, str):
+                gemini_contents.append(f"System: {message.content}")
+            elif isinstance(message.content, list):
+                # Extract text from system message
+                system_text = ""
+                for part in message.content:
+                    if isinstance(part, dict) and part.get('type') == 'text':
+                        system_text += part.get('text', '')
+                    elif isinstance(part, ContentPartText):
+                        system_text += part.text
+                if system_text:
+                    gemini_contents.append(f"System: {system_text}")
+            break
+
+    # Process user and assistant messages
+    # Process all messages in their original order
+    for message in messages:
+        if message.role == "system":
+            continue # Already handled
+
+        # For string content, add as text
+        if isinstance(message.content, str):
+            prefix = "Human: " if message.role == "user" else "AI: "
+            gemini_contents.append(f"{prefix}{message.content}")
+
+        # For list content, process each part
+        elif isinstance(message.content, list):
+            # First collect all text parts
+            text_content = ""
+
+            for part in message.content:
+                # Handle text parts
+                if isinstance(part, dict) and part.get('type') == 'text':
+                    text_content += part.get('text', '')
+                elif isinstance(part, ContentPartText):
+                    text_content += part.text
+
+            # Add the combined text content if any
+            if text_content:
+                prefix = "Human: " if message.role == "user" else "AI: "
+                gemini_contents.append(f"{prefix}{text_content}")
+
+            # Then process image parts
+            for part in message.content:
+                # Handle image parts
+                if isinstance(part, dict) and part.get('type') == 'image_url':
+                    image_url = part.get('image_url', {}).get('url', '')
+                    if image_url.startswith('data:'):
+                        # Extract mime type and base64 data
+                        mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                        if mime_match:
+                            mime_type, b64_data = mime_match.groups()
+                            image_bytes = base64.b64decode(b64_data)
+                            gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+                elif isinstance(part, ContentPartImage):
+                    image_url = part.image_url.url
+                    if image_url.startswith('data:'):
+                        # Extract mime type and base64 data
+                        mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                        if mime_match:
+                            mime_type, b64_data = mime_match.groups()
+                            image_bytes = base64.b64decode(b64_data)
+                            gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+    return gemini_contents
+
 def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[types.Content, List[types.Content]]:
     """
     Convert OpenAI messages to Gemini format.
 
@@ -545,22 +688,45 @@ def convert_to_openai_format(gemini_response, model: str) -> Dict[str, Any]:
     if hasattr(gemini_response, 'candidates') and len(gemini_response.candidates) > 1:
         choices = []
         for i, candidate in enumerate(gemini_response.candidates):
+            # Extract text content from candidate
+            content = ""
+            if hasattr(candidate, 'text'):
+                content = candidate.text
+            elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                # Look for text in parts
+                for part in candidate.content.parts:
+                    if hasattr(part, 'text'):
+                        content += part.text
+
             choices.append({
                 "index": i,
                 "message": {
                     "role": "assistant",
+                    "content": content
                 },
                 "finish_reason": "stop"
             })
     else:
         # Handle single response (backward compatibility)
+        content = ""
+        # Try different ways to access the text content
+        if hasattr(gemini_response, 'text'):
+            content = gemini_response.text
+        elif hasattr(gemini_response, 'candidates') and gemini_response.candidates:
+            candidate = gemini_response.candidates[0]
+            if hasattr(candidate, 'text'):
+                content = candidate.text
+            elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                for part in candidate.content.parts:
+                    if hasattr(part, 'text'):
+                        content += part.text
+
         choices = [
             {
                 "index": 0,
                 "message": {
                     "role": "assistant",
+                    "content": content
                 },
                 "finish_reason": "stop"
             }
 
@@ -662,6 +828,51 @@ async def list_models(api_key: str = Depends(get_api_key)):
             "root": "gemini-2.5-pro-exp-03-25",
             "parent": None,
         },
+        {
+            "id": "gemini-2.5-pro-exp-03-25-auto", # New auto model
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-exp-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25-search",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25-encrypt",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25-auto", # New auto model
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
         {
             "id": "gemini-2.0-flash",
             "object": "model",
 
@@ -782,157 +993,277 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
     try:
         # Validate model availability
         models_response = await list_models()
+        available_models = [model["id"] for model in models_response.get("data", [])]
+        if not request.model or request.model not in available_models:
             error_response = create_openai_error_response(
                 400, f"Model '{request.model}' not found", "invalid_request_error"
             )
             return JSONResponse(status_code=400, content=error_response)
+
+        # Check model type and extract base model name
+        is_auto_model = request.model.endswith("-auto")
         is_grounded_search = request.model.endswith("-search")
+        is_encrypted_model = request.model.endswith("-encrypt")
+
+        if is_auto_model:
+            base_model_name = request.model.replace("-auto", "")
+        elif is_grounded_search:
+            base_model_name = request.model.replace("-search", "")
         elif is_encrypted_model:
+            base_model_name = request.model.replace("-encrypt", "")
         else:
+            base_model_name = request.model
+
         # Create generation config
         generation_config = create_generation_config(request)
+
         # Use the globally initialized client (from startup)
         global client
         if client is None:
+            error_response = create_openai_error_response(
+                500, "Vertex AI client not initialized", "server_error"
+            )
+            return JSONResponse(status_code=500, content=error_response)
         print(f"Using globally initialized client.")

+        # Common safety settings
         safety_settings = [
+            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
+            types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
+            types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
+            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF")
+        ]
         generation_config["safety_settings"] = safety_settings
+
+        # --- Helper function to check response validity ---
+        def is_response_valid(response):
+            if response is None:
+                return False
+
+            # Check if candidates exist
+            if not hasattr(response, 'candidates') or not response.candidates:
+                return False
+
+            # Get the first candidate
+            candidate = response.candidates[0]
+
+            # Try different ways to access the text content
+            text_content = None
+
+            # Method 1: Direct text attribute on candidate
+            if hasattr(candidate, 'text'):
+                text_content = candidate.text
+            # Method 2: Text attribute on response
+            elif hasattr(response, 'text'):
+                text_content = response.text
+            # Method 3: Content with parts
+            elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                # Look for text in parts
+                for part in candidate.content.parts:
+                    if hasattr(part, 'text') and part.text:
+                        text_content = part.text
+                        break
+
+            # If we found text content and it's not empty, the response is valid
+            if text_content:
+                return True

+            # If no text content was found, check if there are other parts that might be valid
+            if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                if len(candidate.content.parts) > 0:
+                    # Consider valid if there are any parts at all
+                    return True
+
+            # Also check if the response itself has text
+            if hasattr(response, 'text') and response.text:
+                return True
+
+            # If we got here, the response is invalid
+            print(f"Invalid response: No text content found in response structure: {str(response)[:200]}...")
+            return False
+

+        # --- Helper function to make the API call (handles stream/non-stream) ---
+        async def make_gemini_call(model_name, prompt_func, current_gen_config):
+            prompt = prompt_func(request.messages)
+
+            # Log prompt structure
+            if isinstance(prompt, list):
+                print(f"Prompt structure: {len(prompt)} messages")
+            elif isinstance(prompt, types.Content):
+                print("Prompt structure: 1 message")
+            else:
+                # Handle old format case (which returns str or list[Any])
+                if isinstance(prompt, str):
+                    print("Prompt structure: String (old format)")
+                elif isinstance(prompt, list):
+                    print(f"Prompt structure: List[{len(prompt)}] (old format with images)")
+                else:
+                    print("Prompt structure: Unknown format")
+
+
+            if request.stream:
+                # Streaming call
                 response_id = f"chatcmpl-{int(time.time())}"
                 candidate_count = request.n or 1

+                async def stream_generator_inner():
+                    all_chunks_empty = True # Track if we receive any content
+                    first_chunk_received = False
+                    try:
+                        for candidate_index in range(candidate_count):
+                            print(f"Sending streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                            responses = client.models.generate_content_stream(
+                                model=model_name,
+                                contents=prompt,
+                                config=current_gen_config,
+                            )
+
+                            # Use regular for loop, not async for
+                            for chunk in responses:
+                                first_chunk_received = True
+                                if hasattr(chunk, 'text') and chunk.text:
+                                    all_chunks_empty = False
+                                yield convert_chunk_to_openai(chunk, request.model, response_id, candidate_index)

+                        # Check if any chunk was received at all
+                        if not first_chunk_received:
+                            raise ValueError("Stream connection established but no chunks received")
+
+                        yield create_final_chunk(request.model, response_id, candidate_count)
+                        yield "data: [DONE]\n\n"

+                        # Return status based on content received
+                        if all_chunks_empty and first_chunk_received: # Check if we got chunks but they were all empty
+                            raise ValueError("Streamed response contained only empty chunks") # Treat empty stream as failure for retry
+
+                    except Exception as stream_error:
+                        error_msg = f"Error during streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
                        print(error_msg)
+                        # Yield error in SSE format but also raise to signal failure
+                        error_response_content = create_openai_error_response(500, error_msg, "server_error")
+                        yield f"data: {json.dumps(error_response_content)}\n\n"
+                        yield "data: [DONE]\n\n"
+                        raise stream_error # Propagate error for retry logic

+                return StreamingResponse(stream_generator_inner(), media_type="text/event-stream")
+
+            else:
+                # Non-streaming call
+                try:
+                    print(f"Sending request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                    response = client.models.generate_content(
+                        model=model_name,
+                        contents=prompt,
+                        config=current_gen_config,
+                    )
+                    if not is_response_valid(response):
+                        raise ValueError("Invalid or empty response received") # Trigger retry
+
+                    openai_response = convert_to_openai_format(response, request.model)
+                    return JSONResponse(content=openai_response)
+                except Exception as generate_error:
+                    error_msg = f"Error generating content (Model: {model_name}, Format: {prompt_func.__name__}): {str(generate_error)}"
                    print(error_msg)
+                    # Raise error to signal failure for retry logic
+                    raise generate_error
+
+
+        # --- Main Logic ---
+        last_error = None
+
+        if is_auto_model:
+            print(f"Processing auto model: {request.model}")
+            attempts = [
+                {"name": "base", "model": base_model_name, "prompt_func": create_gemini_prompt, "config_modifier": lambda c: c},
+                {"name": "old_format", "model": base_model_name, "prompt_func": create_gemini_prompt_old, "config_modifier": lambda c: c},
+                {"name": "encrypt", "model": base_model_name, "prompt_func": create_encrypted_gemini_prompt, "config_modifier": lambda c: c}
+            ]
+
+            for i, attempt in enumerate(attempts):
+                print(f"Attempt {i+1}/{len(attempts)} using '{attempt['name']}' mode...")
+                current_config = attempt["config_modifier"](generation_config.copy())
+
+                try:
+                    result = await make_gemini_call(attempt["model"], attempt["prompt_func"], current_config)
+
+                    # For streaming, the result is StreamingResponse, success is determined inside make_gemini_call raising an error on failure
+                    # For non-streaming, if make_gemini_call doesn't raise, it's successful
+                    print(f"Attempt {i+1} ('{attempt['name']}') successful.")
+                    return result
+                except Exception as e:
+                    last_error = e
+                    print(f"Attempt {i+1} ('{attempt['name']}') failed: {e}")
+                    if i < len(attempts) - 1:
+                        print("Waiting 1 second before next attempt...")
+                        await asyncio.sleep(1) # Use asyncio.sleep for async context
+                    else:
+                        print("All attempts failed.")

+            # If all attempts failed, return the last error
+            error_msg = f"All retry attempts failed for model {request.model}. Last error: {str(last_error)}"
+            error_response = create_openai_error_response(500, error_msg, "server_error")
+            # If the last attempt was streaming and failed, the error response is already yielded by the generator.
+            # If non-streaming failed last, return the JSON error.
+            if not request.stream:
+                return JSONResponse(status_code=500, content=error_response)
+            else:
+                # The StreamingResponse returned earlier will handle yielding the final error.
+                # We should not return a new response here.
+                # If we reach here after a failed stream, it means the initial StreamingResponse object was returned,
+                # but the generator within it failed on the last attempt.
+                # The generator itself handles yielding the error SSE.
+                # We need to ensure the main function doesn't try to return another response.
+                # Returning the 'result' from the failed attempt (which is the StreamingResponse object)
+                # might be okay IF the generator correctly yields the error and DONE message.
+                # Let's return the StreamingResponse object which contains the failing generator.
+                # This assumes the generator correctly terminates after yielding the error.
+                # Re-evaluate if this causes issues. The goal is to avoid double responses.
+                # It seems returning the StreamingResponse object itself is the correct FastAPI pattern.
+                return result # Return the StreamingResponse object which contains the failing generator
+
+
         else:
+            # Handle non-auto models (base, search, encrypt)
+            current_model_name = base_model_name
+            current_prompt_func = create_gemini_prompt
+            current_config = generation_config.copy()
+
+            if is_grounded_search:
+                print(f"Using grounded search for model: {request.model}")
+                search_tool = types.Tool(google_search=types.GoogleSearch())
+                current_config["tools"] = [search_tool]
+            elif is_encrypted_model:
+                print(f"Using encrypted prompt for model: {request.model}")
+                current_prompt_func = create_encrypted_gemini_prompt
+
             try:
+                result = await make_gemini_call(current_model_name, current_prompt_func, current_config)
+                return result
+            except Exception as e:
+                # Handle potential errors for non-auto models
+                error_msg = f"Error processing model {request.model}: {str(e)}"
+                print(error_msg)
+                error_response = create_openai_error_response(500, error_msg, "server_error")
+                # Similar to auto-fail case, handle stream vs non-stream error return
+                if not request.stream:
+                    return JSONResponse(status_code=500, content=error_response)
+                else:
+                    # Let the StreamingResponse handle yielding the error
+                    return result # Return the StreamingResponse object containing the failing generator
+
+
     except Exception as e:
+        # Catch-all for unexpected errors during setup or logic flow
+        error_msg = f"Unexpected error processing request: {str(e)}"
         print(error_msg)
         error_response = create_openai_error_response(500, error_msg, "server_error")
+        # Ensure we return a JSON response even for stream requests if error happens early
         return JSONResponse(status_code=500, content=error_response)

+# --- Need to import asyncio ---
+# import asyncio # Add this import at the top of the file # Already added below
+
 # Health check endpoint
 @app.get("/health")
 def health_check(api_key: str = Depends(get_api_key)):
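
For reference, a minimal client-side sketch of how the new "-auto" fallback models added by this commit could be exercised through the proxy's OpenAI-compatible chat_completions endpoint. The base URL, the /v1/chat/completions path, and the Bearer-token header are assumptions for illustration only; the diff does not show how the route or get_api_key are actually exposed, so adjust them to your deployment.

# Hypothetical client sketch (not part of the commit): calls the proxy with the
# new "-auto" model id, which the server retries across prompt formats.
import requests

BASE_URL = "http://localhost:7860"   # assumption: wherever this FastAPI app is served
API_KEY = "your-proxy-api-key"       # assumption: key validated by get_api_key

payload = {
    "model": "gemini-2.5-pro-exp-03-25-auto",  # model id added by this commit
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the retry logic in one sentence."},
    ],
    "stream": False,  # set True to receive SSE chunks terminated by "data: [DONE]"
}

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",               # assumed OpenAI-style route
    headers={"Authorization": f"Bearer {API_KEY}"},  # assumed auth scheme
    json=payload,
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])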