""" MMO Build Sprint 3 date : changes : capability to tune MixedLM as well as simple LR in the same page """ import os import streamlit as st import pandas as pd from Eda_functions import format_numbers import pickle from utilities import set_header, load_local_css import statsmodels.api as sm import re from sklearn.preprocessing import MinMaxScaler import matplotlib.pyplot as plt from statsmodels.stats.outliers_influence import variance_inflation_factor import yaml from yaml import SafeLoader import streamlit_authenticator as stauth st.set_option("deprecation.showPyplotGlobalUse", False) import statsmodels.formula.api as smf from Data_prep_functions import * import sqlite3 from utilities import update_db # for i in ["model_tuned", "X_train_tuned", "X_test_tuned", "tuned_model_features", "tuned_model", "tuned_model_dict"] : st.set_page_config( page_title="Model Tuning", page_icon=":shark:", layout="wide", initial_sidebar_state="collapsed", ) load_local_css("styles.css") set_header() # Check for authentication status for k, v in st.session_state.items(): # print(k, v) if k not in [ "logout", "login", "config", "build_tuned_model", ] and not k.startswith("FormSubmitter"): st.session_state[k] = v with open("config.yaml") as file: config = yaml.load(file, Loader=SafeLoader) st.session_state["config"] = config authenticator = stauth.Authenticate( config["credentials"], config["cookie"]["name"], config["cookie"]["key"], config["cookie"]["expiry_days"], config["preauthorized"], ) st.session_state["authenticator"] = authenticator name, authentication_status, username = authenticator.login("Login", "main") auth_status = st.session_state.get("authentication_status") if auth_status == True: authenticator.logout("Logout", "main") is_state_initiaized = st.session_state.get("initialized", False) if "project_dct" not in st.session_state: st.error("Please load a project from Home page") st.stop() if not os.path.exists( os.path.join(st.session_state["project_path"], "best_models.pkl") ): st.error("Please save a model before tuning") st.stop() conn = sqlite3.connect( r"DB/User.db", check_same_thread=False ) # connection with sql db c = conn.cursor() if not is_state_initiaized: if "session_name" not in st.session_state: st.session_state["session_name"] = None if ( "session_state_saved" in st.session_state["project_dct"]["model_build"].keys() ): for key in [ "Model", "date", "saved_model_names", "media_data", "X_test_spends", ]: if key not in st.session_state: st.session_state[key] = st.session_state["project_dct"][ "model_build" ]["session_state_saved"][key] st.session_state["bin_dict"] = st.session_state["project_dct"][ "model_build" ]["session_state_saved"]["bin_dict"] if ( "used_response_metrics" not in st.session_state or st.session_state["used_response_metrics"] == [] ): st.session_state["used_response_metrics"] = st.session_state[ "project_dct" ]["model_build"]["session_state_saved"][ "used_response_metrics" ] else: st.error("Please load a session with a built model") st.stop() # if 'sel_model' not in st.session_state["project_dct"]["model_tuning"].keys(): # st.session_state["project_dct"]["model_tuning"]['sel_model']= {} for key in ["select_all_flags_check", "selected_flags", "sel_model"]: if key not in st.session_state["project_dct"]["model_tuning"].keys(): st.session_state["project_dct"]["model_tuning"][key] = {} # Sprint3 # is_panel = st.session_state['is_panel'] # panel_col = 'markets' # set the panel column date_col = "date" panel_col = [ col.lower() .replace(".", "_") .replace("@", "_") .replace(" ", "_") .replace("-", "") .replace(":", "") .replace("__", "_") for col in st.session_state["bin_dict"]["Panel Level 1"] ][ 0 ] # set the panel column is_panel = True if len(panel_col) > 0 else False # flag indicating there is not tuned model till now # Sprint4 - model tuned dict if "Model_Tuned" not in st.session_state: st.session_state["Model_Tuned"] = {} st.title("1. Model Tuning") if "is_tuned_model" not in st.session_state: st.session_state["is_tuned_model"] = {} # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default if ( "used_response_metrics" in st.session_state and st.session_state["used_response_metrics"] != [] ): default_target_idx = ( st.session_state["project_dct"]["model_tuning"].get( "sel_target_col", None ) if st.session_state["project_dct"]["model_tuning"].get( "sel_target_col", None ) is not None else st.session_state["used_response_metrics"][0] ) def format_display(inp): return inp.title().replace("_", " ").strip() sel_target_col = st.selectbox( "Select the response metric", st.session_state["used_response_metrics"], index=st.session_state["used_response_metrics"].index( default_target_idx ), format_func=format_display, ) target_col = ( sel_target_col.lower() .replace(" ", "_") .replace("-", "") .replace(":", "") .replace("__", "_") ) st.session_state["project_dct"]["model_tuning"][ "sel_target_col" ] = sel_target_col else: sel_target_col = "Total Approved Accounts - Revenue" target_col = "total_approved_accounts_revenue" # Sprint4 - Look through all saved models, only show saved models of the sel resp metric (target_col) # saved_models = st.session_state['saved_model_names'] with open( os.path.join(st.session_state["project_path"], "best_models.pkl"), "rb" ) as file: model_dict = pickle.load(file) saved_models = model_dict.keys() required_saved_models = [ m.split("__")[0] for m in saved_models if m.split("__")[1] == target_col ] if len(required_saved_models) > 0: default_model_idx = st.session_state["project_dct"]["model_tuning"][ "sel_model" ].get(sel_target_col, required_saved_models[0]) sel_model = st.selectbox( "Select the model to tune", required_saved_models, index=required_saved_models.index(default_model_idx), ) else: default_model_idx = st.session_state["project_dct"]["model_tuning"][ "sel_model" ].get(sel_target_col, 0) sel_model = st.selectbox( "Select the model to tune", required_saved_models ) st.session_state["project_dct"]["model_tuning"]["sel_model"][ sel_target_col ] = default_model_idx sel_model_dict = model_dict[ sel_model + "__" + target_col ] # Sprint4 - get the model obj of the selected model X_train = sel_model_dict["X_train"] X_test = sel_model_dict["X_test"] y_train = sel_model_dict["y_train"] y_test = sel_model_dict["y_test"] df = st.session_state["media_data"] if "selected_model" not in st.session_state: st.session_state["selected_model"] = 0 st.markdown("### 1.1 Event Flags") st.markdown( "Helps in quantifying the impact of specific occurrences of events" ) flag_expander_default = ( st.session_state["project_dct"]["model_tuning"].get( "flag_expander", None ) if st.session_state["project_dct"]["model_tuning"].get( "flag_expander", None ) is not None else False ) with st.expander("Apply Event Flags", flag_expander_default): st.session_state["project_dct"]["model_tuning"]["flag_expander"] = True model = sel_model_dict["Model_object"] date = st.session_state["date"] date = pd.to_datetime(date) X_train = sel_model_dict["X_train"] # features_set= model_dict[st.session_state["selected_model"]]['feature_set'] features_set = sel_model_dict["feature_set"] col = st.columns(3) min_date = min(date) max_date = max(date) start_date_default = ( st.session_state["project_dct"]["model_tuning"].get( "start_date_default" ) if st.session_state["project_dct"]["model_tuning"].get( "start_date_default" ) is not None else min_date ) end_date_default = ( st.session_state["project_dct"]["model_tuning"].get( "end_date_default" ) if st.session_state["project_dct"]["model_tuning"].get( "end_date_default" ) is not None else max_date ) with col[0]: start_date = st.date_input( "Select Start Date", start_date_default, min_value=min_date, max_value=max_date, ) with col[1]: end_date_default = ( end_date_default if end_date_default >= start_date else start_date ) end_date = st.date_input( "Select End Date", end_date_default, min_value=max(min_date, start_date), max_value=max_date, ) with col[2]: repeat_default = ( st.session_state["project_dct"]["model_tuning"].get( "repeat_default" ) if st.session_state["project_dct"]["model_tuning"].get( "repeat_default" ) is not None else "No" ) repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1 repeat = st.selectbox( "Repeat Annually", ["Yes", "No"], index=repeat_default_idx ) st.session_state["project_dct"]["model_tuning"][ "start_date_default" ] = start_date st.session_state["project_dct"]["model_tuning"][ "end_date_default" ] = end_date st.session_state["project_dct"]["model_tuning"][ "repeat_default" ] = repeat if repeat == "Yes": repeat = True else: repeat = False if "Flags" not in st.session_state: st.session_state["Flags"] = {} if "flags" in st.session_state["project_dct"]["model_tuning"].keys(): st.session_state["Flags"] = st.session_state["project_dct"][ "model_tuning" ]["flags"] # print("**"*50) # print(y_train) # print("**"*50) # print(model.fittedvalues) if is_panel: # Sprint3 met, line_values, fig_flag = plot_actual_vs_predicted( X_train[date_col], y_train, model.fittedvalues, model, target_column=sel_target_col, flag=(start_date, end_date), repeat_all_years=repeat, is_panel=True, ) st.plotly_chart(fig_flag, use_container_width=True) # create flag on test met, test_line_values, fig_flag = plot_actual_vs_predicted( X_test[date_col], y_test, sel_model_dict["pred_test"], model, target_column=sel_target_col, flag=(start_date, end_date), repeat_all_years=repeat, is_panel=True, ) else: pred_train = model.predict(X_train[features_set]) met, line_values, fig_flag = plot_actual_vs_predicted( X_train[date_col], y_train, pred_train, model, flag=(start_date, end_date), repeat_all_years=repeat, is_panel=False, ) st.plotly_chart(fig_flag, use_container_width=True) pred_test = model.predict(X_test[features_set]) met, test_line_values, fig_flag = plot_actual_vs_predicted( X_test[date_col], y_test, pred_test, model, flag=(start_date, end_date), repeat_all_years=repeat, is_panel=False, ) flag_name = "f1_flag" flag_name = st.text_input("Enter Flag Name") # Sprint4 - add selected target col to flag name if st.button("Update flag"): st.session_state["Flags"][flag_name + "__" + target_col] = {} st.session_state["Flags"][flag_name + "__" + target_col][ "train" ] = line_values st.session_state["Flags"][flag_name + "__" + target_col][ "test" ] = test_line_values st.success(f'{flag_name + "__" + target_col} stored') st.session_state["project_dct"]["model_tuning"]["flags"] = ( st.session_state["Flags"] ) # Sprint4 - only show flag created for the particular target col if st.session_state["Flags"] is None: st.session_state["Flags"] = {} target_model_flags = [ f.split("__")[0] for f in st.session_state["Flags"].keys() if f.split("__")[1] == target_col ] options = list(target_model_flags) selected_options = [] num_columns = 4 num_rows = -(-len(options) // num_columns) tick = False if st.checkbox( "Select all", value=st.session_state["project_dct"]["model_tuning"][ "select_all_flags_check" ].get(sel_target_col, False), ): tick = True st.session_state["project_dct"]["model_tuning"][ "select_all_flags_check" ][sel_target_col] = True else: st.session_state["project_dct"]["model_tuning"][ "select_all_flags_check" ][sel_target_col] = False selection_defualts = st.session_state["project_dct"]["model_tuning"][ "selected_flags" ].get(sel_target_col, []) selected_options = selection_defualts for row in range(num_rows): cols = st.columns(num_columns) for col in cols: if options: option = options.pop(0) option_default = ( True if option in selection_defualts else False ) selected = col.checkbox(option, value=(tick or option_default)) if selected: selected_options.append(option) st.session_state["project_dct"]["model_tuning"]["selected_flags"][ sel_target_col ] = selected_options st.markdown("### 1.2 Select Parameters to Apply") parameters = st.columns(3) with parameters[0]: Trend = st.checkbox( "**Trend**", value=st.session_state["project_dct"]["model_tuning"].get( "trend_check", False ), ) st.markdown( "Helps account for long-term trends or seasonality that could influence advertising effectiveness" ) with parameters[1]: week_number = st.checkbox( "**Week_number**", value=st.session_state["project_dct"]["model_tuning"].get( "week_num_check", False ), ) st.markdown( "Assists in detecting and incorporating weekly patterns or seasonality" ) with parameters[2]: sine_cosine = st.checkbox( "**Sine and Cosine Waves**", value=st.session_state["project_dct"]["model_tuning"].get( "sine_cosine_check", False ), ) st.markdown( "Helps in capturing cyclical patterns or seasonality in the data" ) # # def get_tuned_model(): # st.session_state['build_tuned_model']=True if st.button( "Build model with Selected Parameters and Flags", key="build_tuned_model",use_container_width=True ): new_features = features_set st.header("2.1 Results Summary") # date=list(df.index) # df = df.reset_index(drop=True) # X_train=df[features_set] ss = MinMaxScaler() if is_panel == True: X_train_tuned = X_train[features_set] # X_train_tuned = pd.DataFrame(ss.fit_transform(X), columns=X.columns) X_train_tuned[target_col] = X_train[target_col] X_train_tuned[date_col] = X_train[date_col] X_train_tuned[panel_col] = X_train[panel_col] X_test_tuned = X_test[features_set] # X_test_tuned = pd.DataFrame(ss.transform(X), columns=X.columns) X_test_tuned[target_col] = X_test[target_col] X_test_tuned[date_col] = X_test[date_col] X_test_tuned[panel_col] = X_test[panel_col] else: X_train_tuned = X_train[features_set] # X_train_tuned = pd.DataFrame(ss.fit_transform(X_train_tuned), columns=X_train_tuned.columns) X_test_tuned = X_test[features_set] # X_test_tuned = pd.DataFrame(ss.transform(X_test_tuned), columns=X_test_tuned.columns) for flag in selected_options: # Spirnt4 - added target_col in flag name X_train_tuned[flag] = st.session_state["Flags"][ flag + "__" + target_col ]["train"] X_test_tuned[flag] = st.session_state["Flags"][ flag + "__" + target_col ]["test"] # test # X_train_tuned.to_csv("Test/X_train_tuned_flag.csv",index=False) # X_test_tuned.to_csv("Test/X_test_tuned_flag.csv",index=False) # print("()()"*20,flag, len(st.session_state['Flags'][flag])) if Trend: st.session_state["project_dct"]["model_tuning"][ "trend_check" ] = True # Sprint3 - group by panel, calculate trend of each panel spearately. Add trend to new feature set if is_panel: newdata = pd.DataFrame() panel_wise_end_point_train = {} for panel, groupdf in X_train_tuned.groupby(panel_col): groupdf.sort_values(date_col, inplace=True) groupdf["Trend"] = np.arange(1, len(groupdf) + 1, 1) newdata = pd.concat([newdata, groupdf]) panel_wise_end_point_train[panel] = len(groupdf) X_train_tuned = newdata.copy() test_newdata = pd.DataFrame() for panel, test_groupdf in X_test_tuned.groupby(panel_col): test_groupdf.sort_values(date_col, inplace=True) start = panel_wise_end_point_train[panel] + 1 end = start + len(test_groupdf) # should be + 1? - Sprint4 # print("??"*20, panel, len(test_groupdf), len(np.arange(start, end, 1)), start) test_groupdf["Trend"] = np.arange(start, end, 1) test_newdata = pd.concat([test_newdata, test_groupdf]) X_test_tuned = test_newdata.copy() new_features = new_features + ["Trend"] else: X_train_tuned["Trend"] = np.arange( 1, len(X_train_tuned) + 1, 1 ) X_test_tuned["Trend"] = np.arange( len(X_train_tuned) + 1, len(X_train_tuned) + len(X_test_tuned) + 1, 1, ) new_features = new_features + ["Trend"] else: st.session_state["project_dct"]["model_tuning"][ "trend_check" ] = False if week_number: st.session_state["project_dct"]["model_tuning"][ "week_num_check" ] = True # Sprint3 - create weeknumber from date column in xtrain tuned. add week num to new feature set if is_panel: X_train_tuned[date_col] = pd.to_datetime( X_train_tuned[date_col] ) X_train_tuned["Week_number"] = X_train_tuned[ date_col ].dt.day_of_week if X_train_tuned["Week_number"].nunique() == 1: st.write( "All dates in the data are of the same week day. Hence Week number can't be used." ) else: X_test_tuned[date_col] = pd.to_datetime( X_test_tuned[date_col] ) X_test_tuned["Week_number"] = X_test_tuned[ date_col ].dt.day_of_week new_features = new_features + ["Week_number"] else: date = pd.to_datetime(date.values) X_train_tuned["Week_number"] = pd.to_datetime( X_train[date_col] ).dt.day_of_week X_test_tuned["Week_number"] = pd.to_datetime( X_test[date_col] ).dt.day_of_week new_features = new_features + ["Week_number"] else: st.session_state["project_dct"]["model_tuning"][ "week_num_check" ] = False if sine_cosine: st.session_state["project_dct"]["model_tuning"][ "sine_cosine_check" ] = True # Sprint3 - create panel wise sine cosine waves in xtrain tuned. add to new feature set if is_panel: new_features = new_features + ["sine_wave", "cosine_wave"] newdata = pd.DataFrame() newdata_test = pd.DataFrame() groups = X_train_tuned.groupby(panel_col) frequency = 2 * np.pi / 365 # Adjust the frequency as needed train_panel_wise_end_point = {} for panel, groupdf in groups: num_samples = len(groupdf) train_panel_wise_end_point[panel] = num_samples days_since_start = np.arange(num_samples) sine_wave = np.sin(frequency * days_since_start) cosine_wave = np.cos(frequency * days_since_start) sine_cosine_df = pd.DataFrame( {"sine_wave": sine_wave, "cosine_wave": cosine_wave} ) assert len(sine_cosine_df) == len(groupdf) # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1) groupdf["sine_wave"] = sine_wave groupdf["cosine_wave"] = cosine_wave newdata = pd.concat([newdata, groupdf]) X_train_tuned = newdata.copy() test_groups = X_test_tuned.groupby(panel_col) for panel, test_groupdf in test_groups: num_samples = len(test_groupdf) start = train_panel_wise_end_point[panel] days_since_start = np.arange(start, start + num_samples, 1) # print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1))) sine_wave = np.sin(frequency * days_since_start) cosine_wave = np.cos(frequency * days_since_start) sine_cosine_df = pd.DataFrame( {"sine_wave": sine_wave, "cosine_wave": cosine_wave} ) assert len(sine_cosine_df) == len(test_groupdf) # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1) test_groupdf["sine_wave"] = sine_wave test_groupdf["cosine_wave"] = cosine_wave newdata_test = pd.concat([newdata_test, test_groupdf]) X_test_tuned = newdata_test.copy() else: new_features = new_features + ["sine_wave", "cosine_wave"] num_samples = len(X_train_tuned) frequency = 2 * np.pi / 365 # Adjust the frequency as needed days_since_start = np.arange(num_samples) sine_wave = np.sin(frequency * days_since_start) cosine_wave = np.cos(frequency * days_since_start) sine_cosine_df = pd.DataFrame( {"sine_wave": sine_wave, "cosine_wave": cosine_wave} ) # Concatenate the sine and cosine waves with the scaled X DataFrame X_train_tuned = pd.concat( [X_train_tuned, sine_cosine_df], axis=1 ) test_num_samples = len(X_test_tuned) start = num_samples days_since_start = np.arange( start, start + test_num_samples, 1 ) sine_wave = np.sin(frequency * days_since_start) cosine_wave = np.cos(frequency * days_since_start) sine_cosine_df = pd.DataFrame( {"sine_wave": sine_wave, "cosine_wave": cosine_wave} ) # Concatenate the sine and cosine waves with the scaled X DataFrame X_test_tuned = pd.concat( [X_test_tuned, sine_cosine_df], axis=1 ) else: st.session_state["project_dct"]["model_tuning"][ "sine_cosine_check" ] = False # model if selected_options: new_features = new_features + selected_options if is_panel: inp_vars_str = " + ".join(new_features) new_features = list(set(new_features)) md_str = target_col + " ~ " + inp_vars_str md_tuned = smf.mixedlm( md_str, data=X_train_tuned[[target_col] + new_features], groups=X_train_tuned[panel_col], ) model_tuned = md_tuned.fit() # plot act v pred for original model and tuned model metrics_table, line, actual_vs_predicted_plot = ( plot_actual_vs_predicted( X_train[date_col], y_train, model.fittedvalues, model, target_column=sel_target_col, is_panel=True, ) ) metrics_table_tuned, line, actual_vs_predicted_plot_tuned = ( plot_actual_vs_predicted( X_train_tuned[date_col], X_train_tuned[target_col], model_tuned.fittedvalues, model_tuned, target_column=sel_target_col, is_panel=True, ) ) else: new_features = list(set(new_features)) model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit() metrics_table, line, actual_vs_predicted_plot = ( plot_actual_vs_predicted( date[:130], y_train, model.predict(X_train[features_set]), model, target_column=sel_target_col, ) ) metrics_table_tuned, line, actual_vs_predicted_plot_tuned = ( plot_actual_vs_predicted( date[:130], y_train, model_tuned.predict(X_train_tuned), model_tuned, target_column=sel_target_col, ) ) mape = np.round(metrics_table.iloc[0, 1], 2) r2 = np.round(metrics_table.iloc[1, 1], 2) adjr2 = np.round(metrics_table.iloc[2, 1], 2) mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2) r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2) adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2) parameters_ = st.columns(3) with parameters_[0]: st.metric("R2", r2_tuned, np.round(r2_tuned - r2, 2)) with parameters_[1]: st.metric( "Adjusted R2", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2) ) with parameters_[2]: st.metric( "MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse" ) st.write(model_tuned.summary()) X_train_tuned[date_col] = X_train[date_col] X_test_tuned[date_col] = X_test[date_col] X_train_tuned[target_col] = y_train X_test_tuned[target_col] = y_test st.header("2.2 Actual vs. Predicted Plot") # if is_panel: # metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(date, y_train, model.predict(X_train), # model, target_column='Revenue',is_panel=True) # else: # metrics_table,line,actual_vs_predicted_plot=plot_actual_vs_predicted(date, y_train, model.predict(X_train), model,target_column='Revenue') if is_panel: metrics_table, line, actual_vs_predicted_plot = ( plot_actual_vs_predicted( X_train_tuned[date_col], X_train_tuned[target_col], model_tuned.fittedvalues, model_tuned, target_column=sel_target_col, is_panel=True, ) ) else: metrics_table, line, actual_vs_predicted_plot = ( plot_actual_vs_predicted( X_train_tuned[date_col], X_train_tuned[target_col], model_tuned.predict(X_train_tuned[new_features]), model_tuned, target_column=sel_target_col, is_panel=False, ) ) # plot_actual_vs_predicted(X_train[date_col], y_train, # model.fittedvalues, model, # target_column='Revenue', # is_panel=is_panel) st.plotly_chart(actual_vs_predicted_plot, use_container_width=True) st.markdown("## 2.3 Residual Analysis") if is_panel: columns = st.columns(2) with columns[0]: fig = plot_residual_predicted( y_train, model_tuned.fittedvalues, X_train_tuned ) st.plotly_chart(fig) with columns[1]: st.empty() fig = qqplot(y_train, model_tuned.fittedvalues) st.plotly_chart(fig) with columns[0]: fig = residual_distribution(y_train, model_tuned.fittedvalues) st.pyplot(fig) else: columns = st.columns(2) with columns[0]: fig = plot_residual_predicted( y_train, model_tuned.predict(X_train_tuned[new_features]), X_train, ) st.plotly_chart(fig) with columns[1]: st.empty() fig = qqplot( y_train, model_tuned.predict(X_train_tuned[new_features]) ) st.plotly_chart(fig) with columns[0]: fig = residual_distribution( y_train, model_tuned.predict(X_train_tuned[new_features]) ) st.pyplot(fig) # st.session_state['is_tuned_model'][target_col] = True # Sprint4 - saved tuned model in a dict st.session_state["Model_Tuned"][sel_model + "__" + target_col] = { "Model_object": model_tuned, "feature_set": new_features, "X_train_tuned": X_train_tuned, "X_test_tuned": X_test_tuned, } # Pending # if st.session_state['build_tuned_model']==True: if st.session_state["Model_Tuned"] is not None: if st.button( "Use This model for Media Planning",use_container_width=True ): # save_model = st.button('Use this model to build response curves', key='saved_tuned_model') # if save_model: st.session_state["is_tuned_model"][target_col] = True with open( os.path.join( st.session_state["project_path"], "tuned_model.pkl" ), "wb", ) as f: # pickle.dump(st.session_state['tuned_model'], f) pickle.dump(st.session_state["Model_Tuned"], f) # Sprint4 st.session_state["project_dct"]["model_tuning"][ "session_state_saved" ] = {} for key in [ "bin_dict", "used_response_metrics", "is_tuned_model", "media_data", "X_test_spends", ]: st.session_state["project_dct"]["model_tuning"][ "session_state_saved" ][key] = st.session_state[key] project_dct_path = os.path.join( st.session_state["project_path"], "project_dct.pkl" ) with open(project_dct_path, "wb") as f: pickle.dump(st.session_state["project_dct"], f) update_db("5_Model_Tuning.py") st.success(sel_model + "__" + target_col + " Tuned saved!")