# Page-scrape residue (hosting status banner "Spaces: Sleeping"); not part of the program.
import streamlit as st | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from Eda_functions import * | |
import numpy as np | |
import pickle | |
import streamlit as st | |
import streamlit.components.v1 as components | |
import sweetviz as sv | |
from utilities import set_header, load_local_css | |
from st_aggrid import GridOptionsBuilder, GridUpdateMode | |
from st_aggrid import GridOptionsBuilder | |
from st_aggrid import AgGrid | |
import base64 | |
import os | |
import tempfile | |
# from ydata_profiling import ProfileReport | |
import re | |
# from pygwalker.api.streamlit import StreamlitRenderer | |
# from Home_redirecting import home | |
import sqlite3 | |
from utilities import update_db | |
# Configure the page, then apply the shared stylesheet and header.
st.set_page_config(
    page_title="Data Validation",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)
load_local_css("styles.css")
set_header()

# Both keys are read below, so a session missing either one is sent back to
# the home page instead of failing with a KeyError.
if "project_dct" not in st.session_state or "project_path" not in st.session_state:
    st.warning("Please select a project from home page")
    st.stop()

data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
try:
    # NOTE(review): unpickling runs code embedded in the file; this is only
    # safe while data_import.pkl is written exclusively by this application.
    with open(data_path, "rb") as f:
        data = pickle.load(f)
except Exception:
    # Any failure to read the import artefact is reported the same way.
    st.error("Please import data from the Data Import Page")
    st.stop()

# NOTE(review): the backslash is a path separator only on Windows; elsewhere
# this names a file literally called "DB\User.db" in the working directory.
# Confirm the deployment target before changing it. The connection is also
# reopened on every script rerun and never closed here.
conn = sqlite3.connect(r"DB\User.db", check_same_thread=False)  # connection with sql db
c = conn.cursor()

# Publish the imported frame and the category -> columns mapping for the
# sections further down the page.
st.session_state["cleaned_data"] = data["final_df"]
st.session_state["category_dict"] = data["bin_dict"]

st.title("Data Validation and Insights")

# A list wrapping at most one entry: the columns filed under "Response Metrics".
target_variables = [
    variables
    for category, variables in st.session_state["category_dict"].items()
    if category == "Response Metrics"
]
def format_display(inp):
    """Return *inp* as a display label: underscores become spaces, words are title-cased."""
    # Underscores and spaces are both uncased, so swapping them before
    # title-casing yields the same word boundaries as doing it afterwards.
    spaced = inp.replace("_", " ")
    return spaced.title().strip()
# ``target_variables`` arrives as a list wrapping at most one list of column
# names; unpacking flattens it and yields [] when the wrapper is empty.
target_variables = list(*target_variables)

validation_settings = st.session_state["project_dct"]["data_validation"]

# Restore the saved choice, falling back to the first option when the stored
# position no longer fits the current list of response metrics.
saved_target_index = validation_settings["target_column"]
if not (
    isinstance(saved_target_index, int)
    and 0 <= saved_target_index < len(target_variables)
):
    saved_target_index = 0

target_column = st.selectbox(
    "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
    target_variables,
    index=saved_target_index,
    format_func=format_display,
)
if target_column is None:
    # Nothing to select: every chart below needs a target, so stop here
    # rather than failing on the index lookup.
    st.warning("No response metric is available in the imported data")
    st.stop()

validation_settings["target_column"] = target_variables.index(target_column)
st.session_state["target_column"] = target_column

# Name of the first panel-level column, as recorded in the category mapping.
panels = st.session_state["category_dict"]["Panel Level 1"][0]
panel_options = st.session_state["cleaned_data"][panels].unique()

# Drop saved panels that are absent from the current data, because the
# multiselect rejects defaults that are not among its options.
available_panels = set(panel_options)
saved_panels = [
    panel
    for panel in (validation_settings["selected_panels"] or [])
    if panel in available_panels
]
selected_panels = st.multiselect(
    "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
    panel_options,
    default=saved_panels,
)
validation_settings["selected_panels"] = selected_panels

# Media columns are summed and every other category is averaged. The date and
# panel columns are grouping keys, so they are never aggregated; the
# configured panel column is excluded alongside the legacy "Panel_1" name.
non_aggregated = {"date", "Panel_1", panels}
aggregation_dict = {
    item: "sum" if key == "Media" else "mean"
    for key, value in st.session_state["category_dict"].items()
    for item in value
    if item not in non_aggregated
}
with st.expander("**Reponse Metric Analysis**"):
    # Build the per-date frame used by every chart on this page. When panels
    # are chosen, filter on the same column the panel selector was populated
    # from, so the filter matches whatever that column is called.
    panel_source = st.session_state["cleaned_data"]
    if selected_panels:
        panel_source = panel_source[panel_source[panels].isin(selected_panels)]
    st.session_state["Cleaned_data_panel"] = (
        panel_source.groupby(by="date").agg(aggregation_dict).reset_index()
    )

    fig = line_plot_target(
        st.session_state["Cleaned_data_panel"],
        target=target_column,
        title=f"{target_column} Over Time",
    )
    st.plotly_chart(fig, use_container_width=True)

    # Column lists per category; a missing category yields an empty list.
    media_channel = list(st.session_state["category_dict"].get("Media", []))
    exo_var = list(st.session_state["category_dict"].get("Exogenous", []))
    internal_var = list(st.session_state["category_dict"].get("Internal", []))
    Non_media_variables = exo_var + internal_var

    st.markdown("### Annual Data Summary")
    summary_df = summary(
        st.session_state["Cleaned_data_panel"],
        media_channel + [target_column],
        spends=None,
        Target=True,
    )
    st.dataframe(
        summary_df,
        use_container_width=True,
    )

    if st.checkbox("Show raw data"):

        def raw_df_gen():
            """Return the per-date frame sorted by date, formatted for display.

            Numeric columns are passed through ``format_numbers`` and the
            date column is rendered as MM/DD/YYYY text after sorting.

            Deliberately left uncached: the function takes no arguments and
            reads session state, so a cache would have nothing to key on.
            """
            panel_df = st.session_state["Cleaned_data_panel"]
            # Keep dates as datetimes until after sorting.
            dates = pd.to_datetime(panel_df["date"])
            raw_df = pd.concat(
                [
                    dates,
                    panel_df.select_dtypes(np.number).applymap(format_numbers),
                ],
                axis=1,
            )
            sorted_raw_df = raw_df.sort_values(by="date", ascending=True)
            sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")
            return sorted_raw_df

        st.dataframe(raw_df_gen())
# One-column layout container.
col1 = st.columns(1)

# Seed the selected media feature once per session; an existing value is kept.
st.session_state.setdefault("selected_feature", None)
def generate_report_with_target(channel_data, target_feature):
    """Write a Sweetviz HTML report for *channel_data* and return its path.

    The analysis is labelled "Dataset" and uses *target_feature* as the
    target. The file is written as ``report.html`` inside a freshly created
    temporary directory, which this function does not remove.
    """
    analysis = sv.analyze([channel_data, "Dataset"], target_feat=target_feature)
    output_path = os.path.join(tempfile.mkdtemp(), "report.html")
    # Write the HTML to disk without opening a browser window.
    analysis.show_html(filepath=output_path, open_browser=False)
    return output_path
def generate_profile_report(df):
    """Write the profile report of *df* to an HTML file and return its path.

    The file is written as ``report.html`` inside a freshly created temporary
    directory, which this function does not remove.
    """
    # NOTE(review): ``profile_report`` is not a built-in DataFrame method; it
    # presumably comes from a profiling package whose import is commented out
    # in this file. Confirm it is registered elsewhere.
    profile = df.profile_report()
    output_path = os.path.join(tempfile.mkdtemp(), "report.html")
    profile.to_file(output_path)
    return output_path
# st.header() | |
def _offer_report_download(report_file, download_name):
    """Offer *report_file* for download as *download_name*, or warn if it is missing."""
    if not os.path.exists(report_file):
        st.warning("Report generation failed. Unable to find the report file.")
        return
    with open(report_file, "rb") as report_handle:
        st.success("Report Generated")
        st.download_button(
            label="Download EDA Report",
            data=report_handle.read(),
            file_name=download_name,
            mime="text/html",
        )


with st.expander("Univariate and Bivariate Report"):
    profile_col, sweetviz_col = st.columns(2)

    # Left column: profile report of the per-date frame.
    with profile_col:
        if st.button(
            "Generate Profile Report",
            help="Univariate report which inlcudes all statistical analysis",
        ):
            with st.spinner("Generating Report"):
                report_file = generate_profile_report(
                    st.session_state["Cleaned_data_panel"]
                )
            _offer_report_download(report_file, "pandas_profiling_report.html")

    # Right column: Sweetviz report against the selected response metric.
    with sweetviz_col:
        if st.button(
            "Generate Sweetviz Report",
            help="Bivariate report for selected response metric",
        ):
            with st.spinner("Generating Report"):
                report_file = generate_report_with_target(
                    st.session_state["Cleaned_data_panel"], target_column
                )
            _offer_report_download(report_file, "report.html")
# st.warning('Work in Progress') | |
def _mark_validated(variables):
    """Add each variable to the session's validated list once, keeping order.

    The list is mutated in place because the project dictionary may hold a
    reference to the very same list object.
    """
    validated = st.session_state["validation"]
    for variable in variables:
        if variable not in validated:
            validated.append(variable)


with st.expander("Media Variables Analysis"):
    # Media columns that are not themselves cost/spend columns.
    media_variables = [
        col
        for col in media_channel
        if "cost" not in col.lower() and "spend" not in col.lower()
    ]
    st.session_state["selected_feature"] = st.selectbox(
        "Select media", media_variables, format_func=format_display
    )
    st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
        media_variables.index(st.session_state["selected_feature"])
    )

    # Cost/spend columns whose name prefix (the text before "_cost"/"_spend")
    # occurs in the selected feature. Both sides are lower-cased so the match
    # does not depend on the casing of the column names.
    spends_features = [
        col
        for col in st.session_state["Cleaned_data_panel"].columns
        if any(keyword in col.lower() for keyword in ["cost", "spend"])
    ]
    selected_feature_key = st.session_state["selected_feature"].lower()
    spends_feature = [
        col
        for col in spends_features
        if re.split(r"_cost|_spend", col.lower())[0] in selected_feature_key
    ]

    # Seed the working list of validated variables from the saved project.
    if "validation" not in st.session_state:
        st.session_state["validation"] = st.session_state["project_dct"][
            "data_validation"
        ]["validated_variables"]

    val_variables = [col for col in media_channel if col != "date"]

    # Saved validations naming columns that are no longer media variables
    # are discarded wholesale.
    if not set(
        st.session_state["project_dct"]["data_validation"]["validated_variables"]
    ).issubset(set(val_variables)):
        st.session_state["validation"] = []

    if not spends_feature:
        st.warning("No spends varaible available for the selected metric in data")
    else:
        # NOTE(review): the source had lost its indentation; everything from
        # here to the end of the expander was placed inside this branch.
        # Confirm the validation controls belong here.
        fig_row1 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[st.session_state["selected_feature"]],
            y2_cols=[target_column],
            title=f'Analysis of {st.session_state["selected_feature"]} and {target_column} Over Time',
        )
        st.plotly_chart(fig_row1, use_container_width=True)

        st.markdown("### Summary")
        st.dataframe(
            summary(
                st.session_state["cleaned_data"],
                [st.session_state["selected_feature"]],
                spends=spends_feature[0],
            ),
            use_container_width=True,
        )

        cols2 = st.columns(2)

        # Once every media variable is validated the controls are disabled.
        all_validated = set(val_variables).issubset(st.session_state["validation"])
        validate_help = "All media variables are validated" if all_validated else ""

        with cols2[0]:
            if st.button("Validate", disabled=all_validated, help=validate_help):
                _mark_validated([st.session_state["selected_feature"]])
        with cols2[1]:
            if st.checkbox("Validate all", disabled=all_validated, help=validate_help):
                _mark_validated(val_variables)
                st.success("All media variables are validated ✅")

        # Re-check after the controls above, which may have completed the set.
        if not set(val_variables).issubset(st.session_state["validation"]):
            validation_data = pd.DataFrame(
                {
                    "Validate": [
                        col in st.session_state["validation"] for col in val_variables
                    ],
                    "Variables": val_variables,
                }
            )
            sorted_validation_df = validation_data.sort_values(
                by="Variables", ascending=True, na_position="first"
            )

            cols3 = st.columns([1, 30])
            with cols3[1]:
                validation_df = st.data_editor(
                    sorted_validation_df,
                    column_config={
                        "Validate": st.column_config.CheckboxColumn(
                            default=False,
                            width=100,
                        ),
                        "Variables": st.column_config.TextColumn(width=1000),
                    },
                    hide_index=True,
                )

            # Ticked rows are added; unticking does not remove a validation.
            selected_rows = validation_df.loc[
                validation_df["Validate"].eq(True), "Variables"
            ]
            _mark_validated(selected_rows)

        # Keep the project dictionary pointing at the working list, whichever
        # path produced the validations.
        st.session_state["project_dct"]["data_validation"][
            "validated_variables"
        ] = st.session_state["validation"]

        not_validated_variables = [
            col for col in val_variables if col not in st.session_state["validation"]
        ]
        if not_validated_variables:
            not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
            st.warning(not_validated_message)
with st.expander("Non Media Variables Analysis"):
    selected_columns_row4 = st.selectbox(
        "Select Channel",
        Non_media_variables,
        format_func=format_display,
        index=st.session_state["project_dct"]["data_validation"]["Non_media_variables"],
    )
    # Remember the choice by position so it can be restored on the next visit.
    st.session_state["project_dct"]["data_validation"]["Non_media_variables"] = (
        Non_media_variables.index(selected_columns_row4)
    )

    # Dual-series line plot of the chosen variable against the target.
    fig_row4 = line_plot(
        st.session_state["Cleaned_data_panel"],
        x_col="date",
        y1_cols=[selected_columns_row4],
        y2_cols=[target_column],
        title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
    )
    st.plotly_chart(fig_row4, use_container_width=True)

    selected_non_media = selected_columns_row4

    # Work on a copy: adding "Year" to a column selection of the session
    # frame is a chained assignment that pandas warns about.
    sum_df = st.session_state["Cleaned_data_panel"][
        ["date", selected_non_media, target_column]
    ].copy()
    sum_df["Year"] = pd.to_datetime(sum_df["date"]).dt.year

    # Yearly totals plus a grand-total row, formatted for display.
    sum_df = sum_df.drop(columns="date").groupby("Year").agg("sum")
    sum_df.loc["Grand Total"] = sum_df.sum()
    sum_df = sum_df.applymap(format_numbers)
    sum_df = sum_df.fillna("-")
    sum_df = sum_df.replace({"0.0": "-", "nan": "-"})

    st.markdown("### Summary")
    st.dataframe(sum_df, use_container_width=True)
# with st.expander('Interactive Dashboard'): | |
# pygg_app=StreamlitRenderer(st.session_state['cleaned_data']) | |
# pygg_app.explorer() | |
with st.expander("Correlation Analysis"):
    options = list(
        st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
    )
    # The target is always part of the plot, so it is not offered as a choice.
    correlation_candidates = [var for var in options if var != target_column]

    # Pre-select the fourth numeric column as before when it exists and is
    # selectable; otherwise fall back to the first candidate (or nothing).
    # The widget rejects a default that is not among its options.
    preferred_default = options[3] if len(options) > 3 else None
    if preferred_default in correlation_candidates:
        default_selection = [preferred_default]
    else:
        default_selection = correlation_candidates[:1]

    selected_options = st.multiselect(
        "Select Variables For correlation plot",
        correlation_candidates,
        default=default_selection,
    )

    st.pyplot(
        correlation_plot(
            st.session_state["Cleaned_data_panel"],
            selected_options,
            target_column,
        )
    )
# Persist this page's settings when the user asks for it.
save_requested = st.button("Save Changes", use_container_width=True)
if save_requested:
    update_db("2_Data_Validation.py")

    # Write the project dictionary next to the rest of the project files.
    project_dct_path = os.path.join(
        st.session_state["project_path"], "project_dct.pkl"
    )
    with open(project_dct_path, "wb") as project_file:
        pickle.dump(st.session_state["project_dct"], project_file)

    st.success("Changes saved")