|
import os |
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
import requests |
|
import base64 |
|
import json |
|
import boto3 |
|
|
|
|
|
|
|
|
|
# Run python-dotenv's loader first so the lookups below see any values it
# provides.
load_dotenv()

# Settings read from the environment; each is None when its variable is unset.
COGNITO_DOMAIN = os.getenv("COGNITO_DOMAIN")
CLIENT_ID = os.getenv("CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")
APP_URI = os.getenv("APP_URI")
|
|
|
|
|
|
|
def initialise_st_state_vars():
    """
    Seed the Streamlit session state with its default entries.

    Every key that is not yet present in ``st.session_state`` is created
    with its default value; keys that already exist are left untouched.

    Returns:
        Nothing.
    """
    # Built inside the function so the mutable defaults are fresh per call.
    defaults = {
        "auth_code": "",
        "authenticated": False,
        "user_cognito_groups": [],
        "user_info": {},
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
|
|
|
|
|
|
|
def get_auth_code():
    """
    Read the authorization code from the page's query parameters.

    Returns:
        The first value of the ``code`` query parameter, or an empty string
        when the parameter is missing or carries no value.
    """
    query_params = st.experimental_get_query_params()
    try:
        return dict(query_params)["code"][0]
    except (KeyError, IndexError, TypeError):
        # KeyError: no "code" parameter; IndexError: parameter present but
        # with an empty value list; TypeError: unexpected query-param shape.
        return ""
|
|
|
|
|
|
|
def set_auth_code():
    """
    Store the current authorization code in the session state.

    Makes sure the session-state defaults exist, then copies the value
    returned by ``get_auth_code()`` into ``st.session_state["auth_code"]``.

    Returns:
        Nothing.
    """
    initialise_st_state_vars()
    st.session_state["auth_code"] = get_auth_code()
|
|
|
|
|
|
|
def get_user_tokens(auth_code):
    """
    Exchange an authorization code for Cognito tokens.

    Sends a POST request to ``{COGNITO_DOMAIN}/oauth2/token`` using HTTP
    Basic authentication built from ``CLIENT_ID`` and ``CLIENT_SECRET``.

    Args:
        auth_code: Authorization code from the Cognito server.

    Returns:
        A tuple ``(access_token, id_token)``. Both are empty strings when
        ``auth_code`` is empty, when the response body is not valid JSON, or
        when it does not contain both tokens.
    """
    if not auth_code:
        # Nothing to exchange: skip the network round trip entirely.
        return "", ""

    token_url = f"{COGNITO_DOMAIN}/oauth2/token"
    credentials = f"{CLIENT_ID}:{CLIENT_SECRET}"
    encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode(
        "utf-8"
    )
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Authorization": f"Basic {encoded_credentials}",
    }
    body = {
        "grant_type": "authorization_code",
        "client_id": CLIENT_ID,
        "code": auth_code,
        "redirect_uri": APP_URI,
    }
    # Bound the wait so an unresponsive endpoint cannot hang the app.
    token_response = requests.post(
        token_url, headers=headers, data=body, timeout=10
    )
    try:
        tokens = token_response.json()  # parse the body once
        return tokens["access_token"], tokens["id_token"]
    except (KeyError, TypeError, ValueError):
        # KeyError/TypeError: JSON without the expected token fields;
        # ValueError: body that is not valid JSON at all.
        return "", ""
|
|
|
|
|
|
|
def get_user_info(access_token):
    """
    Fetch the user's profile from the Cognito userInfo endpoint.

    Sends a GET request to ``{COGNITO_DOMAIN}/oauth2/userInfo`` with the
    access token as a Bearer credential.

    Args:
        access_token: string access token from the aws cognito user pool
            retrieved using the access code.

    Returns:
        The decoded JSON body of the userInfo response.
    """
    userinfo_url = f"{COGNITO_DOMAIN}/oauth2/userInfo"
    headers = {
        "Content-Type": "application/json;charset=UTF-8",
        "Authorization": f"Bearer {access_token}",
    }
    # Bound the wait so an unresponsive endpoint cannot hang the app.
    userinfo_response = requests.get(userinfo_url, headers=headers, timeout=10)
    return userinfo_response.json()
|
|
|
|
|
def add_user_to_group(user_pool_id, username, group_name):
    """
    Add a user to a Cognito user-pool group and print the outcome.

    Args:
        user_pool_id: Value passed as ``UserPoolId``.
        username: Value passed as ``Username``.
        group_name: Value passed as ``GroupName``.

    Returns:
        Nothing. A success or failure message is printed instead.
    """
    cognito = boto3.client("cognito-idp")
    result = cognito.admin_add_user_to_group(
        UserPoolId=user_pool_id,
        Username=username,
        GroupName=group_name,
    )

    # Success means the response metadata reports HTTP 200.
    succeeded = (
        "ResponseMetadata" in result
        and result["ResponseMetadata"]["HTTPStatusCode"] == 200
    )
    if not succeeded:
        print("Failed to add user to group. Error:", result)
        return
    print("User added to group successfully.")
|
|
|
|
|
|
|
|
|
def get_user_pool_id(user_pool_name):
    """
    Look up the id of a Cognito user pool by its name.

    Walks every page returned by ``list_user_pools`` (following
    ``NextToken``) and returns the first pool whose name matches.

    Args:
        user_pool_name: Name of the user pool to find.

    Returns:
        The ``Id`` of the first user pool whose ``Name`` equals
        ``user_pool_name``.

    Raises:
        ValueError: if no listed user pool has that name.
    """
    client = boto3.client("cognito-idp")

    # 60 pools are requested per call; keep following NextToken so pools
    # beyond the first page are found as well.
    request_args = {"MaxResults": 60}
    while True:
        page = client.list_user_pools(**request_args)
        for user_pool in page.get("UserPools", []):
            if user_pool["Name"] == user_pool_name:
                return user_pool["Id"]

        next_token = page.get("NextToken")
        if not next_token:
            break
        request_args["NextToken"] = next_token

    raise ValueError(f"User Pool '{user_pool_name}' not found.")
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_base64(data):
    """
    Pad a base64 string with "=" up to a multiple of four characters.

    Args:
        data: base64 token string.

    Returns:
        ``data`` unchanged when its length is already a multiple of four,
        otherwise ``data`` followed by the "=" characters needed to reach
        the next multiple of four.
    """
    remainder = len(data) % 4
    if remainder == 0:
        return data
    return data + "=" * (4 - remainder)
|
def get_user_cognito_groups(id_token):
    """
    Decode an id token and return the user's Cognito groups.

    Only the payload segment is decoded.
    NOTE(review): the signature segment is never checked here, so the result
    should only be trusted for tokens obtained directly from the token
    endpoint.

    Args:
        id_token: id token of a successfully authenticated user (three
            dot-separated base64url segments), or an empty value.

    Returns:
        A list with the values of the ``cognito:groups`` claim, or an empty
        list when the token is empty or carries no such claim.

    Raises:
        ValueError: if the token does not consist of exactly three
            dot-separated segments.
    """
    if not id_token:
        return []

    segments = id_token.split(".")
    if len(segments) != 3:
        raise ValueError(
            "id_token must consist of three dot-separated segments."
        )
    payload = segments[1]

    # Restore the "=" padding that base64 decoding requires.
    padded_payload = payload + "=" * (-len(payload) % 4)
    claims = json.loads(base64.urlsafe_b64decode(padded_payload))

    # The claim may be absent (e.g. a user that belongs to no group); treat
    # that as "no groups" instead of failing with KeyError.
    return list(claims.get("cognito:groups", []))
|
|
|
|
|
|
|
def set_st_state_vars():
    """
    Sets the streamlit state variables after user authentication.

    Reads the authorization code from the query parameters, exchanges it for
    tokens and, when an access token is obtained, stores the code, the
    authenticated flag, the user info and the user's Cognito groups in
    ``st.session_state``. The session state is left unchanged when there is
    no code or no access token.

    Returns:
        Nothing.
    """
    initialise_st_state_vars()

    auth_code = get_auth_code()
    if not auth_code:
        # No code in the URL: nothing to exchange, so make no network calls.
        return

    already_exchanged = (
        st.session_state["authenticated"]
        and st.session_state["auth_code"] == auth_code
    )
    if already_exchanged:
        # This code was exchanged on an earlier run. Authorization codes are
        # single-use, so posting it again would only produce a failed
        # request that leaves the state as it is.
        return

    access_token, id_token = get_user_tokens(auth_code)
    if not access_token:
        # Exchange failed: skip the userInfo call and keep the current state.
        return

    # Compute everything before touching the state, so a failure in either
    # call cannot leave the session half-updated.
    user_info = get_user_info(access_token)
    user_cognito_groups = get_user_cognito_groups(id_token)

    st.session_state["auth_code"] = auth_code
    st.session_state["authenticated"] = True
    st.session_state["user_info"] = user_info
    st.session_state["user_cognito_groups"] = user_cognito_groups
|
|
|
|
|
|
|
# Login / logout URLs under COGNITO_DOMAIN; APP_URI is passed as the
# redirect_uri / logout_uri query parameter.
login_link = (
    f"{COGNITO_DOMAIN}/login"
    f"?client_id={CLIENT_ID}"
    "&response_type=code"
    "&scope=email+openid"
    f"&redirect_uri={APP_URI}"
)
logout_link = (
    f"{COGNITO_DOMAIN}/logout"
    f"?client_id={CLIENT_ID}"
    f"&logout_uri={APP_URI}"
)

# Stylesheet shared by the login and logout links.
html_css_login = """
<style>
.button-login {
  background-color: skyblue;
  color: white !important;
  padding: 1em 1.5em;
  text-decoration: none;
  text-transform: uppercase;
}
.button-login:hover {
  background-color: #555;
  text-decoration: none;
}
.button-login:active {
  background-color: black;
}
</style>
"""

# Each button is the stylesheet followed by one anchor element.
html_button_login = (
    f"{html_css_login}"
    f"<a href='{login_link}' class='button-login' target='_self'>Se connecter</a>"
)
html_button_logout = (
    f"{html_css_login}"
    f"<a href='{logout_link}' class='button-login' target='_self'>Se déconnecter</a>"
)
|
def button_login():
    """
    Render the login button in the Streamlit sidebar.

    Returns:
        Whatever ``st.sidebar.markdown`` returns for the login button HTML.
    """
    # unsafe_allow_html is required because the markup is raw <style>/<a> HTML.
    return st.sidebar.markdown(html_button_login, unsafe_allow_html=True)
|
def button_logout():
    """
    Render the logout button in the Streamlit sidebar.

    Returns:
        Whatever ``st.sidebar.markdown`` returns for the logout button HTML.
    """
    # unsafe_allow_html is required because the markup is raw <style>/<a> HTML.
    return st.sidebar.markdown(html_button_logout, unsafe_allow_html=True)
|
|