Abhishek Thakur commited on
Commit
bef585f
·
1 Parent(s): b745d94
competitions/app.py CHANGED
@@ -16,6 +16,7 @@ from competitions import utils
16
  from competitions.errors import AuthenticationError
17
  from competitions.info import CompetitionInfo
18
  from competitions.leaderboard import Leaderboard
 
19
  from competitions.runner import JobRunner
20
  from competitions.submissions import Submissions
21
  from competitions.text import SUBMISSION_SELECTION_TEXT, SUBMISSION_TEXT
@@ -27,7 +28,7 @@ COMPETITION_ID = os.environ.get("COMPETITION_ID")
27
  OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/tmp/model")
28
  START_DATE = os.environ.get("START_DATE", "2000-12-31")
29
  DISABLE_PUBLIC_LB = int(os.environ.get("DISABLE_PUBLIC_LB", 0))
30
-
31
 
32
  disable_progress_bars()
33
 
@@ -69,6 +70,9 @@ thread.start()
69
 
70
 
71
  app = FastAPI()
 
 
 
72
  static_path = os.path.join(BASE_DIR, "static")
73
  app.mount("/static", StaticFiles(directory=static_path), name="static")
74
  templates_path = os.path.join(BASE_DIR, "templates")
@@ -92,6 +96,11 @@ async def read_form(request: Request):
92
  return templates.TemplateResponse("index.html", context)
93
 
94
 
 
 
 
 
 
95
  @app.get("/competition_info", response_class=JSONResponse)
96
  async def get_comp_info(request: Request):
97
  info = COMP_INFO.competition_desc
 
16
  from competitions.errors import AuthenticationError
17
  from competitions.info import CompetitionInfo
18
  from competitions.leaderboard import Leaderboard
19
+ from competitions.oauth import attach_oauth
20
  from competitions.runner import JobRunner
21
  from competitions.submissions import Submissions
22
  from competitions.text import SUBMISSION_SELECTION_TEXT, SUBMISSION_TEXT
 
28
  OUTPUT_PATH = os.environ.get("OUTPUT_PATH", "/tmp/model")
29
  START_DATE = os.environ.get("START_DATE", "2000-12-31")
30
  DISABLE_PUBLIC_LB = int(os.environ.get("DISABLE_PUBLIC_LB", 0))
31
+ USE_OAUTH = int(os.environ.get("USE_OAUTH", 0))
32
 
33
  disable_progress_bars()
34
 
 
70
 
71
 
72
  app = FastAPI()
73
+ if USE_OAUTH == 1:
74
+ attach_oauth(app)
75
+
76
  static_path = os.path.join(BASE_DIR, "static")
77
  app.mount("/static", StaticFiles(directory=static_path), name="static")
78
  templates_path = os.path.join(BASE_DIR, "templates")
 
96
  return templates.TemplateResponse("index.html", context)
97
 
98
 
99
+ @app.get("/use_oauth", response_class=JSONResponse)
100
+ async def use_oauth(request: Request):
101
+ return {"response": USE_OAUTH}
102
+
103
+
104
  @app.get("/competition_info", response_class=JSONResponse)
105
  async def get_comp_info(request: Request):
106
  info = COMP_INFO.competition_desc
competitions/create.py CHANGED
@@ -119,6 +119,9 @@ def _create_readme(competition_name):
119
  _readme += "sdk: docker\n"
120
  _readme += "pinned: false\n"
121
  _readme += "duplicated_from: autotrain-projects/autotrain-advanced\n"
 
 
 
122
  _readme += "---\n"
123
  _readme = io.BytesIO(_readme.encode())
124
  return _readme
 
119
  _readme += "sdk: docker\n"
120
  _readme += "pinned: false\n"
121
  _readme += "duplicated_from: autotrain-projects/autotrain-advanced\n"
122
+ _readme += "hf_oauth: true\n"
123
+ _readme += "hf_oauth_scopes:\n"
124
+ _readme += " - read-repos\n"
125
  _readme += "---\n"
126
  _readme = io.BytesIO(_readme.encode())
127
  return _readme
competitions/oauth.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OAuth support for AutoTrain.
2
+ Taken from: https://github.com/gradio-app/gradio/blob/main/gradio/oauth.py
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import hashlib
8
+ import os
9
+ import typing
10
+ import urllib.parse
11
+ import warnings
12
+ from dataclasses import dataclass, field
13
+
14
+ import fastapi
15
+ from authlib.integrations.starlette_client import OAuth
16
+ from fastapi.responses import RedirectResponse
17
+ from huggingface_hub import whoami
18
+ from starlette.middleware.sessions import SessionMiddleware
19
+
20
+
21
+ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID")
22
+ OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET")
23
+ OAUTH_SCOPES = os.environ.get("OAUTH_SCOPES")
24
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL")
25
+
26
+
27
+ def attach_oauth(app: fastapi.FastAPI):
28
+ # Add `/login/huggingface`, `/login/callback` and `/logout` routes to enable OAuth in the Gradio app.
29
+ # If the app is running in a Space, OAuth is enabled normally. Otherwise, we mock the "real" routes to make the
30
+ # user log in with a fake user profile - without any calls to hf.co.
31
+ if os.environ.get("SPACE_ID") is not None and os.environ.get("HF_TOKEN") is None:
32
+ _add_oauth_routes(app)
33
+ else:
34
+ _add_mocked_oauth_routes(app)
35
+
36
+ # Session Middleware requires a secret key to sign the cookies. Let's use a hash
37
+ # of the OAuth secret key to make it unique to the Space + updated in case OAuth
38
+ # config gets updated.
39
+ session_secret = (OAUTH_CLIENT_SECRET or "") + "-v3"
40
+ # ^ if we change the session cookie format in the future, we can bump the version of the session secret to make
41
+ # sure cookies are invalidated. Otherwise some users with an old cookie format might get a HTTP 500 error.
42
+ app.add_middleware(
43
+ SessionMiddleware,
44
+ secret_key=hashlib.sha256(session_secret.encode()).hexdigest(),
45
+ same_site="none",
46
+ https_only=True,
47
+ )
48
+
49
+
50
+ def _add_oauth_routes(app: fastapi.FastAPI) -> None:
51
+ """Add OAuth routes to the FastAPI app (login, callback handler and logout)."""
52
+ # Check environment variables
53
+ msg = (
54
+ "OAuth is required but {} environment variable is not set. Make sure you've enabled OAuth in your Space by"
55
+ " setting `hf_oauth: true` in the Space metadata."
56
+ )
57
+ if OAUTH_CLIENT_ID is None:
58
+ raise ValueError(msg.format("OAUTH_CLIENT_ID"))
59
+ if OAUTH_CLIENT_SECRET is None:
60
+ raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
61
+ if OAUTH_SCOPES is None:
62
+ raise ValueError(msg.format("OAUTH_SCOPES"))
63
+ if OPENID_PROVIDER_URL is None:
64
+ raise ValueError(msg.format("OPENID_PROVIDER_URL"))
65
+
66
+ # Register OAuth server
67
+ oauth = OAuth()
68
+ oauth.register(
69
+ name="huggingface",
70
+ client_id=OAUTH_CLIENT_ID,
71
+ client_secret=OAUTH_CLIENT_SECRET,
72
+ client_kwargs={"scope": OAUTH_SCOPES},
73
+ server_metadata_url=OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
74
+ )
75
+
76
+ # Define OAuth routes
77
+ @app.get("/login/huggingface")
78
+ async def oauth_login(request: fastapi.Request):
79
+ """Endpoint that redirects to HF OAuth page."""
80
+ # Define target (where to redirect after login)
81
+ redirect_uri = _generate_redirect_uri(request)
82
+ return await oauth.huggingface.authorize_redirect(request, redirect_uri) # type: ignore
83
+
84
+ @app.get("/login/callback")
85
+ async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
86
+ """Endpoint that handles the OAuth callback."""
87
+ oauth_info = await oauth.huggingface.authorize_access_token(request) # type: ignore
88
+ request.session["oauth_info"] = oauth_info
89
+ return _redirect_to_target(request)
90
+
91
+ @app.get("/logout")
92
+ async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
93
+ """Endpoint that logs out the user (e.g. delete cookie session)."""
94
+ request.session.pop("oauth_info", None)
95
+ return _redirect_to_target(request)
96
+
97
+
98
+ def _add_mocked_oauth_routes(app: fastapi.FastAPI) -> None:
99
+ """Add fake oauth routes if Gradio is run locally and OAuth is enabled.
100
+ Instead of authenticating with HF, a mocked user profile is added to the session.
101
+ """
102
+ warnings.warn(
103
+ "AutoTrain does not support OAuth features outside of a Space environment. To help"
104
+ " you debug your app locally, the login and logout buttons are mocked with your"
105
+ " profile. To make it work, your machine must be logged in to Huggingface."
106
+ )
107
+ mocked_oauth_info = _get_mocked_oauth_info()
108
+
109
+ # Define OAuth routes
110
+ @app.get("/login/huggingface")
111
+ async def oauth_login(request: fastapi.Request): # noqa: ARG001
112
+ """Fake endpoint that redirects to HF OAuth page."""
113
+ # Define target (where to redirect after login)
114
+ redirect_uri = _generate_redirect_uri(request)
115
+ return RedirectResponse("/login/callback?" + urllib.parse.urlencode({"_target_url": redirect_uri}))
116
+
117
+ @app.get("/login/callback")
118
+ async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
119
+ """Endpoint that handles the OAuth callback."""
120
+ request.session["oauth_info"] = mocked_oauth_info
121
+ return _redirect_to_target(request)
122
+
123
+ @app.get("/logout")
124
+ async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
125
+ """Endpoint that logs out the user (e.g. delete cookie session)."""
126
+ request.session.pop("oauth_info", None)
127
+ logout_url = str(request.url).replace("/logout", "/") # preserve query params
128
+ return RedirectResponse(url=logout_url)
129
+
130
+
131
+ def _generate_redirect_uri(request: fastapi.Request) -> str:
132
+ if "_target_url" in request.query_params:
133
+ # if `_target_url` already in query params => respect it
134
+ target = request.query_params["_target_url"]
135
+ else:
136
+ # otherwise => keep query params
137
+ target = "/?" + urllib.parse.urlencode(request.query_params)
138
+
139
+ redirect_uri = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target)
140
+ redirect_uri_as_str = str(redirect_uri)
141
+ if redirect_uri.netloc.endswith(".hf.space"):
142
+ # In Space, FastAPI redirect as http but we want https
143
+ redirect_uri_as_str = redirect_uri_as_str.replace("http://", "https://")
144
+ return redirect_uri_as_str
145
+
146
+
147
+ def _redirect_to_target(request: fastapi.Request, default_target: str = "/") -> RedirectResponse:
148
+ target = request.query_params.get("_target_url", default_target)
149
+ return RedirectResponse(target)
150
+
151
+
152
+ @dataclass
153
+ class OAuthProfile(typing.Dict): # inherit from Dict for backward compatibility
154
+ """
155
+ A OAuthProfile object that can be used to inject the profile of a user in a
156
+ function. If a function expects `OAuthProfile` or `Optional[OAuthProfile]` as input,
157
+ the value will be injected from the FastAPI session if the user is logged in. If the
158
+ user is not logged in and the function expects `OAuthProfile`, an error will be
159
+ raised.
160
+
161
+ Attributes:
162
+ name (str): The name of the user (e.g. 'abhishek').
163
+ username (str): The username of the user (e.g. 'abhishek')
164
+ profile (str): The profile URL of the user (e.g. 'https://huggingface.co/abhishek').
165
+ picture (str): The profile picture URL of the user.
166
+ """
167
+
168
+ name: str = field(init=False)
169
+ username: str = field(init=False)
170
+ profile: str = field(init=False)
171
+ picture: str = field(init=False)
172
+
173
+ def __init__(self, data: dict): # hack to make OAuthProfile backward compatible
174
+ self.update(data)
175
+ self.name = self["name"]
176
+ self.username = self["preferred_username"]
177
+ self.profile = self["profile"]
178
+ self.picture = self["picture"]
179
+
180
+
181
+ @dataclass
182
+ class OAuthToken:
183
+ """
184
+ A Gradio OAuthToken object that can be used to inject the access token of a user in a
185
+ function. If a function expects `OAuthToken` or `Optional[OAuthToken]` as input,
186
+ the value will be injected from the FastAPI session if the user is logged in. If the
187
+ user is not logged in and the function expects `OAuthToken`, an error will be
188
+ raised.
189
+
190
+ Attributes:
191
+ token (str): The access token of the user.
192
+ scope (str): The scope of the access token.
193
+ expires_at (int): The expiration timestamp of the access token.
194
+ """
195
+
196
+ token: str
197
+ scope: str
198
+ expires_at: int
199
+
200
+
201
+ def _get_mocked_oauth_info() -> typing.Dict:
202
+ token = os.environ.get("HF_TOKEN")
203
+ if token is None:
204
+ raise ValueError(
205
+ "Your machine must be logged in to HF to debug AutoTrain locally. Please "
206
+ "set `HF_TOKEN` as environment variable "
207
+ "with one of your access token. You can generate a new token in your "
208
+ "settings page (https://huggingface.co/settings/tokens)."
209
+ )
210
+
211
+ user = whoami(token=token)
212
+ if user["type"] != "user":
213
+ raise ValueError(
214
+ "Your machine is not logged in with a personal account. Please use a "
215
+ "personal access token. You can generate a new token in your settings page"
216
+ " (https://huggingface.co/settings/tokens)."
217
+ )
218
+
219
+ return {
220
+ "access_token": token,
221
+ "token_type": "bearer",
222
+ "expires_in": 3600,
223
+ "id_token": "AAAAAAAAAAAAAAAAAAAAAAAAAA",
224
+ "scope": "openid profile",
225
+ "expires_at": 1691676444,
226
+ "userinfo": {
227
+ "sub": "11111111111111111111111",
228
+ "name": user["fullname"],
229
+ "preferred_username": user["name"],
230
+ "profile": f"https://huggingface.co/{user['name']}",
231
+ "picture": user["avatarUrl"],
232
+ "website": "",
233
+ "aud": "00000000-0000-0000-0000-000000000000",
234
+ "auth_time": 1691672844,
235
+ "nonce": "aaaaaaaaaaaaaaaaaaa",
236
+ "iat": 1691672844,
237
+ "exp": 1691676444,
238
+ "iss": "https://huggingface.co",
239
+ },
240
+ }
competitions/templates/index.html CHANGED
@@ -244,6 +244,31 @@
244
  });
245
 
246
  </script>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  </head>
248
 
249
  <body class="flex h-screen">
@@ -355,13 +380,18 @@
355
  <span class="flex-1 ms-3 whitespace-nowrap">Team</span>
356
  </a>
357
  </li> -->
358
- <li>
359
  <label for="user_token" class="text-xs font-medium">Hugging Face <a
360
  href="https://huggingface.co/settings/tokens" target="_blank">Token</a> (read-only)
361
  </label>
362
  <input type="password" name="user_token" id="user_token"
363
  class="mt-1 block w-full border border-gray-300 px-3 py-1.5 bg-white rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
364
  </li>
 
 
 
 
 
365
  </ul>
366
  <footer>
367
  <div class="w-full mx-auto max-w-screen-xl p-4 md:flex md:items-center md:justify-between">
 
244
  });
245
 
246
  </script>
247
+ <script>
248
+ function makeApiRequest(url, callback) {
249
+ var xhr = new XMLHttpRequest();
250
+ xhr.open("GET", url, true);
251
+ xhr.onreadystatechange = function () {
252
+ if (xhr.readyState === 4 && xhr.status === 200) {
253
+ var response = JSON.parse(xhr.responseText);
254
+ callback(response.response);
255
+ }
256
+ };
257
+ xhr.send();
258
+ }
259
+
260
+ function checkOAuth() {
261
+ var url = "/use_oauth";
262
+ makeApiRequest(url, function (response) {
263
+ if (response === 1) {
264
+ document.getElementById("userToken").style.display = "none";
265
+ document.getElementById("loginButton").style.display = "block";
266
+ }
267
+ });
268
+ }
269
+
270
+ window.onload = checkOAuth;
271
+ </script>
272
  </head>
273
 
274
  <body class="flex h-screen">
 
380
  <span class="flex-1 ms-3 whitespace-nowrap">Team</span>
381
  </a>
382
  </li> -->
383
+ <li id="userToken">
384
  <label for="user_token" class="text-xs font-medium">Hugging Face <a
385
  href="https://huggingface.co/settings/tokens" target="_blank">Token</a> (read-only)
386
  </label>
387
  <input type="password" name="user_token" id="user_token"
388
  class="mt-1 block w-full border border-gray-300 px-3 py-1.5 bg-white rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
389
  </li>
390
+ <li id="loginButton" style="display: none;">
391
+ <a href="/login/huggingface"
392
+ class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded">Login with Hugging
393
+ Face</a>
394
+ </li>
395
  </ul>
396
  <footer>
397
  <div class="w-full mx-auto max-w-screen-xl p-4 md:flex md:items-center md:justify-between">