"""
Google API Client - Handles all communication with Google's Gemini API.
This module is used by both OpenAI compatibility layer and native Gemini endpoints.
"""
import asyncio
import json
import logging

import requests
from fastapi import Response
from fastapi.responses import StreamingResponse
from google.auth.transport.requests import Request as GoogleAuthRequest

from .auth import get_credentials, save_credentials, get_user_project_id, onboard_user
from .utils import get_user_agent
from .config import (
    CODE_ASSIST_ENDPOINT,
    DEFAULT_SAFETY_SETTINGS,
    get_base_model_name,
    is_search_model,
    get_thinking_budget,
    should_include_thoughts,
)


def send_gemini_request(payload: dict, is_streaming: bool = False) -> Response:
"""
Send a request to Google's Gemini API.
Args:
payload: The request payload in Gemini format
is_streaming: Whether this is a streaming request
Returns:
FastAPI Response object
"""
# Get and validate credentials
creds = get_credentials()
if not creds:
return Response(
content="Authentication failed. Please restart the proxy to log in.",
status_code=500
)
# Refresh credentials if needed
if creds.expired and creds.refresh_token:
try:
creds.refresh(GoogleAuthRequest())
save_credentials(creds)
        except Exception as e:
            logging.error(f"Token refresh failed: {str(e)}")
            return Response(
                content="Token refresh failed. Please restart the proxy to re-authenticate.",
                status_code=500
            )
elif not creds.token:
return Response(
content="No access token. Please restart the proxy to re-authenticate.",
status_code=500
)
# Get project ID and onboard user
proj_id = get_user_project_id(creds)
if not proj_id:
return Response(content="Failed to get user project ID.", status_code=500)
onboard_user(creds, proj_id)
# Build the final payload with project info
final_payload = {
"model": payload.get("model"),
"project": proj_id,
"request": payload.get("request", {})
}
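    # Illustrative shape of the wrapped payload sent upstream (model name, project ID,
    # and request contents below are placeholder values, not real ones):
    #
    # {
    #     "model": "gemini-2.5-pro",
    #     "project": "example-project-id",
    #     "request": {
    #         "contents": [{"role": "user", "parts": [{"text": "Hello"}]}],
    #         "generationConfig": {}
    #     }
    # }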
# Determine the action and URL
action = "streamGenerateContent" if is_streaming else "generateContent"
target_url = f"{CODE_ASSIST_ENDPOINT}/v1internal:{action}"
if is_streaming:
target_url += "?alt=sse"
# Build request headers
request_headers = {
"Authorization": f"Bearer {creds.token}",
"Content-Type": "application/json",
"User-Agent": get_user_agent(),
}
final_post_data = json.dumps(final_payload)
# Send the request
try:
if is_streaming:
resp = requests.post(target_url, data=final_post_data, headers=request_headers, stream=True)
return _handle_streaming_response(resp)
else:
resp = requests.post(target_url, data=final_post_data, headers=request_headers)
return _handle_non_streaming_response(resp)
except requests.exceptions.RequestException as e:
logging.error(f"Request to Google API failed: {str(e)}")
return Response(
content=json.dumps({"error": {"message": f"Request failed: {str(e)}"}}),
status_code=500,
media_type="application/json"
)
except Exception as e:
logging.error(f"Unexpected error during Google API request: {str(e)}")
return Response(
content=json.dumps({"error": {"message": f"Unexpected error: {str(e)}"}}),
status_code=500,
media_type="application/json"
)
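
# Usage sketch (illustrative only): callers are expected to pass a payload already in
# the wrapped {"model": ..., "request": ...} shape produced by the helpers below, e.g.:
#
#     payload = build_gemini_payload_from_native({"contents": [...]}, "gemini-2.5-flash")
#     response = send_gemini_request(payload, is_streaming=False)
#
# The model name above is an example value, not a fixed contract.
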
def _handle_streaming_response(resp) -> StreamingResponse:
"""Handle streaming response from Google API."""
# Check for HTTP errors before starting to stream
if resp.status_code != 200:
logging.error(f"Google API returned status {resp.status_code}: {resp.text}")
error_message = f"Google API error: {resp.status_code}"
try:
error_data = resp.json()
if "error" in error_data:
error_message = error_data["error"].get("message", error_message)
        except ValueError:
            # Error body was not valid JSON; keep the generic status-code message
            pass
# Return error as a streaming response
async def error_generator():
error_response = {
"error": {
"message": error_message,
"type": "invalid_request_error" if resp.status_code == 404 else "api_error",
"code": resp.status_code
}
}
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8')
response_headers = {
"Content-Type": "text/event-stream",
"Content-Disposition": "attachment",
"Vary": "Origin, X-Origin, Referer",
"X-XSS-Protection": "0",
"X-Frame-Options": "SAMEORIGIN",
"X-Content-Type-Options": "nosniff",
"Server": "ESF"
}
return StreamingResponse(
error_generator(),
media_type="text/event-stream",
headers=response_headers,
status_code=resp.status_code
)
async def stream_generator():
try:
with resp:
for chunk in resp.iter_lines():
if chunk:
if not isinstance(chunk, str):
chunk = chunk.decode('utf-8')
if chunk.startswith('data: '):
chunk = chunk[len('data: '):]
try:
obj = json.loads(chunk)
if "response" in obj:
response_chunk = obj["response"]
response_json = json.dumps(response_chunk, separators=(',', ':'))
response_line = f"data: {response_json}\n\n"
yield response_line.encode('utf-8')
await asyncio.sleep(0)
else:
obj_json = json.dumps(obj, separators=(',', ':'))
yield f"data: {obj_json}\n\n".encode('utf-8')
except json.JSONDecodeError:
continue
except requests.exceptions.RequestException as e:
logging.error(f"Streaming request failed: {str(e)}")
error_response = {
"error": {
"message": f"Upstream request failed: {str(e)}",
"type": "api_error",
"code": 502
}
}
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8')
except Exception as e:
logging.error(f"Unexpected error during streaming: {str(e)}")
error_response = {
"error": {
"message": f"An unexpected error occurred: {str(e)}",
"type": "api_error",
"code": 500
}
}
yield f'data: {json.dumps(error_response)}\n\n'.encode('utf-8')
response_headers = {
"Content-Type": "text/event-stream",
"Content-Disposition": "attachment",
"Vary": "Origin, X-Origin, Referer",
"X-XSS-Protection": "0",
"X-Frame-Options": "SAMEORIGIN",
"X-Content-Type-Options": "nosniff",
"Server": "ESF"
}
return StreamingResponse(
stream_generator(),
media_type="text/event-stream",
headers=response_headers
)
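
# Example of the unwrapping performed by stream_generator() (payloads are illustrative):
#
#   upstream SSE line:  data: {"response": {"candidates": [{"content": {...}}]}}
#   re-emitted line:    data: {"candidates":[{"content":{...}}]}
#
# Chunks without a "response" wrapper are forwarded as-is; non-JSON lines are skipped.
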
def _handle_non_streaming_response(resp) -> Response:
"""Handle non-streaming response from Google API."""
if resp.status_code == 200:
try:
google_api_response = resp.text
if google_api_response.startswith('data: '):
google_api_response = google_api_response[len('data: '):]
google_api_response = json.loads(google_api_response)
standard_gemini_response = google_api_response.get("response")
return Response(
content=json.dumps(standard_gemini_response),
status_code=200,
media_type="application/json; charset=utf-8"
)
except (json.JSONDecodeError, AttributeError) as e:
logging.error(f"Failed to parse Google API response: {str(e)}")
return Response(
content=resp.content,
status_code=resp.status_code,
media_type=resp.headers.get("Content-Type")
)
else:
# Log the error details
logging.error(f"Google API returned status {resp.status_code}: {resp.text}")
# Try to parse error response and provide meaningful error message
try:
error_data = resp.json()
if "error" in error_data:
error_message = error_data["error"].get("message", f"API error: {resp.status_code}")
error_response = {
"error": {
"message": error_message,
"type": "invalid_request_error" if resp.status_code == 404 else "api_error",
"code": resp.status_code
}
}
return Response(
content=json.dumps(error_response),
status_code=resp.status_code,
media_type="application/json"
)
except (json.JSONDecodeError, KeyError):
pass
# Fallback to original response if we can't parse the error
return Response(
content=resp.content,
status_code=resp.status_code,
media_type=resp.headers.get("Content-Type")
)
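
# Example error translation (illustrative): an upstream body such as
#   {"error": {"message": "Model not found", "code": 404, "status": "NOT_FOUND"}}
# is returned to the client as
#   {"error": {"message": "Model not found", "type": "invalid_request_error", "code": 404}}
# while the original HTTP status code is preserved.
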
def build_gemini_payload_from_openai(openai_payload: dict) -> dict:
"""
Build a Gemini API payload from an OpenAI-transformed request.
This is used when OpenAI requests are converted to Gemini format.
"""
# Extract model from the payload
model = openai_payload.get("model")
# Get safety settings or use defaults
safety_settings = openai_payload.get("safetySettings", DEFAULT_SAFETY_SETTINGS)
# Build the request portion
request_data = {
"contents": openai_payload.get("contents"),
"systemInstruction": openai_payload.get("systemInstruction"),
"cachedContent": openai_payload.get("cachedContent"),
"tools": openai_payload.get("tools"),
"toolConfig": openai_payload.get("toolConfig"),
"safetySettings": safety_settings,
"generationConfig": openai_payload.get("generationConfig", {}),
}
# Remove any keys with None values
request_data = {k: v for k, v in request_data.items() if v is not None}
return {
"model": model,
"request": request_data
}
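
# Minimal usage sketch (field values are placeholders):
#
#     wrapped = build_gemini_payload_from_openai({
#         "model": "gemini-2.5-pro",
#         "contents": [{"role": "user", "parts": [{"text": "Hi"}]}],
#     })
#     # wrapped == {"model": "gemini-2.5-pro",
#     #             "request": {"contents": [...],
#     #                         "safetySettings": DEFAULT_SAFETY_SETTINGS,
#     #                         "generationConfig": {}}}
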
def build_gemini_payload_from_native(native_request: dict, model_from_path: str) -> dict:
"""
Build a Gemini API payload from a native Gemini request.
This is used for direct Gemini API calls.
"""
native_request["safetySettings"] = DEFAULT_SAFETY_SETTINGS
if "generationConfig" not in native_request:
native_request["generationConfig"] = {}
if "thinkingConfig" not in native_request["generationConfig"]:
native_request["generationConfig"]["thinkingConfig"] = {}
# Configure thinking based on model variant
thinking_budget = get_thinking_budget(model_from_path)
include_thoughts = should_include_thoughts(model_from_path)
native_request["generationConfig"]["thinkingConfig"]["includeThoughts"] = include_thoughts
native_request["generationConfig"]["thinkingConfig"]["thinkingBudget"] = thinking_budget
# Add Google Search grounding for search models
if is_search_model(model_from_path):
if "tools" not in native_request:
native_request["tools"] = []
# Add googleSearch tool if not already present
if not any(tool.get("googleSearch") for tool in native_request["tools"]):
native_request["tools"].append({"googleSearch": {}})
return {
"model": get_base_model_name(model_from_path), # Use base model name for API call
"request": native_request
    }
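
# Minimal usage sketch (the model alias below is a placeholder; the real variants,
# thinking budgets, and search flags are defined in config.py):
#
#     wrapped = build_gemini_payload_from_native(
#         {"contents": [{"role": "user", "parts": [{"text": "Hi"}]}]},
#         "gemini-2.5-flash",
#     )
#     # wrapped["model"] is the base model name, and wrapped["request"] now carries
#     # DEFAULT_SAFETY_SETTINGS plus a generationConfig.thinkingConfig block.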