# RFI / pages / 8_Response_Curves.py
# Manoj - "first commit" (9938325)
import streamlit as st
# Page-level settings. This call is kept ahead of the remaining imports
# (presumably so it is the first Streamlit command executed - confirm
# before moving it).
page_settings = dict(
    page_title="Response Curves",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**page_settings)
import os
import glob
import json
import sqlite3
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import r2_score
from utilities import project_selection, initialize_data, set_header, load_local_css
from utilities import (
get_panels_names,
get_metrics_names,
name_formating,
load_json_files,
generate_rcs_data,
)
# Styling
load_local_css("styles.css")
set_header()

# A project has to be chosen first: when "project_dct" is absent from the
# session, show the project picker and halt this run of the page.
project_loaded = "project_dct" in st.session_state
if not project_loaded:
    project_selection()
    st.stop()

# Connection with the SQLite user database.
# NOTE(review): neither `conn` nor `c` is referenced again in the visible
# file and the connection is never closed - confirm it is still needed.
database_file = r"DB\User.db"
conn = sqlite3.connect(database_file, check_same_thread=False)
c = conn.cursor()

# Header row: greeting in the wide left column, active project on the right
col_project_data = st.columns([2, 1])
col_project_data[0].markdown(f"**Welcome {st.session_state['username']}**")
col_project_data[1].markdown(
    f"**Current Project: {st.session_state['project_name']}**"
)

# Page Title
st.title("Response Curves")
# Function to build s curve
def s_curve(x, K, b, a, x0):
    """Evaluate the S-curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Args:
        x: Input value(s): a scalar, a NumPy array, or a plain list/tuple
            of numbers.
        K: Numerator of the curve.
        b: Multiplier applied to the exponential term.
        a: Rate applied to ``x - x0`` inside the exponential.
        x0: Offset subtracted from ``x``.

    Returns:
        The curve evaluated at ``x``: a NumPy scalar for scalar input, an
        array for array-like input.
    """
    # A plain list/tuple does not support element-wise ``x - x0`` when x0 is
    # a Python number, so promote it to an array first. Arrays and scalars
    # are passed through untouched to keep the original behaviour.
    if isinstance(x, (list, tuple)):
        x = np.asarray(x, dtype=float)
    return K / (1 + b * np.exp(-a * (x - x0)))
# Function to update the RCS parameters in the modified JSON data
def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
    """Persist the S-curve parameters currently held by the input widgets.

    Registered as the ``on_change`` callback of the K / b / a / x0 number
    inputs. Reads the four widget values from ``st.session_state`` and
    writes them into ``rcs_data_modified.json`` (inside the project path)
    under ``[metrics_selected][panel_selected][channel_selected]``.

    Any failure to read, update or write the file is reported with a toast
    and the function returns without raising.

    Args:
        metrics_selected: Metric key in the RCS JSON data.
        panel_selected: Panel key under that metric.
        channel_selected: Channel key under that panel.
    """
    failure_message = "Failed to Load/Update. Tool reset to default settings."

    # Widget keys are unique per metric / panel / channel selection
    key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
    updated_parameters = {
        name: st.session_state[f"{name}_updated_key_{key_suffix}"]
        for name in ("K", "b", "a", "x0")
    }

    modified_json_file_path = os.path.join(
        st.session_state["project_path"], "rcs_data_modified.json"
    )

    # Load the existing modified RCS data (JSONDecodeError is a ValueError)
    try:
        with open(modified_json_file_path, "r") as json_file:
            rcs_data_modified = json.load(json_file)
    except (OSError, ValueError):
        st.toast(failure_message, icon="⚠️")
        return

    # Replace the parameters of the selected channel and serialize *before*
    # opening the file for writing, so that a missing metric/panel entry or
    # an unserializable value cannot leave a truncated file behind.
    try:
        rcs_data_modified[metrics_selected][panel_selected][
            channel_selected
        ] = updated_parameters
        payload = json.dumps(rcs_data_modified, indent=4)
    except (KeyError, IndexError, TypeError, ValueError):
        st.toast(failure_message, icon="⚠️")
        return

    # Save the updated data back to the JSON file
    try:
        with open(modified_json_file_path, "w") as json_file:
            json_file.write(payload)
    except OSError:
        st.toast(failure_message, icon="⚠️")
        return
# Function to reset the parameters to their default values
def reset_parameters(metrics_selected, panel_selected, channel_selected):
    """Restore one channel's S-curve parameters to their original values.

    Reads K / b / a / x0 for the selected channel from
    ``rcs_data_original.json``, drops the matching widget values from
    ``st.session_state`` and writes the original parameters into
    ``rcs_data_modified.json``. Both files live in the project path.

    Any failure to read, update or write a file is reported with a toast
    and the function returns without raising. (Previously an unreadable
    modified file was replaced by ``{}`` and then indexed by metric and
    panel, which always raised ``KeyError``.)

    Args:
        metrics_selected: Metric key in the RCS JSON data.
        panel_selected: Panel key under that metric.
        channel_selected: Channel key under that panel.
    """
    failure_message = "Failed to Load/Update. Tool reset to default settings."
    parameter_names = ("K", "b", "a", "x0")

    project_path = st.session_state["project_path"]
    original_json_file_path = os.path.join(project_path, "rcs_data_original.json")
    modified_json_file_path = os.path.join(project_path, "rcs_data_modified.json")

    # Load the original parameters of the selected channel
    try:
        with open(original_json_file_path, "rb") as original_json_file:
            rcs_data_original = json.load(original_json_file)
        original_channel_data = rcs_data_original[metrics_selected][
            panel_selected
        ][channel_selected]
        original_parameters = {
            name: original_channel_data[name] for name in parameter_names
        }
    except (OSError, ValueError, KeyError, IndexError, TypeError):
        st.toast(failure_message, icon="⚠️")
        return

    # Drop the widget values so the inputs are re-initialised on the next run
    key_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
    for name in parameter_names:
        widget_key = f"{name}_updated_key_{key_suffix}"
        if widget_key in st.session_state:
            del st.session_state[widget_key]

    # Put the original parameters into the modified data and serialize before
    # opening the file for writing, so a failure cannot truncate the file.
    try:
        with open(modified_json_file_path, "r") as json_file:
            rcs_data_modified = json.load(json_file)
        rcs_data_modified[metrics_selected][panel_selected][
            channel_selected
        ] = original_parameters
        payload = json.dumps(rcs_data_modified, indent=4)
    except (OSError, ValueError, KeyError, IndexError, TypeError):
        st.toast(failure_message, icon="⚠️")
        return

    # Save the reset data back to the JSON file
    try:
        with open(modified_json_file_path, "w") as json_file:
            json_file.write(payload)
    except OSError:
        st.toast(failure_message, icon="⚠️")
        return
# st.cache_data (rather than st.cache_resource) so each caller receives its
# own copy of the DataFrame instead of one shared, mutable object.
@st.cache_data(show_spinner=False)
def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
    """List the channels whose S-curve parameters differ from the originals.

    Args:
        original_data: Nested mapping ``metric -> panel -> channel -> data``
            holding the original parameters.
        modified_data: Mapping of the same layout holding the modified
            parameters.
        metrics_selected: Metric key to inspect.
        panel_selected: Panel key to inspect.

    Returns:
        A DataFrame with one row per channel for which at least one of
        K, b, a, x0 differs, showing the original and the modified values.
        The DataFrame is empty when nothing differs.

    Raises:
        KeyError: If the metric, panel, a channel or a parameter is missing
            from either mapping.
    """
    parameter_names = ("K", "b", "a", "x0")

    # Retrieve the data for the selected metric and panel
    original_data_selection = original_data[metrics_selected][panel_selected]
    modified_data_selection = modified_data[metrics_selected][panel_selected]

    data = []
    for channel in original_data_selection:
        original_values = [
            original_data_selection[channel][name] for name in parameter_names
        ]
        modified_values = [
            modified_data_selection[channel][name] for name in parameter_names
        ]

        # Keep only channels where at least one parameter was changed
        if not any(o != m for o, m in zip(original_values, modified_values)):
            continue

        row = {
            "Metric": name_formating(metrics_selected),
            "Panel": name_formating(panel_selected),
            "Channel": name_formating(channel),
        }
        # Column order: all "(Original)" columns first, then all "(Modified)"
        row.update(
            {f"{name} (Original)": value
             for name, value in zip(parameter_names, original_values)}
        )
        row.update(
            {f"{name} (Modified)": value
             for name, value in zip(parameter_names, modified_values)}
        )
        data.append(row)

    return pd.DataFrame(data)
# Folder inside the current project that holds the metrics-level data
directory = os.path.join(st.session_state["project_path"], "metrics_level_data")

# Metric names reported by get_metrics_names for that folder
metrics_list = get_metrics_names(directory)

# Without any metric there is nothing to plot: warn the user and halt this run
no_metrics_found = len(metrics_list) == 0
if no_metrics_found:
    st.warning(
        "Please tune at least one model to generate response curves data.",
        icon="⚠️",
    )
    st.stop()
# Three side-by-side slots for the selectors (metric | channel | panel)
metric_col, channel_col, panel_col = st.columns(3)

# Metric selector, options in sorted order with the first one preselected
metric_options = sorted(metrics_list)
metrics_selected = metric_col.selectbox(
    "Response Metrics",
    metric_options,
    index=0,
    format_func=name_formating,
    key="response_metrics_selectbox",
)

# Panel names come from the overview workbook of the chosen metric
file_selected = f"metrics_level_data/data_test_overview_panel@#{metrics_selected}.xlsx"
file_selected_path = os.path.join(st.session_state["project_path"], file_selected)
panel_list = get_panels_names(file_selected_path)

# Panel selector, options in sorted order with the first one preselected
panel_options = sorted(panel_list)
panel_selected = panel_col.selectbox(
    "Panel",
    panel_options,
    index=0,
    key="panel_selected_selectbox",
)
# Locations of the original and the user-modified RCS JSON files
project_path = st.session_state["project_path"]
original_json_file_path = os.path.join(project_path, "rcs_data_original.json")
modified_json_file_path = os.path.join(project_path, "rcs_data_modified.json")

# Both files must be present; if either one is missing, generate them
rcs_files_present = os.path.exists(original_json_file_path) and os.path.exists(
    modified_json_file_path
)
if rcs_files_present:
    print(
        f"RCS JSON file already exists at {original_json_file_path}. No need to generate new RCS data."
    )
else:
    print(
        f"RCS JSON file does not exist at {original_json_file_path}. Generating new RCS data..."
    )
    generate_rcs_data(original_json_file_path, modified_json_file_path)

# Read both files into memory
original_data, modified_data = load_json_files(
    original_json_file_path, modified_json_file_path
)
# Channels available for the chosen metric and panel (taken from the original data)
panel_channels = original_data[metrics_selected][panel_selected]
chanel_list_final = list(panel_channels.keys())

# Channel selector, options in sorted order
channel_selected = channel_col.selectbox(
    "Channel",
    sorted(chanel_list_final),
    key="selected_channel_name_selectbox",
    format_func=name_formating,
)

# Original and modified entries for the same metric / panel / channel
original_channel_data = panel_channels[channel_selected]
modified_channel_data = modified_data[metrics_selected][panel_selected][
    channel_selected
]

# Data points (x, y), power-of-ten divisor applied to x before evaluating the
# curve, and the x range used for drawing the curves - all from the original entry
x, y, power, x_plot = (
    original_channel_data[field] for field in ("x", "y", "power", "x_plot")
)

# Original S-curve parameters
K_orig, b_orig, a_orig, x0_orig = (
    original_channel_data[name] for name in ("K", "b", "a", "x0")
)

# Modified (user-adjusted) S-curve parameters
K_mod, b_mod, a_mod, x0_mod = (
    modified_channel_data[name] for name in ("K", "b", "a", "x0")
)
# Scatter of the observed data points
fig = px.scatter(
    x=x,
    y=y,
    title="Original and Modified S-Curve Plot",
    labels={"x": "Spends", "y": name_formating(metrics_selected)},
)

# Both curves are evaluated on the plotting range divided by 10**power
scaled_x_plot = np.array(x_plot) / 10**power

# (legend name, curve parameters, line colour); modified curve is added first
curve_specs = (
    ("Modified", (K_mod, b_mod, a_mod, x0_mod), "red"),
    # Semi-transparent green
    ("Original", (K_orig, b_orig, a_orig, x0_orig), "rgba(0, 255, 0, 0.6)"),
)
for curve_name, curve_params, curve_color in curve_specs:
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(scaled_x_plot, *curve_params),
            line=dict(color=curve_color),
            name=curve_name,
        ),
    )

# Final titles and legend heading (these override the ones set on the scatter)
fig.update_layout(
    title="Comparison of Original and Modified S-Curves",
    xaxis_title="Input (Clicks, Impressions, etc..)",
    yaxis_title=name_formating(metrics_selected),
    legend_title="Curve Type",
)

# Render the chart across the full container width
st.plotly_chart(fig, use_container_width=True)
# Observed x values divided by 10**power, as used when evaluating the curves
scaled_x = np.array(x) / 10**power

# R² of the original curve against the observed y values
y_orig_pred = s_curve(scaled_x, K_orig, b_orig, a_orig, x0_orig)
r2_orig = r2_score(y, y_orig_pred)

# R² of the modified curve against the same observations
y_mod_pred = s_curve(scaled_x, K_mod, b_mod, a_mod, x0_mod)
r2_mod = r2_score(y, y_mod_pred)

# Change in R² introduced by the modification
r2_diff = r2_mod - r2_orig

# Show the three figures side by side, rounded to two decimals
st.write("## R² Comparison")
r2_col = st.columns(3)
r2_readouts = (
    ("R² (Original)", r2_orig),
    ("R² (Modified)", r2_mod),
    ("Difference in R²", r2_diff),
)
for r2_slot, (r2_label, r2_value) in zip(r2_col, r2_readouts):
    r2_slot.metric(r2_label, f"{r2_value:.2f}")
# Session-state keys, unique per metric / panel / channel selection
selection_suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
K_key = f"K_updated_key_{selection_suffix}"
b_key = f"b_updated_key_{selection_suffix}"
a_key = f"a_updated_key_{selection_suffix}"
x0_key = f"x0_updated_key_{selection_suffix}"

# Seed each key with the stored modified value unless it is already set
for state_key, stored_value in (
    (K_key, K_mod),
    (b_key, b_mod),
    (a_key, a_mod),
    (x0_key, x0_mod),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = stored_value

# One number input per parameter; every change is persisted through
# modify_rcs_parameters.
# NOTE(review): min_value is 0 for all four inputs - confirm that stored
# parameters can never be negative.
rsc_ip_col = st.columns(4)
parameter_inputs = {}
for input_slot, parameter_label, state_key in zip(
    rsc_ip_col, ("K", "b", "a", "x0"), (K_key, b_key, a_key, x0_key)
):
    with input_slot:
        parameter_inputs[parameter_label] = st.number_input(
            parameter_label,
            step=0.001,
            min_value=0.0000,
            format="%.4f",
            on_change=modify_rcs_parameters,
            args=(metrics_selected, panel_selected, channel_selected),
            key=state_key,
        )

# Current widget values under their original names
K_updated, b_updated, a_updated, x0_updated = (
    parameter_inputs[name] for name in ("K", "b", "a", "x0")
)
# Reset button on the left, download button on the right
reset_download_col = st.columns(2)

with reset_download_col[0]:
    if st.button(
        "Reset",
        use_container_width=True,
    ):
        # Restore the original parameters, then rerun so the page reflects them
        reset_parameters(metrics_selected, panel_selected, channel_selected)
        st.rerun()

with reset_download_col[1]:
    # Read the file first and guard only the file access. The button itself
    # is rendered outside the try block: a bare `except` around a Streamlit
    # call would silently hide any failure of that call and also traps
    # BaseException-derived signals.
    try:
        with open(modified_json_file_path, "rb") as json_file:
            modified_rcs_payload = json_file.read()
    except OSError:
        modified_rcs_payload = None

    if modified_rcs_payload is None:
        st.warning(
            "Modified response curve data could not be read for download.",
            icon="⚠️",
        )
    else:
        # Serve the file content exactly as stored on disk
        st.download_button(
            label="Download",
            data=modified_rcs_payload,
            file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
            mime="application/json",
            use_container_width=True,
        )
# Table of channels whose parameters differ from the originals
updated_parm_df = updated_parm_gen(
    original_data, modified_data, metrics_selected, panel_selected
)

# Nothing changed -> informational note; otherwise show the comparison table
if updated_parm_df.empty:
    st.info("No parameters are updated for the selected Metric and Panel")
else:
    st.write("## Parameter Comparison for Selected Metric and Panel")
    st.dataframe(updated_parm_df, hide_index=True)