v6Mastercardapp / pages /8_Response_Curves.py
BlendMMM's picture
Upload 11 files
ff89010 verified
import streamlit as st
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
from utilities import (
channel_name_formating,
load_authenticator,
initialize_data,
fetch_actual_data,
)
from sklearn.metrics import r2_score
from collections import OrderedDict
from classes import class_from_dict, class_to_dict
import pickle
import json
import sqlite3
from utilities import update_db
# Write every session-state entry back to itself, skipping the keys
# "logout", "login", "config" and anything starting with "FormSubmitter".
# NOTE(review): presumably this keeps widget-bound values alive when the
# user switches pages - confirm against the Streamlit version in use.
for state_key, state_value in st.session_state.items():
    if state_key in ("logout", "login", "config"):
        continue
    if state_key.startswith("FormSubmitter"):
        continue
    st.session_state[state_key] = state_value
def s_curve(x, K, b, a, x0):
    """Evaluate the S-shaped response curve at ``x``.

    Computes ``K / (1 + b * exp(-a * (x - x0)))``.

    Parameters
    ----------
    x
        Scalar or NumPy array of input values.
    K, b, a, x0
        Curve parameters used exactly as in the formula above.

    Returns
    -------
    The curve value(s), with the same shape as ``x``.
    """
    decay = np.exp(-a * (x - x0))
    denominator = 1 + b * decay
    return K / denominator
def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state

    The scenario held in ``st.session_state["scenario"]`` is converted with
    ``class_to_dict`` and stored under ``scenario_name`` in
    ``st.session_state["saved_scenarios"]``. The ``"scenario_input"`` entry
    is then cleared and the whole mapping of saved scenarios is pickled to
    ``../saved_scenarios.pkl`` (resolved against the working directory).

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    if "saved_scenarios" not in st.session_state:
        # Create only the missing entry. Rebinding ``st.session_state``
        # itself would replace the whole session state with an empty dict.
        st.session_state["saved_scenarios"] = OrderedDict()

    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    st.session_state["scenario_input"] = ""
    print(type(st.session_state["saved_scenarios"]))

    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)
def reset_curve_parameters(
    metrics=None, panel=None, selected_channel_name=None
):
    """Forget the edited curve parameters so they are re-seeded on rerun.

    Removes the ``"K"``, ``"b"``, ``"a"`` and ``"x0"`` entries from the
    session state. When all three identifiers are given, the matching
    ``"<metrics>#@<panel>#@<channel>"`` entry is also removed from
    ``st.session_state["update_rcs"]``.

    Missing entries are ignored, so the function is safe to call from a
    widget callback before the parameter inputs have been created.

    Parameters
    ----------
    metrics
        Response metric part of the override key, or None.
    panel
        Panel part of the override key, or None.
    selected_channel_name
        Channel part of the override key, or None.
    """
    for param_key in ("K", "b", "a", "x0"):
        # Guarded delete, matching how fetch_panel_data clears these keys.
        if param_key in st.session_state:
            del st.session_state[param_key]

    if metrics is None or panel is None or selected_channel_name is None:
        return

    override_key = f"{metrics}#@{panel}#@{selected_channel_name}"
    overrides = st.session_state.get("update_rcs", {})
    overrides.pop(override_key, None)
def update_response_curve(
    K_updated,
    b_updated,
    a_updated,
    x0_updated,
    metrics=None,
    panel=None,
    selected_channel_name=None,
):
    """Store the current curve parameters on the selected scenario channel.

    The values written are the ``"K"``, ``"b"``, ``"a"`` and ``"x0"``
    entries of the session state (the keys of the parameter inputs), not
    the ``*_updated`` arguments.

    The scenario is looked up under ``"<metric>-<panel>"``, built from the
    ``"response_metrics_selectbox"`` and ``"panel_selected_selectbox"``
    session-state entries. Deriving the key here avoids depending on a
    module-level variable being assigned before the callback fires.

    Parameters
    ----------
    K_updated, b_updated, a_updated, x0_updated
        Accepted for call compatibility; currently unused.
    metrics, panel
        Accepted for call compatibility; currently unused.
    selected_channel_name
        Key of the channel, within the scenario's ``channels`` mapping,
        whose ``response_curve_params`` is replaced.
    """
    print(
        "[DEBUG] update_response_curves: ",
        st.session_state["project_dct"]["scenario_planner"].keys(),
    )

    scenario_key = (
        f"{st.session_state['response_metrics_selectbox']}"
        f"-{st.session_state['panel_selected_selectbox']}"
    )
    scenario = st.session_state["project_dct"]["scenario_planner"][scenario_key]

    scenario.channels[selected_channel_name].response_curve_params = {
        param_key: st.session_state[param_key]
        for param_key in ("K", "b", "a", "x0")
    }
# if (
# metrics is not None
# and panel is not None
# and selected_channel_name is not None
# ):
# st.session_state["update_rcs"][
# f"{metrics}#@{panel}#@{selected_channel_name}"
# ] = {
# "K": K_updated,
# "b": b_updated,
# "a": a_updated,
# "x0": x0_updated,
# }
# st.session_state["scenario"].channels[
# selected_channel_name
# ].response_curve_params = {
# "K": K_updated,
# "b": b_updated,
# "a": a_updated,
# "x0": x0_updated,
# }
# authenticator = st.session_state.get('authenticator')
# if authenticator is None:
# authenticator = load_authenticator()
# name, authentication_status, username = authenticator.login('Login', 'main')
# auth_status = st.session_state.get('authentication_status')
# if auth_status == True:
# is_state_initiaized = st.session_state.get('initialized',False)
# if not is_state_initiaized:
# print("Scenario page state reloaded")
import pandas as pd
@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    """Return the distinct values of the ``Panel`` column of a workbook.

    Reads the ``"RAW DATA MMM"`` sheet of ``file_selected`` with pandas.

    Parameters
    ----------
    file_selected
        Path of the Excel workbook to read.

    Returns
    -------
    list or None
        The distinct ``Panel`` values in order of first appearance, or
        ``None`` when the sheet has no ``Panel`` column.
    """
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
    if "Panel" not in raw_data_mmm_df.columns:
        return None

    # dict.fromkeys de-duplicates like set() but keeps first-seen order,
    # so the returned order does not depend on string hashing.
    return list(dict.fromkeys(raw_data_mmm_df["Panel"]))
import glob
import os
def get_excel_names(directory):
    """List the names encoded after ``@#`` in Excel file names.

    Looks in ``directory`` (non-recursively) for ``.xlsx`` and ``.xls``
    files whose name contains ``@#``. For each match it returns the text
    after the last ``@#`` with the file extension removed.

    Parameters
    ----------
    directory
        Directory to search. It is glob-escaped, so characters such as
        ``[`` in the path are matched literally.

    Returns
    -------
    list of str
        One entry per matching file: ``.xlsx`` matches first, then
        ``.xls`` matches, each group in the order ``glob`` yields them.
        Empty when nothing matches.
    """
    safe_directory = glob.escape(directory)
    last_portions = []
    for extension in ("xlsx", "xls"):
        pattern = os.path.join(safe_directory, f"*@#*.{extension}")
        for file_path in glob.glob(pattern):
            tail = os.path.basename(file_path).split("@#")[-1]
            # splitext removes only the trailing extension; a chained
            # str.replace would also strip ".xlsx"/".xls" mid-name.
            last_portions.append(os.path.splitext(tail)[0])
    return last_portions
def name_formating(channel_name):
    """Turn an underscore-separated name into a title-cased label.

    Every underscore becomes a space and the result is passed through
    ``str.title``.
    """
    spaced = " ".join(channel_name.split("_"))
    return spaced.title()
def fetch_panel_data():
    """Load data and scenario state for the selected metric and panel.

    The metric and panel are read from the ``"response_metrics_selectbox"``
    and ``"panel_selected_selectbox"`` session-state entries. This function
    is also registered as the ``on_change`` callback of those selectors.
    Streamlit runs callbacks before the script body, so a module-level
    variable assigned in the body would still hold the previous selection.

    Steps performed:

    * ``fetch_actual_data`` fills ``"actual_input_df"`` and
      ``"actual_contribution_df"`` in the session state.
    * The scenario stored under ``"<metric>-<panel>"`` in
      ``project_dct["scenario_planner"]`` becomes ``"scenario"``. If none
      is stored yet, ``initialize_data`` is called and the resulting
      ``"scenario"`` is stored under that key.
    * ``"rcs"`` and ``"powers"`` are rebuilt from the scenario's channels.
    * ``"K"``, ``"b"``, ``"a"`` and ``"x0"`` are removed from the session
      state so they are re-seeded for the new selection.
    """
    print("DEBUG etch_panel_data: running... ")
    metrics_selected = st.session_state["response_metrics_selectbox"]
    panel_selected = st.session_state["panel_selected_selectbox"]
    file_selected = (
        "./metrics_level_data/"
        f"Overview_data_test_panel@#{metrics_selected}.xlsx"
    )
    print(panel_selected)

    # The "Aggregated" and per-panel cases used identical calls, so a
    # single call covers both.
    (
        st.session_state["actual_input_df"],
        st.session_state["actual_contribution_df"],
    ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)

    unique_key = f"{metrics_selected}-{panel_selected}"
    print("unique_key")

    scenario_planner = st.session_state["project_dct"]["scenario_planner"]
    if unique_key not in scenario_planner:
        initialize_data(
            panel=panel_selected,
            target_file=file_selected,
            updated_rcs={},
            metrics=metrics_selected,
        )
        scenario_planner[unique_key] = st.session_state["scenario"]
    else:
        st.session_state["scenario"] = scenario_planner[unique_key]

    # NOTE(review): rcs/powers are rebuilt from the scenario on both
    # branches; confirm initialize_data does not need to own these values.
    st.session_state["rcs"] = {}
    st.session_state["powers"] = {}
    for channel_name, _channel in scenario_planner[unique_key].channels.items():
        st.session_state["rcs"][channel_name] = _channel.response_curve_params
        st.session_state["powers"][channel_name] = _channel.power

    for param_key in ("K", "b", "a", "x0"):
        if param_key in st.session_state:
            del st.session_state[param_key]
# This page needs a loaded project; without one, show an error and halt.
if "project_dct" not in st.session_state:
    st.error("Please load a project from home")
    st.stop()

# Open the SQLite user database and keep a cursor at module level.
# NOTE(review): the path uses a backslash, which is a directory separator
# only on Windows; elsewhere it is part of the file name - confirm intent.
database_file = r"DB\User.db"
conn = sqlite3.connect(database_file, check_same_thread=False)
c = conn.cursor()

st.subheader("Build Response Curves")

# First visit in this session: create the overrides store and set the flag.
if "update_rcs" not in st.session_state:
    st.session_state["update_rcs"] = {}
    st.session_state["first_time"] = True
# Three equal columns: metric selector, channel selector, panel selector.
col1, col2, col3 = st.columns([1, 1, 1])

# Selectable metrics come from the "@#<metric>" part of the workbook names.
directory = "metrics_level_data"
metrics_list = get_excel_names(directory)
metrics_selected = col1.selectbox(
    "Response Metrics",
    metrics_list,
    on_change=fetch_panel_data,
    format_func=name_formating,
    key="response_metrics_selectbox",
)

file_selected = (
    f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
)

# panel_fetch returns None when the workbook has no "Panel" column; fall
# back to an empty list so only "Aggregated" is offered instead of
# raising a TypeError on the concatenation.
panel_list = panel_fetch(file_selected)
final_panel_list = ["Aggregated"] + (panel_list or [])
panel_selected = col3.selectbox(
    "Panel",
    final_panel_list,
    on_change=fetch_panel_data,
    key="panel_selected_selectbox",
)
# Load data once per session; later reloads are triggered by the
# selectors' on_change callbacks.
is_state_initiaized = st.session_state.get("initialized_rcs", False)
print(is_state_initiaized)
if not is_state_initiaized:
    print("DEBUG.....", "Here")
    fetch_panel_data()
    st.session_state["initialized_rcs"] = True

# Key under which the scenario for the current metric/panel pair is stored.
unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"

# Channel options: every channel of the active scenario, then "Others".
active_scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
chanel_list_final = [*active_scenario.channels.keys(), "Others"]

# Changing the channel clears K/b/a/x0 so they are re-seeded below.
selected_channel_name = col2.selectbox(
    "Channel",
    chanel_list_final,
    format_func=channel_name_formating,
    on_change=reset_curve_parameters,
    key="selected_channel_name_selectbox",
)
rcs = st.session_state["rcs"]

# The channel selector can offer an entry (such as "Others") that has no
# parameters in rcs. Every rcs[selected_channel_name] lookup below would
# then raise KeyError, so stop here with a message instead.
if selected_channel_name not in rcs:
    st.warning("No response curve is available for the selected channel.")
    st.stop()

# Seed the editable parameters from the stored curve unless the user
# already has values in the session state.
for param_key in ("K", "b", "a", "x0"):
    if param_key not in st.session_state:
        st.session_state[param_key] = rcs[selected_channel_name][param_key]

# Observed inputs (x) and contributions (y) for the selected channel.
x = st.session_state["actual_input_df"][selected_channel_name].values
y = st.session_state["actual_contribution_df"][selected_channel_name].values

# Scaling exponent: ceil(log10(max x)) - 3; x is divided by 10**power
# before it is fed to the curve.
# NOTE(review): a maximum of 0 makes the logarithm -inf - confirm that
# all-zero inputs cannot reach this page.
power = np.ceil(np.log(x.max()) / np.log(10)) - 3

print(f"DEBUG BUILD RCS: {selected_channel_name}")
print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")
# Evaluate both curves on 50 points spanning 0 to five times the largest
# observed input, scaled by 10**power like the fitted data.
x_plot = np.linspace(0, 5 * max(x), 50)
scaled_x_plot = x_plot / 10**power

# "Modified" uses the parameters currently held in the session state.
modified_curve = s_curve(
    scaled_x_plot,
    st.session_state["K"],
    st.session_state["b"],
    st.session_state["a"],
    st.session_state["x0"],
)

# "Actual" uses the parameters stored in rcs for this channel.
stored_params = rcs[selected_channel_name]
actual_curve = s_curve(
    scaled_x_plot,
    stored_params["K"],
    stored_params["b"],
    stored_params["a"],
    stored_params["x0"],
)

# Scatter of the observed points with the two curves on top.
fig = px.scatter(x=x, y=y)
fig.add_trace(
    go.Scatter(
        x=x_plot,
        y=modified_curve,
        line=dict(color="red"),
        name="Modified",
    )
)
fig.add_trace(
    go.Scatter(
        x=x_plot,
        y=actual_curve,
        line=dict(color="rgba(0, 255, 0, 0.4)"),
        name="Actual",
    )
)

fig.update_layout(title_text="Response Curve", showlegend=True)
fig.update_annotations(font_size=10)
fig.update_xaxes(title="Spends")
fig.update_yaxes(title="Revenue")
st.plotly_chart(fig, use_container_width=True)
# R2 of each curve against the observed contributions, with inputs
# scaled by 10**power.
scaled_x = x / 10**power

r2 = r2_score(
    y,
    s_curve(
        scaled_x,
        st.session_state["K"],
        st.session_state["b"],
        st.session_state["a"],
        st.session_state["x0"],
    ),
)

stored_curve_params = rcs[selected_channel_name]
r2_actual = r2_score(
    y,
    s_curve(
        scaled_x,
        stored_curve_params["K"],
        stored_curve_params["b"],
        stored_curve_params["a"],
        stored_curve_params["x0"],
    ),
)

# Show both scores, rounded to two decimals, in the first two columns.
columns = st.columns((1, 1, 2))
columns[0].metric("R2 Modified", round(r2, 2))
columns[1].metric("R2 Actual", round(r2_actual, 2))
st.markdown("#### Set Parameters", unsafe_allow_html=True)
columns = st.columns(4)

# Holder for the latest value returned by each parameter input.
if "updated_parms" not in st.session_state:
    st.session_state["updated_parms"] = {
        "K_updated": 0,
        "b_updated": 0,
        "a_updated": 0,
        "x0_updated": 0,
    }

# One number input per parameter, keyed by the parameter name so the
# widget value lives in the session state. Only "a" sets an explicit step.
_input_specs = (
    ("K", {}),
    ("b", {}),
    ("a", {"step": 0.0001}),
    ("x0", {}),
)
for _column, (_param, _extra) in zip(columns, _input_specs):
    st.session_state["updated_parms"][f"{_param}_updated"] = (
        _column.number_input(_param, key=_param, format="%0.5f", **_extra)
    )
update_col, reset_col = st.columns([1, 1])

# Arguments bound to the update callback when the button is rendered.
_update_args = (
    st.session_state["updated_parms"]["K_updated"],
    st.session_state["updated_parms"]["b_updated"],
    st.session_state["updated_parms"]["a_updated"],
    st.session_state["updated_parms"]["x0_updated"],
    metrics_selected,
    panel_selected,
    selected_channel_name,
)
_update_clicked = update_col.button(
    "Update Parameters",
    on_click=update_response_curve,
    args=_update_args,
    use_container_width=True,
)

# On a click, copy the entered values into the rcs entry of the channel.
if _update_clicked:
    for _param in ("K", "b", "a", "x0"):
        st.session_state["rcs"][selected_channel_name][_param] = (
            st.session_state["updated_parms"][f"{_param}_updated"]
        )

# The reset callback receives the three parts of the override key.
reset_col.button(
    "Reset Parameters",
    on_click=reset_curve_parameters,
    args=(metrics_selected, panel_selected, selected_channel_name),
    use_container_width=True,
)
st.divider()

# Left: file name for the export. Right: the download button, which stays
# disabled until a name has been entered.
save_col, down_col = st.columns([1, 1])
file_name = save_col.text_input(
    "rcs download file name",
    key="file_name_input",
    placeholder="File name",
    label_visibility="collapsed",
)

# The export is the rcs mapping serialised as JSON.
down_col.download_button(
    label="Download response curves",
    data=json.dumps(rcs),
    file_name=f"{file_name}.json",
    mime="application/json",
    disabled=not file_name,
    use_container_width=True,
)
def s_curve_derivative(x, K, b, a, x0):
    """Slope of the S-curve ``K / (1 + b * exp(-a * (x - x0)))`` at ``x``.

    Returns ``a * b * K * e / (1 + b * e) ** 2`` where
    ``e = exp(-a * (x - x0))``. ``x`` may be a scalar or a NumPy array.
    """
    decay = np.exp(-a * (x - x0))
    return a * b * K * decay / ((1 + b * decay) ** 2)
# Current S-curve parameters, read from the session state.
K = st.session_state["K"]
b = st.session_state["b"]
a = st.session_state["a"]
x0 = st.session_state["x0"]

# One table row per stored override. Keys have the form
# "<metrics>#@<panel>#@<channel>".
rows = []
for _override_key, _params in st.session_state["update_rcs"].items():
    _metrics_part, _panel_part, _channel_part = _override_key.split("#@")
    rows.append(
        {
            "Metrics": name_formating(_metrics_part),
            "Panel": name_formating(_panel_part),
            "Channel Name": _channel_part,
            "K": _params["K"],
            "b": _params["b"],
            "a": _params["a"],
            "x0": _params["x0"],
        }
    )

updated_parms_df = pd.DataFrame(rows)

# Show the table only when at least one override exists.
if st.session_state["update_rcs"]:
    st.markdown("#### Updated Parameters", unsafe_allow_html=True)
    st.dataframe(updated_parms_df, hide_index=True)
else:
    st.info("No parameters are updated")

update_db("8_Build_Response_Curves.py")