Update app/main.py
Browse files- app/main.py +220 -6
app/main.py
CHANGED
@@ -15,6 +15,8 @@ import random
|
|
15 |
import urllib.parse
|
16 |
from google.oauth2 import service_account
|
17 |
import config
|
|
|
|
|
18 |
|
19 |
from google.genai import types
|
20 |
|
@@ -1149,6 +1151,15 @@ async def list_models(api_key: str = Depends(get_api_key)):
|
|
1149 |
"root": "gemini-2.5-pro-exp-03-25",
|
1150 |
"parent": None,
|
1151 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1152 |
{
|
1153 |
"id": "gemini-2.5-pro-preview-03-25",
|
1154 |
"object": "model",
|
@@ -1336,6 +1347,15 @@ def create_openai_error_response(status_code: int, message: str, error_type: str
|
|
1336 |
}
|
1337 |
}
|
1338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1339 |
@app.post("/v1/chat/completions")
|
1340 |
async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
|
1341 |
try:
|
@@ -1348,10 +1368,193 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
1348 |
)
|
1349 |
return JSONResponse(status_code=400, content=error_response)
|
1350 |
|
1351 |
-
#
|
1352 |
-
|
1353 |
-
|
1354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1355 |
is_encrypted_full_model = request.model.endswith("-encrypt-full")
|
1356 |
is_nothinking_model = request.model.endswith("-nothinking")
|
1357 |
is_max_thinking_model = request.model.endswith("-max")
|
@@ -1418,7 +1621,8 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
1418 |
types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
|
1419 |
types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
|
1420 |
types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
|
1421 |
-
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF")
|
|
|
1422 |
]
|
1423 |
generation_config["safety_settings"] = safety_settings
|
1424 |
|
@@ -1518,8 +1722,18 @@ async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_ap
|
|
1518 |
# --- Main Logic ---
|
1519 |
last_error = None
|
1520 |
|
1521 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1522 |
print(f"Processing auto model: {request.model}")
|
|
|
1523 |
# Define encryption instructions for system_instruction
|
1524 |
encryption_instructions = [
|
1525 |
"// AI Assistant Configuration //",
|
|
|
15 |
import urllib.parse
|
16 |
from google.oauth2 import service_account
|
17 |
import config
|
18 |
+
import openai # Added import
|
19 |
+
from google.auth.transport.requests import Request as AuthRequest # Added import
|
20 |
|
21 |
from google.genai import types
|
22 |
|
|
|
1151 |
"root": "gemini-2.5-pro-exp-03-25",
|
1152 |
"parent": None,
|
1153 |
},
|
1154 |
+
{ # Added new model entry for OpenAI endpoint
|
1155 |
+
"id": "gemini-2.5-pro-exp-03-25-openai",
|
1156 |
+
"object": "model",
|
1157 |
+
"created": int(time.time()),
|
1158 |
+
"owned_by": "google",
|
1159 |
+
"permission": [],
|
1160 |
+
"root": "gemini-2.5-pro-exp-03-25", # Underlying model
|
1161 |
+
"parent": None,
|
1162 |
+
},
|
1163 |
{
|
1164 |
"id": "gemini-2.5-pro-preview-03-25",
|
1165 |
"object": "model",
|
|
|
1347 |
}
|
1348 |
}
|
1349 |
|
1350 |
+
# Helper for token refresh
|
1351 |
+
def _refresh_auth(credentials):
    """Refresh *credentials* and return the resulting access token.

    Calls ``credentials.refresh`` with a new ``AuthRequest`` transport and
    then reads ``credentials.token``. Any exception raised by either step is
    printed and swallowed, and ``None`` is returned instead; the caller
    treats a falsy result as "no token available".

    Args:
        credentials: Object exposing ``refresh(request)`` and ``token``.

    Returns:
        The value of ``credentials.token`` after the refresh, or ``None``
        when the refresh (or the token lookup) raised.
    """
    # NOTE(review): this is a plain synchronous call made from the async
    # request handler; if the transport performs blocking network I/O it
    # will stall the event loop for the duration of the refresh -- confirm.
    try:
        credentials.refresh(AuthRequest())
        refreshed_token = credentials.token
    except Exception as refresh_error:
        # Best effort: report the failure and let the caller decide how to
        # respond instead of propagating the exception.
        print(f"Error refreshing GCP token: {refresh_error}")
        refreshed_token = None
    return refreshed_token
|
1358 |
+
|
1359 |
@app.post("/v1/chat/completions")
|
1360 |
async def chat_completions(request: OpenAIRequest, api_key: str = Depends(get_api_key)): # Add request parameter
|
1361 |
try:
|
|
|
1368 |
)
|
1369 |
return JSONResponse(status_code=400, content=error_response)
|
1370 |
|
1371 |
+
# --- Handle specific OpenAI client model ---
|
1372 |
+
if request.model == "gemini-2.5-pro-exp-03-25-openai":
|
1373 |
+
print(f"INFO: Using OpenAI library path for model: {request.model}")
|
1374 |
+
|
1375 |
+
# --- Determine Credentials for OpenAI Client (Correct Priority) ---
|
1376 |
+
credentials_to_use = None
|
1377 |
+
project_id_to_use = None
|
1378 |
+
credential_source = "unknown"
|
1379 |
+
|
1380 |
+
# Priority 1: GOOGLE_CREDENTIALS_JSON (JSON String in Env Var)
|
1381 |
+
credentials_json_str = os.environ.get("GOOGLE_CREDENTIALS_JSON")
|
1382 |
+
if credentials_json_str:
|
1383 |
+
try:
|
1384 |
+
credentials_info = json.loads(credentials_json_str)
|
1385 |
+
if not isinstance(credentials_info, dict): raise ValueError("JSON is not a dict")
|
1386 |
+
required = ["type", "project_id", "private_key_id", "private_key", "client_email"]
|
1387 |
+
if any(f not in credentials_info for f in required): raise ValueError("Missing required fields")
|
1388 |
+
|
1389 |
+
credentials = service_account.Credentials.from_service_account_info(
|
1390 |
+
credentials_info, scopes=['https://www.googleapis.com/auth/cloud-platform']
|
1391 |
+
)
|
1392 |
+
project_id = credentials.project_id
|
1393 |
+
credentials_to_use = credentials
|
1394 |
+
project_id_to_use = project_id
|
1395 |
+
credential_source = "GOOGLE_CREDENTIALS_JSON env var"
|
1396 |
+
print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
|
1397 |
+
except Exception as e:
|
1398 |
+
print(f"WARNING: [OpenAI Path] Error processing GOOGLE_CREDENTIALS_JSON: {e}. Trying next method.")
|
1399 |
+
credentials_to_use = None # Ensure reset if failed
|
1400 |
+
|
1401 |
+
# Priority 2: Credential Manager (Rotated Files)
|
1402 |
+
if credentials_to_use is None:
|
1403 |
+
print(f"INFO: [OpenAI Path] Checking Credential Manager (directory: {credential_manager.credentials_dir})")
|
1404 |
+
rotated_credentials, rotated_project_id = credential_manager.get_next_credentials()
|
1405 |
+
if rotated_credentials and rotated_project_id:
|
1406 |
+
credentials_to_use = rotated_credentials
|
1407 |
+
project_id_to_use = rotated_project_id
|
1408 |
+
credential_source = f"Credential Manager file (Index: {credential_manager.current_index -1 if credential_manager.current_index > 0 else len(credential_manager.credentials_files) - 1})"
|
1409 |
+
print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
|
1410 |
+
else:
|
1411 |
+
print(f"INFO: [OpenAI Path] No credentials loaded via Credential Manager.")
|
1412 |
+
|
1413 |
+
# Priority 3: GOOGLE_APPLICATION_CREDENTIALS (File Path in Env Var)
|
1414 |
+
if credentials_to_use is None:
|
1415 |
+
file_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
|
1416 |
+
if file_path:
|
1417 |
+
print(f"INFO: [OpenAI Path] Checking GOOGLE_APPLICATION_CREDENTIALS file path: {file_path}")
|
1418 |
+
if os.path.exists(file_path):
|
1419 |
+
try:
|
1420 |
+
credentials = service_account.Credentials.from_service_account_file(
|
1421 |
+
file_path, scopes=['https://www.googleapis.com/auth/cloud-platform']
|
1422 |
+
)
|
1423 |
+
project_id = credentials.project_id
|
1424 |
+
credentials_to_use = credentials
|
1425 |
+
project_id_to_use = project_id
|
1426 |
+
credential_source = "GOOGLE_APPLICATION_CREDENTIALS file path"
|
1427 |
+
print(f"INFO: [OpenAI Path] Using credentials from {credential_source} for project: {project_id_to_use}")
|
1428 |
+
except Exception as e:
|
1429 |
+
print(f"ERROR: [OpenAI Path] Failed to load credentials from GOOGLE_APPLICATION_CREDENTIALS path ({file_path}): {e}")
|
1430 |
+
else:
|
1431 |
+
print(f"ERROR: [OpenAI Path] GOOGLE_APPLICATION_CREDENTIALS file does not exist at path: {file_path}")
|
1432 |
+
|
1433 |
+
# Error if no credentials found after all checks
|
1434 |
+
if credentials_to_use is None or project_id_to_use is None:
|
1435 |
+
error_msg = "No valid credentials found for OpenAI client path. Tried GOOGLE_CREDENTIALS_JSON, Credential Manager, and GOOGLE_APPLICATION_CREDENTIALS."
|
1436 |
+
print(f"ERROR: {error_msg}")
|
1437 |
+
error_response = create_openai_error_response(500, error_msg, "server_error")
|
1438 |
+
return JSONResponse(status_code=500, content=error_response)
|
1439 |
+
# --- Credentials Determined ---
|
1440 |
+
|
1441 |
+
# Get/Refresh GCP Token from the chosen credentials (credentials_to_use)
|
1442 |
+
gcp_token = None
|
1443 |
+
if credentials_to_use.expired or not credentials_to_use.token:
|
1444 |
+
print(f"INFO: [OpenAI Path] Refreshing GCP token (Source: {credential_source})...")
|
1445 |
+
gcp_token = _refresh_auth(credentials_to_use)
|
1446 |
+
else:
|
1447 |
+
gcp_token = credentials_to_use.token
|
1448 |
+
|
1449 |
+
if not gcp_token:
|
1450 |
+
error_msg = f"Failed to obtain valid GCP token for OpenAI client (Source: {credential_source})."
|
1451 |
+
print(f"ERROR: {error_msg}")
|
1452 |
+
error_response = create_openai_error_response(500, error_msg, "server_error")
|
1453 |
+
return JSONResponse(status_code=500, content=error_response)
|
1454 |
+
|
1455 |
+
# Configuration using determined Project ID
|
1456 |
+
PROJECT_ID = project_id_to_use
|
1457 |
+
LOCATION = "us-central1" # Assuming same location as genai client
|
1458 |
+
VERTEX_AI_OPENAI_ENDPOINT_URL = (
|
1459 |
+
f"https://{LOCATION}-aiplatform.googleapis.com/v1beta1/"
|
1460 |
+
f"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi"
|
1461 |
+
)
|
1462 |
+
UNDERLYING_MODEL_ID = "gemini-2.5-pro-exp-03-25" # As specified
|
1463 |
+
|
1464 |
+
# Initialize Async OpenAI Client
|
1465 |
+
openai_client = openai.AsyncOpenAI(
|
1466 |
+
base_url=VERTEX_AI_OPENAI_ENDPOINT_URL,
|
1467 |
+
api_key=gcp_token,
|
1468 |
+
)
|
1469 |
+
|
1470 |
+
# Define standard safety settings (as used elsewhere)
|
1471 |
+
openai_safety_settings = [
|
1472 |
+
{
|
1473 |
+
"category": "HARM_CATEGORY_HARASSMENT",
|
1474 |
+
"threshold": "OFF"
|
1475 |
+
},
|
1476 |
+
{
|
1477 |
+
"category": "HARM_CATEGORY_HATE_SPEECH",
|
1478 |
+
"threshold": "OFF"
|
1479 |
+
},
|
1480 |
+
{
|
1481 |
+
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
1482 |
+
"threshold": "OFF"
|
1483 |
+
},
|
1484 |
+
{
|
1485 |
+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
1486 |
+
"threshold": "OFF"
|
1487 |
+
},
|
1488 |
+
{
|
1489 |
+
"category": 'HARM_CATEGORY_CIVIC_INTEGRITY',
|
1490 |
+
"threshold": 'OFF'
|
1491 |
+
}
|
1492 |
+
]
|
1493 |
+
|
1494 |
+
# Prepare parameters for OpenAI client call
|
1495 |
+
openai_params = {
|
1496 |
+
"model": UNDERLYING_MODEL_ID,
|
1497 |
+
"messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
|
1498 |
+
"temperature": request.temperature,
|
1499 |
+
"max_tokens": request.max_tokens,
|
1500 |
+
"top_p": request.top_p,
|
1501 |
+
"stream": request.stream,
|
1502 |
+
"stop": request.stop,
|
1503 |
+
# "presence_penalty": request.presence_penalty,
|
1504 |
+
# "frequency_penalty": request.frequency_penalty,
|
1505 |
+
"seed": request.seed,
|
1506 |
+
"n": request.n,
|
1507 |
+
# Note: logprobs/response_logprobs mapping might need adjustment
|
1508 |
+
# Note: top_k is not directly supported by standard OpenAI API spec
|
1509 |
+
}
|
1510 |
+
# Add safety settings via extra_body
|
1511 |
+
openai_extra_body = {
|
1512 |
+
'google': {
|
1513 |
+
'safety_settings': openai_safety_settings
|
1514 |
+
}
|
1515 |
+
}
|
1516 |
+
openai_params = {k: v for k, v in openai_params.items() if v is not None}
|
1517 |
+
|
1518 |
+
|
1519 |
+
# Make the call using OpenAI client
|
1520 |
+
if request.stream:
|
1521 |
+
async def openai_stream_generator():
|
1522 |
+
try:
|
1523 |
+
stream = await openai_client.chat.completions.create(
|
1524 |
+
**openai_params,
|
1525 |
+
extra_body=openai_extra_body # Pass safety settings here
|
1526 |
+
)
|
1527 |
+
async for chunk in stream:
|
1528 |
+
yield f"data: {chunk.model_dump_json()}\n\n"
|
1529 |
+
yield "data: [DONE]\n\n"
|
1530 |
+
except Exception as stream_error:
|
1531 |
+
error_msg = f"Error during OpenAI client streaming for {request.model}: {str(stream_error)}"
|
1532 |
+
print(error_msg)
|
1533 |
+
error_response_content = create_openai_error_response(500, error_msg, "server_error")
|
1534 |
+
yield f"data: {json.dumps(error_response_content)}\n\n"
|
1535 |
+
yield "data: [DONE]\n\n"
|
1536 |
+
|
1537 |
+
return StreamingResponse(openai_stream_generator(), media_type="text/event-stream")
|
1538 |
+
else:
|
1539 |
+
try:
|
1540 |
+
response = await openai_client.chat.completions.create(
|
1541 |
+
**openai_params,
|
1542 |
+
extra_body=openai_extra_body # Pass safety settings here
|
1543 |
+
)
|
1544 |
+
return JSONResponse(content=response.model_dump(exclude_unset=True))
|
1545 |
+
except Exception as generate_error:
|
1546 |
+
error_msg = f"Error calling OpenAI client for {request.model}: {str(generate_error)}"
|
1547 |
+
print(error_msg)
|
1548 |
+
error_response = create_openai_error_response(500, error_msg, "server_error")
|
1549 |
+
return JSONResponse(status_code=500, content=error_response)
|
1550 |
+
|
1551 |
+
# --- End of specific OpenAI client model handling ---
|
1552 |
+
|
1553 |
+
# Check model type and extract base model name (Changed to elif)
|
1554 |
+
elif request.model.endswith("-auto"):
|
1555 |
+
is_auto_model = True
|
1556 |
+
is_grounded_search = False
|
1557 |
+
is_encrypted_model = False
|
1558 |
is_encrypted_full_model = request.model.endswith("-encrypt-full")
|
1559 |
is_nothinking_model = request.model.endswith("-nothinking")
|
1560 |
is_max_thinking_model = request.model.endswith("-max")
|
|
|
1621 |
types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
|
1622 |
types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"),
|
1623 |
types.SafetySetting(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"),
|
1624 |
+
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
|
1625 |
+
types.SafetySetting(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="OFF")
|
1626 |
]
|
1627 |
generation_config["safety_settings"] = safety_settings
|
1628 |
|
|
|
1722 |
# --- Main Logic ---
|
1723 |
last_error = None
|
1724 |
|
1725 |
+
# --- Main Logic --- (Ensure flags are correctly set if the first 'if' wasn't met)
|
1726 |
+
# Re-evaluate flags based on elif structure for clarity if needed, or rely on the fact that the first 'if' returned.
|
1727 |
+
is_auto_model = request.model.endswith("-auto") # This will be False if the first 'if' was True
|
1728 |
+
is_grounded_search = request.model.endswith("-search")
|
1729 |
+
is_encrypted_model = request.model.endswith("-encrypt")
|
1730 |
+
is_encrypted_full_model = request.model.endswith("-encrypt-full")
|
1731 |
+
is_nothinking_model = request.model.endswith("-nothinking")
|
1732 |
+
is_max_thinking_model = request.model.endswith("-max")
|
1733 |
+
|
1734 |
+
if is_auto_model: # This remains the primary check after the openai specific one
|
1735 |
print(f"Processing auto model: {request.model}")
|
1736 |
+
base_model_name = request.model.replace("-auto", "") # Ensure base_model_name is set here too
|
1737 |
# Define encryption instructions for system_instruction
|
1738 |
encryption_instructions = [
|
1739 |
"// AI Assistant Configuration //",
|