bibibi12345 committed on
Commit
9275605
·
verified ·
1 Parent(s): 35867dc

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +220 -6
app/main.py CHANGED
@@ -15,6 +15,8 @@ import random
15
  import urllib.parse
16
  from google.oauth2 import service_account
17
  import config
 
 
18
 
19
  from google.genai import types
20
 
@@ -1149,6 +1151,15 @@ async def list_models(api_key: str = Depends(get_api_key)):
1149
  "root": "gemini-2.5-pro-exp-03-25",
1150
  "parent": None,
1151
  },
 
 
 
 
 
 
 
 
 
1152
  {
1153
  "id": "gemini-2.5-pro-preview-03-25",
1154
  "object": "model",
@@ -1336,6 +1347,15 @@ def create_openai_error_response(status_code: int, message: str, error_type: str
1336
  }
1337
  }
1338
 
 
 
 
 
 
 
 
 
 
1339
  @app.post("/v1/chat/completions")
1340
  async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
1341
  try:
@@ -1348,10 +1368,193 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
1348
  )
1349
  return JSONResponse(status_code=400, content=error_response)
1350
 
1351
- # Check model type and extract base model name
1352
- is_auto_model = request.model.endswith("-auto")
1353
- is_grounded_search = request.model.endswith("-search")
1354
- is_encrypted_model = request.model.endswith("-encrypt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1355
  is_encrypted_full_model = request.model.endswith("-encrypt-full")
1356
  is_nothinking_model = request.model.endswith("-nothinking")
1357
  is_max_thinking_model = request.model.endswith("-max")
@@ -1418,7 +1621,8 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
1418
  types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
1419
  types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
1420
  types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
1421
- types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF")
 
1422
  ]
1423
  generation_config["safety_settings"] = safety_settings
1424
 
@@ -1518,8 +1722,18 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
1518
  # --- Main Logic ---
1519
  last_error = None
1520
 
1521
- if is_auto_model:
 
 
 
 
 
 
 
 
 
1522
  print(f"Processing auto model: {request.model}")
 
1523
  # Define encryption instructions for system_instruction
1524
  encryption_instructions = [
1525
  "// AI Assistant Configuration //",
 
15
  import urllib.parse
16
  from google.oauth2 import service_account
17
  import config
18
+ import openai # Added import
19
+ from google.auth.transport.requests import Request as AuthRequest # Added import
20
 
21
  from google.genai import types
22
 
 
1151
  "root": "gemini-2.5-pro-exp-03-25",
1152
  "parent": None,
1153
  },
1154
+ { # Added new model entry for OpenAI endpoint
1155
+ "id": "gemini-2.5-pro-exp-03-25-openai",
1156
+ "object": "model",
1157
+ "created": int(time.time()),
1158
+ "owned_by": "google",
1159
+ "permission": [],
1160
+ "root": "gemini-2.5-pro-exp-03-25", # Underlying model
1161
+ "parent": None,
1162
+ },
1163
  {
1164
  "id": "gemini-2.5-pro-preview-03-25",
1165
  "object": "model",
 
1347
  }
1348
  }
1349
 
1350
+ # Helper for token refresh
1351
+ def _refresh_auth(credentials):
1352
+ try:
1353
+ credentials.refresh(AuthRequest())
1354
+ return credentials.token
1355
+ except Exception as e:
1356
+ print(f"Error refreshing GCP token: {e}")
1357
+ return None
1358
+
1359
  @app.post("/v1/chat/completions")
1360
  async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
1361
  try:
 
1368
  )
1369
  return JSONResponse(status_code=400, content=error_response)
1370
 
1371
+ # --- Handle specific OpenAI client model ---
1372
+ if request.model == "gemini-2.5-pro-exp-03-25-openai":
1373
+ print(f"INFO: Using OpenAI library path for model: {request.model}")
1374
+
1375
+ # --- Determine Credentials for OpenAI Client (Correct Priority) ---
1376
+ credentials_to_use = None
1377
+ project_id_to_use = None
1378
+ credential_source = "unknown"
1379
+
1380
+ # Priority 1: GOOGLE_CREDENTIALS_JSON (JSON String in Env Var)
1381
+ credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
1382
+ if credentials_json_str:
1383
+ try:
1384
+ credentials_info = json.loads(credentials_json_str)
1385
+ if not isinstance(credentials_info, dict): raise ValueError("JSON is not a dict")
1386
+ required = ["type", "project_id", "private_key_id", "private_key", "client_email"]
1387
+ if any(f not in credentials_info for f in required): raise ValueError("Missing required fields")
1388
+
1389
+ credentials = service_account.Credentials.from_service_account_info(
1390
+ credentials_info, scopes=['https://www.googleapis.com/auth/cloud-platform']
1391
+ )
1392
+ project_id = credentials.project_id
1393
+ credentials_to_use = credentials
1394
+ project_id_to_use = project_id
1395
+ credential_source = "GOOGLE_CREDENTIALS_JSON env var"
1396
+ print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
1397
+ except Exception as e:
1398
+ print(f"WARNING: [OpenAI Path] Error processing GOOGLE_CREDENTIALS_JSON: {e}. Trying next method.")
1399
+ credentials_to_use = None # Ensure reset if failed
1400
+
1401
+ # Priority 2: Credential Manager (Rotated Files)
1402
+ if credentials_to_use is None:
1403
+ print(f"INFO: [OpenAI Path] Checking Credential Manager (directory: {credential_manager.credentials_dir})")
1404
+ rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
1405
+ if rotated_credentials and rotated_project_id:
1406
+ credentials_to_use = rotated_credentials
1407
+ project_id_to_use = rotated_project_id
1408
+ credential_source = f"Credential Manager file (Index: {credential_manager.current_index -1 if credential_manager.current_index > 0 else len(credential_manager.credentials_files) - 1})"
1409
+ print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
1410
+ else:
1411
+ print(f"INFO: [OpenAI Path] No credentials loaded via Credential Manager.")
1412
+
1413
+ # Priority 3: GOOGLE_APPLICATION_CREDENTIALS (File Path in Env Var)
1414
+ if credentials_to_use is None:
1415
+ file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
1416
+ if file_path:
1417
+ print(f"INFO: [OpenAI Path] Checking GOOGLE_APPLICATION_CREDENTIALS file path: {file_path}")
1418
+ if os.path.exists(file_path):
1419
+ try:
1420
+ credentials = service_account.Credentials.from_service_account_file(
1421
+ file_path, scopes=['https://www.googleapis.com/auth/cloud-platform']
1422
+ )
1423
+ project_id = credentials.project_id
1424
+ credentials_to_use = credentials
1425
+ project_id_to_use = project_id
1426
+ credential_source = "GOOGLE_APPLICATION_CREDENTIALS file path"
1427
+ print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
1428
+ except Exception as e:
1429
+ print(f"ERROR: [OpenAI Path] Failed to load credentials from GOOGLE_APPLICATION_CREDENTIALS path ({file_path}): {e}")
1430
+ else:
1431
+ print(f"ERROR: [OpenAI Path] GOOGLE_APPLICATION_CREDENTIALS file does not exist at path: {file_path}")
1432
+
1433
+ # Error if no credentials found after all checks
1434
+ if credentials_to_use is None or project_id_to_use is None:
1435
+ error_msg = "No valid credentials found for OpenAI client path. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager, and GOOGLE_APPLICATION_CREDENTIALS."
1436
+ print(f"ERROR: {error_msg}")
1437
+ error_response = create_openai_error_response(500, error_msg, "server_error")
1438
+ return JSONResponse(status_code=500, content=error_response)
1439
+ # --- Credentials Determined ---
1440
+
1441
+ # Get/Refresh GCP Token from the chosen credentials (credentials_to_use)
1442
+ gcp_token = None
1443
+ if credentials_to_use.expired or not credentials_to_use.token:
1444
+ print(f"INFO: [OpenAI Path] Refreshing GCP token (Source: {credential_source})...")
1445
+ gcp_token = _refresh_auth(credentials_to_use)
1446
+ else:
1447
+ gcp_token = credentials_to_use.token
1448
+
1449
+ if not gcp_token:
1450
+ error_msg = f"Failed to obtain valid GCP token for OpenAI client (Source: {credential_source})."
1451
+ print(f"ERROR: {error_msg}")
1452
+ error_response = create_openai_error_response(500, error_msg, "server_error")
1453
+ return JSONResponse(status_code=500, content=error_response)
1454
+
1455
+ # Configuration using determined Project ID
1456
+ PROJECT_ID = project_id_to_use
1457
+ LOCATION = "us-central1" # Assuming same location as genai client
1458
+ VERTEX_AI_OPENAI_ENDPOINT_URL = (
1459
+ f"https://{LOCATION}-aiplatform.googleapis.com/v1beta1/"
1460
+ f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi"
1461
+ )
1462
+ UNDERLYING_MODEL_ID = "gemini-2.5-pro-exp-03-25" # As specified
1463
+
1464
+ # Initialize Async OpenAI Client
1465
+ openai_client = openai.AsyncOpenAI(
1466
+ base_url=VERTEX_AI_OPENAI_ENDPOINT_URL,
1467
+ api_key=gcp_token,
1468
+ )
1469
+
1470
+ # Define standard safety settings (as used elsewhere)
1471
+ openai_safety_settings = [
1472
+ {
1473
+ "category": "HARM_CATEGORY_HARASSMENT",
1474
+ "threshold": "OFF"
1475
+ },
1476
+ {
1477
+ "category": "HARM_CATEGORY_HATE_SPEECH",
1478
+ "threshold": "OFF"
1479
+ },
1480
+ {
1481
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
1482
+ "threshold": "OFF"
1483
+ },
1484
+ {
1485
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
1486
+ "threshold": "OFF"
1487
+ },
1488
+ {
1489
+ "category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
1490
+ "threshold": 'OFF'
1491
+ }
1492
+ ]
1493
+
1494
+ # Prepare parameters for OpenAI client call
1495
+ openai_params = {
1496
+ "model": UNDERLYING_MODEL_ID,
1497
+ "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
1498
+ "temperature": request.temperature,
1499
+ "max_tokens": request.max_tokens,
1500
+ "top_p": request.top_p,
1501
+ "stream": request.stream,
1502
+ "stop": request.stop,
1503
+ # "presence_penalty": request.presence_penalty,
1504
+ # "frequency_penalty": request.frequency_penalty,
1505
+ "seed": request.seed,
1506
+ "n": request.n,
1507
+ # Note: logprobs/response_logprobs mapping might need adjustment
1508
+ # Note: top_k is not directly supported by standard OpenAI API spec
1509
+ }
1510
+ # Add safety settings via extra_body
1511
+ openai_extra_body = {
1512
+ 'google': {
1513
+ 'safety_settings': openai_safety_settings
1514
+ }
1515
+ }
1516
+ openai_params = {k: v for k, v in openai_params.items() if v is not None}
1517
+
1518
+
1519
+ # Make the call using OpenAI client
1520
+ if request.stream:
1521
+ async def openai_stream_generator():
1522
+ try:
1523
+ stream = await openai_client.chat.completions.create(
1524
+ **openai_params,
1525
+ extra_body=openai_extra_body # Pass safety settings here
1526
+ )
1527
+ async for chunk in stream:
1528
+ yield f"data: {chunk.model_dump_json()}\n\n"
1529
+ yield "data: [DONE]\n\n"
1530
+ except Exception as stream_error:
1531
+ error_msg = f"Error during OpenAI client streaming for {request.model}: {str(stream_error)}"
1532
+ print(error_msg)
1533
+ error_response_content = create_openai_error_response(500, error_msg, "server_error")
1534
+ yield f"data: {json.dumps(error_response_content)}\n\n"
1535
+ yield "data: [DONE]\n\n"
1536
+
1537
+ return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
1538
+ else:
1539
+ try:
1540
+ response = await openai_client.chat.completions.create(
1541
+ **openai_params,
1542
+ extra_body=openai_extra_body # Pass safety settings here
1543
+ )
1544
+ return JSONResponse(content=response.model_dump(exclude_unset=True))
1545
+ except Exception as generate_error:
1546
+ error_msg = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
1547
+ print(error_msg)
1548
+ error_response = create_openai_error_response(500, error_msg, "server_error")
1549
+ return JSONResponse(status_code=500, content=error_response)
1550
+
1551
+ # --- End of specific OpenAI client model handling ---
1552
+
1553
+ # Check model type and extract base model name (Changed to elif)
1554
+ elif request.model.endswith("-auto"):
1555
+ is_auto_model = True
1556
+ is_grounded_search = False
1557
+ is_encrypted_model = False
1558
  is_encrypted_full_model = request.model.endswith("-encrypt-full")
1559
  is_nothinking_model = request.model.endswith("-nothinking")
1560
  is_max_thinking_model = request.model.endswith("-max")
 
1621
  types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
1622
  types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
1623
  types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
1624
+ types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
1625
+ types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")
1626
  ]
1627
  generation_config["safety_settings"] = safety_settings
1628
 
 
1722
  # --- Main Logic ---
1723
  last_error = None
1724
 
1725
+ # --- Main Logic --- (Ensure flags are correctly set if the first 'if' wasn't met)
1726
+ # Re-evaluate flags based on elif structure for clarity if needed, or rely on the fact that the first 'if' returned.
1727
+ is_auto_model = request.model.endswith("-auto") # This will be False if the first 'if' was True
1728
+ is_grounded_search = request.model.endswith("-search")
1729
+ is_encrypted_model = request.model.endswith("-encrypt")
1730
+ is_encrypted_full_model = request.model.endswith("-encrypt-full")
1731
+ is_nothinking_model = request.model.endswith("-nothinking")
1732
+ is_max_thinking_model = request.model.endswith("-max")
1733
+
1734
+ if is_auto_model: # This remains the primary check after the openai specific one
1735
  print(f"Processing auto model: {request.model}")
1736
+ base_model_name = request.model.replace("-auto", "") # Ensure base_model_name is set here too
1737
  # Define encryption instructions for system_instruction
1738
  encryption_instructions = [
1739
  "// AI Assistant Configuration //",