Abhishek Thakur commited on
Commit
6cdb208
·
1 Parent(s): c86fede
Files changed (2) hide show
  1. competitions/app.py +6 -0
  2. competitions/oauth.py +6 -137
competitions/app.py CHANGED
@@ -2,6 +2,7 @@ import datetime
2
  import os
3
  import threading
4
 
 
5
  from fastapi import FastAPI, File, Form, Request, UploadFile
6
  from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
7
  from fastapi.staticfiles import StaticFiles
@@ -88,6 +89,11 @@ async def read_form(request: Request):
88
  """
89
  if USE_OAUTH == 1:
90
  logger.info(request.session.get("oauth_info"))
 
 
 
 
 
91
  if HF_TOKEN is None:
92
  return templates.TemplateResponse("error.html", {"request": request})
93
  context = {
 
2
  import os
3
  import threading
4
 
5
+ import requests
6
  from fastapi import FastAPI, File, Form, Request, UploadFile
7
  from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
8
  from fastapi.staticfiles import StaticFiles
 
89
  """
90
  if USE_OAUTH == 1:
91
  logger.info(request.session.get("oauth_info"))
92
+ oauth_info = request.session.get("oauth_info")
93
+ access_token = oauth_info["access_token"]
94
+ oauth_userinfo_endpoint = "https://huggingface.co/oauth/userinfo"
95
+ res = requests.post(oauth_userinfo_endpoint, headers={"Authorization": f"Bearer {access_token}"}, timeout=10)
96
+ oauth_info["_userinfo"] = res.json()
97
  if HF_TOKEN is None:
98
  return templates.TemplateResponse("error.html", {"request": request})
99
  context = {
competitions/oauth.py CHANGED
@@ -6,16 +6,11 @@ from __future__ import annotations
6
 
7
  import hashlib
8
  import os
9
- import typing
10
  import urllib.parse
11
- import warnings
12
- from dataclasses import dataclass, field
13
 
14
  import fastapi
15
- import requests
16
  from authlib.integrations.starlette_client import OAuth
17
  from fastapi.responses import RedirectResponse
18
- from huggingface_hub import whoami
19
  from starlette.middleware.sessions import SessionMiddleware
20
 
21
 
@@ -32,8 +27,7 @@ def attach_oauth(app: fastapi.FastAPI):
32
  if os.environ.get("SPACE_ID") is not None and int(os.environ.get("USE_OAUTH", 0)) == 1:
33
  _add_oauth_routes(app)
34
  else:
35
- _add_mocked_oauth_routes(app)
36
-
37
  # Session Middleware requires a secret key to sign the cookies. Let's use a hash
38
  # of the OAuth secret key to make it unique to the Space + updated in case OAuth
39
  # config gets updated.
@@ -79,17 +73,14 @@ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
79
  async def oauth_login(request: fastapi.Request):
80
  """Endpoint that redirects to HF OAuth page."""
81
  # Define target (where to redirect after login)
82
- redirect_uri = _generate_redirect_uri(request)
 
83
  return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore
84
 
85
- @app.get("/login/callback")
86
  async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
87
  """Endpoint that handles the OAuth callback."""
88
  oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
89
- access_token = oauth_info["access_token"]
90
- oauth_userinfo_endpoint = "https://huggingface.co/oauth/userinfo"
91
- res = requests.post(oauth_userinfo_endpoint, headers={"Authorization": f"Bearer {access_token}"}, timeout=10)
92
- oauth_info["_userinfo"] = res.json()
93
  request.session["oauth_info"] = oauth_info
94
  return _redirect_to_target(request)
95
 
@@ -100,39 +91,6 @@ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
100
  return _redirect_to_target(request)
101
 
102
 
103
- def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
104
- """Add fake oauth routes if Gradio is run locally and OAuth is enabled.
105
- Instead of authenticating with HF, a mocked user profile is added to the session.
106
- """
107
- warnings.warn(
108
- "AutoTrain does not support OAuth features outside of a Space environment. To help"
109
- " you debug your app locally, the login and logout buttons are mocked with your"
110
- " profile. To make it work, your machine must be logged in to Huggingface."
111
- )
112
- mocked_oauth_info = _get_mocked_oauth_info()
113
-
114
- # Define OAuth routes
115
- @app.get("/login/huggingface")
116
- async def oauth_login(request: fastapi.Request): # noqa: ARG001
117
- """Fake endpoint that redirects to HF OAuth page."""
118
- # Define target (where to redirect after login)
119
- redirect_uri = _generate_redirect_uri(request)
120
- return RedirectResponse("/login/callback?" + urllib.parse.urlencode({"_target_url": redirect_uri}))
121
-
122
- @app.get("/login/callback")
123
- async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
124
- """Endpoint that handles the OAuth callback."""
125
- request.session["oauth_info"] = mocked_oauth_info
126
- return _redirect_to_target(request)
127
-
128
- @app.get("/logout")
129
- async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
130
- """Endpoint that logs out the user (e.g. delete cookie session)."""
131
- request.session.pop("oauth_info", None)
132
- logout_url = str(request.url).replace("/logout", "/") # preserve query params
133
- return RedirectResponse(url=logout_url)
134
-
135
-
136
  def _generate_redirect_uri(request: fastapi.Request) -> str:
137
  if "_target_url" in request.query_params:
138
  # if `_target_url` already in query params => respect it
@@ -150,95 +108,6 @@ def _generate_redirect_uri(request: fastapi.Request) -> str:
150
 
151
 
152
  def _redirect_to_target(request: fastapi.Request, default_target: str = "/") -> RedirectResponse:
153
- target = request.query_params.get("_target_url", default_target)
 
154
  return RedirectResponse(target)
155
-
156
-
157
- @dataclass
158
- class OAuthProfile(typing.Dict): # inherit from Dict for backward compatibility
159
- """
160
- A OAuthProfile object that can be used to inject the profile of a user in a
161
- function. If a function expects `OAuthProfile` or `Optional[OAuthProfile]` as input,
162
- the value will be injected from the FastAPI session if the user is logged in. If the
163
- user is not logged in and the function expects `OAuthProfile`, an error will be
164
- raised.
165
-
166
- Attributes:
167
- name (str): The name of the user (e.g. 'abhishek').
168
- username (str): The username of the user (e.g. 'abhishek')
169
- profile (str): The profile URL of the user (e.g. 'https://huggingface.co/abhishek').
170
- picture (str): The profile picture URL of the user.
171
- """
172
-
173
- name: str = field(init=False)
174
- username: str = field(init=False)
175
- profile: str = field(init=False)
176
- picture: str = field(init=False)
177
-
178
- def __init__(self, data: dict): # hack to make OAuthProfile backward compatible
179
- self.update(data)
180
- self.name = self["name"]
181
- self.username = self["preferred_username"]
182
- self.profile = self["profile"]
183
- self.picture = self["picture"]
184
-
185
-
186
- @dataclass
187
- class OAuthToken:
188
- """
189
- A Gradio OAuthToken object that can be used to inject the access token of a user in a
190
- function. If a function expects `OAuthToken` or `Optional[OAuthToken]` as input,
191
- the value will be injected from the FastAPI session if the user is logged in. If the
192
- user is not logged in and the function expects `OAuthToken`, an error will be
193
- raised.
194
-
195
- Attributes:
196
- token (str): The access token of the user.
197
- scope (str): The scope of the access token.
198
- expires_at (int): The expiration timestamp of the access token.
199
- """
200
-
201
- token: str
202
- scope: str
203
- expires_at: int
204
-
205
-
206
- def _get_mocked_oauth_info() -> typing.Dict:
207
- token = os.environ.get("USER_HF_TOKEN")
208
- if token is None:
209
- raise ValueError(
210
- "Your machine must be logged in to HF to debug AutoTrain locally. Please "
211
- "set `USER_HF_TOKEN` as environment variable "
212
- "with one of your access token. You can generate a new token in your "
213
- "settings page (https://huggingface.co/settings/tokens)."
214
- )
215
-
216
- user = whoami(token=token)
217
- if user["type"] != "user":
218
- raise ValueError(
219
- "Your machine is not logged in with a personal account. Please use a "
220
- "personal access token. You can generate a new token in your settings page"
221
- " (https://huggingface.co/settings/tokens)."
222
- )
223
-
224
- return {
225
- "access_token": "hf_oauth_XXX",
226
- "token_type": "bearer",
227
- "expires_in": 28799,
228
- "id_token": "XXX",
229
- "scope": "openid profile read-repos",
230
- "expires_at": 1709003175,
231
- "userinfo": {
232
- "sub": "123hello123",
233
- "name": "my name",
234
- "preferred_username": "me",
235
- "profile": "https://huggingface.co/user",
236
- "picture": "https://img",
237
- "aud": "jksdahffasdk-435-3-dsf-a",
238
- "auth_time": 1708974376,
239
- "nonce": "jdkfghskfdjhgkfd",
240
- "iat": 1708974376,
241
- "exp": 1708977976,
242
- "iss": "https://huggingface.co",
243
- },
244
- }
 
6
 
7
  import hashlib
8
  import os
 
9
  import urllib.parse
 
 
10
 
11
  import fastapi
 
12
  from authlib.integrations.starlette_client import OAuth
13
  from fastapi.responses import RedirectResponse
 
14
  from starlette.middleware.sessions import SessionMiddleware
15
 
16
 
 
27
  if os.environ.get("SPACE_ID") is not None and int(os.environ.get("USE_OAUTH", 0)) == 1:
28
  _add_oauth_routes(app)
29
  else:
30
+ return
 
31
  # Session Middleware requires a secret key to sign the cookies. Let's use a hash
32
  # of the OAuth secret key to make it unique to the Space + updated in case OAuth
33
  # config gets updated.
 
73
  async def oauth_login(request: fastapi.Request):
74
  """Endpoint that redirects to HF OAuth page."""
75
  # Define target (where to redirect after login)
76
+ # redirect_uri = _generate_redirect_uri(request)
77
+ redirect_uri = request.url_for("auth")
78
  return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore
79
 
80
+ @app.get("/auth")
81
  async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
82
  """Endpoint that handles the OAuth callback."""
83
  oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
 
 
 
 
84
  request.session["oauth_info"] = oauth_info
85
  return _redirect_to_target(request)
86
 
 
91
  return _redirect_to_target(request)
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def _generate_redirect_uri(request: fastapi.Request) -> str:
95
  if "_target_url" in request.query_params:
96
  # if `_target_url` already in query params => respect it
 
108
 
109
 
110
  def _redirect_to_target(request: fastapi.Request, default_target: str = "/") -> RedirectResponse:
111
+ # target = request.query_params.get("_target_url", default_target)
112
+ target = "https://huggingface.co/" + os.environ.get("SPACE_ID")
113
  return RedirectResponse(target)