Spaces:
Sleeping
Sleeping
import streamlit as st | |
import plotly.express as px | |
import numpy as np | |
import plotly.graph_objects as go | |
from utilities import ( | |
channel_name_formating, | |
load_authenticator, | |
initialize_data, | |
fetch_actual_data, | |
) | |
from sklearn.metrics import r2_score | |
from collections import OrderedDict | |
from classes import class_from_dict, class_to_dict | |
import pickle | |
import json | |
import sqlite3 | |
from utilities import update_db | |
# Write every session-state entry back to itself, except the authentication
# keys and Streamlit form-submitter entries.
# NOTE(review): presumably this keeps widget-bound values from being dropped
# when the user switches pages — confirm.
for state_key, state_value in st.session_state.items():
    skip_entry = state_key in ("logout", "login", "config") or state_key.startswith(
        "FormSubmitter"
    )
    if not skip_entry:
        st.session_state[state_key] = state_value
def s_curve(x, K, b, a, x0):
    """Evaluate the logistic response curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Input value(s). Arrays are evaluated element-wise. On this page the
        spends are divided by ``10**power`` before being passed in.
    K : float
        Value the curve approaches as ``a * (x - x0)`` grows large.
    b, a, x0 : float
        Shape, steepness and horizontal-offset parameters.

    Returns
    -------
    float or numpy.ndarray
        Curve value(s), with the same shape as ``x``.
    """
    decay = np.exp(-a * (x - x0))
    return K / (1 + b * decay)
def save_scenario(scenario_name):
    """
    Save the current scenario under the given name and persist all saved scenarios.

    The scenario in ``st.session_state["scenario"]`` is converted with
    ``class_to_dict`` and stored in ``st.session_state["saved_scenarios"]``.
    That mapping is created as an ``OrderedDict`` on first use. The
    ``"scenario_input"`` text is then cleared, and the whole mapping is
    pickled to ``../saved_scenarios.pkl``.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    # Create only the "saved_scenarios" entry. Assigning to st.session_state
    # itself would replace Streamlit's session-state object on the module.
    if "saved_scenarios" not in st.session_state:
        st.session_state["saved_scenarios"] = OrderedDict()
    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    st.session_state["scenario_input"] = ""
    print(type(st.session_state["saved_scenarios"]))
    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)
def reset_curve_parameters(
    metrics=None, panel=None, selected_channel_name=None
):
    """Discard the K/b/a/x0 widget values and, optionally, a recorded update.

    Used as the ``on_change`` callback of the channel select box (called with
    no arguments) and as the ``on_click`` callback of "Reset Parameters".
    Once the widget keys are removed, the page seeds them again from
    ``st.session_state["rcs"]`` on the next run.

    Parameters
    ----------
    metrics, panel, selected_channel_name : optional
        When all three are given, the matching
        ``"<metrics>#@<panel>#@<channel>"`` entry is also removed from
        ``st.session_state["update_rcs"]`` (if present).
    """
    # The keys can already be gone (fetch_panel_data also deletes them), so
    # remove only what exists instead of raising KeyError.
    for param in ("K", "b", "a", "x0"):
        if param in st.session_state:
            del st.session_state[param]

    if (
        metrics is not None
        and panel is not None
        and selected_channel_name is not None
    ):
        update_rcs = st.session_state.get("update_rcs", {})
        update_rcs.pop(f"{metrics}#@{panel}#@{selected_channel_name}", None)
def update_response_curve(
    K_updated,
    b_updated,
    a_updated,
    x0_updated,
    metrics=None,
    panel=None,
    selected_channel_name=None,
):
    """Store the K/b/a/x0 widget values as the selected channel's curve.

    This is the ``on_click`` callback of the "Update Parameters" button. It
    reads the new values from ``st.session_state["K"]``, ``["b"]``, ``["a"]``
    and ``["x0"]``. It then assigns them as a new ``response_curve_params``
    dict on ``selected_channel_name`` of the scenario cached in
    ``project_dct["scenario_planner"]`` for the current metric/panel.

    Parameters
    ----------
    K_updated, b_updated, a_updated, x0_updated
        Values captured when the button was rendered. They are not read here;
        the widget values in session state are used instead. Kept for
        callback-signature compatibility.
    metrics, panel
        Selected metric and panel. Not read here (see above).
    selected_channel_name
        Channel whose response curve is updated.
    """
    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    print("[DEBUG] update_response_curves: ", scenario_planner.keys())
    # Build the key here rather than reading the module-level `unique_key`.
    # That global is only bound once the page script reaches the channel
    # selector, so depending on it from a callback is fragile.
    unique_key = (
        f"{st.session_state['response_metrics_selectbox']}"
        f"-{st.session_state['panel_selected_selectbox']}"
    )
    scenario_planner[unique_key].channels[
        selected_channel_name
    ].response_curve_params = {
        "K": st.session_state["K"],
        "b": st.session_state["b"],
        "a": st.session_state["a"],
        "x0": st.session_state["x0"],
    }
# if ( | |
# metrics is not None | |
# and panel is not None | |
# and selected_channel_name is not None | |
# ): | |
# st.session_state["update_rcs"][ | |
# f"{metrics}#@{panel}#@{selected_channel_name}" | |
# ] = { | |
# "K": K_updated, | |
# "b": b_updated, | |
# "a": a_updated, | |
# "x0": x0_updated, | |
# } | |
# st.session_state["scenario"].channels[ | |
# selected_channel_name | |
# ].response_curve_params = { | |
# "K": K_updated, | |
# "b": b_updated, | |
# "a": a_updated, | |
# "x0": x0_updated, | |
# } | |
# authenticator = st.session_state.get('authenticator') | |
# if authenticator is None: | |
# authenticator = load_authenticator() | |
# name, authentication_status, username = authenticator.login('Login', 'main') | |
# auth_status = st.session_state.get('authentication_status') | |
# if auth_status == True: | |
# is_state_initiaized = st.session_state.get('initialized',False) | |
# if not is_state_initiaized: | |
# print("Scenario page state reloaded") | |
import pandas as pd | |
def panel_fetch(file_selected):
    """Return the distinct panel values of a metrics workbook.

    Reads the ``"RAW DATA MMM"`` sheet of *file_selected*.

    Parameters
    ----------
    file_selected : str
        Path to the Excel workbook.

    Returns
    -------
    list or None
        Distinct values of the ``"Panel"`` column, in order of first
        appearance. ``None`` when the sheet has no ``"Panel"`` column.
    """
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
    if "Panel" not in raw_data_mmm_df.columns:
        return None
    # dict.fromkeys keeps the same distinct values that set() would, but in a
    # stable first-appearance order. Without it, string-hash randomisation
    # reshuffles the "Panel" drop-down between server restarts.
    return list(dict.fromkeys(raw_data_mmm_df["Panel"]))
import glob | |
import os | |
def get_excel_names(directory):
    """Return the metric names encoded in Excel file names under *directory*.

    Only files whose name contains the ``@#`` marker and ends in ``.xlsx`` or
    ``.xls`` are considered; the scan is not recursive. A metric name is the
    part of the file name after the last ``@#``, minus the file extension. For
    example, ``Overview_data_test_panel@#revenue.xlsx`` gives ``"revenue"``.

    Parameters
    ----------
    directory : str
        Folder to scan.

    Returns
    -------
    list[str]
        Names from ``.xlsx`` files first, then from ``.xls`` files. Each group
        is in sorted file-name order. Empty when nothing matches.
    """
    last_portions = []
    # Escape the directory so glob metacharacters in its path are literal.
    escaped_dir = glob.escape(directory)
    for extension in (".xlsx", ".xls"):
        # glob's order depends on the filesystem. Sorting keeps the metric
        # drop-down (and its default choice) stable.
        for path in sorted(glob.glob(os.path.join(escaped_dir, f"*@#*{extension}"))):
            marker_tail = os.path.basename(path).split("@#")[-1]
            # Strip only the trailing extension. str.replace would also remove
            # ".xls"/".xlsx" appearing elsewhere in the name.
            last_portions.append(os.path.splitext(marker_tail)[0])
    return last_portions
def name_formating(channel_name):
    """Turn a snake_case identifier into a title-cased label.

    Underscores become spaces, and each word is capitalised with
    ``str.title``. For example, ``"paid_search"`` becomes ``"Paid Search"``.
    """
    return channel_name.replace("_", " ").title()
def fetch_panel_data():
    """Load the data and scenario for the currently selected metric and panel.

    Called once on first page load, and as the ``on_change`` callback of the
    "Response Metrics" and "Panel" select boxes. It reads the selection from
    ``st.session_state["response_metrics_selectbox"]`` and
    ``st.session_state["panel_selected_selectbox"]``, then:

    1. stores the two frames returned by ``fetch_actual_data`` in
       ``st.session_state["actual_input_df"]`` and
       ``st.session_state["actual_contribution_df"]``;
    2. if ``project_dct["scenario_planner"]`` has no ``"<metric>-<panel>"``
       entry, calls ``initialize_data`` and caches
       ``st.session_state["scenario"]`` under that key; otherwise copies the
       cached scenario back into ``st.session_state["scenario"]``;
    3. rebuilds ``st.session_state["rcs"]`` and ``st.session_state["powers"]``
       from that scenario's channels;
    4. deletes the ``K``/``b``/``a``/``x0`` widget values, so the page seeds
       them again from the new curves.
    """
    print("DEBUG fetch_panel_data: running... ")
    # Read the metric from session state, not from the module-level
    # `metrics_selected`. Inside an on_change callback that global still holds
    # the previous run's value, so the new workbook would be initialised under
    # the old metric name.
    metrics_selected = st.session_state["response_metrics_selectbox"]
    panel_selected = st.session_state["panel_selected_selectbox"]
    file_selected = (
        f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
    )
    print(panel_selected)

    (
        st.session_state["actual_input_df"],
        st.session_state["actual_contribution_df"],
    ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)

    unique_key = f"{metrics_selected}-{panel_selected}"
    print(f"DEBUG fetch_panel_data: unique_key={unique_key}")

    if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
        initialize_data(
            panel=panel_selected,
            target_file=file_selected,
            updated_rcs={},
            metrics=metrics_selected,
        )
        # Assumes initialize_data leaves the new scenario in
        # st.session_state["scenario"], since it is read straight after — confirm.
        st.session_state["project_dct"]["scenario_planner"][unique_key] = (
            st.session_state["scenario"]
        )
    else:
        st.session_state["scenario"] = st.session_state["project_dct"][
            "scenario_planner"
        ][unique_key]

    # rcs[channel] is the channel's own response_curve_params object, not a copy.
    st.session_state["rcs"] = {}
    st.session_state["powers"] = {}
    scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
    for channel_name, channel in scenario.channels.items():
        st.session_state["rcs"][channel_name] = channel.response_curve_params
        st.session_state["powers"][channel_name] = channel.power

    for param in ("K", "b", "a", "x0"):
        if param in st.session_state:
            del st.session_state[param]
# Everything below needs a project that was loaded on the home page.
if "project_dct" not in st.session_state:
    st.error("Please load a project from home")
    st.stop()

# NOTE(review): r"DB\User.db" uses a Windows path separator. On POSIX it names
# a file literally called "DB\User.db" in the working directory. `conn` and `c`
# are not used anywhere else in this file as shown — confirm they are needed.
database_file = r"DB\User.db"
conn = sqlite3.connect(database_file, check_same_thread=False)
c = conn.cursor()

st.subheader("Build Response Curves")

# Recorded parameter edits, keyed "<metric>#@<panel>#@<channel>" -> {K, b, a, x0}.
if "update_rcs" not in st.session_state:
    st.session_state["update_rcs"] = {}

st.session_state["first_time"] = True
col1, col2, col3 = st.columns([1, 1, 1])

# Each response metric has a workbook named "...@#<metric>.xlsx" in this folder.
directory = "metrics_level_data"
metrics_list = get_excel_names(directory)
if not metrics_list:
    # Without this guard the select box returns None and the page tries to
    # read ".../Overview_data_test_panel@#None.xlsx".
    st.error(f"No response metric workbooks found in '{directory}'")
    st.stop()

metrics_selected = col1.selectbox(
    "Response Metrics",
    metrics_list,
    on_change=fetch_panel_data,
    format_func=name_formating,
    key="response_metrics_selectbox",
)

file_selected = (
    f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
)
# panel_fetch returns None when the workbook has no "Panel" column. Treat that
# as "no panels" so only "Aggregated" is offered, instead of raising TypeError.
panel_list = panel_fetch(file_selected)
final_panel_list = ["Aggregated"] + (panel_list or [])

panel_selected = col3.selectbox(
    "Panel",
    final_panel_list,
    on_change=fetch_panel_data,
    key="panel_selected_selectbox",
)

# Load data/scenario for the initial selection once per session. Later
# changes go through the select boxes' on_change callback.
is_state_initiaized = st.session_state.get("initialized_rcs", False)
print(is_state_initiaized)
if not is_state_initiaized:
    print("DEBUG.....", "Here")
    fetch_panel_data()
    st.session_state["initialized_rcs"] = True
# Channel picker: every channel of the cached scenario for the selected
# metric/panel, plus an extra "Others" option.
unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
channel_options = [
    *st.session_state["project_dct"]["scenario_planner"][unique_key].channels.keys(),
    "Others",
]
selected_channel_name = col2.selectbox(
    "Channel",
    channel_options,
    format_func=channel_name_formating,
    on_change=reset_curve_parameters,
    key="selected_channel_name_selectbox",
)

# Seed any missing K/b/a/x0 widget value from the stored curve of the chosen
# channel. They are missing on the first visit, or after reset_curve_parameters
# or fetch_panel_data deleted them.
rcs = st.session_state["rcs"]
for param_name in ("K", "b", "a", "x0"):
    if param_name not in st.session_state:
        st.session_state[param_name] = rcs[selected_channel_name][param_name]

# Observed spends (x) and contributions (y) for the chosen channel.
x = st.session_state["actual_input_df"][selected_channel_name].values
y = st.session_state["actual_contribution_df"][selected_channel_name].values
# Spends are divided by 10**power before the curve is evaluated, where
# power = ceil(log(max spend) / log(10)) - 3.
# NOTE(review): presumably this matches the scaling used when the curves were
# fitted — confirm.
power = np.ceil(np.log(x.max()) / np.log(10)) - 3

print(f"DEBUG BUILD RCS: {selected_channel_name}")
print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")
# Scatter the observed (spend, contribution) points, then overlay two curves:
# the edited widget parameters ("Modified") and the stored ones ("Actual").
# Both are drawn over 50 spend levels from 0 to 5x the largest observed spend.
x_plot = np.linspace(0, 5 * max(x), 50)
scaled_plot_spend = x_plot / 10**power
current_params = st.session_state

fig = px.scatter(x=x, y=y)

modified_curve = go.Scatter(
    x=x_plot,
    y=s_curve(
        scaled_plot_spend,
        current_params["K"],
        current_params["b"],
        current_params["a"],
        current_params["x0"],
    ),
    line={"color": "red"},
    name="Modified",
)
fig.add_trace(modified_curve)

fitted_params = rcs[selected_channel_name]
actual_curve = go.Scatter(
    x=x_plot,
    y=s_curve(
        scaled_plot_spend,
        fitted_params["K"],
        fitted_params["b"],
        fitted_params["a"],
        fitted_params["x0"],
    ),
    line={"color": "rgba(0, 255, 0, 0.4)"},
    name="Actual",
)
fig.add_trace(actual_curve)

fig.update_layout(title_text="Response Curve", showlegend=True)
fig.update_annotations(font_size=10)
fig.update_xaxes(title="Spends")
fig.update_yaxes(title="Revenue")
st.plotly_chart(fig, use_container_width=True)
# Goodness of fit of both curves against the observed contributions, using the
# same 10**power spend scaling as the plot.
scaled_spend = x / 10**power
current_params = st.session_state
r2 = r2_score(
    y,
    s_curve(
        scaled_spend,
        current_params["K"],
        current_params["b"],
        current_params["a"],
        current_params["x0"],
    ),
)
fitted_params = rcs[selected_channel_name]
r2_actual = r2_score(
    y,
    s_curve(
        scaled_spend,
        fitted_params["K"],
        fitted_params["b"],
        fitted_params["a"],
        fitted_params["x0"],
    ),
)

columns = st.columns((1, 1, 2))
columns[0].metric("R2 Modified", round(r2, 2))
columns[1].metric("R2 Actual", round(r2_actual, 2))
st.markdown("#### Set Parameters", unsafe_allow_html=True)
columns = st.columns(4)

if "updated_parms" not in st.session_state:
    st.session_state["updated_parms"] = {
        "K_updated": 0,
        "b_updated": 0,
        "a_updated": 0,
        "x0_updated": 0,
    }

# One number input per curve parameter, laid out left to right. Each widget is
# keyed by the parameter name, so its value lives in st.session_state[<name>].
# The value is also copied into updated_parms under "<name>_updated".
_parameter_inputs = (
    ("K", "K_updated", {}),
    ("b", "b_updated", {}),
    ("a", "a_updated", {"step": 0.0001}),
    ("x0", "x0_updated", {}),
)
for input_column, (param_name, target_key, extra_kwargs) in zip(
    columns, _parameter_inputs
):
    with input_column:
        st.session_state["updated_parms"][target_key] = st.number_input(
            param_name, key=param_name, format="%0.5f", **extra_kwargs
        )
update_col, reset_col = st.columns([1, 1])
pending_params = st.session_state["updated_parms"]

update_clicked = update_col.button(
    "Update Parameters",
    on_click=update_response_curve,
    args=(
        pending_params["K_updated"],
        pending_params["b_updated"],
        pending_params["a_updated"],
        pending_params["x0_updated"],
        metrics_selected,
        panel_selected,
        selected_channel_name,
    ),
    use_container_width=True,
)
if update_clicked:
    # Mirror the applied values into the page's local copy of the curves.
    channel_curve = st.session_state["rcs"][selected_channel_name]
    for curve_key, pending_key in (
        ("K", "K_updated"),
        ("b", "b_updated"),
        ("a", "a_updated"),
        ("x0", "x0_updated"),
    ):
        channel_curve[curve_key] = pending_params[pending_key]

reset_col.button(
    "Reset Parameters",
    on_click=reset_curve_parameters,
    args=(metrics_selected, panel_selected, selected_channel_name),
    use_container_width=True,
)
st.divider()

# Export all current response curves as JSON under a user-chosen file name.
save_col, down_col = st.columns([1, 1])
file_name = save_col.text_input(
    "rcs download file name",
    key="file_name_input",
    placeholder="File name",
    label_visibility="collapsed",
)
down_col.download_button(
    label="Download response curves",
    data=json.dumps(rcs),
    file_name=f"{file_name}.json",
    mime="application/json",
    # Disabled until a file name has been typed.
    disabled=len(file_name) == 0,
    use_container_width=True,
)
def s_curve_derivative(x, K, b, a, x0):
    """Derivative with respect to ``x`` of ``K / (1 + b * exp(-a * (x - x0)))``.

    With ``e = exp(-a * (x - x0))``, this evaluates
    ``a * b * K * e / (1 + b * e) ** 2``. Arrays are evaluated element-wise.
    """
    decay = np.exp(-a * (x - x0))
    return a * b * K * decay / ((1 + b * decay) ** 2)
# Current widget values of the S-curve parameters as module-level names.
# They are not read anywhere else in this file; they can be used ad hoc with
# s_curve_derivative, e.g. to get the slope at a given (scaled) spend.
K, b, a, x0 = (st.session_state[param] for param in ("K", "b", "a", "x0"))
def _update_rcs_row(rc_key, params):
    """Convert one ``update_rcs`` entry into a display row.

    ``rc_key`` has the form ``"<metric>#@<panel>#@<channel>"``, and ``params``
    holds the ``K``/``b``/``a``/``x0`` values.
    """
    metric_name, panel_name, channel = rc_key.split("#@")
    return {
        "Metrics": name_formating(metric_name),
        "Panel": name_formating(panel_name),
        "Channel Name": channel,
        "K": params["K"],
        "b": params["b"],
        "a": params["a"],
        "x0": params["x0"],
    }


# Show every recorded parameter edit as a table, or a notice when there is none.
recorded_updates = st.session_state["update_rcs"]
updated_parms_df = pd.DataFrame(
    [_update_rcs_row(rc_key, params) for rc_key, params in recorded_updates.items()]
)
if recorded_updates:
    st.markdown("#### Updated Parameters", unsafe_allow_html=True)
    st.dataframe(updated_parms_df, hide_index=True)
else:
    st.info("No parameters are updated")

update_db("8_Build_Response_Curves.py")