|
import streamlit as st |
|
|
|
# Page-level configuration. Kept ahead of every other Streamlit call in this
# script (Streamlit expects the page config to be set first).
_page_settings = {
    "page_title": "Response Curves",
    "page_icon": "⚖️",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_page_settings)
|
|
|
import os |
|
import glob |
|
import json |
|
import sqlite3 |
|
import numpy as np |
|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
from sklearn.metrics import r2_score |
|
from utilities import project_selection, initialize_data, set_header, load_local_css |
|
from utilities import ( |
|
get_panels_names, |
|
get_metrics_names, |
|
name_formating, |
|
load_json_files, |
|
generate_rcs_data, |
|
) |
|
|
|
|
|
# Apply the shared stylesheet and the common page header.
load_local_css("styles.css")
set_header()

# No project loaded yet: show the project picker and end this script run.
project_loaded = "project_dct" in st.session_state
if not project_loaded:
    project_selection()
    st.stop()
|
|
|
# Path to the user database, relative to the working directory.
# Built with os.path.join: the former literal r"DB\User.db" only means
# "User.db inside DB" on Windows; elsewhere it names a single file
# containing a backslash.
database_file = os.path.join("DB", "User.db")

# NOTE(review): `conn` and `c` are not used anywhere else in this script —
# confirm whether this connection is still needed on this page.
conn = sqlite3.connect(database_file, check_same_thread=False)
c = conn.cursor()
|
|
|
|
|
# Greeting on the left (2/3 width), active project on the right (1/3 width).
welcome_col, project_col = st.columns([2, 1])
with welcome_col:
    st.markdown(f"**Welcome {st.session_state['username']}**")
with project_col:
    st.markdown(f"**Current Project: {st.session_state['project_name']}**")

st.title("Response Curves")
|
|
|
|
|
|
|
def s_curve(x, K, b, a, x0):
    """Evaluate the logistic S-curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Works element-wise when ``x`` is a NumPy array, and also accepts scalars.

    Args:
        x: Input value(s) at which to evaluate the curve.
        K: Numerator of the curve (its value as the exponential term vanishes).
        b: Multiplier on the exponential term.
        a: Rate applied to ``(x - x0)``.
        x0: Shift of the input.

    Returns:
        The curve value(s), with the same shape as ``x``.
    """
    exponent = -a * (x - x0)
    return K / (1 + b * np.exp(exponent))
|
|
|
|
|
|
|
def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
    """Persist the edited S-curve parameters of one channel.

    Used as the ``on_change`` callback of the K / b / a / x0 number inputs.
    The four current widget values are read from ``st.session_state`` and
    stored as ``{"K", "b", "a", "x0"}`` under
    ``[metrics_selected][panel_selected][channel_selected]`` in
    ``rcs_data_modified.json`` inside the project directory.

    Args:
        metrics_selected: Response-metric key into the RCS data.
        panel_selected: Panel key under the metric.
        channel_selected: Channel key under the panel.

    On a load or save failure a warning toast is shown and the file is left
    untouched.
    """
    key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
    # Widget keys follow the pattern "<param>_updated_key_<metric>_<panel>_<channel>".
    updated_params = {
        param: st.session_state[f"{param}_updated_key_{key_suffix}"]
        for param in ("K", "b", "a", "x0")
    }

    modified_json_file_path = os.path.join(
        st.session_state["project_path"], "rcs_data_modified.json"
    )
    try:
        with open(modified_json_file_path, "r") as json_file:
            rcs_data_modified = json.load(json_file)
    except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
        st.toast("Failed to Load/Update. Tool reset to default settings.", icon="⚠️")
        return

    # Create missing metric/panel levels instead of raising KeyError.
    panel_data = rcs_data_modified.setdefault(metrics_selected, {}).setdefault(
        panel_selected, {}
    )
    panel_data[channel_selected] = updated_params

    try:
        # Serialize before opening for write, so a serialization error cannot
        # leave a truncated file behind.
        serialized = json.dumps(rcs_data_modified, indent=4)
        with open(modified_json_file_path, "w") as json_file:
            json_file.write(serialized)
    except (OSError, TypeError, ValueError):
        st.toast("Failed to Load/Update. Tool reset to default settings.", icon="⚠️")
        return
|
|
|
|
|
|
|
def reset_parameters(metrics_selected, panel_selected, channel_selected):
    """Restore one channel's S-curve parameters to their original values.

    Called when the "Reset" button is pressed. Steps:

    1. Read the original K / b / a / x0 for
       ``[metrics_selected][panel_selected][channel_selected]`` from
       ``rcs_data_original.json`` in the project directory.
    2. Remove the channel's widget keys from ``st.session_state``, so the
       inputs are re-seeded from the saved values on the next run.
    3. Write those original values into ``rcs_data_modified.json``.

    Args:
        metrics_selected: Response-metric key into the RCS data.
        panel_selected: Panel key under the metric.
        channel_selected: Channel key under the panel.

    On a failure to read the original data, or to save the modified file, a
    warning toast is shown and the function returns.
    """
    project_path = st.session_state["project_path"]
    original_json_file_path = os.path.join(project_path, "rcs_data_original.json")
    modified_json_file_path = os.path.join(project_path, "rcs_data_modified.json")
    param_names = ("K", "b", "a", "x0")

    try:
        # Binary mode lets json.load detect the file's UTF-8/16/32 encoding.
        with open(original_json_file_path, "rb") as original_json_file:
            rcs_data_original = json.load(original_json_file)
        original_channel_data = rcs_data_original[metrics_selected][panel_selected][
            channel_selected
        ]
    except (OSError, ValueError, KeyError, TypeError):
        st.toast("Failed to Load/Update. Tool reset to default settings.", icon="⚠️")
        return

    # Drop the widget state; tolerate keys that were never created.
    key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
    for param in param_names:
        state_key = f"{param}_updated_key_{key_suffix}"
        if state_key in st.session_state:
            del st.session_state[state_key]

    try:
        with open(modified_json_file_path, "r") as json_file:
            rcs_data_modified = json.load(json_file)
    except (OSError, ValueError):
        # Modified file missing or unreadable: rebuild it from the original
        # data, so every metric/panel/channel entry still exists afterwards
        # (an empty dict would leave the other channels unreadable).
        rcs_data_modified = rcs_data_original

    panel_data = rcs_data_modified.setdefault(metrics_selected, {}).setdefault(
        panel_selected, {}
    )
    panel_data[channel_selected] = {
        param: original_channel_data[param] for param in param_names
    }

    try:
        # Serialize first, so a failure cannot truncate the existing file.
        serialized = json.dumps(rcs_data_modified, indent=4)
        with open(modified_json_file_path, "w") as json_file:
            json_file.write(serialized)
    except (OSError, TypeError, ValueError):
        st.toast("Failed to Load/Update. Tool reset to default settings.", icon="⚠️")
        return
|
|
|
|
|
@st.cache_data(show_spinner=False)
def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
    """Tabulate the channels whose S-curve parameters were changed.

    For every channel under ``original_data[metrics_selected][panel_selected]``,
    the K / b / a / x0 values are compared with ``!=`` against the same
    channel in ``modified_data``. Each channel with at least one differing
    value becomes a row.

    The result is cached with ``st.cache_data`` (recommended by Streamlit for
    data such as DataFrames), keyed on the arguments. Each caller gets its own
    copy instead of a shared mutable object.

    Args:
        original_data: Nested dict metric -> panel -> channel -> params.
        modified_data: Same layout as ``original_data``. A channel that is
            missing here is treated as unmodified.
        metrics_selected: Metric key.
        panel_selected: Panel key.

    Returns:
        pandas.DataFrame. Columns are "Metric", "Panel", "Channel", then
        "<p> (Original)" and "<p> (Modified)" for p in K, b, a, x0. Empty when
        no parameter differs.
    """
    param_names = ("K", "b", "a", "x0")
    original_panel = original_data[metrics_selected][panel_selected]
    modified_panel = modified_data[metrics_selected][panel_selected]

    rows = []
    for channel, original_params in original_panel.items():
        modified_params = modified_panel.get(channel, original_params)
        original_values = [original_params[p] for p in param_names]
        modified_values = [modified_params[p] for p in param_names]

        # Element-wise != (rather than list equality) keeps plain
        # per-value comparison semantics.
        if not any(o != m for o, m in zip(original_values, modified_values)):
            continue

        row = {
            "Metric": name_formating(metrics_selected),
            "Panel": name_formating(panel_selected),
            "Channel": name_formating(channel),
        }
        row.update({f"{p} (Original)": v for p, v in zip(param_names, original_values)})
        row.update({f"{p} (Modified)": v for p, v in zip(param_names, modified_values)})
        rows.append(row)

    return pd.DataFrame(rows)
|
|
|
|
|
|
|
# Metric names for this project are listed from its metrics_level_data folder.
directory = os.path.join(st.session_state["project_path"], "metrics_level_data")
metrics_list = get_metrics_names(directory)

# Without any metric there is nothing to plot: ask the user to tune a model first.
no_metrics_available = len(metrics_list) == 0
if no_metrics_available:
    st.warning(
        "Please tune at least one model to generate response curves data.",
        icon="⚠️",
    )
    st.stop()
|
|
|
|
|
# Three selector columns. The channel selector (middle column) is populated
# only after the RCS data has been loaded.
metric_col, channel_col, panel_col = st.columns(3)

metrics_selected = metric_col.selectbox(
    "Response Metrics",
    sorted(metrics_list),
    index=0,
    key="response_metrics_selectbox",
    format_func=name_formating,
)

# Panel options are read from this metric's test-overview workbook.
file_selected = (
    f"metrics_level_data/data_test_overview_panel@#{metrics_selected}.xlsx"
)
file_selected_path = os.path.join(st.session_state["project_path"], file_selected)
panel_list = get_panels_names(file_selected_path)

panel_selected = panel_col.selectbox(
    "Panel",
    sorted(panel_list),
    index=0,
    key="panel_selected_selectbox",
)
|
|
|
|
|
# Original (fitted) and modified (user-edited) response-curve data files.
original_json_file_path = os.path.join(
    st.session_state["project_path"], "rcs_data_original.json"
)
modified_json_file_path = os.path.join(
    st.session_state["project_path"], "rcs_data_modified.json"
)

# Regenerate both files when either one is missing. The log message names the
# file(s) that are actually missing, instead of always blaming the original.
missing_rcs_files = [
    path
    for path in (original_json_file_path, modified_json_file_path)
    if not os.path.exists(path)
]
if missing_rcs_files:
    print(
        f"RCS JSON file(s) missing: {', '.join(missing_rcs_files)}. "
        "Generating new RCS data..."
    )
    generate_rcs_data(original_json_file_path, modified_json_file_path)
else:
    print(
        f"RCS JSON files already exist at {original_json_file_path} and "
        f"{modified_json_file_path}. No need to generate new RCS data."
    )

original_data, modified_data = load_json_files(
    original_json_file_path, modified_json_file_path
)
|
|
|
|
|
# The RCS JSON files are generated only when missing, so they may not cover
# every metric/panel offered by the selectors (presumably e.g. a model tuned
# after generation — confirm). Show a warning instead of crashing with KeyError.
# This also covers an empty panel list (selectbox returns None) and a panel
# with no channels.
panel_rcs_data = original_data.get(metrics_selected, {}).get(panel_selected)
if not panel_rcs_data:
    st.warning(
        "No response curve data is available for the selected metric and panel.",
        icon="⚠️",
    )
    st.stop()

chanel_list_final = list(panel_rcs_data.keys())

channel_selected = channel_col.selectbox(
    "Channel",
    sorted(chanel_list_final),
    format_func=name_formating,
    key="selected_channel_name_selectbox",
)
|
|
|
|
|
# Per-channel entries from the original (fitted) and modified (edited) data.
original_channel_data = original_data[metrics_selected][panel_selected][channel_selected]
modified_channel_data = modified_data[metrics_selected][panel_selected][channel_selected]

# Observed points, input scaling exponent and plotting grid come from the
# original entry only.
x, y = original_channel_data["x"], original_channel_data["y"]
power, x_plot = original_channel_data["power"], original_channel_data["x_plot"]

# S-curve parameters (K, b, a, x0) for both versions of the curve.
K_orig, b_orig, a_orig, x0_orig = (
    original_channel_data[param] for param in ("K", "b", "a", "x0")
)
K_mod, b_mod, a_mod, x0_mod = (
    modified_channel_data[param] for param in ("K", "b", "a", "x0")
)
|
|
|
|
|
# Observed (x, y) points as a scatter. The title set here is later replaced
# by update_layout.
fig = px.scatter(
    x=x,
    y=y,
    title="Original and Modified S-Curve Plot",
    labels={"x": "Spends", "y": name_formating(metrics_selected)},
)

# Both curves are drawn over x_plot. The grid is divided by 10**power before
# s_curve is evaluated. The modified curve is added first, then the original.
x_plot_scaled = np.array(x_plot) / 10**power
curve_specs = (
    ("Modified", "red", (K_mod, b_mod, a_mod, x0_mod)),
    ("Original", "rgba(0, 255, 0, 0.6)", (K_orig, b_orig, a_orig, x0_orig)),
)
for trace_name, trace_colour, curve_params in curve_specs:
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(x_plot_scaled, *curve_params),
            line=dict(color=trace_colour),
            name=trace_name,
        ),
    )

fig.update_layout(
    title="Comparison of Original and Modified S-Curves",
    xaxis_title="Input (Clicks, Impressions, etc..)",
    yaxis_title=name_formating(metrics_selected),
    legend_title="Curve Type",
)

st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
# Goodness of fit of each curve against the observed points. Inputs are
# divided by 10**power before evaluating s_curve.
x_scaled = np.array(x) / 10**power
r2_orig = r2_score(y, s_curve(x_scaled, K_orig, b_orig, a_orig, x0_orig))
r2_mod = r2_score(y, s_curve(x_scaled, K_mod, b_mod, a_mod, x0_mod))
r2_diff = r2_mod - r2_orig

st.write("## R² Comparison")
r2_cards = (
    ("R² (Original)", r2_orig),
    ("R² (Modified)", r2_mod),
    ("Difference in R²", r2_diff),
)
for card_col, (card_label, card_value) in zip(st.columns(3), r2_cards):
    card_col.metric(card_label, f"{card_value:.2f}")
|
|
|
|
|
# Session-state keys of this metric/panel/channel's parameter widgets.
key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
K_key = f"K_updated_key_{key_suffix}"
b_key = f"b_updated_key_{key_suffix}"
a_key = f"a_updated_key_{key_suffix}"
x0_key = f"x0_updated_key_{key_suffix}"

# Seed each key from the saved (modified) parameters only if it is not
# already present; existing widget state is left as is.
for state_key, saved_value in (
    (K_key, K_mod),
    (b_key, b_mod),
    (a_key, a_mod),
    (x0_key, x0_mod),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = saved_value
|
|
|
|
|
# One number input per parameter, side by side. Any change is persisted by
# the modify_rcs_parameters callback.
# NOTE(review): min_value=0.0 also applies to x0 (and a/b). The widgets are
# seeded from session state, so confirm fitted values can never be negative.
param_widgets = (("K", K_key), ("b", b_key), ("a", a_key), ("x0", x0_key))
for widget_col, (param_label, widget_key) in zip(st.columns(4), param_widgets):
    with widget_col:
        st.number_input(
            param_label,
            step=0.001,
            min_value=0.0000,
            format="%.4f",
            on_change=modify_rcs_parameters,
            args=(metrics_selected, panel_selected, channel_selected),
            key=widget_key,
        )
|
|
|
|
|
|
|
reset_col, download_col = st.columns(2)

with reset_col:
    if st.button(
        "Reset",
        use_container_width=True,
    ):
        reset_parameters(metrics_selected, panel_selected, channel_selected)
        st.rerun()

with download_col:
    # NOTE: the download is the whole rcs_data_modified.json (all metrics and
    # panels), even though the file name mentions the selected metric/panel.
    # Only the file read sits inside the try. A missing/unreadable file now
    # produces a visible warning instead of being silently swallowed.
    try:
        with open(modified_json_file_path, "r") as file:
            modified_json_text = file.read()
    except OSError:
        st.warning(
            "Modified response curve data is not available for download.",
            icon="⚠️",
        )
    else:
        st.download_button(
            label="Download",
            data=modified_json_text,
            file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
            mime="application/json",
            use_container_width=True,
        )
|
|
|
|
|
# Table of every channel in the selected metric/panel whose parameters differ
# from the original fit.
updated_parm_df = updated_parm_gen(
    original_data, modified_data, metrics_selected, panel_selected
)

if updated_parm_df.empty:
    st.info("No parameters are updated for the selected Metric and Panel")
else:
    st.write("## Parameter Comparison for Selected Metric and Panel")
    st.dataframe(updated_parm_df, hide_index=True)
|
|