Spaces:
Running
Running
import os | |
import json | |
import base64 | |
import time | |
from datetime import datetime | |
from fastapi import Request, HTTPException, Depends | |
from fastapi.security import HTTPBasic | |
from http.server import BaseHTTPRequestHandler, HTTPServer | |
from urllib.parse import urlparse, parse_qs | |
from google.oauth2.credentials import Credentials | |
from google_auth_oauthlib.flow import Flow | |
from google.auth.transport.requests import Request as GoogleAuthRequest | |
from .utils import get_user_agent, get_client_metadata | |
from .config import ( | |
CLIENT_ID, CLIENT_SECRET, SCOPES, CREDENTIAL_FILE, | |
CODE_ASSIST_ENDPOINT, GEMINI_AUTH_PASSWORD | |
) | |
# --- Global State ---
# Process-wide caches that the functions in this module fill in lazily.
credentials = None            # Credentials object cached by get_credentials()
credentials_from_env = False  # when True, save_credentials() skips writing CREDENTIAL_FILE
user_project_id = None        # project id cached by get_user_project_id()
onboarding_complete = False   # flipped to True once onboard_user() succeeds

security = HTTPBasic()  # FastAPI HTTP Basic security scheme instance
class _OAuthCallbackHandler(BaseHTTPRequestHandler):
    """One-shot HTTP handler that captures the OAuth redirect's ``code``.

    ``get_credentials`` serves a single request with this handler and then
    reads the captured value from the class attribute ``auth_code``.
    """

    # Authorization code from the most recent callback, or None when that
    # callback carried no ``code`` query parameter.
    auth_code = None

    def do_GET(self):
        """Record the ``code`` query parameter and answer with a status page."""
        query_components = parse_qs(urlparse(self.path).query)
        code = query_components.get("code", [None])[0]
        # Always overwrite: a callback without a code (e.g. consent denied)
        # must not leave an earlier attempt's code behind for reuse.
        _OAuthCallbackHandler.auth_code = code
        if code:
            self._respond(
                200,
                b"<h1>OAuth authentication successful!</h1><p>You can close this window. Please check the proxy server logs to verify that onboarding completed successfully. No need to restart the proxy.</p>",
            )
        else:
            self._respond(400, b"<h1>Authentication failed.</h1><p>Please try again.</p>")

    def _respond(self, status, body):
        """Send *status* with an HTML *body*."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)
def _matches_auth_password(candidate):
    """Return True if *candidate* equals GEMINI_AUTH_PASSWORD.

    An empty candidate, or an unset/empty configured password, never matches.
    The comparison is constant-time so response timing does not reveal how
    much of the secret a caller guessed correctly.
    """
    import hmac

    if not candidate or not GEMINI_AUTH_PASSWORD:
        return False
    return hmac.compare_digest(
        candidate.encode("utf-8"), str(GEMINI_AUTH_PASSWORD).encode("utf-8")
    )


def authenticate_user(request: Request):
    """Authenticate the user with multiple methods.

    Credentials are checked against GEMINI_AUTH_PASSWORD in this order:

    1. ``key`` query parameter (Gemini client compatibility).
    2. ``x-goog-api-key`` header (Google SDK format).
    3. ``Authorization: Bearer <password>``.
    4. ``Authorization: Basic <base64(username:password)>``.

    The Authorization scheme name is matched case-insensitively.

    Returns:
        ``"api_key_user"``, ``"goog_api_key_user"`` or ``"bearer_user"`` for
        methods 1-3, or the supplied username for Basic auth.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Basic`` when no method
            succeeds.
    """
    if _matches_auth_password(request.query_params.get("key")):
        return "api_key_user"

    if _matches_auth_password(request.headers.get("x-goog-api-key", "")):
        return "goog_api_key_user"

    auth_header = request.headers.get("authorization", "")
    scheme, _, value = auth_header.partition(" ")
    scheme = scheme.lower()

    if scheme == "bearer" and _matches_auth_password(value):
        return "bearer_user"

    if scheme == "basic":
        try:
            username, password = base64.b64decode(value).decode("utf-8").split(":", 1)
        except ValueError:
            # Bad base64 (binascii.Error), non-UTF-8 bytes (UnicodeDecodeError)
            # or a missing ':' -- all ValueError subclasses; reject below.
            pass
        else:
            if _matches_auth_password(password):
                return username

    raise HTTPException(
        status_code=401,
        detail="Invalid authentication credentials. Use HTTP Basic Auth, Bearer token, 'key' query parameter, or 'x-goog-api-key' header.",
        headers={"WWW-Authenticate": "Basic"},
    )
def save_credentials(creds, project_id=None):
    """Persist *creds* (plus a project id) to CREDENTIAL_FILE as JSON.

    Does nothing when ``credentials_from_env`` is set, so externally supplied
    credentials are never written to disk.

    Args:
        creds: Credentials object providing ``token``, ``refresh_token``,
            ``scopes`` and ``expiry``.
        project_id: Project id to store. When falsy, a ``project_id`` already
            present in CREDENTIAL_FILE is carried over.
    """
    from datetime import timezone

    if credentials_from_env:
        return

    creds_data = {
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
        "token": creds.token,
        "refresh_token": creds.refresh_token,
        "scopes": creds.scopes or SCOPES,
        "token_uri": "https://oauth2.googleapis.com/token",
    }

    if creds.expiry:
        expiry = creds.expiry
        if expiry.tzinfo is None:
            # Naive expiry values are treated as UTC.
            expiry = expiry.replace(tzinfo=timezone.utc)
        # Normalise to UTC and always include microseconds. google-auth's
        # Credentials.from_authorized_user_info (used by get_credentials to
        # reload this file) keeps only the text before the first "." and reads
        # it as naive UTC, so "...T12:00:00+00:00" (no fraction) would fail to
        # parse and a non-UTC offset would be silently misread.
        creds_data["expiry"] = expiry.astimezone(timezone.utc).isoformat(
            timespec="microseconds"
        )

    if project_id:
        creds_data["project_id"] = project_id
    elif os.path.exists(CREDENTIAL_FILE):
        try:
            with open(CREDENTIAL_FILE, "r", encoding="utf-8") as f:
                existing_data = json.load(f)
        except (OSError, ValueError):
            # Unreadable or corrupt file: nothing to carry over.
            existing_data = None
        if isinstance(existing_data, dict) and "project_id" in existing_data:
            creds_data["project_id"] = existing_data["project_id"]

    with open(CREDENTIAL_FILE, "w", encoding="utf-8") as f:
        json.dump(creds_data, f, indent=2)
def get_credentials():
    """Load OAuth2 credentials the way gemini-cli does, logging in if needed.

    Sources are tried in order:

    1. The module-level ``credentials`` cache, if it already holds a token.
    2. The ``GEMINI_CREDENTIALS`` environment variable (a JSON string);
       these are flagged as environment-provided and never saved to disk.
    3. ``CREDENTIAL_FILE``; flagged as environment-provided when
       ``GOOGLE_APPLICATION_CREDENTIALS`` is set.
    4. An interactive browser login whose redirect is captured by a one-shot
       HTTP server on port 8080.

    For sources 2 and 3 a refresh is attempted when a refresh token exists;
    a failed refresh is reported and the credentials are used as loaded.

    Returns:
        The Credentials object (also stored in the ``credentials`` global), or
        ``None`` if the interactive login did not produce a token.
    """
    global credentials, credentials_from_env

    if credentials and credentials.token:
        return credentials

    # --- Source 2: JSON blob in the GEMINI_CREDENTIALS environment variable.
    env_creds_json = os.getenv("GEMINI_CREDENTIALS")
    if env_creds_json:
        try:
            creds_data = json.loads(env_creds_json)
            credentials = Credentials.from_authorized_user_info(creds_data, SCOPES)
            credentials_from_env = True  # never written back to CREDENTIAL_FILE
            if credentials.refresh_token:
                try:
                    credentials.refresh(GoogleAuthRequest())
                except Exception as refresh_error:
                    # Best effort: keep using the credentials as loaded.
                    print(f"Warning: failed to refresh GEMINI_CREDENTIALS: {refresh_error}")
            return credentials
        except Exception as load_error:
            print(f"Warning: ignoring unusable GEMINI_CREDENTIALS: {load_error}")

    # --- Source 3: credentials file (CREDENTIAL_FILE includes the
    # GOOGLE_APPLICATION_CREDENTIALS path if set).
    if os.path.exists(CREDENTIAL_FILE):
        try:
            with open(CREDENTIAL_FILE, "r", encoding="utf-8") as f:
                creds_data = json.load(f)
            # Accept files using "access_token" / space-separated "scope"
            # in place of the "token" / "scopes" keys used here.
            if "access_token" in creds_data and "token" not in creds_data:
                creds_data["token"] = creds_data["access_token"]
            if "scope" in creds_data and "scopes" not in creds_data:
                creds_data["scopes"] = creds_data["scope"].split()
            credentials = Credentials.from_authorized_user_info(creds_data, SCOPES)
            # Set before save_credentials() runs: it reads this flag and skips
            # writing when the file came from GOOGLE_APPLICATION_CREDENTIALS.
            credentials_from_env = bool(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
            if credentials.refresh_token:
                try:
                    credentials.refresh(GoogleAuthRequest())
                    save_credentials(credentials)
                except Exception as refresh_error:
                    print(f"Warning: failed to refresh saved credentials: {refresh_error}")
            return credentials
        except Exception as load_error:
            print(f"Warning: could not load {CREDENTIAL_FILE}, starting new login: {load_error}")

    # --- Source 4: interactive browser login.
    client_config = {
        "installed": {
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
        }
    }
    flow = Flow.from_client_config(
        client_config,
        scopes=SCOPES,
        redirect_uri="http://localhost:8080",
    )
    flow.oauth2session.scope = SCOPES
    auth_url, _ = flow.authorization_url(
        access_type="offline",
        prompt="consent",
        include_granted_scopes="true",
    )
    print(f"\nPlease open this URL in your browser to log in:\n{auth_url}\n")

    # Forget any code captured by an earlier attempt so a failed callback
    # cannot be mistaken for a fresh success.
    _OAuthCallbackHandler.auth_code = None
    # Serve exactly one request (the redirect); leaving the context manager
    # closes the listening socket so port 8080 is free for a later retry.
    with HTTPServer(("", 8080), _OAuthCallbackHandler) as server:
        server.handle_request()
    auth_code = _OAuthCallbackHandler.auth_code
    if not auth_code:
        return None

    # NOTE(review): presumably this suppresses the Warning oauthlib raises
    # when the granted scopes differ from the requested ones, so the token
    # exchange can still complete. The patch is process-wide and is always
    # undone in ``finally``.
    import oauthlib.oauth2.rfc6749.parameters

    token_params = oauthlib.oauth2.rfc6749.parameters
    original_validate = token_params.validate_token_parameters

    def patched_validate(params):
        try:
            return original_validate(params)
        except Warning:
            pass

    token_params.validate_token_parameters = patched_validate
    try:
        flow.fetch_token(code=auth_code)
        credentials = flow.credentials
        credentials_from_env = False  # freshly obtained: persist to CREDENTIAL_FILE
        save_credentials(credentials)
        print("Authentication successful! Credentials saved.")
        return credentials
    except Exception as e:
        print(f"Authentication failed: {e}")
        return None
    finally:
        token_params.validate_token_parameters = original_validate
def onboard_user(creds, project_id):
    """Ensure the user is onboarded, matching gemini-cli setupUser behavior.

    Runs at most once per process: after success ``onboarding_complete`` is
    set and later calls return immediately.

    Steps:
      1. Refresh *creds* if expired (and persist them).
      2. Call ``loadCodeAssist``; a ``currentTier`` in the reply means the
         user is already onboarded.
      3. Otherwise choose the default allowed tier (or a legacy fallback) and
         re-issue ``onboardUser`` every 5 seconds until it reports ``done``.

    Args:
        creds: Credentials whose token is sent as a Bearer token.
        project_id: Cloud project id, or a falsy value if none is configured.

    Raises:
        Exception: if the refresh fails, an API call returns an HTTP error,
            the chosen tier requires a project and none was given, or any
            other error occurs.
    """
    global onboarding_complete
    if onboarding_complete:
        return

    if creds.expired and creds.refresh_token:
        try:
            creds.refresh(GoogleAuthRequest())
            save_credentials(creds)
        except Exception as e:
            raise Exception(f"Failed to refresh credentials during onboarding: {str(e)}") from e

    # Imported outside the try below: its except clause names ``requests``,
    # which would be unbound (UnboundLocalError) if the import failed inside.
    import requests

    headers = {
        "Authorization": f"Bearer {creds.token}",
        "Content-Type": "application/json",
        "User-Agent": get_user_agent(),
    }
    load_assist_payload = {
        "cloudaicompanionProject": project_id,
        "metadata": get_client_metadata(project_id),
    }

    try:
        resp = requests.post(
            f"{CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist",
            data=json.dumps(load_assist_payload),
            headers=headers,
        )
        resp.raise_for_status()
        load_data = resp.json()

        # Prefer the current tier; otherwise the first tier flagged default.
        tier = load_data.get("currentTier")
        if not tier:
            tier = next(
                (t for t in load_data.get("allowedTiers", []) if t.get("isDefault")),
                None,
            )
        if not tier:
            tier = {
                "name": "",
                "description": "",
                "id": "legacy-tier",
                "userDefinedCloudaicompanionProject": True,
            }

        if tier.get("userDefinedCloudaicompanionProject") and not project_id:
            raise ValueError("This account requires setting the GOOGLE_CLOUD_PROJECT env var.")

        if load_data.get("currentTier"):
            onboarding_complete = True
            return

        onboard_req_payload = {
            "tierId": tier.get("id"),
            "cloudaicompanionProject": project_id,
            "metadata": get_client_metadata(project_id),
        }
        # Long-running operation: poll until the server reports completion.
        while True:
            onboard_resp = requests.post(
                f"{CODE_ASSIST_ENDPOINT}/v1internal:onboardUser",
                data=json.dumps(onboard_req_payload),
                headers=headers,
            )
            onboard_resp.raise_for_status()
            if onboard_resp.json().get("done"):
                onboarding_complete = True
                break
            time.sleep(5)
    except requests.exceptions.HTTPError as e:
        # Guard against an HTTPError without an attached response object.
        response = getattr(e, "response", None)
        detail = response.text if response is not None else str(e)
        raise Exception(f"User onboarding failed. Please check your Google Cloud project permissions and try again. Error: {detail}") from e
    except Exception as e:
        raise Exception(f"User onboarding failed due to an unexpected error: {str(e)}") from e
def get_user_project_id(creds): | |
"""Gets the user's project ID matching gemini-cli setupUser logic.""" | |
global user_project_id | |
if user_project_id: | |
return user_project_id | |
env_project_id = os.getenv("GOOGLE_CLOUD_PROJECT") | |
if env_project_id: | |
user_project_id = env_project_id | |
save_credentials(creds, user_project_id) | |
return user_project_id | |
if os.path.exists(CREDENTIAL_FILE): | |
try: | |
with open(CREDENTIAL_FILE, "r") as f: | |
creds_data = json.load(f) | |
cached_project_id = creds_data.get("project_id") | |
if cached_project_id: | |
user_project_id = cached_project_id | |
return user_project_id | |
except Exception as e: | |
pass | |
if creds.expired and creds.refresh_token: | |
try: | |
creds.refresh(GoogleAuthRequest()) | |
save_credentials(creds) | |
except Exception as e: | |
raise Exception(f"Failed to refresh credentials while getting project ID: {str(e)}") | |
headers = { | |
"Authorization": f"Bearer {creds.token}", | |
"Content-Type": "application/json", | |
"User-Agent": get_user_agent(), | |
} | |
probe_payload = { | |
"metadata": get_client_metadata(), | |
} | |
try: | |
import requests | |
resp = requests.post( | |
f"{CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist", | |
data=json.dumps(probe_payload), | |
headers=headers, | |
) | |
resp.raise_for_status() | |
data = resp.json() | |
user_project_id = data.get("cloudaicompanionProject") | |
if not user_project_id: | |
raise ValueError("Could not find 'cloudaicompanionProject' in loadCodeAssist response.") | |
save_credentials(creds, user_project_id) | |
return user_project_id | |
except requests.exceptions.HTTPError as e: | |
raise |