|
|
|
import streamlit as st |
|
|
|
st.set_page_config( |
|
page_title="Transformations", |
|
page_icon=":shark:", |
|
layout="wide", |
|
initial_sidebar_state="collapsed", |
|
) |
|
|
|
import os |
|
import pickle |
|
import sqlite3 |
|
import numpy as np |
|
import pandas as pd |
|
from utilities import update_db |
|
import plotly.graph_objects as go |
|
from utilities import set_header, load_local_css, update_db, project_selection |
|
|
|
|
|
load_local_css("styles.css")
set_header()

# --- Session bootstrap -------------------------------------------------
# Make sure the session-state keys this page reads are always present.
if "username" not in st.session_state:
    st.session_state["username"] = None

if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# No project selected yet: show the project picker and halt this run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

if "username" in st.session_state and st.session_state["username"] is not None:

    # NOTE(review): connection is (re)opened on every rerun and never
    # closed on this page — presumably relied on elsewhere; confirm.
    conn = sqlite3.connect(
        r"DB/User.db", check_same_thread=False
    )
    c = conn.cursor()

    # The Data Import page must have been completed first; its pickle is
    # this page's only input.
    if not os.path.exists(
        os.path.join(st.session_state["project_path"], "data_import.pkl")
    ):
        st.error("Please move to Data Import page")
        st.stop()

    with open(
        os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
    ) as f:
        data = pickle.load(f)

    # DataFrame assembled on the import page, plus the category binning
    # (column names grouped under "Media" / "Internal" / "Exogenous" / panels).
    final_df_loaded = data["final_df"]
    bin_dict_loaded = data["bin_dict"]

    # Per-run defaults for this page's own session-state keys.
    if "transformed_columns_dict" not in st.session_state:
        st.session_state["transformed_columns_dict"] = {}

    if "final_df" not in st.session_state:
        st.session_state["final_df"] = final_df_loaded

    if "summary_string" not in st.session_state:
        st.session_state["summary_string"] = None

    # Only these three categories can be transformed on this page.
    original_columns = {
        category: bin_dict_loaded[category]
        for category in ["Media", "Internal", "Exogenous"]
        if category in bin_dict_loaded
    }

    # Optional panel (grouping) columns; either level may be absent.
    panel_1 = bin_dict_loaded.get("Panel Level 1")
    panel_2 = bin_dict_loaded.get("Panel Level 2")

    # Combined list of panel columns (empty list means no panel structure).
    if panel_1:
        panel = panel_1 + panel_2 if panel_2 else panel_1
    else:
        panel = []
|
|
|
|
|
def transformation_widgets(category, transform_params, date_granularity):
    """Render the transformation-selection UI for one variable category.

    Draws an expander containing a multiselect of available transformations
    and, for each selected transformation, a range slider.  The selected
    range is written to two places:

    * ``st.session_state["project_dct"]["transformations"][category]`` so
      choices survive page switches, and
    * ``transform_params[category]`` as a numpy array of every value in the
      range (consumed later by ``apply_category_transformations``).

    Args:
        category: One of "Media", "Internal" or "Exogenous".
        transform_params: Dict of per-category parameter dicts, mutated
            in place.
        date_granularity: "daily"/"weekly"/"monthly"/"irregular"; used only
            in slider headers.
    """
    # Make sure the persisted project dict has a slot for this category.
    if (
        st.session_state["project_dct"]["transformations"] is None
        or st.session_state["project_dct"]["transformations"] == {}
    ):
        st.session_state["project_dct"]["transformations"] = {}
    if category not in st.session_state["project_dct"]["transformations"]:
        st.session_state["project_dct"]["transformations"][category] = {}
    category_state = st.session_state["project_dct"]["transformations"][category]

    # Fallback slider ranges used when nothing is persisted yet.
    predefined_defaults = {
        "Lag": (1, 2),
        "Lead": (1, 2),
        "Moving Average": (1, 2),
        "Saturation": (10, 20),
        "Power": (2, 4),
        "Adstock": (0.5, 0.7),
    }

    # Per-transformation widget configuration:
    # (persisted-state key, header markdown, slider label, min, max, step,
    #  widget-key prefix).  Keys/labels must not change: they are the
    # Streamlit widget identities and the persisted-state schema.
    slider_specs = {
        "Lead": (
            "Lead",
            f"**Lead ({date_granularity})**",
            "Lead periods",
            1, 10, 1, "lead",
        ),
        "Lag": (
            "Lag",
            f"**Lag ({date_granularity})**",
            "Lag periods",
            1, 10, 1, "lag",
        ),
        "Moving Average": (
            "MA",  # historical state key differs from the display name
            f"**Moving Average ({date_granularity})**",
            "Window size for Moving Average",
            1, 10, 1, "ma",
        ),
        "Saturation": (
            "Saturation",
            "**Saturation (%)**",
            "Saturation Percentage",
            0, 100, 10, "sat",
        ),
        "Power": ("Power", "**Power**", "Power", 0, 10, 1, "power"),
        "Adstock": (
            "Adstock",
            "**Adstock**",
            f"Factor ({category})",
            0.0, 1.0, 0.05, "adstock",
        ),
    }

    # Which transformations each category may use.
    transformation_options = {
        "Media": [
            "Lag",
            "Moving Average",
            "Saturation",
            "Power",
            "Adstock",
        ],
        "Internal": ["Lead", "Lag", "Moving Average"],
        "Exogenous": ["Lead", "Lag", "Moving Average"],
    }

    def create_transformation_widgets(column, transformations):
        # One header + range slider per selected transformation, rendered
        # inside *column*; persists the choice and expands the range into
        # the parameter grid.
        with column:
            for transformation in transformations:
                state_key, header, label, lo, hi, step, prefix = slider_specs[
                    transformation
                ]
                default = category_state.get(
                    state_key, predefined_defaults[transformation]
                )
                st.markdown(header)
                selected_range = st.slider(
                    label,
                    lo,
                    hi,
                    default,
                    step,
                    key=f"{prefix}_{category}",
                    label_visibility="collapsed",
                )
                category_state[state_key] = selected_range
                start, end = selected_range
                if transformation == "Adstock":
                    # Float grid; round to avoid np.arange float noise.
                    transform_params[category][transformation] = np.array(
                        [round(a, 3) for a in np.arange(start, end + step, step)]
                    )
                else:
                    transform_params[category][transformation] = np.arange(
                        start, end + step, step
                    )

    expanded = category_state.get("expanded", False)

    with st.expander(f"{category} Transformations", expanded=expanded):
        # Remember that the user opened this expander.
        category_state["expanded"] = True

        sel_transformations = category_state.get(f"transformation_{category}", [])
        transformations_to_apply = st.multiselect(
            "Select transformations to apply",
            options=transformation_options[category],
            default=sel_transformations,
            key=f"transformation_{category}",
        )
        category_state["transformation_" + category] = transformations_to_apply

        # Split the widgets across two columns; the first column gets the
        # extra widget when the count is odd.
        transformations_per_column = (
            len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
        )
        col1, col2 = st.columns(2)
        create_transformation_widgets(
            col1, transformations_to_apply[:transformations_per_column]
        )
        create_transformation_widgets(
            col2, transformations_to_apply[transformations_per_column:]
        )
|
|
|
|
|
def create_specific_transformation_widgets(
    column,
    transformations,
    channel_name,
    date_granularity,
    specific_transform_params,
):
    """Render per-channel override sliders for the given transformations.

    Unlike ``transformation_widgets`` these sliders are not persisted in
    the project dict; defaults always come from the predefined ranges.
    Selected ranges are expanded into numpy grids and written to
    ``specific_transform_params[channel_name]`` in place.

    Args:
        column: Streamlit column to render into.
        transformations: Transformation names to draw sliders for.
        channel_name: Media channel these overrides apply to (also used in
            the widget keys).
        date_granularity: Used only in slider headers.
        specific_transform_params: Dict of per-channel parameter dicts,
            mutated in place.
    """
    # Fixed default ranges for each slider.
    predefined_defaults = {
        "Lag": (1, 2),
        "Lead": (1, 2),
        "Moving Average": (1, 2),
        "Saturation": (10, 20),
        "Power": (2, 4),
        "Adstock": (0.5, 0.7),
    }

    # (header markdown, slider label, min, max, step, widget-key prefix).
    # Keys/labels must stay byte-identical: they are widget identities.
    slider_specs = {
        "Lead": (
            f"**Lead ({date_granularity})**",
            "Lead periods",
            1, 10, 1, "lead",
        ),
        "Lag": (
            f"**Lag ({date_granularity})**",
            "Lag periods",
            1, 10, 1, "lag",
        ),
        "Moving Average": (
            f"**Moving Average ({date_granularity})**",
            "Window size for Moving Average",
            1, 10, 1, "ma",
        ),
        "Saturation": (
            "**Saturation (%)**",
            "Saturation Percentage",
            0, 100, 10, "sat",
        ),
        "Power": ("**Power**", "Power", 0, 10, 1, "power"),
        "Adstock": ("**Adstock**", "Factor", 0.0, 1.0, 0.05, "adstock"),
    }

    with column:
        for transformation in transformations:
            header, label, lo, hi, step, prefix = slider_specs[transformation]
            st.markdown(header)
            selected_range = st.slider(
                label,
                lo,
                hi,
                predefined_defaults[transformation],
                step,
                key=f"{prefix}_{channel_name}_specific",
                label_visibility="collapsed",
            )
            start, end = selected_range
            if transformation == "Adstock":
                # Float grid; round to avoid np.arange float noise.
                specific_transform_params[channel_name][transformation] = np.array(
                    [round(a, 3) for a in np.arange(start, end + step, step)]
                )
            else:
                specific_transform_params[channel_name][transformation] = np.arange(
                    start, end + step, step
                )
|
|
|
|
|
def apply_lag(df, lag):
    """Shift the series forward by *lag* periods (earlier values move down)."""
    lagged = df.shift(periods=lag)
    return lagged
|
|
|
|
|
def apply_lead(df, lead):
    """Shift the series backward by *lead* periods (later values move up)."""
    led = df.shift(periods=-lead)
    return led
|
|
|
|
|
def apply_moving_average(df, window_size):
    """Rolling mean over a trailing window of *window_size* observations."""
    rolling_view = df.rolling(window_size)
    return rolling_view.mean()
|
|
|
|
|
def apply_saturation(df, saturation_percent_100):
    """Apply a logistic (S-curve) saturation transform to a numeric Series.

    Args:
        df: pandas Series of channel values (elementwise ``apply`` below
            implies a Series, not a DataFrame).
        saturation_percent_100: slider value in [0, 100].

    Returns:
        The saturated Series.
    """
    # Convert the slider percentage (0-100) to a fraction.
    saturation_percent = saturation_percent_100 / 100.0

    column_max = df.max()
    column_min = df.min()
    # Midpoint of the observed range is used as the curve's inflection point.
    saturation_point = (column_min + column_max) / 2

    # Guard against log(0) when the slider is at exactly 100%.
    numerator = np.log(
        (1 / (saturation_percent if saturation_percent != 1 else 1 - 1e-9)) - 1
    )
    denominator = np.log(saturation_point / max(column_max, 1e-9))

    # NOTE(review): for non-constant data saturation_point < column_max, so
    # this log is NEGATIVE and max(denominator, 1e-9) clamps it to 1e-9,
    # making steepness enormous (the curve degenerates toward a step).
    # The guard looks like it was meant to avoid division by zero while
    # preserving sign — confirm intended behavior before changing.
    steepness = numerator / max(
        denominator, 1e-9
    )

    # Logistic weighting of each value; x == 0 is nudged to avoid
    # division by zero inside the ratio.
    transformed_series = df.apply(
        lambda x: (
            1 / (1 + (saturation_point / (x if x != 0 else 1e-9)) ** steepness)
        )
        * x
    )

    return transformed_series
|
|
|
|
|
def apply_power(df, power):
    """Raise every value in the series to the given exponent."""
    return pow(df, power)
|
|
|
|
|
def apply_adstock(df, factor):
    """Geometric adstock decay: out[t] = df[t] + factor * out[t-1].

    Each period carries over *factor* of the accumulated (decayed) total
    from the previous period.
    """
    carryover = 0
    decayed_values = []
    for value in df:
        carryover = carryover * factor + value
        decayed_values.append(carryover)
    return pd.Series(decayed_values, index=df.index)
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def generate_transformed_columns(
    original_columns, transform_params, specific_transform_params
):
    """Enumerate the column names each transformation grid will create.

    For every column in *original_columns*, channel-specific parameters (if
    present) take precedence over the column's category-wide parameters.

    Args:
        original_columns: {category: [column, ...]} mapping.
        transform_params: {category: {transformation: values}} grids.
        specific_transform_params: {column: {transformation: values}}
            per-channel overrides.

    Returns:
        (transformed_columns, summary_string) where transformed_columns maps
        each original column to its generated "col@Transform_value" names and
        summary_string is a numbered, HTML-formatted description per column.
    """

    def _format_values(values):
        # "1, 2 and 3" for several values, "1" for a single one.
        if len(values) > 1:
            return ", ".join(map(str, values[:-1])) + " and " + str(values[-1])
        return str(values[0])

    transformed_columns, summary = {}, {}

    for category, columns in original_columns.items():
        for column in columns:
            transformed_columns[column] = []
            summary_details = []

            # Channel-specific settings override the category-wide grid.
            if column in specific_transform_params:
                params = specific_transform_params[column]
            else:
                params = transform_params.get(category, {})

            for transformation, values in params.items():
                for value in values:
                    transformed_columns[column].append(
                        f"{column}@{transformation}_{value}"
                    )
                summary_details.append(
                    f"{transformation} ({_format_values(values)})"
                )

            if summary_details:
                formatted_summary = "⮕ ".join(summary_details)
                summary[column] = f"<strong>{column}</strong>: {formatted_summary}"

    # Number the per-column summaries for display.
    summary_items = [
        f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
    ]
    summary_string = "\n".join(summary_items)

    return transformed_columns, summary_string
|
|
|
|
|
def transform_slice(
    transform_params,
    transformation_functions,
    panel,
    df,
    df_slice,
    category,
    category_df,
):
    """Apply every configured transformation for *category* to *df_slice*.

    Transformations are applied sequentially: after each one, df_slice is
    rebuilt from the transformed columns, so the next transformation in
    the dict chains on top of the previous one's output (producing names
    like "col@Lag_1@Adstock_0.5").

    Args:
        transform_params: {category: {transformation: parameter values}}.
        transformation_functions: {transformation name: callable(series, p)}.
        panel: list of panel column names ([] for non-panel data).
        df: the full source DataFrame (only its panel columns are reused).
        df_slice: the category's columns plus the panel columns.
        category: key into transform_params.
        category_df: accumulator DataFrame (normally passed in empty).

    Returns:
        (category_df, df, df_slice) — transformed columns, the untouched
        source frame, and the last rebuilt slice.
    """
    for transformation, parameters in transform_params[category].items():
        transformation_function = transformation_functions[transformation]

        if len(panel) > 0:
            # Panel data: apply the transformation within each panel group,
            # one output block per parameter value.
            # NOTE(review): category_df is REASSIGNED here on every loop
            # iteration, so only the columns of the LAST transformation in
            # the chain survive — the non-panel branch below accumulates
            # instead. This asymmetry appears intentional (downstream code
            # drops shorter chains), but confirm.
            category_df = pd.concat(
                [
                    df_slice.groupby(panel)
                    .transform(transformation_function, p)
                    .add_suffix(f"@{transformation}_{p}")
                    for p in parameters
                ],
                axis=1,
            )

            # Shift/rolling operations introduce NaNs at the edges; zero-fill.
            category_df.fillna(0, inplace=True)

            # Rebuild the slice so the next transformation chains on top of
            # this one's output columns.
            df_slice = pd.concat(
                [df[panel], category_df],
                axis=1,
            )

        else:
            # Non-panel data: apply the function column-wise for each
            # parameter value; results ACCUMULATE into category_df.
            for p in parameters:
                temp_df = df_slice.apply(
                    lambda x: transformation_function(x, p), axis=0
                ).rename(
                    lambda x: f"{x}@{transformation}_{p}",
                    axis="columns",
                )

                category_df = pd.concat([category_df, temp_df], axis=1)

            category_df.fillna(0, inplace=True)

            df_slice = pd.concat(
                [df[panel], category_df],
                axis=1,
            )

    return category_df, df, df_slice
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def apply_category_transformations(
    df_main, bin_dict, transform_params, panel, specific_transform_params
):
    """Build the transformed DataFrame from category and channel settings.

    Applies the category-wide grids first (excluding channels that have
    channel-specific overrides), then the per-channel grids, concatenates
    everything onto df_main, and finally drops intermediate chain columns.

    Args:
        df_main: the loaded source DataFrame.
        bin_dict: {category: [columns]} binning from the import page.
        transform_params: category-wide {category: {transformation: values}}.
        panel: panel column names ([] when there is no panel structure).
        specific_transform_params: per-channel {column: {transformation: values}}.

    Returns:
        df_main with all transformed columns appended.
    """
    # Dispatch table from transformation name to its implementation.
    transformation_functions = {
        "Lead": apply_lead,
        "Lag": apply_lag,
        "Moving Average": apply_moving_average,
        "Saturation": apply_saturation,
        "Power": apply_power,
        "Adstock": apply_adstock,
    }

    transformed_dfs = []

    # --- Category-wide transformations ---------------------------------
    for category in ["Media", "Exogenous", "Internal"]:
        # Skip categories with nothing configured or no binned columns.
        if (
            category not in transform_params
            or category not in bin_dict
            or not transform_params[category]
        ):
            continue

        category_df = pd.DataFrame()

        df_slice = df_main[bin_dict[category] + panel].copy()

        # Channels with channel-specific settings are handled separately
        # below; remove them from the category-wide pass.
        df_slice = df_slice.drop(
            columns=list(specific_transform_params.keys()), errors="ignore"
        ).copy()

        category_df, df, df_slice_updated = transform_slice(
            transform_params.copy(),
            transformation_functions.copy(),
            panel,
            df_main.copy(),
            df_slice.copy(),
            category,
            category_df.copy(),
        )

        if not category_df.empty:
            transformed_dfs.append(category_df)

    # --- Channel-specific transformations ------------------------------
    for channel_specific in specific_transform_params:

        category_df = pd.DataFrame()

        df_slice_specific = df_main[[channel_specific] + panel].copy()
        # Reuse transform_slice by presenting the channel's parameters
        # under a synthetic "Media" category key.
        transform_params_specific = {
            "Media": specific_transform_params[channel_specific]
        }

        category_df, df, df_slice_specific_updated = transform_slice(
            transform_params_specific.copy(),
            transformation_functions.copy(),
            panel,
            df_main.copy(),
            df_slice_specific.copy(),
            "Media",
            category_df.copy(),
        )

        if not category_df.empty:
            transformed_dfs.append(category_df)

    if len(transformed_dfs) > 0:
        final_df = pd.concat([df_main] + transformed_dfs, axis=1)
    else:
        # Nothing configured: return the input unchanged.
        final_df = df_main

    # Transformed columns carry "@" markers, e.g. "tv@Lag_1@Adstock_0.5".
    columns_with_at = [col for col in final_df.columns if "@" in col]

    columns_to_drop = set()

    # Keep only the longest transformation chain per base column: a column
    # is dropped when another column with the same base prefix has more
    # "@" segments.
    # NOTE(review): startswith(base_name) can also match a DIFFERENT column
    # whose name merely begins with the same text (e.g. "tv" vs "tv_spend")
    # — confirm this over-matching is acceptable.
    for col in columns_with_at:
        base_name = col.split("@")[0]
        for other_col in columns_with_at:
            if other_col.startswith(base_name) and len(other_col.split("@")) > len(
                col.split("@")
            ):
                columns_to_drop.add(col)
                break

    final_df.drop(columns=list(columns_to_drop), inplace=True)

    return final_df
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def infer_date_granularity(df):
    """Infer the reporting cadence of ``df["date"]``.

    Returns "daily", "weekly", "monthly" or "irregular" based on the most
    common gap (in days) between consecutive unique dates.
    """
    # Sort the unique dates first: pandas unique() preserves order of
    # appearance, so with panel (or otherwise unsorted) data the raw
    # consecutive differences can be zero or negative and skew the mode.
    unique_dates = pd.Series(df["date"].unique()).sort_values()
    day_gaps = unique_dates.diff().dt.days.dropna()

    # A single date (or none) leaves no gaps to infer from.
    if day_gaps.empty:
        return "irregular"

    common_freq = day_gaps.mode()[0]

    if common_freq == 1:
        return "daily"
    elif common_freq == 7:
        return "weekly"
    elif 28 <= common_freq <= 31:
        return "monthly"
    else:
        return "irregular"
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cadence string used only for slider headers ("daily"/"weekly"/...).
date_granularity = infer_date_granularity(final_df_loaded)

# One parameter grid per category; filled in by the widgets below.
transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}

# Header row: greeting on the left, current project on the right.
cols1 = st.columns([2, 1])

with cols1[0]:
    st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
    st.markdown(f"**Current Project: {st.session_state['project_name']}**")

st.markdown("### Select Transformations to Apply")

# Channel-specific overrides: per-channel grids that take precedence over
# the category-wide settings further down.
with st.expander("Specific Transformations"):
    select_specific_channels = st.multiselect(
        "Select channels", options=bin_dict_loaded["Media"]
    )

    specific_transform_params = {}
    for select_specific_channel in select_specific_channels:
        specific_transform_params[select_specific_channel] = {}

        st.divider()
        # Pretty display name, e.g. "tv_spend" -> "Tv Spend".
        channel_name = str(select_specific_channel).replace("_", " ").title()
        st.markdown(f"###### {channel_name}")
        transformations_to_apply = st.multiselect(
            "Select transformations to apply",
            options=[
                "Lag",
                "Moving Average",
                "Saturation",
                "Power",
                "Adstock",
            ],
            default="Adstock",
            key=f"specific_transformation_{select_specific_channel}_Media",
        )

        # Split the selected transformations across two columns; the first
        # column gets the extra one when the count is odd.
        transformations_per_column = (
            len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
        )

        col1, col2 = st.columns(2)

        transformations_col1 = transformations_to_apply[:transformations_per_column]
        transformations_col2 = transformations_to_apply[transformations_per_column:]

        create_specific_transformation_widgets(
            col1,
            transformations_col1,
            select_specific_channel,
            date_granularity,
            specific_transform_params,
        )
        create_specific_transformation_widgets(
            col2,
            transformations_col2,
            select_specific_channel,
            date_granularity,
            specific_transform_params,
        )

# Category-wide widgets. "Internal" is intentionally skipped here even
# though transform_params has a slot for it.
for category in ["Media", "Internal", "Exogenous"]:

    if category == "Internal":
        continue

    transformation_widgets(category, transform_params, date_granularity)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Apply the configured transformation grids and stash the result (and a
# human-readable summary) in session state for the display section below.
if st.button("Accept and Proceed", use_container_width=True):
    with st.spinner("Applying transformations..."):
        final_df = apply_category_transformations(
            final_df_loaded.copy(),
            bin_dict_loaded.copy(),
            transform_params.copy(),
            panel.copy(),
            specific_transform_params.copy(),
        )

        transformed_columns_dict, summary_string = generate_transformed_columns(
            original_columns, transform_params, specific_transform_params
        )

        st.session_state["final_df"] = final_df
        st.session_state["summary_string"] = summary_string
|
|
|
|
|
|
|
|
|
|
|
|
|
st.markdown("### Transformed DataFrame")
final_df = st.session_state["final_df"].copy()

# Sort by whichever of the panel/date columns are present, for a stable
# display order.
sort_col = []
for col in final_df.columns:
    if col in ["Panel_1", "Panel_2", "date"]:
        sort_col.append(col)

sorted_final_df = final_df.sort_values(
    by=sort_col, ascending=True, na_position="first"
)

# Drop duplicated column labels (keeps the first occurrence).
sorted_final_df = sorted_final_df.loc[:, ~sorted_final_df.columns.duplicated()]

# Very wide frames are truncated for display performance.
if sorted_final_df.shape[1] > 500:

    st.warning(
        "The transformed DataFrame has more than 500 columns. Displaying only the first 500 columns.",
        icon="⚠️",
    )
    st.dataframe(sorted_final_df.iloc[:, :500], hide_index=True)
else:
    st.dataframe(sorted_final_df, hide_index=True)

# Shape summary of the full (untruncated) transformed frame.
total_rows, total_columns = st.session_state["final_df"].shape
st.markdown(
    f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
    unsafe_allow_html=True,
)

# Human-readable list of the applied transformations, produced by
# generate_transformed_columns on "Accept and Proceed".
if "summary_string" in st.session_state and st.session_state["summary_string"]:
    with st.expander("Summary of Transformations"):
        st.markdown("### Summary of Transformations")
        st.markdown(st.session_state["summary_string"], unsafe_allow_html=True)
|
|
|
def save_to_pickle(file_path, final_df):
    """Persist the transformed DataFrame to *file_path* as a pickle.

    The payload is a one-key dict: {"final_df_transformed": final_df}.

    Note: this was previously wrapped in ``@st.cache_resource``, which is a
    bug for a side-effecting function — a repeat call with equal arguments
    returned the cached (None) result and silently SKIPPED the file write
    (e.g. re-saving after the file was removed), while also keeping a
    reference to the DataFrame alive for the cache's lifetime.  Saving must
    always execute, so the decorator was removed.

    Args:
        file_path: Destination path for the pickle file.
        final_df: The transformed DataFrame to persist.
    """
    with open(file_path, "wb") as f:
        pickle.dump({"final_df_transformed": final_df}, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Candidate columns for the correlation plot (everything except the date).
variables = [col for col in final_df.columns if col.lower() != "date"]

with st.expander("Transformed Variable Correlation Plot"):
    selected_vars = st.multiselect(
        "Choose variables for correlation plot:", variables
    )

    if selected_vars:
        corr_df = final_df[selected_vars].corr()

        # One text annotation per heatmap cell showing the coefficient.
        annotations = []
        for i in range(len(corr_df)):
            for j in range(len(corr_df.columns)):
                annotations.append(
                    go.layout.Annotation(
                        text=f"{corr_df.iloc[i, j]:.2f}",
                        x=corr_df.columns[j],
                        y=corr_df.index[i],
                        showarrow=False,
                        font=dict(color="black"),
                    )
                )

        # Diverging red/blue scale pinned to the full [-1, 1] range.
        heatmap = go.Heatmap(
            z=corr_df.values,
            x=corr_df.columns,
            y=corr_df.index,
            colorscale="RdBu",
            zmin=-1,
            zmax=1,
        )

        layout = go.Layout(
            title="Transformed Variable Correlation Plot",
            xaxis=dict(title="Variables"),
            yaxis=dict(title="Variables"),
            width=1000,
            height=1000,
            annotations=annotations,
        )

        fig = go.Figure(data=[heatmap], layout=layout)

        st.plotly_chart(fig)
    else:
        st.write("Please select at least one variable to plot.")
|
|
|
|
|
|
|
|
|
|
|
# Persist the transformed DataFrame and the project dict to the project
# folder, record page completion in the user DB, and confirm to the user.
if st.button("Accept and Save", use_container_width=True):
    save_to_pickle(
        os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"),
        st.session_state["final_df"],
    )
    project_dct_path = os.path.join(
        st.session_state["project_path"], "project_dct.pkl"
    )

    with open(project_dct_path, "wb") as f:
        pickle.dump(st.session_state["project_dct"], f)

    # Marks this page as completed for the current user/project.
    update_db("3_Transformations.py")

    st.toast("💾 Saved Successfully!")
|
|