Update app/main.py
app/main.py  (+451 -120)
CHANGED
@@ -7,6 +7,7 @@ import base64
 import re
 import json
 import time
 import os
 import glob
 import random
@@ -276,6 +277,148 @@ async def startup_event():
 # Define supported roles for Gemini API
 SUPPORTED_ROLES = ["user", "model"]

 def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[types.Content, List[types.Content]]:
     """
     Convert OpenAI messages to Gemini format.
@@ -545,22 +688,45 @@ def convert_to_openai_format(gemini_response, model: str) -> Dict[str, Any]:
     if hasattr(gemini_response, 'candidates') and len(gemini_response.candidates) > 1:
         choices = []
         for i, candidate in enumerate(gemini_response.candidates):
             choices.append({
                 "index": i,
                 "message": {
                     "role": "assistant",
-                    "content":
                 },
                 "finish_reason": "stop"
             })
     else:
         # Handle single response (backward compatibility)
         choices = [
             {
                 "index": 0,
                 "message": {
                     "role": "assistant",
-                    "content":
                 },
                 "finish_reason": "stop"
             }
@@ -662,6 +828,51 @@ async def list_models(api_key: str = Depends(get_api_key)):
             "root": "gemini-2.5-pro-exp-03-25",
             "parent": None,
         },
         {
             "id": "gemini-2.0-flash",
             "object": "model",
@@ -782,157 +993,277 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
     try:
         # Validate model availability
         models_response = await list_models()
         error_response = create_openai_error_response(
             400, f"Model '{request.model}' not found", "invalid_request_error"
         )
         return JSONResponse(status_code=400, content=error_response)
-        # Check
         is_grounded_search = request.model.endswith("-search")
-        is_encrypted_model = request.model
         elif is_encrypted_model:
         else:
         # Create generation config
         generation_config = create_generation_config(request)
         # Use the globally initialized client (from startup)
         global client
         if client is None:
-            return JSONResponse(status_code=500, content=error_response)
         print(f"Using globally initialized client.")
-        # Initialize Gemini model
-        search_tool = types.Tool(google_search=types.GoogleSearch())

         safety_settings = [
-            types.SafetySetting(
-                category="
-                threshold="OFF"
-                threshold="OFF"
-            ),types.SafetySetting(
-                category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                threshold="OFF"
-            ),types.SafetySetting(
-                category="HARM_CATEGORY_HARASSMENT",
-                threshold="OFF"
-            )]
         generation_config["safety_settings"] = safety_settings
-            print(f" Message {i+1}: role={role}, parts={parts_count}, types={parts_types}")
-        elif isinstance(prompt, types.Content):
-            print("Prompt structure: 1 message")
-            role = prompt.role if hasattr(prompt, 'role') else 'unknown'
-            parts_count = len(prompt.parts) if hasattr(prompt, 'parts') else 0
-            parts_types = [type(p).__name__ for p in (prompt.parts if hasattr(prompt, 'parts') else [])]
-            print(f" Message 1: role={role}, parts={parts_count}, types={parts_types}")
-        else:
-            print("Prompt structure: Unknown format")

         response_id = f"chatcmpl-{int(time.time())}"
         candidate_count = request.n or 1

             print(error_msg)
         else:
-            # Handle non-
             try:
-                openai_response = convert_to_openai_format(response, request.model)
-                return JSONResponse(content=openai_response)
-            except Exception as generate_error:
-                error_msg = f"Error generating content: {str(generate_error)}"
-                print(error_msg)
-                error_response = create_openai_error_response(500, error_msg, "server_error")
-                return JSONResponse(status_code=500, content=error_response)
     except Exception as e:
         print(error_msg)
         error_response = create_openai_error_response(500, error_msg, "server_error")
         return JSONResponse(status_code=500, content=error_response)

 # Health check endpoint
 @app.get("/health")
 def health_check(api_key: str = Depends(get_api_key)):
@@ -7,6 +7,7 @@ import base64
 import re
 import json
 import time
+import asyncio # Add this import
 import os
 import glob
 import random
@@ -276,6 +277,148 @@ async def startup_event():
 # Define supported roles for Gemini API
 SUPPORTED_ROLES = ["user", "model"]

+# Conversion functions
+def create_gemini_prompt_old(messages: List[OpenAIMessage]) -> Union[str, List[Any]]:
+    """
+    Convert OpenAI messages to Gemini format.
+    Returns either a string prompt or a list of content parts if images are present.
+    """
+    # Check if any message contains image content
+    has_images = False
+    for message in messages:
+        if isinstance(message.content, list):
+            for part in message.content:
+                if isinstance(part, dict) and part.get('type') == 'image_url':
+                    has_images = True
+                    break
+                elif isinstance(part, ContentPartImage):
+                    has_images = True
+                    break
+        if has_images:
+            break
+
+    # If no images, use the text-only format
+    if not has_images:
+        prompt = ""
+
+        # Extract system message if present
+        system_message = None
+        # Process all messages in their original order
+        for message in messages:
+            if message.role == "system":
+                # Handle both string and list[dict] content types
+                if isinstance(message.content, str):
+                    system_message = message.content
+                elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                    system_message = message.content[0]['text']
+                else:
+                    # Handle unexpected format or raise error? For now, assume it's usable or skip.
+                    system_message = str(message.content) # Fallback, might need refinement
+                break
+
+        # If system message exists, prepend it
+        if system_message:
+            prompt += f"System: {system_message}\n\n"
+
+        # Add other messages
+        for message in messages:
+            if message.role == "system":
+                continue # Already handled
+
+            # Handle both string and list[dict] content types
+            content_text = ""
+            if isinstance(message.content, str):
+                content_text = message.content
+            elif isinstance(message.content, list) and message.content and isinstance(message.content[0], dict) and 'text' in message.content[0]:
+                content_text = message.content[0]['text']
+            else:
+                # Fallback for unexpected format
+                content_text = str(message.content)
+
+            if message.role == "system":
+                prompt += f"System: {content_text}\n\n"
+            elif message.role == "user":
+                prompt += f"Human: {content_text}\n"
+            elif message.role == "assistant":
+                prompt += f"AI: {content_text}\n"
+
+        # Add final AI prompt if last message was from user
+        if messages[-1].role == "user":
+            prompt += "AI: "
+
+        return prompt
+
+    # If images are present, create a list of content parts
+    gemini_contents = []
+
+    # Extract system message if present and add it first
+    for message in messages:
+        if message.role == "system":
+            if isinstance(message.content, str):
+                gemini_contents.append(f"System: {message.content}")
+            elif isinstance(message.content, list):
+                # Extract text from system message
+                system_text = ""
+                for part in message.content:
+                    if isinstance(part, dict) and part.get('type') == 'text':
+                        system_text += part.get('text', '')
+                    elif isinstance(part, ContentPartText):
+                        system_text += part.text
+                if system_text:
+                    gemini_contents.append(f"System: {system_text}")
+            break
+
+    # Process user and assistant messages
+    # Process all messages in their original order
+    for message in messages:
+        if message.role == "system":
+            continue # Already handled
+
+        # For string content, add as text
+        if isinstance(message.content, str):
+            prefix = "Human: " if message.role == "user" else "AI: "
+            gemini_contents.append(f"{prefix}{message.content}")
+
+        # For list content, process each part
+        elif isinstance(message.content, list):
+            # First collect all text parts
+            text_content = ""
+
+            for part in message.content:
+                # Handle text parts
+                if isinstance(part, dict) and part.get('type') == 'text':
+                    text_content += part.get('text', '')
+                elif isinstance(part, ContentPartText):
+                    text_content += part.text
+
+            # Add the combined text content if any
+            if text_content:
+                prefix = "Human: " if message.role == "user" else "AI: "
+                gemini_contents.append(f"{prefix}{text_content}")
+
+            # Then process image parts
+            for part in message.content:
+                # Handle image parts
+                if isinstance(part, dict) and part.get('type') == 'image_url':
+                    image_url = part.get('image_url', {}).get('url', '')
+                    if image_url.startswith('data:'):
+                        # Extract mime type and base64 data
+                        mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                        if mime_match:
+                            mime_type, b64_data = mime_match.groups()
+                            image_bytes = base64.b64decode(b64_data)
+                            gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+                elif isinstance(part, ContentPartImage):
+                    image_url = part.image_url.url
+                    if image_url.startswith('data:'):
+                        # Extract mime type and base64 data
+                        mime_match = re.match(r'data:([^;]+);base64,(.+)', image_url)
+                        if mime_match:
+                            mime_type, b64_data = mime_match.groups()
+                            image_bytes = base64.b64decode(b64_data)
+                            gemini_contents.append(types.Part.from_bytes(data=image_bytes, mime_type=mime_type))
+    return gemini_contents
+
 def create_gemini_prompt(messages: List[OpenAIMessage]) -> Union[types.Content, List[types.Content]]:
     """
     Convert OpenAI messages to Gemini format.
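A minimal usage sketch (not part of the commit) of the create_gemini_prompt_old fallback added above, assuming OpenAIMessage is the request model defined earlier in app/main.py with role and content fields:

    # Hypothetical example; text-only messages, so the string branch is taken.
    msgs = [
        OpenAIMessage(role="system", content="You are terse."),
        OpenAIMessage(role="user", content="Hi"),
    ]
    print(create_gemini_prompt_old(msgs))
    # Produces roughly:
    #   System: You are terse.
    #
    #   Human: Hi
    #   AI: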
@@ -545,22 +688,45 @@ def convert_to_openai_format(gemini_response, model: str) -> Dict[str, Any]:
     if hasattr(gemini_response, 'candidates') and len(gemini_response.candidates) > 1:
         choices = []
         for i, candidate in enumerate(gemini_response.candidates):
+            # Extract text content from candidate
+            content = ""
+            if hasattr(candidate, 'text'):
+                content = candidate.text
+            elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                # Look for text in parts
+                for part in candidate.content.parts:
+                    if hasattr(part, 'text'):
+                        content += part.text
+
             choices.append({
                 "index": i,
                 "message": {
                     "role": "assistant",
+                    "content": content
                 },
                 "finish_reason": "stop"
             })
     else:
         # Handle single response (backward compatibility)
+        content = ""
+        # Try different ways to access the text content
+        if hasattr(gemini_response, 'text'):
+            content = gemini_response.text
+        elif hasattr(gemini_response, 'candidates') and gemini_response.candidates:
+            candidate = gemini_response.candidates[0]
+            if hasattr(candidate, 'text'):
+                content = candidate.text
+            elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                for part in candidate.content.parts:
+                    if hasattr(part, 'text'):
+                        content += part.text
+
         choices = [
             {
                 "index": 0,
                 "message": {
                     "role": "assistant",
+                    "content": content
                 },
                 "finish_reason": "stop"
             }
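For reference (not part of the diff): with the candidate-text extraction above, a two-candidate Gemini response maps to OpenAI-style choices shaped roughly like the sketch below; the surrounding id, object and usage fields come from the unchanged remainder of convert_to_openai_format, and the content strings here are placeholders.

    # Hypothetical shape only; values are illustrative.
    example_choices = [
        {"index": 0, "message": {"role": "assistant", "content": "candidate 0 text"}, "finish_reason": "stop"},
        {"index": 1, "message": {"role": "assistant", "content": "candidate 1 text"}, "finish_reason": "stop"},
    ]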
@@ -662,6 +828,51 @@ async def list_models(api_key: str = Depends(get_api_key)):
             "root": "gemini-2.5-pro-exp-03-25",
             "parent": None,
         },
+        {
+            "id": "gemini-2.5-pro-exp-03-25-auto", # New auto model
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-exp-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25-search",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25-encrypt",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
+        {
+            "id": "gemini-2.5-pro-preview-03-25-auto", # New auto model
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "google",
+            "permission": [],
+            "root": "gemini-2.5-pro-preview-03-25",
+            "parent": None,
+        },
         {
             "id": "gemini-2.0-flash",
             "object": "model",
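The new entries make the preview model and its -search, -encrypt, and -auto variants selectable by id. A hedged client-side sketch of discovering them, assuming the app is served locally and exposes the OpenAI-style routes this proxy emulates (base_url, path prefix and key are placeholders):

    # Hypothetical discovery call via the openai client; adjust base_url to your deployment.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder-key")
    ids = [m.id for m in client.models.list().data]
    assert "gemini-2.5-pro-preview-03-25-auto" in ids  # one of the ids added by this change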
@@ -782,157 +993,277 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)):
     try:
         # Validate model availability
         models_response = await list_models()
+        available_models = [model["id"] for model in models_response.get("data", [])]
+        if not request.model or request.model not in available_models:
             error_response = create_openai_error_response(
                 400, f"Model '{request.model}' not found", "invalid_request_error"
             )
             return JSONResponse(status_code=400, content=error_response)
+
+        # Check model type and extract base model name
+        is_auto_model = request.model.endswith("-auto")
         is_grounded_search = request.model.endswith("-search")
+        is_encrypted_model = request.model.endswith("-encrypt")
+
+        if is_auto_model:
+            base_model_name = request.model.replace("-auto", "")
+        elif is_grounded_search:
+            base_model_name = request.model.replace("-search", "")
         elif is_encrypted_model:
+            base_model_name = request.model.replace("-encrypt", "")
         else:
+            base_model_name = request.model
+
         # Create generation config
         generation_config = create_generation_config(request)
+
         # Use the globally initialized client (from startup)
         global client
         if client is None:
+            error_response = create_openai_error_response(
+                500, "Vertex AI client not initialized", "server_error"
+            )
+            return JSONResponse(status_code=500, content=error_response)
         print(f"Using globally initialized client.")

+        # Common safety settings
         safety_settings = [
+            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
+            types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
+            types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
+            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF")
+        ]
         generation_config["safety_settings"] = safety_settings
+
+        # --- Helper function to check response validity ---
+        def is_response_valid(response):
+            if response is None:
+                return False
+
+            # Check if candidates exist
+            if not hasattr(response, 'candidates') or not response.candidates:
+                return False
+
+            # Get the first candidate
+            candidate = response.candidates[0]
+
+            # Try different ways to access the text content
+            text_content = None
+
+            # Method 1: Direct text attribute on candidate
+            if hasattr(candidate, 'text'):
+                text_content = candidate.text
+            # Method 2: Text attribute on response
+            elif hasattr(response, 'text'):
+                text_content = response.text
+            # Method 3: Content with parts
+            elif hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                # Look for text in parts
+                for part in candidate.content.parts:
+                    if hasattr(part, 'text') and part.text:
+                        text_content = part.text
+                        break
+
+            # If we found text content and it's not empty, the response is valid
+            if text_content:
+                return True

+            # If no text content was found, check if there are other parts that might be valid
+            if hasattr(candidate, 'content') and hasattr(candidate.content, 'parts'):
+                if len(candidate.content.parts) > 0:
+                    # Consider valid if there are any parts at all
+                    return True
+
+            # Also check if the response itself has text
+            if hasattr(response, 'text') and response.text:
+                return True
+
+            # If we got here, the response is invalid
+            print(f"Invalid response: No text content found in response structure: {str(response)[:200]}...")
+            return False
+

+        # --- Helper function to make the API call (handles stream/non-stream) ---
+        async def make_gemini_call(model_name, prompt_func, current_gen_config):
+            prompt = prompt_func(request.messages)
+
+            # Log prompt structure
+            if isinstance(prompt, list):
+                print(f"Prompt structure: {len(prompt)} messages")
+            elif isinstance(prompt, types.Content):
+                print("Prompt structure: 1 message")
+            else:
+                # Handle old format case (which returns str or list[Any])
+                if isinstance(prompt, str):
+                    print("Prompt structure: String (old format)")
+                elif isinstance(prompt, list):
+                    print(f"Prompt structure: List[{len(prompt)}] (old format with images)")
+                else:
+                    print("Prompt structure: Unknown format")
+
+
+            if request.stream:
+                # Streaming call
                 response_id = f"chatcmpl-{int(time.time())}"
                 candidate_count = request.n or 1

+                async def stream_generator_inner():
+                    all_chunks_empty = True # Track if we receive any content
+                    first_chunk_received = False
+                    try:
+                        for candidate_index in range(candidate_count):
+                            print(f"Sending streaming request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                            responses = client.models.generate_content_stream(
+                                model=model_name,
+                                contents=prompt,
+                                config=current_gen_config,
+                            )
+
+                            # Use regular for loop, not async for
+                            for chunk in responses:
+                                first_chunk_received = True
+                                if hasattr(chunk, 'text') and chunk.text:
+                                    all_chunks_empty = False
+                                yield convert_chunk_to_openai(chunk, request.model, response_id, candidate_index)

+                        # Check if any chunk was received at all
+                        if not first_chunk_received:
+                            raise ValueError("Stream connection established but no chunks received")
+
+                        yield create_final_chunk(request.model, response_id, candidate_count)
+                        yield "data: [DONE]\n\n"

+                        # Return status based on content received
+                        if all_chunks_empty and first_chunk_received: # Check if we got chunks but they were all empty
+                            raise ValueError("Streamed response contained only empty chunks") # Treat empty stream as failure for retry
+
+                    except Exception as stream_error:
+                        error_msg = f"Error during streaming (Model: {model_name}, Format: {prompt_func.__name__}): {str(stream_error)}"
+                        print(error_msg)
+                        # Yield error in SSE format but also raise to signal failure
+                        error_response_content = create_openai_error_response(500, error_msg, "server_error")
+                        yield f"data: {json.dumps(error_response_content)}\n\n"
+                        yield "data: [DONE]\n\n"
+                        raise stream_error # Propagate error for retry logic

+                return StreamingResponse(stream_generator_inner(), media_type="text/event-stream")
+
+            else:
+                # Non-streaming call
+                try:
+                    print(f"Sending request to Gemini API (Model: {model_name}, Prompt Format: {prompt_func.__name__})")
+                    response = client.models.generate_content(
+                        model=model_name,
+                        contents=prompt,
+                        config=current_gen_config,
+                    )
+                    if not is_response_valid(response):
+                        raise ValueError("Invalid or empty response received") # Trigger retry
+
+                    openai_response = convert_to_openai_format(response, request.model)
+                    return JSONResponse(content=openai_response)
+                except Exception as generate_error:
+                    error_msg = f"Error generating content (Model: {model_name}, Format: {prompt_func.__name__}): {str(generate_error)}"
                     print(error_msg)
+                    # Raise error to signal failure for retry logic
+                    raise generate_error
+
+
+        # --- Main Logic ---
+        last_error = None
+
+        if is_auto_model:
+            print(f"Processing auto model: {request.model}")
+            attempts = [
+                {"name": "base", "model": base_model_name, "prompt_func": create_gemini_prompt, "config_modifier": lambda c: c},
+                {"name": "old_format", "model": base_model_name, "prompt_func": create_gemini_prompt_old, "config_modifier": lambda c: c},
+                {"name": "encrypt", "model": base_model_name, "prompt_func": create_encrypted_gemini_prompt, "config_modifier": lambda c: c}
+            ]
+
+            for i, attempt in enumerate(attempts):
+                print(f"Attempt {i+1}/{len(attempts)} using '{attempt['name']}' mode...")
+                current_config = attempt["config_modifier"](generation_config.copy())
+
+                try:
+                    result = await make_gemini_call(attempt["model"], attempt["prompt_func"], current_config)
+
+                    # For streaming, the result is StreamingResponse, success is determined inside make_gemini_call raising an error on failure
+                    # For non-streaming, if make_gemini_call doesn't raise, it's successful
+                    print(f"Attempt {i+1} ('{attempt['name']}') successful.")
+                    return result
+                except Exception as e:
+                    last_error = e
+                    print(f"Attempt {i+1} ('{attempt['name']}') failed: {e}")
+                    if i < len(attempts) - 1:
+                        print("Waiting 1 second before next attempt...")
+                        await asyncio.sleep(1) # Use asyncio.sleep for async context
+                    else:
+                        print("All attempts failed.")

+            # If all attempts failed, return the last error
+            error_msg = f"All retry attempts failed for model {request.model}. Last error: {str(last_error)}"
+            error_response = create_openai_error_response(500, error_msg, "server_error")
+            # If the last attempt was streaming and failed, the error response is already yielded by the generator.
+            # If non-streaming failed last, return the JSON error.
+            if not request.stream:
+                return JSONResponse(status_code=500, content=error_response)
+            else:
+                # The StreamingResponse returned earlier will handle yielding the final error.
+                # We should not return a new response here.
+                # If we reach here after a failed stream, it means the initial StreamingResponse object was returned,
+                # but the generator within it failed on the last attempt.
+                # The generator itself handles yielding the error SSE.
+                # We need to ensure the main function doesn't try to return another response.
+                # Returning the 'result' from the failed attempt (which is the StreamingResponse object)
+                # might be okay IF the generator correctly yields the error and DONE message.
+                # Let's return the StreamingResponse object which contains the failing generator.
+                # This assumes the generator correctly terminates after yielding the error.
+                # Re-evaluate if this causes issues. The goal is to avoid double responses.
+                # It seems returning the StreamingResponse object itself is the correct FastAPI pattern.
+                return result # Return the StreamingResponse object which contains the failing generator
+
+
         else:
+            # Handle non-auto models (base, search, encrypt)
+            current_model_name = base_model_name
+            current_prompt_func = create_gemini_prompt
+            current_config = generation_config.copy()
+
+            if is_grounded_search:
+                print(f"Using grounded search for model: {request.model}")
+                search_tool = types.Tool(google_search=types.GoogleSearch())
+                current_config["tools"] = [search_tool]
+            elif is_encrypted_model:
+                print(f"Using encrypted prompt for model: {request.model}")
+                current_prompt_func = create_encrypted_gemini_prompt
+
             try:
+                result = await make_gemini_call(current_model_name, current_prompt_func, current_config)
+                return result
+            except Exception as e:
+                # Handle potential errors for non-auto models
+                error_msg = f"Error processing model {request.model}: {str(e)}"
+                print(error_msg)
+                error_response = create_openai_error_response(500, error_msg, "server_error")
+                # Similar to auto-fail case, handle stream vs non-stream error return
+                if not request.stream:
+                    return JSONResponse(status_code=500, content=error_response)
+                else:
+                    # Let the StreamingResponse handle yielding the error
+                    return result # Return the StreamingResponse object containing the failing generator
+
+
     except Exception as e:
+        # Catch-all for unexpected errors during setup or logic flow
+        error_msg = f"Unexpected error processing request: {str(e)}"
         print(error_msg)
         error_response = create_openai_error_response(500, error_msg, "server_error")
+        # Ensure we return a JSON response even for stream requests if error happens early
         return JSONResponse(status_code=500, content=error_response)

+# --- Need to import asyncio ---
+# import asyncio # Add this import at the top of the file # Already added below
+
 # Health check endpoint
 @app.get("/health")
 def health_check(api_key: str = Depends(get_api_key)):
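Taken together, the -auto path tries the structured prompt first, then the legacy text format, then the encrypted prompt, pausing one second between failed attempts; the first attempt that neither raises nor returns an empty response wins. A simplified, self-contained sketch of that retry ladder (call_fn stands in for make_gemini_call; the names are illustrative, not part of the commit):

    # Hypothetical distillation of the '-auto' fallback loop shown in the diff above.
    import asyncio

    async def try_with_fallbacks(call_fn, base_model, prompt_builders):
        last_error = None
        for i, build_prompt in enumerate(prompt_builders):
            try:
                # A successful call ends the ladder immediately.
                return await call_fn(base_model, build_prompt)
            except Exception as exc:
                last_error = exc
                if i < len(prompt_builders) - 1:
                    await asyncio.sleep(1)  # same 1-second pause as the diff
        # All attempts failed; the caller turns this into an OpenAI-style 500 error.
        raise last_error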