Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import plotly.express as px | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from utilities import ( | |
| channel_name_formating, | |
| load_authenticator, | |
| initialize_data, | |
| fetch_actual_data, | |
| ) | |
| from sklearn.metrics import r2_score | |
| from collections import OrderedDict | |
| from classes import class_from_dict, class_to_dict | |
| import pickle | |
| import json | |
| import sqlite3 | |
| from utilities import update_db | |
# Write every session-state entry back onto itself. This is the usual Streamlit
# pattern for keeping widget values alive across page switches (presumably the
# intent here). The login/logout/config keys and form-submit buttons are skipped.
for _state_key, _state_value in st.session_state.items():
    if _state_key in ("logout", "login", "config") or _state_key.startswith(
        "FormSubmitter"
    ):
        continue
    st.session_state[_state_key] = _state_value
def s_curve(x, K, b, a, x0):
    """Evaluate the logistic response curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Works element-wise when ``x`` is a NumPy array.
    """
    exponent = -a * (x - x0)
    return K / (1 + b * np.exp(exponent))
def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state.

    The scenario in ``st.session_state["scenario"]`` is serialised with
    ``class_to_dict`` into ``st.session_state["saved_scenarios"]``. The whole
    collection is then pickled to ``../saved_scenarios.pkl``. The
    ``scenario_input`` entry is cleared afterwards.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    if "saved_scenarios" not in st.session_state:
        # Create only this entry. Rebinding st.session_state itself would
        # discard the entire session state.
        st.session_state["saved_scenarios"] = OrderedDict()
    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    st.session_state["scenario_input"] = ""
    print(type(st.session_state["saved_scenarios"]))
    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)
def reset_curve_parameters(
    metrics=None, panel=None, selected_channel_name=None
):
    """
    Discard the edited S-curve parameters so they are re-seeded from ``rcs``.

    Removes the ``K``/``b``/``a``/``x0`` editor values from the session state.
    When ``metrics``, ``panel`` and ``selected_channel_name`` are all given,
    it also drops the matching ``"<metrics>#@<panel>#@<channel>"`` override
    from ``st.session_state["update_rcs"]``.

    Used as the channel selectbox ``on_change`` callback (no arguments) and as
    the "Reset Parameters" button callback (with arguments).
    """
    for param in ("K", "b", "a", "x0"):
        # Widget-backed keys can already be gone, e.g. when the previous run
        # stopped before the number inputs were rendered.
        if param in st.session_state:
            del st.session_state[param]

    if metrics is None or panel is None or selected_channel_name is None:
        return

    update_key = f"{metrics}#@{panel}#@{selected_channel_name}"
    # .get returns the stored dict itself, so deleting from it updates state.
    updated_rcs = st.session_state.get("update_rcs", {})
    if update_key in updated_rcs:
        del updated_rcs[update_key]
def update_response_curve(
    K_updated,
    b_updated,
    a_updated,
    x0_updated,
    metrics=None,
    panel=None,
    selected_channel_name=None,
):
    """
    Store the edited S-curve parameters on the selected channel of the active scenario.

    Used as the "Update Parameters" button ``on_click`` callback. The values are
    read from the ``K``/``b``/``a``/``x0`` number-input keys in the session
    state. ``K_updated``, ``b_updated``, ``a_updated``, ``x0_updated``,
    ``metrics`` and ``panel`` are accepted for the callback signature but are
    not used.

    The scenario is looked up under ``"<metric>-<panel>"`` in
    ``project_dct["scenario_planner"]``, the same key format the page uses.
    """
    # Derive the key from the selectbox state rather than relying on a
    # module-level variable assigned later in the script.
    unique_key = (
        f"{st.session_state['response_metrics_selectbox']}"
        f"-{st.session_state['panel_selected_selectbox']}"
    )
    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    print(
        "[DEBUG] update_response_curves: ",
        scenario_planner.keys(),
    )
    scenario_planner[unique_key].channels[
        selected_channel_name
    ].response_curve_params = {
        "K": st.session_state["K"],
        "b": st.session_state["b"],
        "a": st.session_state["a"],
        "x0": st.session_state["x0"],
    }
| # if ( | |
| # metrics is not None | |
| # and panel is not None | |
| # and selected_channel_name is not None | |
| # ): | |
| # st.session_state["update_rcs"][ | |
| # f"{metrics}#@{panel}#@{selected_channel_name}" | |
| # ] = { | |
| # "K": K_updated, | |
| # "b": b_updated, | |
| # "a": a_updated, | |
| # "x0": x0_updated, | |
| # } | |
| # st.session_state["scenario"].channels[ | |
| # selected_channel_name | |
| # ].response_curve_params = { | |
| # "K": K_updated, | |
| # "b": b_updated, | |
| # "a": a_updated, | |
| # "x0": x0_updated, | |
| # } | |
| # authenticator = st.session_state.get('authenticator') | |
| # if authenticator is None: | |
| # authenticator = load_authenticator() | |
| # name, authentication_status, username = authenticator.login('Login', 'main') | |
| # auth_status = st.session_state.get('authentication_status') | |
| # if auth_status == True: | |
| # is_state_initiaized = st.session_state.get('initialized',False) | |
| # if not is_state_initiaized: | |
| # print("Scenario page state reloaded") | |
| import pandas as pd | |
def panel_fetch(file_selected):
    """
    Return the distinct values of the ``Panel`` column of an Excel workbook.

    Reads the ``"RAW DATA MMM"`` sheet of ``file_selected``. Returns ``None``
    when that sheet has no ``Panel`` column; otherwise returns a sorted list of
    the unique panel values.
    """
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
    if "Panel" not in raw_data_mmm_df.columns:
        return None
    # Set iteration order of strings changes between interpreter runs (hash
    # randomisation), which reshuffled the panel selectbox; sort for a stable
    # order. key=str tolerates mixed or NaN cell types.
    return sorted(set(raw_data_mmm_df["Panel"]), key=str)
| import glob | |
| import os | |
def get_excel_names(directory):
    """
    Return the metric names encoded in Excel file names inside ``directory``.

    Only ``.xlsx`` and ``.xls`` files whose name contains ``@#`` are
    considered. The metric name is the part after the last ``@#`` with the
    file extension removed, e.g.
    ``Overview_data_test_panel@#revenue.xlsx`` -> ``revenue``.
    All ``.xlsx`` matches come first, then ``.xls`` matches, each sorted by path.
    """
    last_portions = []
    for extension in (".xlsx", ".xls"):
        pattern = os.path.join(directory, f"*@#*{extension}")
        # glob order follows the filesystem; sort so the selectbox order is stable
        for file in sorted(glob.glob(pattern)):
            base_name = os.path.basename(file)
            # splitext removes only the trailing extension. str.replace would
            # also strip ".xls"/".xlsx" occurring inside the metric name.
            last_portion = os.path.splitext(base_name.split("@#")[-1])[0]
            last_portions.append(last_portion)
    return last_portions
def name_formating(channel_name):
    """Turn an underscore-separated name into a title-cased label.

    Example: ``"paid_search"`` -> ``"Paid Search"``.
    """
    return " ".join(channel_name.split("_")).title()
def fetch_panel_data():
    """
    Load actuals and response curves for the currently selected metric and panel.

    Used as the ``on_change`` callback of the "Response Metrics" and "Panel"
    selectboxes, and called once on first page load. The selection is read
    from ``st.session_state``, then the function:

    * refreshes ``actual_input_df`` / ``actual_contribution_df`` via
      ``fetch_actual_data``;
    * the first time a ``"<metric>-<panel>"`` combination is seen, builds its
      scenario with ``initialize_data`` and caches it in
      ``project_dct["scenario_planner"]``; otherwise restores the cached
      scenario into ``st.session_state["scenario"]``;
    * rebuilds the per-channel ``rcs`` (response-curve params) and ``powers``;
    * removes the K/b/a/x0 editor values so the page re-seeds them from ``rcs``.
    """
    print("DEBUG etch_panel_data: running... ")
    # Read the metric from session state: inside an on_change callback the
    # module-level `metrics_selected` still holds the previous run's value.
    metrics_selected = st.session_state["response_metrics_selectbox"]
    panel_selected = st.session_state["panel_selected_selectbox"]
    file_selected = f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
    print(panel_selected)
    # "Aggregated" and specific panels were handled with identical calls.
    (
        st.session_state["actual_input_df"],
        st.session_state["actual_contribution_df"],
    ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)

    unique_key = f"{metrics_selected}-{panel_selected}"
    print(unique_key)
    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    if unique_key not in scenario_planner:
        initialize_data(
            panel=panel_selected,
            target_file=file_selected,
            updated_rcs={},
            metrics=metrics_selected,
        )
        scenario_planner[unique_key] = st.session_state["scenario"]
    else:
        st.session_state["scenario"] = scenario_planner[unique_key]

    # NOTE(review): indentation was lost in the copy this was rebuilt from;
    # rcs/powers are rebuilt for both branches here — confirm against the original.
    channels = scenario_planner[unique_key].channels
    st.session_state["rcs"] = {
        channel_name: _channel.response_curve_params
        for channel_name, _channel in channels.items()
    }
    st.session_state["powers"] = {
        channel_name: _channel.power
        for channel_name, _channel in channels.items()
    }
    for param in ("K", "b", "a", "x0"):
        if param in st.session_state:
            del st.session_state[param]
# This page needs a project loaded from the home page, which provides
# project_dct["scenario_planner"]. Without one, show an error and halt the run.
if "project_dct" not in st.session_state:
    st.error("Please load a project from home")
    st.stop()
# NOTE(review): the backslash makes this a Windows-style relative path. On POSIX
# it names a file literally called "DB\User.db" in the working directory.
# Neither `conn` nor `c` is used elsewhere in this file as shown — confirm
# before changing the path, since sqlite3.connect fails if the DB/ directory is missing.
database_file = r"DB\User.db"
conn = sqlite3.connect(
    database_file, check_same_thread=False
)  # connection with sql db
c = conn.cursor()
st.subheader("Build Response Curves")
# Parameter overrides keyed as "<metric>#@<panel>#@<channel>" (the format the
# "Updated Parameters" table below splits on).
if "update_rcs" not in st.session_state:
    st.session_state["update_rcs"] = {}
    # NOTE(review): indentation was lost in this copy of the file. This assumes
    # first_time is set only on first initialisation — confirm.
    st.session_state["first_time"] = True
# Selector row: response metric (col1), channel (col2, filled further down), panel (col3).
col1, col2, col3 = st.columns([1, 1, 1])
directory = "metrics_level_data"
# Metric names are parsed from "<prefix>@#<metric>.xlsx" files in this directory.
metrics_list = get_excel_names(directory)
if not metrics_list:
    # An empty selectbox yields None, which would become a "...@#None.xlsx" path.
    st.error(f"No response metric workbooks found in '{directory}'.")
    st.stop()
metrics_selected = col1.selectbox(
    "Response Metrics",
    metrics_list,
    on_change=fetch_panel_data,
    format_func=name_formating,
    key="response_metrics_selectbox",
)
file_selected = (
    f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
)
panel_list = panel_fetch(file_selected)
# panel_fetch returns None when the workbook has no "Panel" column; offer only
# the aggregated view then, instead of raising TypeError on list + None.
final_panel_list = ["Aggregated"] + (panel_list or [])
panel_selected = col3.selectbox(
    "Panel",
    final_panel_list,
    on_change=fetch_panel_data,
    key="panel_selected_selectbox",
)
# Load data for the default selection once per session; afterwards the
# selectbox callbacks take care of reloading.
is_state_initiaized = st.session_state.get("initialized_rcs", False)
print(is_state_initiaized)
if not is_state_initiaized:
    print("DEBUG.....", "Here")
    fetch_panel_data()
    st.session_state["initialized_rcs"] = True
# Channels of the active "<metric>-<panel>" scenario, plus an extra "Others" option.
unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
chanel_list_final = list(
    st.session_state["project_dct"]["scenario_planner"][
        unique_key
    ].channels.keys()
) + ["Others"]
selected_channel_name = col2.selectbox(
    "Channel",
    chanel_list_final,
    format_func=channel_name_formating,
    on_change=reset_curve_parameters,
    key="selected_channel_name_selectbox",
)
rcs = st.session_state["rcs"]
# rcs is built from the scenario's channels, so the appended "Others" option
# usually has no curve. Stop cleanly instead of raising KeyError below.
if selected_channel_name not in rcs:
    st.warning(
        f"No response curve is available for {channel_name_formating(selected_channel_name)}."
    )
    st.stop()
# Seed the parameter editor from the stored curve unless values are already present.
for _param in ("K", "b", "a", "x0"):
    if _param not in st.session_state:
        st.session_state[_param] = rcs[selected_channel_name][_param]
# Observed spend (x) and contribution (y) of the selected channel.
x = st.session_state["actual_input_df"][selected_channel_name].values
y = st.session_state["actual_contribution_df"][selected_channel_name].values
# Spend is divided by 10**power before it is passed to s_curve, which puts the
# largest spend in the (10**2, 10**3] range.
# NOTE(review): this must match the scaling used when the stored curve
# parameters were fitted (presumably this same expression). Compare with the
# per-channel `power` kept in st.session_state["powers"], which this page does
# not use. x.max() <= 0 makes power -inf/nan.
power = np.ceil(np.log(x.max()) / np.log(10)) - 3
print(f"DEBUG BUILD RCS: {selected_channel_name}")
print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")
# Observed points plus two curves on a spend grid reaching 5x the largest
# observed spend. "Modified" uses the values currently in the parameter editor
# (session state); "Actual" uses the stored curve in rcs.
x_plot = np.linspace(0, 5 * max(x), 50)
fig = px.scatter(x=x, y=y)
for _trace_name, _trace_params, _trace_color in (
    ("Modified", st.session_state, "red"),
    ("Actual", rcs[selected_channel_name], "rgba(0, 255, 0, 0.4)"),
):
    _trace_y = s_curve(
        x_plot / 10**power,
        _trace_params["K"],
        _trace_params["b"],
        _trace_params["a"],
        _trace_params["x0"],
    )
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=_trace_y,
            line=dict(color=_trace_color),
            name=_trace_name,
        )
    )
fig.update_layout(title_text="Response Curve", showlegend=True)
fig.update_annotations(font_size=10)
fig.update_xaxes(title="Spends")
fig.update_yaxes(title="Revenue")
st.plotly_chart(fig, use_container_width=True)
# R2 of the edited curve and the stored curve against the observed
# contributions, both evaluated at the observed spends in the 10**power scale.
r2, r2_actual = (
    r2_score(
        y,
        s_curve(x / 10**power, _p["K"], _p["b"], _p["a"], _p["x0"]),
    )
    for _p in (st.session_state, rcs[selected_channel_name])
)
columns = st.columns((1, 1, 2))
columns[0].metric("R2 Modified", round(r2, 2))
columns[1].metric("R2 Actual", round(r2_actual, 2))
st.markdown("#### Set Parameters", unsafe_allow_html=True)
columns = st.columns(4)
if "updated_parms" not in st.session_state:
    st.session_state["updated_parms"] = {
        "K_updated": 0,
        "b_updated": 0,
        "a_updated": 0,
        "x0_updated": 0,
    }
# One number input per curve parameter. Each widget is keyed by the parameter
# name, so it edits the seeded session-state value directly. Only "a" uses a
# fine step size.
for _col, _label, _parm_key, _extra_kwargs in (
    (columns[0], "K", "K_updated", {}),
    (columns[1], "b", "b_updated", {}),
    (columns[2], "a", "a_updated", {"step": 0.0001}),
    (columns[3], "x0", "x0_updated", {}),
):
    with _col:
        st.session_state["updated_parms"][_parm_key] = st.number_input(
            _label, key=_label, format="%0.5f", **_extra_kwargs
        )
| # st.session_state["project_dct"]["scenario_planner"]["K_number_input"] = ( | |
| # st.session_state["updated_parms"]["K_updated"] | |
| # ) | |
| # st.session_state["project_dct"]["scenario_planner"]["b_number_input"] = ( | |
| # st.session_state["updated_parms"]["b_updated"] | |
| # ) | |
| # st.session_state["project_dct"]["scenario_planner"]["a_number_input"] = ( | |
| # st.session_state["updated_parms"]["a_updated"] | |
| # ) | |
| # st.session_state["project_dct"]["scenario_planner"]["x0_number_input"] = ( | |
| # st.session_state["updated_parms"]["x0_updated"] | |
| # ) | |
update_col, reset_col = st.columns([1, 1])
_edited = st.session_state["updated_parms"]
_update_clicked = update_col.button(
    "Update Parameters",
    on_click=update_response_curve,
    args=(
        _edited["K_updated"],
        _edited["b_updated"],
        _edited["a_updated"],
        _edited["x0_updated"],
        metrics_selected,
        panel_selected,
        selected_channel_name,
    ),
    use_container_width=True,
)
if _update_clicked:
    # Mirror the edited values into the page's rcs entry for this channel.
    _channel_rc = st.session_state["rcs"][selected_channel_name]
    for _rc_key, _edited_key in (
        ("K", "K_updated"),
        ("b", "b_updated"),
        ("a", "a_updated"),
        ("x0", "x0_updated"),
    ):
        _channel_rc[_rc_key] = _edited[_edited_key]
reset_col.button(
    "Reset Parameters",
    on_click=reset_curve_parameters,
    args=(metrics_selected, panel_selected, selected_channel_name),
    use_container_width=True,
)
st.divider()
# File-name box and download button side by side. The button stays disabled
# until a name is entered; the payload is the rcs mapping serialised as JSON.
save_col, down_col = st.columns([1, 1])
file_name = save_col.text_input(
    "rcs download file name",
    key="file_name_input",
    placeholder="File name",
    label_visibility="collapsed",
)
down_col.download_button(
    label="Download response curves",
    data=json.dumps(rcs),
    file_name=f"{file_name}.json",
    mime="application/json",
    disabled=not file_name,
    use_container_width=True,
)
def s_curve_derivative(x, K, b, a, x0):
    """Return d/dx of ``K / (1 + b * exp(-a * (x - x0)))``.

    Equals ``a * b * K * e / (1 + b * e) ** 2`` with ``e = exp(-a * (x - x0))``;
    works element-wise on NumPy arrays.
    """
    decay = np.exp(-a * (x - x0))
    return a * b * K * decay / ((1 + b * decay) ** 2)
# Parameters of the S-curve, taken from the current editor values.
# In this file they are referenced only by the commented-out slope
# calculation below.
K = st.session_state["K"]
b = st.session_state["b"]
a = st.session_state["a"]
x0 = st.session_state["x0"]
# # Optimized spend value obtained from the tool
# optimized_spend = st.number_input(
#     "value of x"
# )  # Replace this with your optimized spend value
# # Calculate the slope at the optimized spend value
# slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)
# st.write("Slope ", slope_at_optimized_spend)
def _updated_parms_row(update_key, params):
    """Flatten one update_rcs entry ("<metric>#@<panel>#@<channel>") into a table row."""
    metric_name, panel_name, channel = update_key.split("#@")
    return {
        "Metrics": name_formating(metric_name),
        "Panel": name_formating(panel_name),
        "Channel Name": channel,
        "K": params["K"],
        "b": params["b"],
        "a": params["a"],
        "x0": params["x0"],
    }


# Summary of every recorded parameter override, or a notice when there is none.
rows = [
    _updated_parms_row(update_key, params)
    for update_key, params in st.session_state["update_rcs"].items()
]
updated_parms_df = pd.DataFrame(rows)
if st.session_state["update_rcs"]:
    st.markdown("#### Updated Parameters", unsafe_allow_html=True)
    st.dataframe(updated_parms_df, hide_index=True)
else:
    st.info("No parameters are updated")
update_db("8_Build_Response_Curves.py")