"""Streamlit page: inspect and hand-tune per-channel S-curve response curves."""

import glob
import json
import os
import pickle
import sqlite3
from collections import OrderedDict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from sklearn.metrics import r2_score

from classes import class_from_dict, class_to_dict
from utilities import (
    channel_name_formating,
    load_authenticator,
    initialize_data,
    fetch_actual_data,
    update_db,
)

# Session-state keys that back the four S-curve number inputs.
_CURVE_PARAM_KEYS = ("K", "b", "a", "x0")

# Re-assign every stored value to itself (skipping auth and form-submitter
# keys) -- presumably so widget-backed values survive page switches.
for k, v in st.session_state.items():
    if k not in ["logout", "login", "config"] and not k.startswith(
        "FormSubmitter"
    ):
        st.session_state[k] = v


def s_curve(x, K, b, a, x0):
    """Evaluate the S-curve ``K / (1 + b * exp(-a * (x - x0)))`` at ``x``."""
    return K / (1 + b * np.exp(-a * (x - x0)))


def s_curve_derivative(x, K, b, a, x0):
    """Return the derivative of :func:`s_curve` with respect to ``x``."""
    decay = np.exp(-a * (x - x0))
    return a * b * K * decay / ((1 + b * decay) ** 2)


def _scenario_key():
    """Build the ``"<metric>-<panel>"`` key from the two selectbox values."""
    return (
        f"{st.session_state['response_metrics_selectbox']}"
        f"-{st.session_state['panel_selected_selectbox']}"
    )


def _clear_curve_inputs():
    """Remove K/b/a/x0 from session state, ignoring keys that are absent."""
    for param in _CURVE_PARAM_KEYS:
        if param in st.session_state:
            del st.session_state[param]


def _curve_params(source):
    """Return ``(K, b, a, x0)`` read from a mapping-like ``source``."""
    return tuple(source[param] for param in _CURVE_PARAM_KEYS)


def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state
    and pickle all saved scenarios to ``../saved_scenarios.pkl``.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    if "saved_scenarios" not in st.session_state:
        # Create only the missing entry; assigning to ``st.session_state``
        # itself would replace Streamlit's session-state object.
        st.session_state["saved_scenarios"] = OrderedDict()

    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    st.session_state["scenario_input"] = ""
    print(type(st.session_state["saved_scenarios"]))
    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)


def reset_curve_parameters(
    metrics=None, panel=None, selected_channel_name=None
):
    """
    Drop the K/b/a/x0 inputs so the page re-seeds them from ``rcs``.

    When ``metrics``, ``panel`` and ``selected_channel_name`` are all given,
    the matching ``"<metrics>#@<panel>#@<channel>"`` entry is also removed
    from ``st.session_state["update_rcs"]`` (if present).
    """
    _clear_curve_inputs()

    if metrics is None or panel is None or selected_channel_name is None:
        return

    st.session_state["update_rcs"].pop(
        f"{metrics}#@{panel}#@{selected_channel_name}", None
    )


def update_response_curve(
    K_updated,
    b_updated,
    a_updated,
    x0_updated,
    metrics=None,
    panel=None,
    selected_channel_name=None,
):
    """
    Store the current K/b/a/x0 inputs on the selected channel of the
    active scenario.

    The values are read from ``st.session_state`` rather than from the
    ``*_updated`` arguments; those arguments (and ``metrics``/``panel``) are
    accepted for call-site compatibility but not used here.
    """
    print(
        "[DEBUG] update_response_curves: ",
        st.session_state["project_dct"]["scenario_planner"].keys(),
    )

    # Derive the scenario key from session state instead of relying on a
    # module-level global being defined before this callback fires.
    scenario = st.session_state["project_dct"]["scenario_planner"][
        _scenario_key()
    ]
    scenario.channels[selected_channel_name].response_curve_params = {
        param: st.session_state[param] for param in _CURVE_PARAM_KEYS
    }
    # NOTE(review): nothing in this file writes to
    # st.session_state["update_rcs"], so the "Updated Parameters" table
    # below stays empty unless another module fills it -- confirm intent.


@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    """
    Return the distinct values of the ``Panel`` column in the
    ``RAW DATA MMM`` sheet of ``file_selected``, or ``None`` when the sheet
    has no ``Panel`` column.
    """
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")

    if "Panel" not in raw_data_mmm_df.columns:
        return None
    return list(set(raw_data_mmm_df["Panel"]))


def get_excel_names(directory):
    """
    List the metric names encoded in Excel file names under ``directory``.

    Files matching ``*@#*.xlsx`` or ``*@#*.xls`` are considered; the name is
    the part after the last ``@#`` with the file extension removed.
    """
    last_portions = []

    for extension in ("xlsx", "xls"):
        pattern = os.path.join(directory, f"*@#*.{extension}")
        for file in glob.glob(pattern):
            tail = os.path.basename(file).split("@#")[-1]
            # splitext strips only the trailing extension, so a name that
            # happens to contain ".xls" elsewhere is left intact.
            last_portions.append(os.path.splitext(tail)[0])

    return last_portions


def name_formating(channel_name):
    """Replace underscores with spaces and title-case the result."""
    return channel_name.replace("_", " ").title()


def fetch_panel_data():
    """
    Load data for the selected metric/panel and activate its scenario.

    Reads the two selectbox values from session state, stores the actual
    input/contribution frames, then either initialises a new scenario for
    the ``"<metric>-<panel>"`` key or restores the stored one together with
    its ``rcs`` and ``powers``. Finally clears the K/b/a/x0 inputs.
    """
    print("DEBUG etch_panel_data: running... ")

    # Read the metric from session state: when this runs as an on_change
    # callback the module-level ``metrics_selected`` may not reflect the
    # new selection yet.
    metrics = st.session_state["response_metrics_selectbox"]
    file_selected = (
        f"./metrics_level_data/Overview_data_test_panel@#{metrics}.xlsx"
    )
    panel_selected = st.session_state["panel_selected_selectbox"]
    print(panel_selected)

    (
        st.session_state["actual_input_df"],
        st.session_state["actual_contribution_df"],
    ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)

    unique_key = _scenario_key()
    print("unique_key")

    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    if unique_key not in scenario_planner:
        # Presumably initialize_data populates "scenario", "rcs" and
        # "powers" in session state -- verify in utilities.
        initialize_data(
            panel=panel_selected,
            target_file=file_selected,
            updated_rcs={},
            metrics=metrics,
        )
        scenario_planner[unique_key] = st.session_state["scenario"]
    else:
        # NOTE(review): the rcs/powers rebuild was placed inside this
        # branch when restoring the collapsed indentation -- confirm.
        st.session_state["scenario"] = scenario_planner[unique_key]
        st.session_state["rcs"] = {}
        st.session_state["powers"] = {}

        for channel_name, _channel in scenario_planner[
            unique_key
        ].channels.items():
            st.session_state["rcs"][
                channel_name
            ] = _channel.response_curve_params
            st.session_state["powers"][channel_name] = _channel.power

    _clear_curve_inputs()


if "project_dct" not in st.session_state:
    st.error("Please load a project from home")
    st.stop()

# NOTE(review): this connection/cursor is not used or closed anywhere in
# this file, and the backslash in the path is only a separator on Windows.
database_file = r"DB\User.db"
conn = sqlite3.connect(
    database_file, check_same_thread=False
)  # connection with sql db
c = conn.cursor()

st.subheader("Build Response Curves")

if "update_rcs" not in st.session_state:
    st.session_state["update_rcs"] = {}

# NOTE(review): set on every rerun as reconstructed -- confirm it was not
# meant to sit inside the "update_rcs" check above.
st.session_state["first_time"] = True

col1, col2, col3 = st.columns([1, 1, 1])

directory = "metrics_level_data"
metrics_list = get_excel_names(directory)
if not metrics_list:
    # Without a metric the file path below would end in "@#None.xlsx".
    st.error(f"No response metric files found in '{directory}'")
    st.stop()

metrics_selected = col1.selectbox(
    "Response Metrics",
    metrics_list,
    on_change=fetch_panel_data,
    format_func=name_formating,
    key="response_metrics_selectbox",
)

file_selected = (
    f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
)
panel_list = panel_fetch(file_selected)
# panel_fetch returns None when the sheet has no "Panel" column.
final_panel_list = ["Aggregated"] + (panel_list or [])

panel_selected = col3.selectbox(
    "Panel",
    final_panel_list,
    on_change=fetch_panel_data,
    key="panel_selected_selectbox",
)

is_state_initiaized = st.session_state.get("initialized_rcs", False)
print(is_state_initiaized)
if not is_state_initiaized:
    print("DEBUG.....", "Here")
    fetch_panel_data()
    st.session_state["initialized_rcs"] = True

unique_key = _scenario_key()
# NOTE(review): "Others" is offered here but "rcs" is indexed by the
# selected name below -- confirm "Others" is a key of "rcs".
chanel_list_final = list(
    st.session_state["project_dct"]["scenario_planner"][
        unique_key
    ].channels.keys()
) + ["Others"]

selected_channel_name = col2.selectbox(
    "Channel",
    chanel_list_final,
    format_func=channel_name_formating,
    on_change=reset_curve_parameters,
    key="selected_channel_name_selectbox",
)

rcs = st.session_state["rcs"]

# Seed the number inputs from the stored curve before the widgets exist.
for _param in _CURVE_PARAM_KEYS:
    if _param not in st.session_state:
        st.session_state[_param] = rcs[selected_channel_name][_param]

x = st.session_state["actual_input_df"][selected_channel_name].values
y = st.session_state["actual_contribution_df"][selected_channel_name].values

# Order of magnitude of the largest spend, minus 3; x is divided by
# 10**power before the curve is evaluated.
power = np.ceil(np.log(x.max()) / np.log(10)) - 3

print(f"DEBUG BUILD RCS: {selected_channel_name}")
print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")

modified_params = _curve_params(st.session_state)
actual_params = _curve_params(rcs[selected_channel_name])

x_plot = np.linspace(0, 5 * max(x), 50)
x_plot_scaled = x_plot / 10**power

fig = px.scatter(x=x, y=y)
fig.add_trace(
    go.Scatter(
        x=x_plot,
        y=s_curve(x_plot_scaled, *modified_params),
        line=dict(color="red"),
        name="Modified",
    ),
)
fig.add_trace(
    go.Scatter(
        x=x_plot,
        y=s_curve(x_plot_scaled, *actual_params),
        line=dict(color="rgba(0, 255, 0, 0.4)"),
        name="Actual",
    ),
)

fig.update_layout(title_text="Response Curve", showlegend=True)
fig.update_annotations(font_size=10)
fig.update_xaxes(title="Spends")
fig.update_yaxes(title="Revenue")

st.plotly_chart(fig, use_container_width=True)

x_scaled = x / 10**power
r2 = r2_score(y, s_curve(x_scaled, *modified_params))
r2_actual = r2_score(y, s_curve(x_scaled, *actual_params))

columns = st.columns((1, 1, 2))
with columns[0]:
    st.metric("R2 Modified", round(r2, 2))
with columns[1]:
    st.metric("R2 Actual", round(r2_actual, 2))

st.markdown("#### Set Parameters", unsafe_allow_html=True)
columns = st.columns(4)

if "updated_parms" not in st.session_state:
    st.session_state["updated_parms"] = {
        "K_updated": 0,
        "b_updated": 0,
        "a_updated": 0,
        "x0_updated": 0,
    }

with columns[0]:
    st.session_state["updated_parms"]["K_updated"] = st.number_input(
        "K", key="K", format="%0.5f"
    )
with columns[1]:
    st.session_state["updated_parms"]["b_updated"] = st.number_input(
        "b", key="b", format="%0.5f"
    )
with columns[2]:
    st.session_state["updated_parms"]["a_updated"] = st.number_input(
        "a", key="a", step=0.0001, format="%0.5f"
    )
with columns[3]:
    st.session_state["updated_parms"]["x0_updated"] = st.number_input(
        "x0", key="x0", format="%0.5f"
    )

update_col, reset_col = st.columns([1, 1])
if update_col.button(
    "Update Parameters",
    on_click=update_response_curve,
    args=(
        st.session_state["updated_parms"]["K_updated"],
        st.session_state["updated_parms"]["b_updated"],
        st.session_state["updated_parms"]["a_updated"],
        st.session_state["updated_parms"]["x0_updated"],
        metrics_selected,
        panel_selected,
        selected_channel_name,
    ),
    use_container_width=True,
):
    # Copy the values just read from the number inputs into the stored curve.
    for _param in _CURVE_PARAM_KEYS:
        st.session_state["rcs"][selected_channel_name][_param] = (
            st.session_state["updated_parms"][f"{_param}_updated"]
        )

reset_col.button(
    "Reset Parameters",
    on_click=reset_curve_parameters,
    args=(metrics_selected, panel_selected, selected_channel_name),
    use_container_width=True,
)

st.divider()
save_col, down_col = st.columns([1, 1])

with save_col:
    file_name = st.text_input(
        "rcs download file name",
        key="file_name_input",
        placeholder="File name",
        label_visibility="collapsed",
    )
down_col.download_button(
    label="Download response curves",
    data=json.dumps(rcs),
    file_name=f"{file_name}.json",
    mime="application/json",
    disabled=len(file_name) == 0,
    use_container_width=True,
)

# Parameters of the S-curve
K = st.session_state["K"]
b = st.session_state["b"]
a = st.session_state["a"]
x0 = st.session_state["x0"]

# One table row per "<metrics>#@<panel>#@<channel>" override entry.
rows = []
for key, value in st.session_state["update_rcs"].items():
    metrics, panel, channel_name = key.split("#@")
    rows.append(
        {
            "Metrics": name_formating(metrics),
            "Panel": name_formating(panel),
            "Channel Name": channel_name,
            "K": value["K"],
            "b": value["b"],
            "a": value["a"],
            "x0": value["x0"],
        }
    )

updated_parms_df = pd.DataFrame(rows)

if st.session_state["update_rcs"]:
    st.markdown("#### Updated Parameters", unsafe_allow_html=True)
    st.dataframe(updated_parms_df, hide_index=True)
else:
    st.info("No parameters are updated")

update_db("8_Build_Response_Curves.py")