import os
import json
import requests
import re
import uvicorn
from datetime import datetime
from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import urlparse, parse_qs
import ijson
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from google.auth.transport.requests import Request as GoogleAuthRequest

# --- Configuration ---
CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
CLIENT_SECRET = "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"
SCOPES = [
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
    "openid",
]
GEMINI_DIR = os.path.dirname(os.path.abspath(__file__))  # Same directory as the script
CREDENTIAL_FILE = os.path.join(GEMINI_DIR, "oauth_creds.json")
CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"

# --- Global State ---
credentials = None
user_project_id = None

app = FastAPI()

# Helper class to adapt a generator of bytes into a file-like object
# that ijson can read from.
class _GeneratorStream:
    def __init__(self, generator):
        self.generator = generator
        self.buffer = b''

    def read(self, size=-1):
        # This read implementation is crucial for streaming.
        # It must not block to read the entire stream if size is -1.
        if size == -1:
            # If asked to read all, return what's in the buffer and get one more chunk.
            try:
                self.buffer += next(self.generator)
            except StopIteration:
                pass
            data = self.buffer
            self.buffer = b''
            return data
        # Otherwise, read from the generator until we have enough bytes.
        while len(self.buffer) < size:
            try:
                self.buffer += next(self.generator)
            except StopIteration:
                # Generator is exhausted.
                break
        data = self.buffer[:size]
        self.buffer = self.buffer[size:]
        return data

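# NOTE: _GeneratorStream is not referenced anywhere below; it is kept as a hook
# for switching to incremental parsing with ijson instead of the manual brace
# matching done in stream_generator().
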
class _OAuthCallbackHandler(BaseHTTPRequestHandler):
    auth_code = None

    def do_GET(self):
        query_components = parse_qs(urlparse(self.path).query)
        code = query_components.get("code", [None])[0]
        if code:
            _OAuthCallbackHandler.auth_code = code
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            self.wfile.write(b"<h1>Authentication successful!</h1><p>You can close this window and restart the proxy.</p>")
        else:
            self.send_response(400)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            self.wfile.write(b"<h1>Authentication failed.</h1><p>Please try again.</p>")

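# The handler above is served exactly once: get_credentials() starts a local
# HTTP server on port 8080, Google redirects the browser there after consent,
# and the authorization code is captured from the query string.
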
def get_user_project_id(creds):
    """Gets the user's project ID from cache or by probing the API."""
    global user_project_id
    if user_project_id:
        return user_project_id

    # First, try to load project ID from credential file
    if os.path.exists(CREDENTIAL_FILE):
        try:
            with open(CREDENTIAL_FILE, "r") as f:
                creds_data = json.load(f)
            cached_project_id = creds_data.get("project_id")
            if cached_project_id:
                user_project_id = cached_project_id
                print(f"Loaded project ID from cache: {user_project_id}")
                return user_project_id
        except Exception as e:
            print(f"Could not load project ID from cache: {e}")

    # If not found in cache, probe for it
    print("Project ID not found in cache. Probing for user project ID...")
    headers = {
        "Authorization": f"Bearer {creds.token}",
        "Content-Type": "application/json",
    }
    probe_payload = {
        "cloudaicompanionProject": "gcp-project",
        "metadata": {
            "ideType": "VSCODE",
            "pluginType": "GEMINI"
        }
    }
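    # The loadCodeAssist response reports the Code Assist project associated
    # with this account in its "cloudaicompanionProject" field; the placeholder
    # project and IDE metadata above are simply what this probe sends along.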
    try:
        resp = requests.post(
            f"{CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist",
            data=json.dumps(probe_payload),
            headers=headers,
        )
        resp.raise_for_status()
        data = resp.json()
        user_project_id = data.get("cloudaicompanionProject")
        if not user_project_id:
            raise ValueError("Could not find 'cloudaicompanionProject' in loadCodeAssist response.")
        print(f"Successfully fetched user project ID: {user_project_id}")
        # Save the project ID to the credential file for future use
        save_credentials(creds, user_project_id)
        print("Project ID saved to credential file for future use.")
        return user_project_id
    except requests.exceptions.HTTPError as e:
        print(f"Error fetching project ID: {e.response.text}")
        raise

def save_credentials(creds, project_id=None):
    os.makedirs(GEMINI_DIR, exist_ok=True)
    creds_data = {
        "access_token": creds.token,
        "refresh_token": creds.refresh_token,
        "scope": " ".join(creds.scopes),
        "token_type": "Bearer",
        "expiry_date": creds.expiry.isoformat() if creds.expiry else None,
    }
    # If project_id is provided, save it; otherwise preserve existing project_id
    if project_id:
        creds_data["project_id"] = project_id
    elif os.path.exists(CREDENTIAL_FILE):
        try:
            with open(CREDENTIAL_FILE, "r") as f:
                existing_data = json.load(f)
            if "project_id" in existing_data:
                creds_data["project_id"] = existing_data["project_id"]
        except Exception:
            pass  # If we can't read existing file, just continue without project_id
    with open(CREDENTIAL_FILE, "w") as f:
        json.dump(creds_data, f)

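# Credential resolution order used below: a valid in-memory credential wins,
# then an in-memory credential that can be refreshed, then the cached
# oauth_creds.json file, and finally a fresh interactive browser login.
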
def get_credentials():
    """Loads credentials from cache or initiates the OAuth 2.0 flow."""
    global credentials
    if credentials:
        if credentials.valid:
            return credentials
        if credentials.expired and credentials.refresh_token:
            print("Refreshing expired credentials...")
            try:
                credentials.refresh(GoogleAuthRequest())
                save_credentials(credentials)
                print("Credentials refreshed successfully.")
                return credentials
            except Exception as e:
                print(f"Could not refresh token: {e}. Attempting to load from file.")

    if os.path.exists(CREDENTIAL_FILE):
        try:
            with open(CREDENTIAL_FILE, "r") as f:
                creds_data = json.load(f)
            # Load project ID if available
            global user_project_id
            cached_project_id = creds_data.get("project_id")
            if cached_project_id:
                user_project_id = cached_project_id
                print(f"Loaded project ID from credential file: {user_project_id}")
            expiry = None
            expiry_str = creds_data.get("expiry_date")
            if expiry_str:
                if not isinstance(expiry_str, str) or not expiry_str.strip():
                    expiry = None
                else:
                    if expiry_str.endswith('Z'):
                        expiry_str = expiry_str[:-1] + '+00:00'
                    expiry = datetime.fromisoformat(expiry_str)
                    # google-auth compares expiry against a naive UTC timestamp,
                    # so normalize any timezone-aware value to naive UTC.
                    if expiry.tzinfo is not None:
                        expiry = (expiry - expiry.utcoffset()).replace(tzinfo=None)
            credentials = Credentials(
                token=creds_data.get("access_token"),
                refresh_token=creds_data.get("refresh_token"),
                token_uri="https://oauth2.googleapis.com/token",
                client_id=CLIENT_ID,
                client_secret=CLIENT_SECRET,
                scopes=SCOPES,
                expiry=expiry
            )
            if credentials.expired and credentials.refresh_token:
                print("Loaded credentials from file are expired. Refreshing...")
                credentials.refresh(GoogleAuthRequest())
                save_credentials(credentials)
            print("Successfully loaded credentials from cache.")
            return credentials
        except Exception as e:
            print(f"Could not load cached credentials: {e}. Starting new login.")

    client_config = {
        "installed": {
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
        }
    }
    flow = Flow.from_client_config(
        client_config, scopes=SCOPES, redirect_uri="http://localhost:8080"
    )
    auth_url, _ = flow.authorization_url(access_type="offline", prompt="consent")
    print(f"\nPlease open this URL in your browser to log in:\n{auth_url}\n")
    server = HTTPServer(("", 8080), _OAuthCallbackHandler)
    server.handle_request()
    auth_code = _OAuthCallbackHandler.auth_code
    if not auth_code:
        print("Failed to retrieve authorization code.")
        return None
    flow.fetch_token(code=auth_code)
    credentials = flow.credentials
    save_credentials(credentials)
    print("Authentication successful! Credentials saved.")
    return credentials

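# --- Request proxying ---
# Incoming Gemini API calls are rewritten into Code Assist's v1internal format
# and forwarded with the user's OAuth token and project ID attached.
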
# Catch-all POST route: without a route decorator FastAPI would never dispatch
# requests to this handler; the ":path" converter keeps slashes in full_path.
@app.post("/{full_path:path}")
async def proxy_request(request: Request, full_path: str):
    creds = get_credentials()
    if not creds:
        return Response(content="Authentication failed. Please restart the proxy to log in.", status_code=500)
    proj_id = get_user_project_id(creds)
    if not proj_id:
        return Response(content="Failed to get user project ID.", status_code=500)

    post_data = await request.body()
    path = f"/{full_path}"
    model_name_from_url = None
    action = None
    model_match = re.match(r"/(v\d+(?:beta)?)/models/([^:]+):(\w+)", path)
    is_streaming = False
    if model_match:
        model_name_from_url = model_match.group(2)
        action = model_match.group(3)
        target_url = f"{CODE_ASSIST_ENDPOINT}/v1internal:{action}"
        if "stream" in action.lower():
            is_streaming = True
    else:
        target_url = f"{CODE_ASSIST_ENDPOINT}{path}"

    try:
        incoming_json = json.loads(post_data)
        final_model = model_name_from_url if model_match else incoming_json.get("model")
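        # The v1internal endpoint expects an envelope of the form
        # {"model": ..., "project": ..., "request": {<standard Gemini body>}},
        # so the incoming request body is repackaged accordingly.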
        structured_payload = {
            "model": final_model,
            "project": proj_id,
            "request": {
                "contents": incoming_json.get("contents"),
                "systemInstruction": incoming_json.get("systemInstruction"),
                "cachedContent": incoming_json.get("cachedContent"),
                "tools": incoming_json.get("tools"),
                "toolConfig": incoming_json.get("toolConfig"),
                "safetySettings": incoming_json.get("safetySettings"),
                "generationConfig": incoming_json.get("generationConfig"),
            },
        }
        structured_payload["request"] = {
            k: v
            for k, v in structured_payload["request"].items()
            if v is not None
        }
        final_post_data = json.dumps(structured_payload)
    except (json.JSONDecodeError, AttributeError):
        final_post_data = post_data
    headers = {
        "Authorization": f"Bearer {creds.token}",
        "Content-Type": "application/json",
        # 'Accept-Encoding' is deliberately left unset: requests negotiates gzip
        # on its own and transparently decompresses the body as it streams in.
    }
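    # Streaming responses from the upstream arrive as one long JSON array of
    # envelope objects. The generator below carves complete objects out of the
    # buffer as bytes arrive and re-emits each inner "response" object as a
    # Server-Sent Events "data:" line.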
    if is_streaming:
        async def stream_generator():
            try:
                print(f"[STREAM] Starting streaming request to: {target_url}")
                print(f"[STREAM] Request payload size: {len(final_post_data)} bytes")
                with requests.post(target_url, data=final_post_data, headers=headers, stream=True) as resp:
                    print(f"[STREAM] Response status: {resp.status_code}")
                    print(f"[STREAM] Response headers: {dict(resp.headers)}")
                    resp.raise_for_status()
                    buffer = ""
                    brace_count = 0
                    in_array = False
                    chunk_count = 0
                    total_bytes = 0
                    objects_yielded = 0
                    print("[STREAM] Starting to process chunks...")
                    for chunk in resp.iter_content(chunk_size=1024, decode_unicode=True):
                        chunk_count += 1
                        chunk_size = len(chunk) if chunk else 0
                        total_bytes += chunk_size
                        print(f"[STREAM] Chunk #{chunk_count}: {chunk_size} bytes, total: {total_bytes} bytes")
                        if chunk:
                            print(f"[STREAM] Chunk content preview: {repr(chunk[:100])}")
                            buffer += chunk
                            print(f"[STREAM] Buffer size after chunk: {len(buffer)} chars")
                            # Process complete JSON objects from the buffer
                            processing_iterations = 0
                            while buffer:
                                processing_iterations += 1
                                if processing_iterations > 100:  # Prevent infinite loops
                                    print("[STREAM] WARNING: Too many processing iterations, breaking")
                                    break
                                buffer = buffer.lstrip()
                                if not buffer:
                                    print("[STREAM] Buffer empty after lstrip")
                                    break
                                print(f"[STREAM] Processing buffer (len={len(buffer)}): {repr(buffer[:50])}")
                                # Handle array start
                                if buffer.startswith('[') and not in_array:
                                    print("[STREAM] Found array start, entering array mode")
                                    buffer = buffer[1:].lstrip()
                                    in_array = True
                                    continue
                                # Handle array end
                                if buffer.startswith(']'):
                                    print("[STREAM] Found array end, stopping processing")
                                    break
                                # Skip commas between objects
                                if buffer.startswith(','):
                                    print("[STREAM] Skipping comma separator")
                                    buffer = buffer[1:].lstrip()
                                    continue
                                # Look for complete JSON objects
                                if buffer.startswith('{'):
                                    print("[STREAM] Found object start, parsing JSON object...")
                                    brace_count = 0
                                    in_string = False
                                    escape_next = False
                                    end_pos = -1
                                    for i, char in enumerate(buffer):
                                        if escape_next:
                                            escape_next = False
                                            continue
                                        if char == '\\':
                                            escape_next = True
                                            continue
                                        if char == '"' and not escape_next:
                                            in_string = not in_string
                                            continue
                                        if not in_string:
                                            if char == '{':
                                                brace_count += 1
                                            elif char == '}':
                                                brace_count -= 1
                                                if brace_count == 0:
                                                    end_pos = i + 1
                                                    break
                                    if end_pos > 0:
                                        # Found complete JSON object
                                        json_str = buffer[:end_pos]
                                        buffer = buffer[end_pos:].lstrip()
                                        print(f"[STREAM] Found complete JSON object ({len(json_str)} chars): {repr(json_str[:200])}")
                                        try:
                                            obj = json.loads(json_str)
                                            print(f"[STREAM] Successfully parsed JSON object with keys: {list(obj.keys())}")
                                            if "response" in obj:
                                                response_chunk = obj["response"]
                                                objects_yielded += 1
                                                response_json = json.dumps(response_chunk)
                                                print(f"[STREAM] Yielding object #{objects_yielded} (response size: {len(response_json)} chars)")
                                                print(f"[STREAM] Response content preview: {repr(response_json[:200])}")
                                                yield f"data: {response_json}\n\n"
                                            else:
                                                print("[STREAM] Object has no 'response' key, skipping")
                                        except json.JSONDecodeError as e:
                                            print(f"[STREAM] Failed to parse JSON object: {e}")
                                            print(f"[STREAM] Problematic JSON: {repr(json_str[:500])}")
                                        continue
                                    else:
                                        # Incomplete object, wait for more data
                                        print(f"[STREAM] Incomplete JSON object (brace_count={brace_count}), waiting for more data")
                                        break
                                else:
                                    # Skip unexpected characters
                                    print(f"[STREAM] Skipping unexpected character: {repr(buffer[0])}")
                                    buffer = buffer[1:]
                    print(f"[STREAM] Finished processing. Total chunks: {chunk_count}, total bytes: {total_bytes}, objects yielded: {objects_yielded}")
            except requests.exceptions.RequestException as e:
                print(f"Error during streaming request: {e}")
                error_message = json.dumps({"error": {"message": f"Upstream request failed: {e}"}})
                yield f"data: {error_message}\n\n"
            except Exception as e:
                print(f"An unexpected error occurred during streaming: {e}")
                error_message = json.dumps({"error": {"message": f"An unexpected error occurred: {e}"}})
                yield f"data: {error_message}\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        resp = requests.post(target_url, data=final_post_data, headers=headers)
        if resp.status_code == 200:
            try:
                google_api_response = resp.json()
                # The actual response is nested under the "response" key
                standard_gemini_response = google_api_response.get("response")
                # The standard client expects a list containing the response object
                return Response(content=json.dumps([standard_gemini_response]), status_code=200, media_type="application/json")
            except (json.JSONDecodeError, AttributeError) as e:
                print(f"Error converting to standard Gemini format: {e}")
                # Fallback to sending the original content if conversion fails
                return Response(content=resp.content, status_code=resp.status_code, media_type=resp.headers.get("Content-Type"))
        else:
            return Response(content=resp.content, status_code=resp.status_code, media_type=resp.headers.get("Content-Type"))

if __name__ == "__main__":
    print("Initializing credentials...")
    creds = get_credentials()
    if creds:
        get_user_project_id(creds)
        print("\nStarting Gemini proxy server on http://localhost:8888")
        print("Send your Gemini API requests to this address.")
        uvicorn.run(app, host="0.0.0.0", port=8888)
    else:
        print("\nCould not obtain credentials. Please authenticate and restart the server.")

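# Example call once the proxy is running (model name and prompt are illustrative):
#   curl -X POST "http://localhost:8888/v1beta/models/gemini-2.5-pro:streamGenerateContent" \
#     -H "Content-Type: application/json" \
#     -d '{"contents": [{"role": "user", "parts": [{"text": "Hello"}]}]}'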