RFI / pages /3_Transformations.py
Manoj
latest
fde220d
# Importing necessary libraries
import streamlit as st
st.set_page_config(
page_title="Transformations",
page_icon=":shark:",
layout="wide",
initial_sidebar_state="collapsed",
)
import os
import pickle
import sqlite3
import numpy as np
import pandas as pd
from utilities import update_db
import plotly.graph_objects as go
from utilities import set_header, load_local_css, update_db, project_selection
load_local_css("styles.css")
set_header()
if "username" not in st.session_state:
st.session_state["username"] = None
if "project_name" not in st.session_state:
st.session_state["project_name"] = None
if "project_dct" not in st.session_state:
project_selection()
st.stop()
if "username" in st.session_state and st.session_state["username"] is not None:
conn = sqlite3.connect(
r"DB/User.db", check_same_thread=False
) # connection with sql db
c = conn.cursor()
if not os.path.exists(
os.path.join(st.session_state["project_path"], "data_import.pkl")
):
st.error("Please move to Data Import page")
st.stop()
# Deserialize and load the objects from the pickle file
with open(
os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
) as f:
data = pickle.load(f)
# Accessing the loaded objects
final_df_loaded = data["final_df"]
bin_dict_loaded = data["bin_dict"]
# Initialize session state
if "transformed_columns_dict" not in st.session_state:
st.session_state["transformed_columns_dict"] = {} # Default empty dictionary
if "final_df" not in st.session_state:
st.session_state["final_df"] = final_df_loaded # Default as original dataframe
if "summary_string" not in st.session_state:
st.session_state["summary_string"] = None # Default as None
# Extract original columns for specified categories
original_columns = {
category: bin_dict_loaded[category]
for category in ["Media", "Internal", "Exogenous"]
if category in bin_dict_loaded
}
# Retrive Panel columns
panel_1 = bin_dict_loaded.get("Panel Level 1")
panel_2 = bin_dict_loaded.get("Panel Level 2")
# Apply transformations on panel level
if panel_1:
panel = panel_1 + panel_2 if panel_2 else panel_1
else:
panel = []
# Function to build transformation widgets
def transformation_widgets(category, transform_params, date_granularity):
if (
st.session_state["project_dct"]["transformations"] is None
or st.session_state["project_dct"]["transformations"] == {}
):
st.session_state["project_dct"]["transformations"] = {}
if category not in st.session_state["project_dct"]["transformations"].keys():
st.session_state["project_dct"]["transformations"][category] = {}
# Define a dict of pre-defined default values of every transformation
predefined_defualts = {
"Lag": (1, 2),
"Lead": (1, 2),
"Moving Average": (1, 2),
"Saturation": (10, 20),
"Power": (2, 4),
"Adstock": (0.5, 0.7),
}
# Transformation Options
transformation_options = {
"Media": [
"Lag",
"Moving Average",
"Saturation",
"Power",
"Adstock",
],
"Internal": ["Lead", "Lag", "Moving Average"],
"Exogenous": ["Lead", "Lag", "Moving Average"],
}
expanded = st.session_state["project_dct"]["transformations"][category].get(
"expanded", False
)
# Define a helper function to create widgets for each transformation
def create_transformation_widgets(column, transformations):
with column:
for transformation in transformations:
# Conditionally create widgets for selected transformations
if transformation == "Lead":
lead_default = st.session_state["project_dct"][
"transformations"
][category].get("Lead", predefined_defualts["Lead"])
st.markdown(f"**Lead ({date_granularity})**")
lead = st.slider(
"Lead periods",
1,
10,
lead_default,
1,
key=f"lead_{category}",
label_visibility="collapsed",
)
st.session_state["project_dct"]["transformations"][category][
"Lead"
] = lead
start = lead[0]
end = lead[1]
step = 1
transform_params[category]["Lead"] = np.arange(
start, end + step, step
)
if transformation == "Lag":
lag_default = st.session_state["project_dct"][
"transformations"
][category].get("Lag", predefined_defualts["Lag"])
st.markdown(f"**Lag ({date_granularity})**")
lag = st.slider(
"Lag periods",
1,
10,
lag_default,
1,
key=f"lag_{category}",
label_visibility="collapsed",
)
st.session_state["project_dct"]["transformations"][category][
"Lag"
] = lag
start = lag[0]
end = lag[1]
step = 1
transform_params[category]["Lag"] = np.arange(
start, end + step, step
)
if transformation == "Moving Average":
ma_default = st.session_state["project_dct"]["transformations"][
category
].get("MA", predefined_defualts["Moving Average"])
st.markdown(f"**Moving Average ({date_granularity})**")
window = st.slider(
"Window size for Moving Average",
1,
10,
ma_default,
1,
key=f"ma_{category}",
label_visibility="collapsed",
)
st.session_state["project_dct"]["transformations"][category][
"MA"
] = window
start = window[0]
end = window[1]
step = 1
transform_params[category]["Moving Average"] = np.arange(
start, end + step, step
)
if transformation == "Saturation":
st.markdown("**Saturation (%)**")
saturation_default = st.session_state["project_dct"][
"transformations"
][category].get("Saturation", predefined_defualts["Saturation"])
saturation_point = st.slider(
f"Saturation Percentage",
0,
100,
saturation_default,
10,
key=f"sat_{category}",
label_visibility="collapsed",
)
st.session_state["project_dct"]["transformations"][category][
"Saturation"
] = saturation_point
start = saturation_point[0]
end = saturation_point[1]
step = 10
transform_params[category]["Saturation"] = np.arange(
start, end + step, step
)
if transformation == "Power":
st.markdown("**Power**")
power_default = st.session_state["project_dct"][
"transformations"
][category].get("Power", predefined_defualts["Power"])
power = st.slider(
f"Power",
0,
10,
power_default,
1,
key=f"power_{category}",
label_visibility="collapsed",
)
st.session_state["project_dct"]["transformations"][category][
"Power"
] = power
start = power[0]
end = power[1]
step = 1
transform_params[category]["Power"] = np.arange(
start, end + step, step
)
if transformation == "Adstock":
ads_default = st.session_state["project_dct"][
"transformations"
][category].get("Adstock", predefined_defualts["Adstock"])
st.markdown("**Adstock**")
rate = st.slider(
f"Factor ({category})",
0.0,
1.0,
ads_default,
0.05,
key=f"adstock_{category}",
label_visibility="collapsed",
)
st.session_state["project_dct"]["transformations"][category][
"Adstock"
] = rate
start = rate[0]
end = rate[1]
step = 0.05
adstock_range = [
round(a, 3) for a in np.arange(start, end + step, step)
]
transform_params[category]["Adstock"] = np.array(adstock_range)
with st.expander(f"{category} Transformations", expanded=expanded):
st.session_state["project_dct"]["transformations"][category][
"expanded"
] = True
# Let users select which transformations to apply
sel_transformations = st.session_state["project_dct"]["transformations"][
category
].get(f"transformation_{category}", [])
transformations_to_apply = st.multiselect(
"Select transformations to apply",
options=transformation_options[category],
default=sel_transformations,
key=f"transformation_{category}",
# on_change=selection_change(),
)
st.session_state["project_dct"]["transformations"][category][
"transformation_" + category
] = transformations_to_apply
# Determine the number of transformations to put in each column
transformations_per_column = (
len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
)
# Create two columns
col1, col2 = st.columns(2)
# Assign transformations to each column
transformations_col1 = transformations_to_apply[:transformations_per_column]
transformations_col2 = transformations_to_apply[transformations_per_column:]
# Create widgets in each column
create_transformation_widgets(col1, transformations_col1)
create_transformation_widgets(col2, transformations_col2)
# Define a helper function to create widgets for each specific transformation
def create_specific_transformation_widgets(
column,
transformations,
channel_name,
date_granularity,
specific_transform_params,
):
# Define a dict of pre-defined default values of every transformation
predefined_defualts = {
"Lag": (1, 2),
"Lead": (1, 2),
"Moving Average": (1, 2),
"Saturation": (10, 20),
"Power": (2, 4),
"Adstock": (0.5, 0.7),
}
with column:
for transformation in transformations:
# Conditionally create widgets for selected transformations
if transformation == "Lead":
st.markdown(f"**Lead ({date_granularity})**")
lead = st.slider(
"Lead periods",
1,
10,
predefined_defualts["Lead"],
1,
key=f"lead_{channel_name}_specific",
label_visibility="collapsed",
)
start = lead[0]
end = lead[1]
step = 1
specific_transform_params[channel_name]["Lead"] = np.arange(
start, end + step, step
)
if transformation == "Lag":
st.markdown(f"**Lag ({date_granularity})**")
lag = st.slider(
"Lag periods",
1,
10,
predefined_defualts["Lag"],
1,
key=f"lag_{channel_name}_specific",
label_visibility="collapsed",
)
start = lag[0]
end = lag[1]
step = 1
specific_transform_params[channel_name]["Lag"] = np.arange(
start, end + step, step
)
if transformation == "Moving Average":
st.markdown(f"**Moving Average ({date_granularity})**")
window = st.slider(
"Window size for Moving Average",
1,
10,
predefined_defualts["Moving Average"],
1,
key=f"ma_{channel_name}_specific",
label_visibility="collapsed",
)
start = window[0]
end = window[1]
step = 1
specific_transform_params[channel_name]["Moving Average"] = (
np.arange(start, end + step, step)
)
if transformation == "Saturation":
st.markdown("**Saturation (%)**")
saturation_point = st.slider(
f"Saturation Percentage",
0,
100,
predefined_defualts["Saturation"],
10,
key=f"sat_{channel_name}_specific",
label_visibility="collapsed",
)
start = saturation_point[0]
end = saturation_point[1]
step = 10
specific_transform_params[channel_name]["Saturation"] = np.arange(
start, end + step, step
)
if transformation == "Power":
st.markdown("**Power**")
power = st.slider(
f"Power",
0,
10,
predefined_defualts["Power"],
1,
key=f"power_{channel_name}_specific",
label_visibility="collapsed",
)
start = power[0]
end = power[1]
step = 1
specific_transform_params[channel_name]["Power"] = np.arange(
start, end + step, step
)
if transformation == "Adstock":
st.markdown("**Adstock**")
rate = st.slider(
f"Factor",
0.0,
1.0,
predefined_defualts["Adstock"],
0.05,
key=f"adstock_{channel_name}_specific",
label_visibility="collapsed",
)
start = rate[0]
end = rate[1]
step = 0.05
adstock_range = [
round(a, 3) for a in np.arange(start, end + step, step)
]
specific_transform_params[channel_name]["Adstock"] = np.array(
adstock_range
)
# Function to apply Lag transformation
def apply_lag(df, lag):
return df.shift(lag)
# Function to apply Lead transformation
def apply_lead(df, lead):
return df.shift(-lead)
# Function to apply Moving Average transformation
def apply_moving_average(df, window_size):
return df.rolling(window=window_size).mean()
# Function to apply Moving Average transformation
def apply_saturation(df, saturation_percent_100):
# Convert saturation percentage from 100-based to fraction
saturation_percent = saturation_percent_100 / 100.0
# Calculate saturation point and steepness
column_max = df.max()
column_min = df.min()
saturation_point = (column_min + column_max) / 2
numerator = np.log(
(1 / (saturation_percent if saturation_percent != 1 else 1 - 1e-9)) - 1
)
denominator = np.log(saturation_point / max(column_max, 1e-9))
steepness = numerator / max(
denominator, 1e-9
) # Avoid division by zero with a small constant
# Apply the saturation transformation with safeguard for division by zero
transformed_series = df.apply(
lambda x: (
1 / (1 + (saturation_point / (x if x != 0 else 1e-9)) ** steepness)
)
* x
)
return transformed_series
# Function to apply Power transformation
def apply_power(df, power):
return df**power
# Function to apply Adstock transformation
def apply_adstock(df, factor):
x = 0
# Use the walrus operator to update x iteratively with the Adstock formula
adstock_var = [x := x * factor + v for v in df]
ans = pd.Series(adstock_var, index=df.index)
return ans
# Function to generate transformed columns names
@st.cache_resource(show_spinner=False)
def generate_transformed_columns(
original_columns, transform_params, specific_transform_params
):
transformed_columns, summary = {}, {}
for category, columns in original_columns.items():
for column in columns:
transformed_columns[column] = []
summary_details = (
[]
) # List to hold transformation details for the current column
if column in specific_transform_params.keys():
for transformation, values in specific_transform_params[
column
].items():
# Generate transformed column names for each value
for value in values:
transformed_name = f"{column}@{transformation}_{value}"
transformed_columns[column].append(transformed_name)
# Format the values list as a string with commas and "and" before the last item
if len(values) > 1:
formatted_values = (
", ".join(map(str, values[:-1]))
+ " and "
+ str(values[-1])
)
else:
formatted_values = str(values[0])
# Add transformation details
summary_details.append(f"{transformation} ({formatted_values})")
else:
if category in transform_params:
for transformation, values in transform_params[
category
].items():
# Generate transformed column names for each value
for value in values:
transformed_name = f"{column}@{transformation}_{value}"
transformed_columns[column].append(transformed_name)
# Format the values list as a string with commas and "and" before the last item
if len(values) > 1:
formatted_values = (
", ".join(map(str, values[:-1]))
+ " and "
+ str(values[-1])
)
else:
formatted_values = str(values[0])
# Add transformation details
summary_details.append(
f"{transformation} ({formatted_values})"
)
# Only add to summary if there are transformation details for the column
if summary_details:
formatted_summary = "⮕ ".join(summary_details)
# Use <strong> tags to make the column name bold
summary[column] = f"<strong>{column}</strong>: {formatted_summary}"
# Generate a comprehensive summary string for all columns
summary_items = [
f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
]
summary_string = "\n".join(summary_items)
return transformed_columns, summary_string
# Function to transform Dataframe slice
def transform_slice(
transform_params,
transformation_functions,
panel,
df,
df_slice,
category,
category_df,
):
# Iterate through each transformation and its parameters for the current category
for transformation, parameters in transform_params[category].items():
transformation_function = transformation_functions[transformation]
# Check if there is panel data to group by
if len(panel) > 0:
# Apply the transformation to each group
category_df = pd.concat(
[
df_slice.groupby(panel)
.transform(transformation_function, p)
.add_suffix(f"@{transformation}_{p}")
for p in parameters
],
axis=1,
)
# Replace all NaN or null values in category_df with 0
category_df.fillna(0, inplace=True)
# Update df_slice
df_slice = pd.concat(
[df[panel], category_df],
axis=1,
)
else:
for p in parameters:
# Apply the transformation function to each column
temp_df = df_slice.apply(
lambda x: transformation_function(x, p), axis=0
).rename(
lambda x: f"{x}@{transformation}_{p}",
axis="columns",
)
# Concatenate the transformed DataFrame slice to the category DataFrame
category_df = pd.concat([category_df, temp_df], axis=1)
# Replace all NaN or null values in category_df with 0
category_df.fillna(0, inplace=True)
# Update df_slice
df_slice = pd.concat(
[df[panel], category_df],
axis=1,
)
return category_df, df, df_slice
# Function to apply transformations to DataFrame slices based on specified categories and parameters
@st.cache_resource(show_spinner=False)
def apply_category_transformations(
df_main, bin_dict, transform_params, panel, specific_transform_params
):
# Dictionary for function mapping
transformation_functions = {
"Lead": apply_lead,
"Lag": apply_lag,
"Moving Average": apply_moving_average,
"Saturation": apply_saturation,
"Power": apply_power,
"Adstock": apply_adstock,
}
# List to collect all transformed DataFrames
transformed_dfs = []
# Iterate through each category specified in transform_params
for category in ["Media", "Exogenous", "Internal"]:
if (
category not in transform_params
or category not in bin_dict
or not transform_params[category]
):
continue # Skip categories without transformations
# Initialize category_df as an empty DataFrame
category_df = pd.DataFrame()
# Slice the DataFrame based on the columns specified in bin_dict for the current category
df_slice = df_main[bin_dict[category] + panel].copy()
# Drop the column from df_slice to skip specific transformations
df_slice = df_slice.drop(
columns=list(specific_transform_params.keys()), errors="ignore"
).copy()
category_df, df, df_slice_updated = transform_slice(
transform_params.copy(),
transformation_functions.copy(),
panel,
df_main.copy(),
df_slice.copy(),
category,
category_df.copy(),
)
# Append the transformed category DataFrame to the list if it's not empty
if not category_df.empty:
transformed_dfs.append(category_df)
# Apply channel specific transforms
for channel_specific in specific_transform_params:
# Initialize category_df as an empty DataFrame
category_df = pd.DataFrame()
df_slice_specific = df_main[[channel_specific] + panel].copy()
transform_params_specific = {
"Media": specific_transform_params[channel_specific]
}
category_df, df, df_slice_specific_updated = transform_slice(
transform_params_specific.copy(),
transformation_functions.copy(),
panel,
df_main.copy(),
df_slice_specific.copy(),
"Media",
category_df.copy(),
)
# Append the transformed category DataFrame to the list if it's not empty
if not category_df.empty:
transformed_dfs.append(category_df)
# If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame
if len(transformed_dfs) > 0:
final_df = pd.concat([df_main] + transformed_dfs, axis=1)
else:
# If no transformations were applied, use the original DataFrame
final_df = df_main
# Find columns with '@' in their names
columns_with_at = [col for col in final_df.columns if "@" in col]
# Create a set of columns to drop
columns_to_drop = set()
# Iterate through columns with '@' to find shorter names to drop
for col in columns_with_at:
base_name = col.split("@")[0]
for other_col in columns_with_at:
if other_col.startswith(base_name) and len(other_col.split("@")) > len(
col.split("@")
):
columns_to_drop.add(col)
break
# Drop the identified columns from the DataFrame
final_df.drop(columns=list(columns_to_drop), inplace=True)
return final_df
# Function to infers the granularity of the date column in a DataFrame
@st.cache_resource(show_spinner=False)
def infer_date_granularity(df):
# Find the most common difference
common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
# Map the most common difference to a granularity
if common_freq == 1:
return "daily"
elif common_freq == 7:
return "weekly"
elif 28 <= common_freq <= 31:
return "monthly"
else:
return "irregular"
#########################################################################################################################################################
# User input for transformations
#########################################################################################################################################################
# Infer date granularity
date_granularity = infer_date_granularity(final_df_loaded)
# Initialize the main dictionary to store the transformation parameters for each category
transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
# User input for transformations
cols1 = st.columns([2, 1])
with cols1[0]:
st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
st.markdown(f"**Current Project: {st.session_state['project_name']}**")
st.markdown("### Select Transformations to Apply")
with st.expander("Specific Transformations"):
select_specific_channels = st.multiselect(
"Select channels", options=bin_dict_loaded["Media"]
)
specific_transform_params = {}
for select_specific_channel in select_specific_channels:
specific_transform_params[select_specific_channel] = {}
st.divider()
channel_name = str(select_specific_channel).replace("_", " ").title()
st.markdown(f"###### {channel_name}")
transformations_to_apply = st.multiselect(
"Select transformations to apply",
options=[
"Lag",
"Moving Average",
"Saturation",
"Power",
"Adstock",
],
default="Adstock",
key=f"specific_transformation_{select_specific_channel}_Media",
)
# Determine the number of transformations to put in each column
transformations_per_column = (
len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
)
# Create two columns
col1, col2 = st.columns(2)
# Assign transformations to each column
transformations_col1 = transformations_to_apply[:transformations_per_column]
transformations_col2 = transformations_to_apply[transformations_per_column:]
# Create widgets in each column
create_specific_transformation_widgets(
col1,
transformations_col1,
select_specific_channel,
date_granularity,
specific_transform_params,
)
create_specific_transformation_widgets(
col2,
transformations_col2,
select_specific_channel,
date_granularity,
specific_transform_params,
)
for category in ["Media", "Internal", "Exogenous"]:
# Skip Internal
if category == "Internal":
continue
transformation_widgets(category, transform_params, date_granularity)
#########################################################################################################################################################
# Apply transformations
#########################################################################################################################################################
# Apply category-based transformations to the DataFrame
if st.button("Accept and Proceed", use_container_width=True):
with st.spinner("Applying transformations..."):
final_df = apply_category_transformations(
final_df_loaded.copy(),
bin_dict_loaded.copy(),
transform_params.copy(),
panel.copy(),
specific_transform_params.copy(),
)
# Generate a dictionary mapping original column names to lists of transformed column names
transformed_columns_dict, summary_string = generate_transformed_columns(
original_columns, transform_params, specific_transform_params
)
# Store into transformed dataframe and summary session state
st.session_state["final_df"] = final_df
st.session_state["summary_string"] = summary_string
#########################################################################################################################################################
# Display the transformed DataFrame and summary
#########################################################################################################################################################
# Display the transformed DataFrame in the Streamlit app
st.markdown("### Transformed DataFrame")
final_df = st.session_state["final_df"].copy()
sort_col = []
for col in final_df.columns:
if col in ["Panel_1", "Panel_2", "date"]:
sort_col.append(col)
sorted_final_df = final_df.sort_values(
by=sort_col, ascending=True, na_position="first"
)
# Dropping duplicate columns
sorted_final_df = sorted_final_df.loc[:, ~sorted_final_df.columns.duplicated()]
# Check the number of columns and show only the first 500 if there are more
if sorted_final_df.shape[1] > 500:
# Display a warning if the DataFrame has more than 500 columns
st.warning(
"The transformed DataFrame has more than 500 columns. Displaying only the first 500 columns.",
icon="⚠️",
)
st.dataframe(sorted_final_df.iloc[:, :500], hide_index=True)
else:
st.dataframe(sorted_final_df, hide_index=True)
# Total rows and columns
total_rows, total_columns = st.session_state["final_df"].shape
st.markdown(
f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
unsafe_allow_html=True,
)
# Display the summary of transformations as markdown
if "summary_string" in st.session_state and st.session_state["summary_string"]:
with st.expander("Summary of Transformations"):
st.markdown("### Summary of Transformations")
st.markdown(st.session_state["summary_string"], unsafe_allow_html=True)
@st.cache_resource(show_spinner=False)
def save_to_pickle(file_path, final_df):
# Open the file in write-binary mode and dump the objects
with open(file_path, "wb") as f:
pickle.dump({"final_df_transformed": final_df}, f)
# Data is now saved to file
#########################################################################################################################################################
# Correlation Plot
#########################################################################################################################################################
# Filter out the 'date' column
variables = [col for col in final_df.columns if col.lower() != "date"]
# Expander with multiselect
with st.expander("Transformed Variable Correlation Plot"):
selected_vars = st.multiselect(
"Choose variables for correlation plot:", variables
)
# Calculate correlation
if selected_vars:
corr_df = final_df[selected_vars].corr()
# Prepare text annotations with 2 decimal places
annotations = []
for i in range(len(corr_df)):
for j in range(len(corr_df.columns)):
annotations.append(
go.layout.Annotation(
text=f"{corr_df.iloc[i, j]:.2f}",
x=corr_df.columns[j],
y=corr_df.index[i],
showarrow=False,
font=dict(color="black"),
)
)
# Plotly correlation plot using go
heatmap = go.Heatmap(
z=corr_df.values,
x=corr_df.columns,
y=corr_df.index,
colorscale="RdBu",
zmin=-1,
zmax=1,
)
layout = go.Layout(
title="Transformed Variable Correlation Plot",
xaxis=dict(title="Variables"),
yaxis=dict(title="Variables"),
width=1000,
height=1000,
annotations=annotations,
)
fig = go.Figure(data=[heatmap], layout=layout)
st.plotly_chart(fig)
else:
st.write("Please select at least one variable to plot.")
#########################################################################################################################################################
# Accept and Save
#########################################################################################################################################################
if st.button("Accept and Save", use_container_width=True):
save_to_pickle(
os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"),
st.session_state["final_df"],
)
project_dct_path = os.path.join(
st.session_state["project_path"], "project_dct.pkl"
)
with open(project_dct_path, "wb") as f:
pickle.dump(st.session_state["project_dct"], f)
update_db("3_Transformations.py")
st.toast("💾 Saved Successfully!")