bibibi12345 commited on
Commit
488d5ee
·
1 Parent(s): ba53187

added onboarding logic. better mimic behavior. support preflight packages. support x-google authentication for cherry studio

Browse files
Files changed (1) hide show
  1. gemini_proxy.py +337 -94
gemini_proxy.py CHANGED
@@ -4,10 +4,13 @@ import requests
4
  import re
5
  import uvicorn
6
  import base64
 
 
7
  from datetime import datetime
8
  from fastapi import FastAPI, Request, Response, HTTPException, Depends
9
  from fastapi.responses import StreamingResponse
10
  from fastapi.security import HTTPBasic, HTTPBasicCredentials
 
11
  from http.server import BaseHTTPRequestHandler, HTTPServer
12
  from urllib.parse import urlparse, parse_qs
13
  import ijson
@@ -16,6 +19,7 @@ from dotenv import load_dotenv
16
  from google.oauth2.credentials import Credentials
17
  from google_auth_oauthlib.flow import Flow
18
  from google.auth.transport.requests import Request as GoogleAuthRequest
 
19
 
20
  # Load environment variables from .env file
21
  load_dotenv()
@@ -27,21 +31,38 @@ SCOPES = [
27
  "https://www.googleapis.com/auth/cloud-platform",
28
  "https://www.googleapis.com/auth/userinfo.email",
29
  "https://www.googleapis.com/auth/userinfo.profile",
30
- "openid",
31
  ]
32
- GEMINI_DIR = os.path.dirname(os.path.abspath(__file__)) # Same directory as the script
33
- CREDENTIAL_FILE = os.path.join(GEMINI_DIR, "oauth_creds.json")
34
  CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"
35
  GEMINI_PORT = int(os.getenv("GEMINI_PORT", "8888")) # Default to 8888 if not set
36
  GEMINI_AUTH_PASSWORD = os.getenv("GEMINI_AUTH_PASSWORD", "123456") # Default password
 
37
 
38
  # --- Global State ---
39
  credentials = None
40
  user_project_id = None
 
41
 
42
  app = FastAPI()
43
  security = HTTPBasic()
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def authenticate_user(request: Request):
46
  """Authenticate the user with multiple methods."""
47
  # Check for API key in query parameters first (for Gemini client compatibility)
@@ -49,10 +70,17 @@ def authenticate_user(request: Request):
49
  if api_key and api_key == GEMINI_AUTH_PASSWORD:
50
  return "api_key_user"
51
 
 
 
 
 
 
52
  # Check for API key in Authorization header (Bearer token format)
53
  auth_header = request.headers.get("authorization", "")
54
- if auth_header.startswith("Bearer ") and auth_header[7:] == GEMINI_AUTH_PASSWORD:
55
- return "bearer_user"
 
 
56
 
57
  # Check for HTTP Basic Authentication
58
  if auth_header.startswith("Basic "):
@@ -68,7 +96,7 @@ def authenticate_user(request: Request):
68
  # If none of the authentication methods work
69
  raise HTTPException(
70
  status_code=401,
71
- detail="Invalid authentication credentials. Use HTTP Basic Auth, Bearer token, or 'key' query parameter.",
72
  headers={"WWW-Authenticate": "Basic"},
73
  )
74
 
@@ -121,22 +149,156 @@ class _OAuthCallbackHandler(BaseHTTPRequestHandler):
121
  self.end_headers()
122
  self.wfile.write(b"<h1>Authentication failed.</h1><p>Please try again.</p>")
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def get_user_project_id(creds):
125
- """Gets the user's project ID from environment variable, cache, or by probing the API."""
126
  global user_project_id
127
  if user_project_id:
128
  return user_project_id
129
 
130
- # First, check for environment variable override
131
- env_project_id = os.getenv("GEMINI_PROJECT_ID")
132
  if env_project_id:
133
  user_project_id = env_project_id
134
- print(f"Using project ID from environment variable: {user_project_id}")
135
- # Save the environment project ID to cache for consistency
 
 
 
 
 
 
 
136
  save_credentials(creds, user_project_id)
137
  return user_project_id
138
 
139
- # Second, try to load project ID from credential file
140
  if os.path.exists(CREDENTIAL_FILE):
141
  try:
142
  with open(CREDENTIAL_FILE, "r") as f:
@@ -149,19 +311,28 @@ def get_user_project_id(creds):
149
  except Exception as e:
150
  print(f"Could not load project ID from cache: {e}")
151
 
152
- # If not found in environment or cache, probe for it
153
  print("Project ID not found in environment or cache. Probing for user project ID...")
 
 
 
 
 
 
 
 
 
 
 
 
154
  headers = {
155
  "Authorization": f"Bearer {creds.token}",
156
  "Content-Type": "application/json",
 
157
  }
158
 
159
  probe_payload = {
160
- "cloudaicompanionProject": "gcp-project",
161
- "metadata": {
162
- "ideType": "VSCODE",
163
- "pluginType": "GEMINI"
164
- }
165
  }
166
 
167
  try:
@@ -177,7 +348,6 @@ def get_user_project_id(creds):
177
  raise ValueError("Could not find 'cloudaicompanionProject' in loadCodeAssist response.")
178
  print(f"Successfully fetched user project ID: {user_project_id}")
179
 
180
- # Save the project ID to the credential file for future use
181
  save_credentials(creds, user_project_id)
182
  print("Project ID saved to credential file for future use.")
183
 
@@ -187,15 +357,20 @@ def get_user_project_id(creds):
187
  raise
188
 
189
  def save_credentials(creds, project_id=None):
190
- os.makedirs(GEMINI_DIR, exist_ok=True)
191
  creds_data = {
 
 
192
  "access_token": creds.token,
193
  "refresh_token": creds.refresh_token,
194
- "scope": " ".join(creds.scopes),
195
  "token_type": "Bearer",
196
- "expiry_date": creds.expiry.isoformat() if creds.expiry else None,
197
  }
198
 
 
 
 
 
199
  # If project_id is provided, save it; otherwise preserve existing project_id
200
  if project_id:
201
  creds_data["project_id"] = project_id
@@ -212,65 +387,83 @@ def save_credentials(creds, project_id=None):
212
  json.dump(creds_data, f)
213
 
214
  def get_credentials():
215
- """Loads credentials from cache or initiates the OAuth 2.0 flow."""
216
  global credentials
217
-
218
- if credentials:
219
- if credentials.valid:
220
- return credentials
221
- if credentials.expired and credentials.refresh_token:
222
- print("Refreshing expired credentials...")
223
- try:
 
 
 
 
224
  credentials.refresh(GoogleAuthRequest())
225
  save_credentials(credentials)
226
- print("Credentials refreshed successfully.")
227
- return credentials
228
- except Exception as e:
229
- print(f"Could not refresh token: {e}. Attempting to load from file.")
230
-
231
  if os.path.exists(CREDENTIAL_FILE):
232
  try:
233
  with open(CREDENTIAL_FILE, "r") as f:
234
  creds_data = json.load(f)
235
-
236
- # Load project ID if available
237
- global user_project_id
238
- cached_project_id = creds_data.get("project_id")
239
- if cached_project_id:
240
- user_project_id = cached_project_id
241
- print(f"Loaded project ID from credential file: {user_project_id}")
242
-
243
- expiry = None
244
- expiry_str = creds_data.get("expiry_date")
245
- if expiry_str:
246
- if not isinstance(expiry_str, str) or not expiry_str.strip():
247
- expiry = None
248
- elif expiry_str.endswith('Z'):
249
- expiry_str = expiry_str[:-1] + '+00:00'
250
- expiry = datetime.fromisoformat(expiry_str)
251
- else:
252
- expiry = datetime.fromisoformat(expiry_str)
253
-
254
- credentials = Credentials(
255
- token=creds_data.get("access_token"),
256
- refresh_token=creds_data.get("refresh_token"),
257
- token_uri="https://oauth2.googleapis.com/token",
258
- client_id=CLIENT_ID,
259
- client_secret=CLIENT_SECRET,
260
- scopes=SCOPES,
261
- expiry=expiry
262
- )
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  if credentials.expired and credentials.refresh_token:
265
- print("Loaded credentials from file are expired. Refreshing...")
266
- credentials.refresh(GoogleAuthRequest())
267
- save_credentials(credentials)
268
-
269
- print("Successfully loaded credentials from cache.")
 
 
 
270
  return credentials
271
  except Exception as e:
272
  print(f"Could not load cached credentials: {e}. Starting new login.")
273
 
 
274
  client_config = {
275
  "installed": {
276
  "client_id": CLIENT_ID,
@@ -279,10 +472,22 @@ def get_credentials():
279
  "token_uri": "https://oauth2.googleapis.com/token",
280
  }
281
  }
 
 
282
  flow = Flow.from_client_config(
283
- client_config, scopes=SCOPES, redirect_uri="http://localhost:8080"
 
 
 
 
 
 
 
 
 
 
 
284
  )
285
- auth_url, _ = flow.authorization_url(access_type="offline", prompt="consent")
286
  print(f"\nPlease open this URL in your browser to log in:\n{auth_url}\n")
287
 
288
  server = HTTPServer(("", 8080), _OAuthCallbackHandler)
@@ -293,37 +498,74 @@ def get_credentials():
293
  print("Failed to retrieve authorization code.")
294
  return None
295
 
296
- flow.fetch_token(code=auth_code)
297
- credentials = flow.credentials
298
- save_credentials(credentials)
299
- print("Authentication successful! Credentials saved.")
300
- return credentials
301
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
- @app.post("/{full_path:path}")
304
  async def proxy_request(request: Request, full_path: str, username: str = Depends(authenticate_user)):
 
 
305
  creds = get_credentials()
306
  if not creds:
 
307
  return Response(content="Authentication failed. Please restart the proxy to log in.", status_code=500)
308
 
309
- # Only check if credentials are properly formed and not expired locally
310
- if not creds.valid:
311
- if creds.expired and creds.refresh_token:
312
- print("Credentials expired locally. Refreshing...")
313
- try:
314
- creds.refresh(GoogleAuthRequest())
315
- save_credentials(creds)
316
- print("Credentials refreshed successfully.")
317
- except Exception as e:
318
- print(f"Could not refresh token during request: {e}")
319
- return Response(content="Token refresh failed. Please restart the proxy to re-authenticate.", status_code=500)
320
- else:
321
- print("Credentials are invalid locally and cannot be refreshed.")
322
- return Response(content="Invalid credentials. Please restart the proxy to re-authenticate.", status_code=500)
323
 
324
  proj_id = get_user_project_id(creds)
325
  if not proj_id:
326
  return Response(content="Failed to get user project ID.", status_code=500)
 
 
327
 
328
  post_data = await request.body()
329
  path = f"/{full_path}"
@@ -392,8 +634,7 @@ async def proxy_request(request: Request, full_path: str, username: str = Depend
392
  headers = {
393
  "Authorization": f"Bearer {creds.token}",
394
  "Content-Type": "application/json",
395
- # We remove 'Accept-Encoding' to allow the server to send gzip,
396
- # which it seems to stream correctly. We will decompress on the fly.
397
  }
398
 
399
  if is_streaming:
@@ -572,7 +813,9 @@ if __name__ == "__main__":
572
  print("Initializing credentials...")
573
  creds = get_credentials()
574
  if creds:
575
- get_user_project_id(creds)
 
 
576
  print(f"\nStarting Gemini proxy server on http://localhost:{GEMINI_PORT}")
577
  print("Send your Gemini API requests to this address.")
578
  print(f"Authentication required - Password: {GEMINI_AUTH_PASSWORD}")
 
4
  import re
5
  import uvicorn
6
  import base64
7
+ import platform
8
+ import time
9
  from datetime import datetime
10
  from fastapi import FastAPI, Request, Response, HTTPException, Depends
11
  from fastapi.responses import StreamingResponse
12
  from fastapi.security import HTTPBasic, HTTPBasicCredentials
13
+ from fastapi.middleware.cors import CORSMiddleware
14
  from http.server import BaseHTTPRequestHandler, HTTPServer
15
  from urllib.parse import urlparse, parse_qs
16
  import ijson
 
19
  from google.oauth2.credentials import Credentials
20
  from google_auth_oauthlib.flow import Flow
21
  from google.auth.transport.requests import Request as GoogleAuthRequest
22
+ from google.auth.exceptions import RefreshError
23
 
24
  # Load environment variables from .env file
25
  load_dotenv()
 
31
  "https://www.googleapis.com/auth/cloud-platform",
32
  "https://www.googleapis.com/auth/userinfo.email",
33
  "https://www.googleapis.com/auth/userinfo.profile",
 
34
  ]
35
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
36
+ CREDENTIAL_FILE = os.path.join(SCRIPT_DIR, "oauth_creds.json")
37
  CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com"
38
  GEMINI_PORT = int(os.getenv("GEMINI_PORT", "8888")) # Default to 8888 if not set
39
  GEMINI_AUTH_PASSWORD = os.getenv("GEMINI_AUTH_PASSWORD", "123456") # Default password
40
+ CLI_VERSION = "0.1.5" # Match current gemini-cli version
41
 
42
  # --- Global State ---
43
  credentials = None
44
  user_project_id = None
45
+ onboarding_complete = False
46
 
47
  app = FastAPI()
48
  security = HTTPBasic()
49
 
50
+ # Add CORS middleware for preflight requests
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=["*"], # Allow all origins
54
+ allow_credentials=True,
55
+ allow_methods=["*"], # Allow all methods
56
+ allow_headers=["*"], # Allow all headers
57
+ )
58
+
59
+ def get_user_agent():
60
+ """Generate User-Agent string matching gemini-cli format."""
61
+ version = CLI_VERSION
62
+ system = platform.system()
63
+ arch = platform.machine()
64
+ return f"GeminiCLI/{version} ({system}; {arch})"
65
+
66
  def authenticate_user(request: Request):
67
  """Authenticate the user with multiple methods."""
68
  # Check for API key in query parameters first (for Gemini client compatibility)
 
70
  if api_key and api_key == GEMINI_AUTH_PASSWORD:
71
  return "api_key_user"
72
 
73
+ # Check for API key in x-goog-api-key header (Google SDK format)
74
+ goog_api_key = request.headers.get("x-goog-api-key", "")
75
+ if goog_api_key and goog_api_key == GEMINI_AUTH_PASSWORD:
76
+ return "goog_api_key_user"
77
+
78
  # Check for API key in Authorization header (Bearer token format)
79
  auth_header = request.headers.get("authorization", "")
80
+ if auth_header.startswith("Bearer "):
81
+ bearer_token = auth_header[7:]
82
+ if bearer_token == GEMINI_AUTH_PASSWORD:
83
+ return "bearer_user"
84
 
85
  # Check for HTTP Basic Authentication
86
  if auth_header.startswith("Basic "):
 
96
  # If none of the authentication methods work
97
  raise HTTPException(
98
  status_code=401,
99
+ detail="Invalid authentication credentials. Use HTTP Basic Auth, Bearer token, 'key' query parameter, or 'x-goog-api-key' header.",
100
  headers={"WWW-Authenticate": "Basic"},
101
  )
102
 
 
149
  self.end_headers()
150
  self.wfile.write(b"<h1>Authentication failed.</h1><p>Please try again.</p>")
151
 
152
+ def get_platform_string():
153
+ """Generate platform string matching gemini-cli format."""
154
+ system = platform.system().upper()
155
+ arch = platform.machine().upper()
156
+
157
+ # Map to gemini-cli platform format
158
+ if system == "DARWIN":
159
+ if arch in ["ARM64", "AARCH64"]:
160
+ return "DARWIN_ARM64"
161
+ else:
162
+ return "DARWIN_AMD64"
163
+ elif system == "LINUX":
164
+ if arch in ["ARM64", "AARCH64"]:
165
+ return "LINUX_ARM64"
166
+ else:
167
+ return "LINUX_AMD64"
168
+ elif system == "WINDOWS":
169
+ return "WINDOWS_AMD64"
170
+ else:
171
+ return "PLATFORM_UNSPECIFIED"
172
+
173
+ def get_client_metadata(project_id=None):
174
+ return {
175
+ "ideType": "IDE_UNSPECIFIED",
176
+ "platform": get_platform_string(),
177
+ "pluginType": "GEMINI",
178
+ "duetProject": project_id,
179
+ }
180
+
181
+ def onboard_user(creds, project_id):
182
+ """Ensures the user is onboarded, matching gemini-cli setupUser behavior."""
183
+ global onboarding_complete
184
+ if onboarding_complete:
185
+ return
186
+
187
+ # Refresh credentials if expired before making API calls
188
+ if creds.expired and creds.refresh_token:
189
+ print("Credentials expired. Refreshing before onboarding...")
190
+ try:
191
+ creds.refresh(GoogleAuthRequest())
192
+ save_credentials(creds)
193
+ print("Credentials refreshed successfully.")
194
+ except Exception as e:
195
+ print(f"Could not refresh credentials: {e}")
196
+ raise
197
+
198
+ print("Checking user onboarding status...")
199
+ headers = {
200
+ "Authorization": f"Bearer {creds.token}",
201
+ "Content-Type": "application/json",
202
+ "User-Agent": get_user_agent(),
203
+ }
204
+
205
+ # 1. Call loadCodeAssist to check current status
206
+ load_assist_payload = {
207
+ "cloudaicompanionProject": project_id,
208
+ "metadata": get_client_metadata(project_id),
209
+ }
210
+
211
+ try:
212
+ resp = requests.post(
213
+ f"{CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist",
214
+ data=json.dumps(load_assist_payload),
215
+ headers=headers,
216
+ )
217
+ resp.raise_for_status()
218
+ load_data = resp.json()
219
+
220
+ # Determine the tier to use (current or default)
221
+ tier = None
222
+ if load_data.get("currentTier"):
223
+ tier = load_data["currentTier"]
224
+ print("User is already onboarded.")
225
+ else:
226
+ # Find default tier for onboarding
227
+ for allowed_tier in load_data.get("allowedTiers", []):
228
+ if allowed_tier.get("isDefault"):
229
+ tier = allowed_tier
230
+ break
231
+
232
+ if not tier:
233
+ # Fallback tier if no default found (matching gemini-cli logic)
234
+ tier = {
235
+ "name": "",
236
+ "description": "",
237
+ "id": "legacy-tier",
238
+ "userDefinedCloudaicompanionProject": True,
239
+ }
240
+
241
+ # Check if project ID is required but missing
242
+ if tier.get("userDefinedCloudaicompanionProject") and not project_id:
243
+ raise ValueError("This account requires setting the GOOGLE_CLOUD_PROJECT env var.")
244
+
245
+ # If already onboarded, skip the onboarding process
246
+ if load_data.get("currentTier"):
247
+ onboarding_complete = True
248
+ return
249
+
250
+ print(f"Onboarding user to tier: {tier.get('name', 'legacy-tier')}")
251
+ onboard_req_payload = {
252
+ "tierId": tier.get("id"),
253
+ "cloudaicompanionProject": project_id,
254
+ "metadata": get_client_metadata(project_id),
255
+ }
256
+
257
+ # 2. Poll onboardUser until complete (matching gemini-cli polling logic)
258
+ while True:
259
+ onboard_resp = requests.post(
260
+ f"{CODE_ASSIST_ENDPOINT}/v1internal:onboardUser",
261
+ data=json.dumps(onboard_req_payload),
262
+ headers=headers,
263
+ )
264
+ onboard_resp.raise_for_status()
265
+ lro_data = onboard_resp.json()
266
+
267
+ if lro_data.get("done"):
268
+ print("Onboarding successful.")
269
+ onboarding_complete = True
270
+ break
271
+
272
+ print("Onboarding in progress, waiting 5 seconds...")
273
+ time.sleep(5)
274
+
275
+ except requests.exceptions.HTTPError as e:
276
+ print(f"Error during onboarding: {e.response.text}")
277
+ raise
278
+
279
  def get_user_project_id(creds):
280
+ """Gets the user's project ID matching gemini-cli setupUser logic."""
281
  global user_project_id
282
  if user_project_id:
283
  return user_project_id
284
 
285
+ # First, check for GOOGLE_CLOUD_PROJECT environment variable (matching gemini-cli)
286
+ env_project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
287
  if env_project_id:
288
  user_project_id = env_project_id
289
+ print(f"Using project ID from GOOGLE_CLOUD_PROJECT: {user_project_id}")
290
+ save_credentials(creds, user_project_id)
291
+ return user_project_id
292
+
293
+ # Second, check for GEMINI_PROJECT_ID as fallback
294
+ gemini_env_project_id = os.getenv("GEMINI_PROJECT_ID")
295
+ if gemini_env_project_id:
296
+ user_project_id = gemini_env_project_id
297
+ print(f"Using project ID from GEMINI_PROJECT_ID: {user_project_id}")
298
  save_credentials(creds, user_project_id)
299
  return user_project_id
300
 
301
+ # Third, try to load project ID from credential file
302
  if os.path.exists(CREDENTIAL_FILE):
303
  try:
304
  with open(CREDENTIAL_FILE, "r") as f:
 
311
  except Exception as e:
312
  print(f"Could not load project ID from cache: {e}")
313
 
314
+ # If not found in environment or cache, probe for it via loadCodeAssist
315
  print("Project ID not found in environment or cache. Probing for user project ID...")
316
+
317
+ # Refresh credentials if expired before making API calls
318
+ if creds.expired and creds.refresh_token:
319
+ print("Credentials expired. Refreshing before project ID probe...")
320
+ try:
321
+ creds.refresh(GoogleAuthRequest())
322
+ save_credentials(creds)
323
+ print("Credentials refreshed successfully.")
324
+ except Exception as e:
325
+ print(f"Could not refresh credentials: {e}")
326
+ raise
327
+
328
  headers = {
329
  "Authorization": f"Bearer {creds.token}",
330
  "Content-Type": "application/json",
331
+ "User-Agent": get_user_agent(),
332
  }
333
 
334
  probe_payload = {
335
+ "metadata": get_client_metadata(),
 
 
 
 
336
  }
337
 
338
  try:
 
348
  raise ValueError("Could not find 'cloudaicompanionProject' in loadCodeAssist response.")
349
  print(f"Successfully fetched user project ID: {user_project_id}")
350
 
 
351
  save_credentials(creds, user_project_id)
352
  print("Project ID saved to credential file for future use.")
353
 
 
357
  raise
358
 
359
  def save_credentials(creds, project_id=None):
 
360
  creds_data = {
361
+ "client_id": CLIENT_ID,
362
+ "client_secret": CLIENT_SECRET,
363
  "access_token": creds.token,
364
  "refresh_token": creds.refresh_token,
365
+ "scope": " ".join(creds.scopes) if creds.scopes else " ".join(SCOPES),
366
  "token_type": "Bearer",
367
+ "token_uri": "https://oauth2.googleapis.com/token",
368
  }
369
 
370
+ # Add expiry if available
371
+ if creds.expiry:
372
+ creds_data["expiry"] = creds.expiry.isoformat()
373
+
374
  # If project_id is provided, save it; otherwise preserve existing project_id
375
  if project_id:
376
  creds_data["project_id"] = project_id
 
387
  json.dump(creds_data, f)
388
 
389
  def get_credentials():
390
+ """Loads credentials matching gemini-cli OAuth2 flow."""
391
  global credentials
392
+
393
+ # Check environment for credentials first
394
+ env_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
395
+ if env_creds and os.path.exists(env_creds):
396
+ try:
397
+ with open(env_creds, "r") as f:
398
+ creds_data = json.load(f)
399
+ credentials = Credentials.from_authorized_user_info(creds_data, SCOPES)
400
+ print("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS.")
401
+ if credentials.expired and credentials.refresh_token:
402
+ print("Refreshing expired credentials...")
403
  credentials.refresh(GoogleAuthRequest())
404
  save_credentials(credentials)
405
+ return credentials
406
+ except Exception as e:
407
+ print(f"Could not load credentials from GOOGLE_APPLICATION_CREDENTIALS: {e}")
408
+
409
+ # Fallback to cached credentials
410
  if os.path.exists(CREDENTIAL_FILE):
411
  try:
412
  with open(CREDENTIAL_FILE, "r") as f:
413
  creds_data = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
+ credentials = Credentials.from_authorized_user_info(creds_data, SCOPES)
416
+ print("Loaded credentials from cache.")
417
+
418
+ # Try to refresh if we have refresh token but no access token
419
+ if not credentials.token and credentials.refresh_token:
420
+ print("Attempting to refresh credentials...")
421
+ try:
422
+ from google.auth.transport.requests import Request as AuthRequest
423
+ auth_request = AuthRequest()
424
+ credentials.refresh(auth_request)
425
+ print("Credentials refreshed successfully!")
426
+
427
+ # Save refreshed credentials
428
+ updated_creds_data = {
429
+ 'client_id': credentials.client_id,
430
+ 'client_secret': credentials.client_secret,
431
+ 'access_token': credentials.token,
432
+ 'refresh_token': credentials.refresh_token,
433
+ 'scope': credentials.scopes,
434
+ 'token_type': 'Bearer',
435
+ 'token_uri': credentials.token_uri,
436
+ 'expiry': credentials.expiry.isoformat() if credentials.expiry else None,
437
+ 'project_id': creds_data.get('project_id')
438
+ }
439
+
440
+ with open(CREDENTIAL_FILE, 'w') as f:
441
+ json.dump(updated_creds_data, f, indent=2)
442
+ print("Refreshed credentials saved.")
443
+
444
+ except Exception as e:
445
+ print(f"Failed to refresh credentials: {e}")
446
+ return None
447
+
448
+ # Check if we have a valid token after potential refresh
449
+ if not credentials.token:
450
+ print("No access token available after refresh attempt. Starting new login.")
451
+ return None
452
+
453
  if credentials.expired and credentials.refresh_token:
454
+ print("Refreshing expired credentials...")
455
+ try:
456
+ credentials.refresh(GoogleAuthRequest())
457
+ save_credentials(credentials)
458
+ print("Credentials refreshed and saved.")
459
+ except Exception as refresh_error:
460
+ print(f"Failed to refresh credentials: {refresh_error}. Starting new login.")
461
+ return None
462
  return credentials
463
  except Exception as e:
464
  print(f"Could not load cached credentials: {e}. Starting new login.")
465
 
466
+ # If no valid credentials, start new login flow
467
  client_config = {
468
  "installed": {
469
  "client_id": CLIENT_ID,
 
472
  "token_uri": "https://oauth2.googleapis.com/token",
473
  }
474
  }
475
+
476
+ # Create flow with include_granted_scopes to handle scope changes
477
  flow = Flow.from_client_config(
478
+ client_config,
479
+ scopes=SCOPES,
480
+ redirect_uri="http://localhost:8080"
481
+ )
482
+
483
+ # Set include_granted_scopes to handle additional scopes gracefully
484
+ flow.oauth2session.scope = SCOPES
485
+
486
+ auth_url, _ = flow.authorization_url(
487
+ access_type="offline",
488
+ prompt="consent",
489
+ include_granted_scopes='true'
490
  )
 
491
  print(f"\nPlease open this URL in your browser to log in:\n{auth_url}\n")
492
 
493
  server = HTTPServer(("", 8080), _OAuthCallbackHandler)
 
498
  print("Failed to retrieve authorization code.")
499
  return None
500
 
501
+ # Monkey patch to handle scope validation warnings
502
+ import oauthlib.oauth2.rfc6749.parameters
503
+ original_validate = oauthlib.oauth2.rfc6749.parameters.validate_token_parameters
504
+
505
+ def patched_validate(params):
506
+ try:
507
+ return original_validate(params)
508
+ except Warning:
509
+ # Ignore scope change warnings
510
+ pass
511
+
512
+ oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = patched_validate
513
+
514
+ try:
515
+ flow.fetch_token(code=auth_code)
516
+ credentials = flow.credentials
517
+ save_credentials(credentials)
518
+ print("Authentication successful! Credentials saved.")
519
+ return credentials
520
+ except Exception as e:
521
+ print(f"Authentication failed: {e}")
522
+ return None
523
+ finally:
524
+ # Restore original function
525
+ oauthlib.oauth2.rfc6749.parameters.validate_token_parameters = original_validate
526
+
527
+
528
+ @app.options("/{full_path:path}")
529
+ async def handle_preflight(request: Request, full_path: str):
530
+ """Handle CORS preflight requests without authentication."""
531
+ return Response(
532
+ status_code=200,
533
+ headers={
534
+ "Access-Control-Allow-Origin": "*",
535
+ "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, PATCH, OPTIONS",
536
+ "Access-Control-Allow-Headers": "*",
537
+ "Access-Control-Allow-Credentials": "true",
538
+ }
539
+ )
540
 
541
+ @app.api_route("/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
542
  async def proxy_request(request: Request, full_path: str, username: str = Depends(authenticate_user)):
543
+ print(f"[{request.method}] /{full_path} - User: {username}")
544
+
545
  creds = get_credentials()
546
  if not creds:
547
+ print("❌ No credentials available")
548
  return Response(content="Authentication failed. Please restart the proxy to log in.", status_code=500)
549
 
550
+ # Check if credentials need refreshing (more lenient validation)
551
+ if creds.expired and creds.refresh_token:
552
+ print("Credentials expired. Refreshing...")
553
+ try:
554
+ creds.refresh(GoogleAuthRequest())
555
+ save_credentials(creds)
556
+ print("Credentials refreshed successfully.")
557
+ except Exception as e:
558
+ print(f"Could not refresh token during request: {e}")
559
+ return Response(content="Token refresh failed. Please restart the proxy to re-authenticate.", status_code=500)
560
+ elif not creds.token:
561
+ print("No access token available.")
562
+ return Response(content="No access token. Please restart the proxy to re-authenticate.", status_code=500)
 
563
 
564
  proj_id = get_user_project_id(creds)
565
  if not proj_id:
566
  return Response(content="Failed to get user project ID.", status_code=500)
567
+
568
+ onboard_user(creds, proj_id)
569
 
570
  post_data = await request.body()
571
  path = f"/{full_path}"
 
634
  headers = {
635
  "Authorization": f"Bearer {creds.token}",
636
  "Content-Type": "application/json",
637
+ "User-Agent": get_user_agent(),
 
638
  }
639
 
640
  if is_streaming:
 
813
  print("Initializing credentials...")
814
  creds = get_credentials()
815
  if creds:
816
+ proj_id = get_user_project_id(creds)
817
+ if proj_id:
818
+ onboard_user(creds, proj_id)
819
  print(f"\nStarting Gemini proxy server on http://localhost:{GEMINI_PORT}")
820
  print("Send your Gemini API requests to this address.")
821
  print(f"Authentication required - Password: {GEMINI_AUTH_PASSWORD}")