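# Page-scrape residue kept for provenance (looks like: file path, committer,
# commit message, short revision hash). Commented out so it is not parsed as code.
# RFI / pages / 2_Data_Validation_and_Insights.py
# Manoj
# latest
# fde220d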
import streamlit as st
import pandas as pd
from Eda_functions import *
import numpy as np
import pickle
import streamlit as st
from utilities import set_header, load_local_css,update_db,project_selection
import os
import tempfile
import sqlite3
from utilities import update_db
import re
# --- Page bootstrap ---------------------------------------------------------
# Page-wide Streamlit settings, followed by the shared stylesheet and header
# supplied by the project's `utilities` module.
st.set_page_config(
    page_title="Data Validation",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)
load_local_css("styles.css")
set_header()

# Seed the session keys this page reads so later lookups never miss.
for _session_key in ("username", "project_name"):
    if _session_key not in st.session_state:
        st.session_state[_session_key] = None

# No project dictionary in the session yet: hand over to the project picker.
if "project_dct" not in st.session_state:
    project_selection()

# Without a project path nothing below can load its data, so end the run here.
# NOTE(review): the file's indentation was lost; this check is treated as a
# top-level statement rather than as part of the branch above - confirm.
if "project_path" not in st.session_state:
    st.stop()
# --- Access guard and data loading -------------------------------------------
# The page body is only meant for a logged-in user. Ending the run here is
# equivalent to wrapping everything that follows in the positive check.
# NOTE(review): assumes the rest of the script was nested under that check
# before the indentation was lost - confirm.
if 'username' not in st.session_state or st.session_state['username'] is None:
    st.stop()

# The Data Import page is expected to have written this pickle into the
# project folder. NOTE(review): pickle.load runs arbitrary code from the file,
# so this is only safe while the file is written by the app itself.
import_pickle_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
try:
    with open(import_pickle_path, "rb") as pickle_file:
        data = pickle.load(pickle_file)
except Exception:
    # Any failure to read the file is reported as "data not imported yet".
    st.error("Please import data from the Data Import Page")
    st.stop()

conn = sqlite3.connect(r"DB\User.db", check_same_thread=False)  # connection with sql db
c = conn.cursor()

# Publish the imported frame and its column-category mapping to the session.
st.session_state["cleaned_data"] = data["final_df"]
st.session_state["category_dict"] = data["bin_dict"]
# Banner row: greeting in the wider left column, active project on the right.
greeting_col, project_col = st.columns([2, 1])
greeting_col.markdown(f"**Welcome {st.session_state['username']}**")
project_col.markdown(f"**Current Project: {st.session_state['project_name']}**")

st.title("Data Validation and Insights")
# Collect the "Response Metrics" bucket of the category mapping. The result is
# a list wrapping that bucket (empty when the category is absent), which is the
# shape the unpacking step further down expects.
target_variables = [
    members
    for category, members in st.session_state["category_dict"].items()
    if category == "Response Metrics"
]
def format_display(inp):
    """Turn a raw column name into a human-readable widget label.

    Title-cases the text, swaps underscores for spaces and trims surrounding
    whitespace, e.g. ``"tv_spend"`` becomes ``"Tv Spend"``.

    Args:
        inp: The option to label. Non-string values are converted with
            ``str`` first so they do not raise ``AttributeError``.

    Returns:
        The display label as a string.
    """
    return str(inp).title().replace("_", " ").strip()
# Unwrap the single "Response Metrics" bucket into a flat list of column names
# (stays empty when the category is missing).
target_variables = list(*target_variables)

# Restore the previously chosen target. The stored position can be stale
# (different data, fewer metrics) or not an integer at all; Streamlit raises
# on an out-of-range index, so fall back to the first option in those cases.
_target_state = st.session_state["project_dct"]["data_validation"]
_stored_target_idx = _target_state["target_column"]
if (
    not isinstance(_stored_target_idx, int)
    or not 0 <= _stored_target_idx < len(target_variables)
):
    _stored_target_idx = 0

target_column = st.selectbox(
    "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
    target_variables,
    index=_stored_target_idx,
    format_func=format_display,
)

# Remember the position only when something was actually selected; with no
# response metrics the selectbox returns None and `.index` would raise.
if target_column in target_variables:
    _target_state["target_column"] = target_variables.index(target_column)
st.session_state["target_column"] = target_column
# --- Panel selection ----------------------------------------------------------
# Data without a "Panel_1" column is treated as a single synthetic panel "1";
# the picker is then disabled because there is nothing to choose between.
_cleaned_frame = st.session_state["cleaned_data"]
disable = "Panel_1" not in _cleaned_frame.columns
if disable:
    _cleaned_frame["Panel_1"] = ["1"] * len(_cleaned_frame)

_panel_options = list(_cleaned_frame["Panel_1"].unique())

# Streamlit raises when a default is not among the options, so drop remembered
# panels that no longer exist in the current data (and tolerate a stored None).
_remembered_panels = [
    panel
    for panel in (
        st.session_state["project_dct"]["data_validation"]["selected_panels"] or []
    )
    if panel in _panel_options
]

selected_panels = st.multiselect(
    "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
    _panel_options,
    default=_remembered_panels,
    disabled=disable,
)
st.session_state["project_dct"]["data_validation"]["selected_panels"] = selected_panels
# Per-column aggregation rule used when rows are collapsed per date: columns in
# the "Media" category are summed, columns of every other category are
# averaged. "date" and "Panel_1" are excluded, as are names that are not
# columns of the cleaned data.
_available_columns = st.session_state["cleaned_data"].columns
aggregation_dict = {}
for category, members in st.session_state["category_dict"].items():
    how = "sum" if category == "Media" else "mean"
    for column in members:
        if column in ("date", "Panel_1") or column not in _available_columns:
            continue
        aggregation_dict[column] = how
with st.expander("**Reponse Metric Analysis**"):
    # Restrict to the chosen panels (all rows when none are chosen), then
    # collapse to one row per date using the per-column aggregation rules.
    _panel_source = st.session_state["cleaned_data"]
    if selected_panels:
        _panel_source = _panel_source[_panel_source["Panel_1"].isin(selected_panels)]
    st.session_state["Cleaned_data_panel"] = (
        _panel_source.groupby(by="date").agg(aggregation_dict).reset_index()
    )

    # Target metric over time.
    fig = line_plot_target(
        st.session_state["Cleaned_data_panel"],
        target=target_column,
        title=f"{target_column} Over Time",
    )
    st.plotly_chart(fig, use_container_width=True)

    # Column groups by category; a missing category yields an empty list.
    # These names are reused by the sections further down the page.
    _category_dict = st.session_state["category_dict"]
    media_channel = list(_category_dict.get("Media", []))
    spends_features = list(_category_dict.get("Spends", []))
    exo_var = list(_category_dict.get("Exogenous", []))
    internal_var = list(_category_dict.get("Internal", []))
    Non_media_variables = exo_var + internal_var

    st.markdown("### Annual Data Summary")
    summary_df = summary(
        st.session_state["Cleaned_data_panel"],
        media_channel + [target_column] + spends_features,
        spends=None,
        Target=True,
    )
    st.dataframe(
        summary_df.sort_index(axis=1),
        use_container_width=True,
    )

    if st.checkbox("Show raw data"):
        # A bare `st.cache_resource(show_spinner=False)` statement used to sit
        # here. It was never applied as a decorator, so it cached nothing. It
        # is dropped rather than turned into a decorator: a cached
        # zero-argument function would keep returning the first selection's
        # table after the panels change.

        def raw_df_gen():
            """Build the display table of the per-date panel data.

            Returns:
                A DataFrame with the ``date`` column plus every numeric column
                passed through ``format_numbers``, sorted by date ascending,
                with dates rendered as ``MM/DD/YYYY`` strings.
            """
            panel_df = st.session_state["Cleaned_data_panel"]
            # Keep real datetimes until after sorting so the order is
            # chronological rather than lexicographic.
            dates = pd.to_datetime(panel_df["date"])
            raw_df = pd.concat(
                [
                    dates,
                    panel_df.select_dtypes(np.number).applymap(format_numbers),
                ],
                axis=1,
            )
            sorted_raw_df = raw_df.sort_values(by="date", ascending=True)
            sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")
            return sorted_raw_df

        st.dataframe(raw_df_gen())
col1 = st.columns(1)

if "selected_feature" not in st.session_state:
    st.session_state["selected_feature"] = None

with st.expander("Media Variables Analysis"):
    # Feature plotted against the target; options are media and spend columns.
    st.session_state["selected_feature"] = st.selectbox(
        "Select media", media_channel + spends_features, format_func=format_display
    )
    selected_feature = st.session_state["selected_feature"]

    # Guess the matching spend column: a spend column matches when its name,
    # cut at "_cost"/"_spend", is contained in the selected feature's name.
    # Both sides are lower-cased so mixed-case column names still match.
    spends_col = st.columns(2)
    _selected_lower = (selected_feature or "").lower()
    spends_feature = [
        col
        for col in spends_features
        if re.split(r"_cost|_spend", col.lower())[0] in _selected_lower
    ]
    with spends_col[0]:
        if not spends_feature:
            st.warning("No 'spends' variable available for the selected metric in the data. Please ensure the columns are properly named.or select them in the provided selecttion box")
        else:
            st.write(f'Selected "{spends_feature[0]}" as the corresponding spends variable automatically. If this is incorrect, please click the checkbox to change the variable.')
    with spends_col[1]:
        # Manual choice is forced when nothing matched, optional otherwise.
        if not spends_feature or st.checkbox('Select "Spends" variable for CPM and CPC calculation'):
            spends_feature = [st.selectbox('Spends variable', spends_features)]

    _media_state = st.session_state["project_dct"]["data_validation"]
    if "validation" not in st.session_state:
        st.session_state["validation"] = _media_state["validated_variables"]

    val_variables = [col for col in media_channel if col != "date"]

    # Saved validations that mention columns outside the current media list
    # are discarded. The saved list is reset together with the session list:
    # resetting only the session copy left this check failing on every rerun,
    # which kept the rest of this section from ever rendering again.
    if not set(_media_state["validated_variables"]).issubset(val_variables):
        st.session_state["validation"] = []
        _media_state["validated_variables"] = st.session_state["validation"]

    def _record_validated(names):
        """Add `names` to the validated list, once each, media columns only."""
        validated = st.session_state["validation"]
        for name in names:
            # Skipping non-media names (e.g. a selected spend column) keeps the
            # list a subset of `val_variables`; skipping known names stops the
            # list from growing on every Streamlit rerun.
            if name in val_variables and name not in validated:
                validated.append(name)

    fig_row1 = line_plot(
        st.session_state["Cleaned_data_panel"],
        x_col="date",
        y1_cols=[selected_feature],
        y2_cols=[target_column],
        title=f'Analysis of {selected_feature} and {target_column} Over Time',
    )
    st.plotly_chart(fig_row1, use_container_width=True)

    st.markdown("### Summary")
    st.dataframe(
        summary(
            st.session_state["cleaned_data"],
            [selected_feature],
            spends=spends_feature[0],
        ),
        use_container_width=True,
    )

    # The validate controls are disabled once every media variable is covered.
    cols2 = st.columns(2)
    disable = set(val_variables).issubset(st.session_state["validation"])
    help_text = "All media variables are validated" if disable else ""

    with cols2[0]:
        if st.button("Validate", disabled=disable, help=help_text):
            _record_validated([selected_feature])
    with cols2[1]:
        if st.checkbox("Validate all", disabled=disable, help=help_text):
            _record_validated(val_variables)
            st.success("All media variables are validated ✅")

    # While something is still unvalidated, show an editable checklist and
    # treat every ticked row as validated.
    if not set(val_variables).issubset(st.session_state["validation"]):
        validation_data = pd.DataFrame(
            {
                "Validate": [
                    col in st.session_state["validation"] for col in val_variables
                ],
                "Variables": val_variables,
            }
        )
        sorted_validation_df = validation_data.sort_values(
            by="Variables", ascending=True, na_position="first"
        )
        cols3 = st.columns([1, 30])
        with cols3[1]:
            validation_df = st.data_editor(
                sorted_validation_df,
                column_config={
                    "Validate": st.column_config.CheckboxColumn(
                        default=False,
                        width=100,
                    ),
                    "Variables": st.column_config.TextColumn(width=1000),
                },
                hide_index=True,
            )
        selected_rows = validation_df[validation_df["Validate"].eq(True)]["Variables"]
        _record_validated(selected_rows)

    # Mirror the session list into the project dictionary on every run so that
    # "Save Changes" persists the current state even when the checklist above
    # was not shown.
    _media_state["validated_variables"] = st.session_state["validation"]

    not_validated_variables = [
        col for col in val_variables if col not in st.session_state["validation"]
    ]
    if not_validated_variables:
        not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
        st.warning(not_validated_message)
with st.expander("Non Media Variables Analysis"):
    if not Non_media_variables:
        st.warning('Non media variables not present')
    else:
        # Restore the previous choice; a stale or non-integer stored position
        # falls back to the first option instead of making Streamlit raise.
        _non_media_state = st.session_state["project_dct"]["data_validation"]
        _stored_non_media_idx = _non_media_state["Non_media_variables"]
        if (
            not isinstance(_stored_non_media_idx, int)
            or not 0 <= _stored_non_media_idx < len(Non_media_variables)
        ):
            _stored_non_media_idx = 0

        selected_columns_row4 = st.selectbox(
            "Select Channel",
            Non_media_variables,
            format_func=format_display,
            index=_stored_non_media_idx,
        )
        _non_media_state["Non_media_variables"] = Non_media_variables.index(
            selected_columns_row4
        )

        # Dual-axis line plot of the chosen variable against the target.
        fig_row4 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[selected_columns_row4],
            y2_cols=[target_column],
            title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
        )
        st.plotly_chart(fig_row4, use_container_width=True)

        selected_non_media = selected_columns_row4
        _panel_frame = st.session_state["Cleaned_data_panel"]
        # Work on a copy: adding "Year" to a column slice of the session frame
        # triggers pandas' SettingWithCopyWarning.
        sum_df = _panel_frame[["date", selected_non_media, target_column]].copy()
        sum_df["Year"] = pd.to_datetime(_panel_frame["date"]).dt.year

        # Yearly totals plus a grand-total row, formatted for display; empty
        # and zero cells are shown as "-".
        sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
        sum_df.loc["Grand Total"] = sum_df.sum()
        sum_df = sum_df.applymap(format_numbers)
        sum_df.fillna("-", inplace=True)
        sum_df = sum_df.replace({"0.0": "-", "nan": "-"})

        st.markdown("### Summary")
        st.dataframe(sum_df, use_container_width=True)
# with st.expander('Interactive Dashboard'):
# pygg_app=StreamlitRenderer(st.session_state['cleaned_data'])
# pygg_app.explorer()
with st.expander("Correlation Analysis"):
    # Numeric columns of the per-date panel data; the target itself is not
    # offered because it is always the reference of the plot.
    options = list(
        st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
    )
    correlation_candidates = [var for var in options if var != target_column]

    # Keep the historical default (the fourth numeric column) when it exists
    # and is selectable. Indexing it blindly raised IndexError with fewer than
    # four numeric columns, and Streamlit rejects a default that is not among
    # the options (which happened when that column was the target).
    if len(options) > 3 and options[3] in correlation_candidates:
        default_selection = [options[3]]
    else:
        default_selection = correlation_candidates[:1]

    selected_options = st.multiselect(
        "Select Variables For correlation plot",
        correlation_candidates,
        default=default_selection,
    )
    st.pyplot(
        correlation_plot(
            st.session_state["Cleaned_data_panel"],
            selected_options,
            target_column,
        )
    )
# --- Persist this page's choices ---------------------------------------------
# On click: call `update_db` with this page's identifier (presumably recording
# progress - confirm in `utilities`), then pickle the project dictionary into
# the project folder and confirm to the user.
save_clicked = st.button("Save Changes", use_container_width=True)
if save_clicked:
    update_db("2_Data_Validation.py")

    project_dct_path = os.path.join(st.session_state["project_path"], "project_dct.pkl")
    with open(project_dct_path, "wb") as project_file:
        pickle.dump(st.session_state["project_dct"], project_file)

    st.success("Changes saved")