"""Helpers for authenticating a Streamlit app against an AWS Cognito hosted UI.

The module reads its configuration from environment variables (optionally
loaded from a ``.env`` file), exchanges the authorization code found in the
page URL for tokens, and stores the resulting authentication state in
``st.session_state``.
"""

import base64
import json
import os

import boto3
import requests
import streamlit as st
from dotenv import load_dotenv

# ------------------------------------
# Read constants from environment file
# ------------------------------------
load_dotenv()
COGNITO_DOMAIN = os.environ.get("COGNITO_DOMAIN")
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
APP_URI = os.environ.get("APP_URI")

# Upper bound (seconds) for each HTTP call to Cognito, so that an unreachable
# endpoint cannot block the Streamlit script run indefinitely.
REQUEST_TIMEOUT_SECONDS = 10


# ------------------------------------
# Initialise Streamlit state variables
# ------------------------------------
def initialise_st_state_vars():
    """
    Initialise Streamlit state variables.

    Only keys that are missing from ``st.session_state`` are created, so
    values set earlier in the session are never overwritten.

    Returns:
        Nothing.
    """
    defaults = (
        ("auth_code", ""),
        ("authenticated", False),
        ("user_cognito_groups", []),
        ("user_info", {}),
    )
    for key, default in defaults:
        if key not in st.session_state:
            st.session_state[key] = default


# ----------------------------------
# Get authorization code after login
# ----------------------------------
def get_auth_code():
    """
    Reads the authorization code from the ``code`` URL query parameter.

    Returns:
        auth_code: the first value of the ``code`` query parameter, or an
        empty string when the parameter is absent.
    """
    # NOTE(review): ``experimental_get_query_params`` appears to be deprecated
    # in recent Streamlit releases in favour of ``st.query_params`` - confirm
    # against the Streamlit version this app is pinned to before upgrading.
    auth_query_params = st.experimental_get_query_params()
    try:
        return dict(auth_query_params)["code"][0]
    except (KeyError, TypeError, IndexError):
        return ""


# ----------------------------------
# Set authorization code after login
# ----------------------------------
def set_auth_code():
    """
    Sets the auth_code state variable from the current URL.

    Returns:
        Nothing.
    """
    initialise_st_state_vars()
    st.session_state["auth_code"] = get_auth_code()


# -------------------------------------------------------
# Use authorization code to get user access and id tokens
# -------------------------------------------------------
def get_user_tokens(auth_code):
    """
    Gets user tokens by making a post request call.

    Args:
        auth_code: Authorization code from cognito server.

    Returns:
        A tuple ``(access_token, id_token)``. Both elements are empty strings
        when ``auth_code`` is empty or when the response does not contain
        both tokens.
    """
    # Without a code there is nothing to exchange; skip the network call.
    if not auth_code:
        return "", ""

    # Variables to make a post request
    token_url = f"{COGNITO_DOMAIN}/oauth2/token"
    client_secret_string = f"{CLIENT_ID}:{CLIENT_SECRET}"
    client_secret_encoded = base64.b64encode(
        client_secret_string.encode("utf-8")
    ).decode("utf-8")
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Authorization": f"Basic {client_secret_encoded}",
    }
    body = {
        "grant_type": "authorization_code",
        "client_id": CLIENT_ID,
        "code": auth_code,
        "redirect_uri": APP_URI,
    }

    token_response = requests.post(
        token_url,
        headers=headers,
        data=body,
        timeout=REQUEST_TIMEOUT_SECONDS,
    )

    try:
        # Parse the body once; ValueError covers a non-JSON response body.
        tokens = token_response.json()
        return tokens["access_token"], tokens["id_token"]
    except (KeyError, TypeError, ValueError):
        return "", ""


# ---------------------------------------------
# Use access token to retrieve user information
# ---------------------------------------------
def get_user_info(access_token):
    """
    Gets user info from aws cognito server.

    Args:
        access_token: string access token from the aws cognito user pool
        retrieved using the access code.

    Returns:
        userinfo_response: the decoded JSON body of the userInfo response.
    """
    userinfo_url = f"{COGNITO_DOMAIN}/oauth2/userInfo"
    headers = {
        "Content-Type": "application/json;charset=UTF-8",
        "Authorization": f"Bearer {access_token}",
    }

    userinfo_response = requests.get(
        userinfo_url,
        headers=headers,
        timeout=REQUEST_TIMEOUT_SECONDS,
    )

    return userinfo_response.json()


def add_user_to_group(user_pool_id, username, group_name):
    """
    Adds a user to a cognito group and prints the outcome.

    Args:
        user_pool_id: id of the cognito user pool.
        username: name of the user to add.
        group_name: name of the group the user is added to.

    Returns:
        Nothing.
    """
    client = boto3.client("cognito-idp")

    response = client.admin_add_user_to_group(
        UserPoolId=user_pool_id,
        Username=username,
        GroupName=group_name,
    )

    # Check the response for any errors
    succeeded = (
        "ResponseMetadata" in response
        and response["ResponseMetadata"]["HTTPStatusCode"] == 200
    )
    if succeeded:
        print("User added to group successfully.")
    else:
        print("Failed to add user to group. Error:", response)


def get_user_pool_id(user_pool_name):
    """
    Looks up the id of a cognito user pool by its name.

    Args:
        user_pool_name: name of the user pool to find.

    Returns:
        The id of the first user pool whose name matches.

    Raises:
        ValueError: if no user pool with the given name exists.
    """
    client = boto3.client("cognito-idp")

    # Follow NextToken so that pools beyond the first page are also searched.
    request_kwargs = {"MaxResults": 60}
    while True:
        response = client.list_user_pools(**request_kwargs)
        for user_pool in response["UserPools"]:
            if user_pool["Name"] == user_pool_name:
                return user_pool["Id"]

        next_token = response.get("NextToken")
        if not next_token:
            break
        request_kwargs["NextToken"] = next_token

    # If the User Pool with the given name is not found
    raise ValueError(f"User Pool '{user_pool_name}' not found.")


# -------------------------------------------------------
# Decode access token to JWT to get user's cognito groups
# -------------------------------------------------------
# Ref - https://gist.github.com/GuillaumeDerval/b300af6d4f906f38a051351afab3b95c
def pad_base64(data):
    """
    Makes sure base64 data is padded.

    Args:
        data: base64 token string.

    Returns:
        data: padded token string.
    """
    missing_padding = len(data) % 4
    if missing_padding != 0:
        data += "=" * (4 - missing_padding)
    return data


def get_user_cognito_groups(id_token):
    """
    Decode id token to get user cognito groups.

    The token payload is only base64-decoded here; its signature is not
    checked by this function.

    Args:
        id_token: id token of a successfully authenticated user.

    Returns:
        user_cognito_groups: a list of all the cognito groups the user
        belongs to. Empty when the token is empty or carries no
        ``cognito:groups`` claim.
    """
    if not id_token:
        return []

    _header, payload, _signature = id_token.split(".")
    printable_payload = base64.urlsafe_b64decode(pad_base64(payload))
    payload_dict = json.loads(printable_payload)
    # A token without the claim means "no groups", not an error.
    return list(dict(payload_dict).get("cognito:groups", []))


# -----------------------------
# Set Streamlit state variables
# -----------------------------
def set_st_state_vars():
    """
    Sets the streamlit state variables after user authentication.

    The session state is only updated when an access token was obtained;
    otherwise the initialised defaults are left untouched.

    Returns:
        Nothing.
    """
    initialise_st_state_vars()
    auth_code = get_auth_code()
    access_token, id_token = get_user_tokens(auth_code)

    # No token means no authenticated user: avoid the userInfo call and the
    # token decoding, whose results would be discarded anyway.
    if not access_token:
        return

    user_info = get_user_info(access_token)
    user_cognito_groups = get_user_cognito_groups(id_token)

    st.session_state["auth_code"] = auth_code
    st.session_state["authenticated"] = True
    st.session_state["user_info"] = user_info
    st.session_state["user_cognito_groups"] = user_cognito_groups


# -----------------------------
# Login/ Logout HTML components
# -----------------------------
login_link = f"{COGNITO_DOMAIN}/login?client_id={CLIENT_ID}&response_type=code&scope=email+openid&redirect_uri={APP_URI}"
logout_link = f"{COGNITO_DOMAIN}/logout?client_id={CLIENT_ID}&logout_uri={APP_URI}"

# NOTE(review): the CSS string is blank and the button HTML below contains no
# anchor referencing login_link / logout_link - presumably markup was lost at
# some point; confirm against the intended design before relying on these.
html_css_login = """ """

html_button_login = (
    html_css_login
    + "Se connecter"
)
html_button_logout = (
    html_css_login
    + "Se déconnecter"
)


def button_login():
    """
    Renders the login button in the sidebar.

    Returns:
        The value returned by ``st.sidebar.markdown`` for the login button.
    """
    return st.sidebar.markdown(html_button_login, unsafe_allow_html=True)


def button_logout():
    """
    Renders the logout button in the sidebar.

    Returns:
        The value returned by ``st.sidebar.markdown`` for the logout button.
    """
    return st.sidebar.markdown(html_button_logout, unsafe_allow_html=True)