BlendMMM committed on
Commit
bd8809c
·
1 Parent(s): 46305dc

Update pages/4_Model_Build.py

Browse files
Files changed (1) hide show
  1. pages/4_Model_Build.py +1062 -1062
pages/4_Model_Build.py CHANGED
@@ -1,1062 +1,1062 @@
1
- """
2
- MMO Build Sprint 3
3
- additions : adding more variables to session state for saved model : random effect, predicted train & test
4
-
5
- MMO Build Sprint 4
6
- additions : ability to run models for different response metrics
7
- """
8
-
9
- import streamlit as st
10
- import pandas as pd
11
- import plotly.express as px
12
- import plotly.graph_objects as go
13
- from Eda_functions import format_numbers
14
- import numpy as np
15
- import pickle
16
- from st_aggrid import AgGrid
17
- from st_aggrid import GridOptionsBuilder, GridUpdateMode
18
- from utilities import set_header, load_local_css
19
- from st_aggrid import GridOptionsBuilder
20
- import time
21
- import itertools
22
- import statsmodels.api as sm
23
- import numpy as npc
24
- import re
25
- import itertools
26
- from sklearn.metrics import (
27
- mean_absolute_error,
28
- r2_score,
29
- mean_absolute_percentage_error,
30
- )
31
- from sklearn.preprocessing import MinMaxScaler
32
- import os
33
- import matplotlib.pyplot as plt
34
- from statsmodels.stats.outliers_influence import variance_inflation_factor
35
- import yaml
36
- from yaml import SafeLoader
37
- import streamlit_authenticator as stauth
38
-
39
- st.set_option("deprecation.showPyplotGlobalUse", False)
40
- import statsmodels.api as sm
41
- import statsmodels.formula.api as smf
42
-
43
- from datetime import datetime
44
- import seaborn as sns
45
- from Data_prep_functions import *
46
- import sqlite3
47
- from utilities import update_db
48
-
49
-
50
- @st.cache_resource(show_spinner=False)
51
- # def save_to_pickle(file_path, final_df):
52
- # # Open the file in write-binary mode and dump the objects
53
- # with open(file_path, "wb") as f:
54
- # pickle.dump({file_path: final_df}, f)
55
-
56
-
57
- def get_random_effects(media_data, panel_col, _mdf):
58
- random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
59
-
60
- for i, market in enumerate(media_data[panel_col].unique()):
61
- print(i, end="\r")
62
- intercept = _mdf.random_effects[market].values[0]
63
- random_eff_df.loc[i, "random_effect"] = intercept
64
- random_eff_df.loc[i, panel_col] = market
65
-
66
- return random_eff_df
67
-
68
-
69
- def mdf_predict(X_df, mdf, random_eff_df):
70
- X = X_df.copy()
71
- X["fixed_effect"] = mdf.predict(X)
72
- X = pd.merge(X, random_eff_df, on=panel_col, how="left")
73
- X["pred"] = X["fixed_effect"] + X["random_effect"]
74
- # X.to_csv('Test/megred_df.csv',index=False)
75
- X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
76
- return X["pred"]
77
-
78
-
79
- st.set_page_config(
80
- page_title="Model Build",
81
- page_icon=":shark:",
82
- layout="wide",
83
- initial_sidebar_state="collapsed",
84
- )
85
-
86
- load_local_css("styles.css")
87
- set_header()
88
-
89
- # Check for authentication status
90
- for k, v in st.session_state.items():
91
- if k not in [
92
- "logout",
93
- "login",
94
- "config",
95
- "model_build_button",
96
- ] and not k.startswith("FormSubmitter"):
97
- st.session_state[k] = v
98
- with open("config.yaml") as file:
99
- config = yaml.load(file, Loader=SafeLoader)
100
- st.session_state["config"] = config
101
- authenticator = stauth.Authenticate(
102
- config["credentials"],
103
- config["cookie"]["name"],
104
- config["cookie"]["key"],
105
- config["cookie"]["expiry_days"],
106
- config["preauthorized"],
107
- )
108
- st.session_state["authenticator"] = authenticator
109
- name, authentication_status, username = authenticator.login("Login", "main")
110
- auth_status = st.session_state.get("authentication_status")
111
-
112
- if auth_status == True:
113
- authenticator.logout("Logout", "main")
114
- is_state_initiaized = st.session_state.get("initialized", False)
115
-
116
- conn = sqlite3.connect(
117
- r"DB/User.db", check_same_thread=False
118
- ) # connection with sql db
119
- c = conn.cursor()
120
-
121
- if not is_state_initiaized:
122
-
123
- if "session_name" not in st.session_state:
124
- st.session_state["session_name"] = None
125
-
126
- if "project_dct" not in st.session_state:
127
- st.error("Please load a project from Home page")
128
- st.stop()
129
-
130
- st.title("1. Build Your Model")
131
-
132
- if not os.path.exists(
133
- os.path.join(st.session_state["project_path"], "data_import.pkl")
134
- ):
135
- st.error("Please move to Data Import Page and save.")
136
- st.stop()
137
- with open(
138
- os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
139
- ) as f:
140
- data = pickle.load(f)
141
- st.session_state["bin_dict"] = data["bin_dict"]
142
-
143
- if not os.path.exists(
144
- os.path.join(
145
- st.session_state["project_path"], "final_df_transformed.pkl"
146
- )
147
- ):
148
- st.error(
149
- "Please move to Transformation Page and save transformations."
150
- )
151
- st.stop()
152
- with open(
153
- os.path.join(
154
- st.session_state["project_path"], "final_df_transformed.pkl"
155
- ),
156
- "rb",
157
- ) as f:
158
- data = pickle.load(f)
159
- media_data = data["final_df_transformed"]
160
- media_data.to_csv("Test/media_data.csv", index=False)
161
- train_idx = int(len(media_data) / 5) * 4
162
- # Sprint4 - available response metrics is a list of all reponse metrics in the data
163
- ## these will be put in a drop down
164
-
165
- st.session_state["media_data"] = media_data
166
-
167
- if "available_response_metrics" not in st.session_state:
168
- # st.session_state['available_response_metrics'] = ['Total Approved Accounts - Revenue',
169
- # 'Total Approved Accounts - Appsflyer',
170
- # 'Account Requests - Appsflyer',
171
- # 'App Installs - Appsflyer']
172
-
173
- st.session_state["available_response_metrics"] = st.session_state[
174
- "bin_dict"
175
- ]["Response Metrics"]
176
- # Sprint4
177
- if "is_tuned_model" not in st.session_state:
178
- st.session_state["is_tuned_model"] = {}
179
- for resp_metric in st.session_state["available_response_metrics"]:
180
- resp_metric = (
181
- resp_metric.lower()
182
- .replace(" ", "_")
183
- .replace("-", "")
184
- .replace(":", "")
185
- .replace("__", "_")
186
- )
187
- st.session_state["is_tuned_model"][resp_metric] = False
188
-
189
- # Sprint4 - used_response_metrics is a list of resp metrics for which user has created & saved a model
190
- if "used_response_metrics" not in st.session_state:
191
- st.session_state["used_response_metrics"] = []
192
-
193
- # Sprint4 - saved_model_names
194
- if "saved_model_names" not in st.session_state:
195
- st.session_state["saved_model_names"] = []
196
-
197
- if "Model" not in st.session_state:
198
- if (
199
- "session_state_saved"
200
- in st.session_state["project_dct"]["model_build"].keys()
201
- and st.session_state["project_dct"]["model_build"][
202
- "session_state_saved"
203
- ]
204
- is not None
205
- and "Model"
206
- in st.session_state["project_dct"]["model_build"][
207
- "session_state_saved"
208
- ].keys()
209
- ):
210
- st.session_state["Model"] = st.session_state["project_dct"][
211
- "model_build"
212
- ]["session_state_saved"]["Model"]
213
- else:
214
- st.session_state["Model"] = {}
215
-
216
- # Sprint4 - select a response metric
217
- default_target_idx = (
218
- st.session_state["project_dct"]["model_build"].get(
219
- "sel_target_col", None
220
- )
221
- if st.session_state["project_dct"]["model_build"].get(
222
- "sel_target_col", None
223
- )
224
- is not None
225
- else st.session_state["available_response_metrics"][0]
226
- )
227
-
228
- sel_target_col = st.selectbox(
229
- "Select the response metric",
230
- st.session_state["available_response_metrics"],
231
- index=st.session_state["available_response_metrics"].index(
232
- default_target_idx
233
- ),
234
- )
235
- # , on_change=reset_save())
236
- st.session_state["project_dct"]["model_build"][
237
- "sel_target_col"
238
- ] = sel_target_col
239
-
240
- target_col = (
241
- sel_target_col.lower()
242
- .replace(" ", "_")
243
- .replace("-", "")
244
- .replace(":", "")
245
- .replace("__", "_")
246
- )
247
- new_name_dct = {
248
- col: col.lower()
249
- .replace(".", "_")
250
- .lower()
251
- .replace("@", "_")
252
- .replace(" ", "_")
253
- .replace("-", "")
254
- .replace(":", "")
255
- .replace("__", "_")
256
- for col in media_data.columns
257
- }
258
- media_data.columns = [
259
- col.lower()
260
- .replace(".", "_")
261
- .replace("@", "_")
262
- .replace(" ", "_")
263
- .replace("-", "")
264
- .replace(":", "")
265
- .replace("__", "_")
266
- for col in media_data.columns
267
- ]
268
- panel_col = [
269
- col.lower()
270
- .replace(".", "_")
271
- .replace("@", "_")
272
- .replace(" ", "_")
273
- .replace("-", "")
274
- .replace(":", "")
275
- .replace("__", "_")
276
- for col in st.session_state["bin_dict"]["Panel Level 1"]
277
- ][
278
- 0
279
- ] # set the panel column
280
- date_col = "date"
281
-
282
- is_panel = True if len(panel_col) > 0 else False
283
-
284
- if "is_panel" not in st.session_state:
285
- st.session_state["is_panel"] = is_panel
286
-
287
- if is_panel:
288
- media_data.sort_values([date_col, panel_col], inplace=True)
289
- else:
290
- media_data.sort_values(date_col, inplace=True)
291
-
292
- media_data.reset_index(drop=True, inplace=True)
293
-
294
- date = media_data[date_col]
295
- st.session_state["date"] = date
296
- y = media_data[target_col]
297
-
298
- if is_panel:
299
- spends_data = media_data[
300
- [
301
- c
302
- for c in media_data.columns
303
- if "_cost" in c.lower() or "_spend" in c.lower()
304
- ]
305
- + [date_col, panel_col]
306
- ]
307
- # Sprint3 - spends for resp curves
308
- else:
309
- spends_data = media_data[
310
- [
311
- c
312
- for c in media_data.columns
313
- if "_cost" in c.lower() or "_spend" in c.lower()
314
- ]
315
- + [date_col]
316
- ]
317
-
318
- y = media_data[target_col]
319
- media_data.drop([date_col], axis=1, inplace=True)
320
- media_data.reset_index(drop=True, inplace=True)
321
-
322
- columns = st.columns(2)
323
-
324
- old_shape = media_data.shape
325
-
326
- if "old_shape" not in st.session_state:
327
- st.session_state["old_shape"] = old_shape
328
-
329
- if "media_data" not in st.session_state:
330
- st.session_state["media_data"] = pd.DataFrame()
331
-
332
- # Sprint3
333
- if "orig_media_data" not in st.session_state:
334
- st.session_state["orig_media_data"] = pd.DataFrame()
335
-
336
- # Sprint3 additions
337
- if "random_effects" not in st.session_state:
338
- st.session_state["random_effects"] = pd.DataFrame()
339
- if "pred_train" not in st.session_state:
340
- st.session_state["pred_train"] = []
341
- if "pred_test" not in st.session_state:
342
- st.session_state["pred_test"] = []
343
- # end of Sprint3 additions
344
-
345
- # Section 3 - Create combinations
346
-
347
- # bucket=['paid_search', 'kwai','indicacao','infleux', 'influencer','FB: Level Achieved - Tier 1 Impressions',
348
- # ' FB: Level Achieved - Tier 2 Impressions','paid_social_others',
349
- # ' GA App: Will And Cid Pequena Baixo Risco Clicks',
350
- # 'digital_tactic_others',"programmatic"
351
- # ]
352
-
353
- # srishti - bucket names changed
354
- bucket = [
355
- "paid_search",
356
- "kwai",
357
- "indicacao",
358
- "infleux",
359
- "influencer",
360
- "fb_level_achieved_tier_2",
361
- "fb_level_achieved_tier_1",
362
- "paid_social_others",
363
- "ga_app",
364
- "digital_tactic_others",
365
- "programmatic",
366
- ]
367
-
368
- # with columns[0]:
369
- # if st.button('Create Combinations of Variables'):
370
-
371
- top_3_correlated_features = []
372
- # # for col in st.session_state['media_data'].columns[:19]:
373
- # original_cols = [c for c in st.session_state['media_data'].columns if
374
- # "_clicks" in c.lower() or "_impressions" in c.lower()]
375
- # original_cols = [c for c in original_cols if "_lag" not in c.lower() and "_adstock" not in c.lower()]
376
-
377
- original_cols = (
378
- st.session_state["bin_dict"]["Media"]
379
- + st.session_state["bin_dict"]["Internal"]
380
- )
381
-
382
- original_cols = [
383
- col.lower()
384
- .replace(".", "_")
385
- .replace("@", "_")
386
- .replace(" ", "_")
387
- .replace("-", "")
388
- .replace(":", "")
389
- .replace("__", "_")
390
- for col in original_cols
391
- ]
392
- original_cols = [col for col in original_cols if "_cost" not in col]
393
- # for col in st.session_state['media_data'].columns[:19]:
394
- for col in original_cols: # srishti - new
395
- corr_df = (
396
- pd.concat(
397
- [st.session_state["media_data"].filter(regex=col), y], axis=1
398
- )
399
- .corr()[target_col]
400
- .iloc[:-1]
401
- )
402
- top_3_correlated_features.append(
403
- list(corr_df.sort_values(ascending=False).head(2).index)
404
- )
405
- flattened_list = [
406
- item for sublist in top_3_correlated_features for item in sublist
407
- ]
408
- # all_features_set={var:[col for col in flattened_list if var in col] for var in bucket}
409
- all_features_set = {
410
- var: [col for col in flattened_list if var in col]
411
- for var in bucket
412
- if len([col for col in flattened_list if var in col]) > 0
413
- } # srishti
414
- channels_all = [values for values in all_features_set.values()]
415
- st.session_state["combinations"] = list(itertools.product(*channels_all))
416
- # if 'combinations' not in st.session_state:
417
- # st.session_state['combinations']=combinations_all
418
-
419
- st.session_state["final_selection"] = st.session_state["combinations"]
420
- # st.success('Created combinations')
421
-
422
- # revenue.reset_index(drop=True,inplace=True)
423
- y.reset_index(drop=True, inplace=True)
424
- if "Model_results" not in st.session_state:
425
- st.session_state["Model_results"] = {
426
- "Model_object": [],
427
- "Model_iteration": [],
428
- "Feature_set": [],
429
- "MAPE": [],
430
- "R2": [],
431
- "ADJR2": [],
432
- "pos_count": [],
433
- }
434
-
435
- def reset_model_result_dct():
436
- st.session_state["Model_results"] = {
437
- "Model_object": [],
438
- "Model_iteration": [],
439
- "Feature_set": [],
440
- "MAPE": [],
441
- "R2": [],
442
- "ADJR2": [],
443
- "pos_count": [],
444
- }
445
-
446
- # if st.button('Build Model'):
447
-
448
- if "iterations" not in st.session_state:
449
- st.session_state["iterations"] = 0
450
-
451
- if "final_selection" not in st.session_state:
452
- st.session_state["final_selection"] = False
453
-
454
- save_path = r"Model/"
455
- if st.session_state["final_selection"]:
456
- st.write(
457
- f'Total combinations created {format_numbers(len(st.session_state["final_selection"]))}'
458
- )
459
-
460
- # st.session_state["project_dct"]["model_build"]["all_iters_check"] = False
461
-
462
- checkbox_default = (
463
- st.session_state["project_dct"]["model_build"]["all_iters_check"]
464
- if st.session_state["project_dct"]["model_build"]["all_iters_check"]
465
- is not None
466
- else False
467
- )
468
-
469
- if st.checkbox("Build all iterations", value=checkbox_default):
470
- # st.session_state["project_dct"]["model_build"]["all_iters_check"]
471
- iterations = len(st.session_state["final_selection"])
472
- st.session_state["project_dct"]["model_build"][
473
- "all_iters_check"
474
- ] = True
475
-
476
- else:
477
- iterations = st.number_input(
478
- "Select the number of iterations to perform",
479
- min_value=0,
480
- step=100,
481
- value=st.session_state["iterations"],
482
- on_change=reset_model_result_dct,
483
- )
484
- st.session_state["project_dct"]["model_build"][
485
- "all_iters_check"
486
- ] = False
487
- st.session_state["project_dct"]["model_build"][
488
- "iterations"
489
- ] = iterations
490
-
491
- # st.stop()
492
-
493
- # build_button = st.session_state["project_dct"]["model_build"]["build_button"] if \
494
- # "build_button" in st.session_state["project_dct"]["model_build"].keys() else False
495
- # model_button =st.button('Build Model', on_click=reset_model_result_dct, key='model_build_button')
496
- # if
497
- # if model_button:
498
- if st.button(
499
- "Build Model",
500
- on_click=reset_model_result_dct,
501
- key="model_build_button",
502
- ):
503
- if iterations < 1:
504
- st.error("Please select number of iterations")
505
- st.stop()
506
- st.session_state["project_dct"]["model_build"]["build_button"] = True
507
- st.session_state["iterations"] = iterations
508
-
509
- # Section 4 - Model
510
- # st.session_state['media_data'] = st.session_state['media_data'].fillna(method='ffill')
511
- st.session_state["media_data"] = st.session_state["media_data"].ffill()
512
- st.markdown(
513
- "Data Split -- Training Period: May 9th, 2023 - October 5th,2023 , Testing Period: October 6th, 2023 - November 7th, 2023 "
514
- )
515
- progress_bar = st.progress(0) # Initialize the progress bar
516
- # time_remaining_text = st.empty() # Create an empty space for time remaining text
517
- start_time = time.time() # Record the start time
518
- progress_text = st.empty()
519
-
520
- # time_elapsed_text = st.empty()
521
- # for i, selected_features in enumerate(st.session_state["final_selection"][40000:40000 + int(iterations)]):
522
- # for i, selected_features in enumerate(st.session_state["final_selection"]):
523
-
524
- if is_panel == True:
525
- for i, selected_features in enumerate(
526
- st.session_state["final_selection"][0 : int(iterations)]
527
- ): # srishti
528
- df = st.session_state["media_data"]
529
-
530
- fet = [var for var in selected_features if len(var) > 0]
531
- inp_vars_str = " + ".join(fet) # new
532
-
533
- X = df[fet]
534
- y = df[target_col]
535
- ss = MinMaxScaler()
536
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
537
-
538
- X[target_col] = y # Sprint2
539
- X[panel_col] = df[panel_col] # Sprint2
540
-
541
- X_train = X.iloc[:train_idx]
542
- X_test = X.iloc[train_idx:]
543
- y_train = y.iloc[:train_idx]
544
- y_test = y.iloc[train_idx:]
545
-
546
- print(X_train.shape)
547
- # model = sm.OLS(y_train, X_train).fit()
548
- md_str = target_col + " ~ " + inp_vars_str
549
- # md = smf.mixedlm("total_approved_accounts_revenue ~ {}".format(inp_vars_str),
550
- # data=X_train[[target_col] + fet],
551
- # groups=X_train[panel_col])
552
- md = smf.mixedlm(
553
- md_str,
554
- data=X_train[[target_col] + fet],
555
- groups=X_train[panel_col],
556
- )
557
- mdf = md.fit()
558
- predicted_values = mdf.fittedvalues
559
-
560
- coefficients = mdf.fe_params.to_dict()
561
- model_positive = [
562
- col for col in coefficients.keys() if coefficients[col] > 0
563
- ]
564
-
565
- pvalues = [var for var in list(mdf.pvalues) if var <= 0.06]
566
-
567
- if (len(model_positive) / len(selected_features)) > 0 and (
568
- len(pvalues) / len(selected_features)
569
- ) >= 0: # srishti - changed just for testing, revert later
570
- # predicted_values = model.predict(X_train)
571
- mape = mean_absolute_percentage_error(
572
- y_train, predicted_values
573
- )
574
- r2 = r2_score(y_train, predicted_values)
575
- adjr2 = 1 - (1 - r2) * (len(y_train) - 1) / (
576
- len(y_train) - len(selected_features) - 1
577
- )
578
-
579
- filename = os.path.join(save_path, f"model_{i}.pkl")
580
- with open(filename, "wb") as f:
581
- pickle.dump(mdf, f)
582
- # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
583
- # model = pickle.load(file)
584
-
585
- st.session_state["Model_results"]["Model_object"].append(
586
- filename
587
- )
588
- st.session_state["Model_results"][
589
- "Model_iteration"
590
- ].append(i)
591
- st.session_state["Model_results"]["Feature_set"].append(
592
- fet
593
- )
594
- st.session_state["Model_results"]["MAPE"].append(mape)
595
- st.session_state["Model_results"]["R2"].append(r2)
596
- st.session_state["Model_results"]["pos_count"].append(
597
- len(model_positive)
598
- )
599
- st.session_state["Model_results"]["ADJR2"].append(adjr2)
600
-
601
- current_time = time.time()
602
- time_taken = current_time - start_time
603
- time_elapsed_minutes = time_taken / 60
604
- completed_iterations_text = f"{i + 1}/{iterations}"
605
- progress_bar.progress((i + 1) / int(iterations))
606
- progress_text.text(
607
- f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
608
- )
609
- st.write(
610
- f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
611
- )
612
-
613
- else:
614
-
615
- for i, selected_features in enumerate(
616
- st.session_state["final_selection"][0 : int(iterations)]
617
- ): # srishti
618
- df = st.session_state["media_data"]
619
-
620
- fet = [var for var in selected_features if len(var) > 0]
621
- inp_vars_str = " + ".join(fet)
622
-
623
- X = df[fet]
624
- y = df[target_col]
625
- ss = MinMaxScaler()
626
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
627
- X = sm.add_constant(X)
628
- X_train = X.iloc[:130]
629
- X_test = X.iloc[130:]
630
- y_train = y.iloc[:130]
631
- y_test = y.iloc[130:]
632
-
633
- model = sm.OLS(y_train, X_train).fit()
634
-
635
- coefficients = model.params.to_list()
636
- model_positive = [coef for coef in coefficients if coef > 0]
637
- predicted_values = model.predict(X_train)
638
- pvalues = [var for var in list(model.pvalues) if var <= 0.06]
639
-
640
- # if (len(model_possitive) / len(selected_features)) > 0.9 and (len(pvalues) / len(selected_features)) >= 0.8:
641
- if (len(model_positive) / len(selected_features)) > 0 and (
642
- len(pvalues) / len(selected_features)
643
- ) >= 0.5: # srishti - changed just for testing, revert later VALID MODEL CRITERIA
644
- # predicted_values = model.predict(X_train)
645
- mape = mean_absolute_percentage_error(
646
- y_train, predicted_values
647
- )
648
- adjr2 = model.rsquared_adj
649
- r2 = model.rsquared
650
-
651
- filename = os.path.join(save_path, f"model_{i}.pkl")
652
- with open(filename, "wb") as f:
653
- pickle.dump(model, f)
654
- # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
655
- # model = pickle.load(file)
656
-
657
- st.session_state["Model_results"]["Model_object"].append(
658
- filename
659
- )
660
- st.session_state["Model_results"][
661
- "Model_iteration"
662
- ].append(i)
663
- st.session_state["Model_results"]["Feature_set"].append(
664
- fet
665
- )
666
- st.session_state["Model_results"]["MAPE"].append(mape)
667
- st.session_state["Model_results"]["R2"].append(r2)
668
- st.session_state["Model_results"]["ADJR2"].append(adjr2)
669
- st.session_state["Model_results"]["pos_count"].append(
670
- len(model_positive)
671
- )
672
-
673
- current_time = time.time()
674
- time_taken = current_time - start_time
675
- time_elapsed_minutes = time_taken / 60
676
- completed_iterations_text = f"{i + 1}/{iterations}"
677
- progress_bar.progress((i + 1) / int(iterations))
678
- progress_text.text(
679
- f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
680
- )
681
- st.write(
682
- f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
683
- )
684
-
685
- pd.DataFrame(st.session_state["Model_results"]).to_csv(
686
- "model_output.csv"
687
- )
688
-
689
- def to_percentage(value):
690
- return f"{value * 100:.1f}%"
691
-
692
- ## Section 5 - Select Model
693
- st.title("2. Select Models")
694
- show_results_defualt = (
695
- st.session_state["project_dct"]["model_build"]["show_results_check"]
696
- if st.session_state["project_dct"]["model_build"]["show_results_check"]
697
- is not None
698
- else False
699
- )
700
- if "tick" not in st.session_state:
701
- st.session_state["tick"] = False
702
- if st.checkbox(
703
- "Show results of top 10 models (based on MAPE and Adj. R2)",
704
- value=show_results_defualt,
705
- ):
706
- st.session_state["project_dct"]["model_build"][
707
- "show_results_check"
708
- ] = True
709
- st.session_state["tick"] = True
710
- st.write(
711
- "Select one model iteration to generate performance metrics for it:"
712
- )
713
- data = pd.DataFrame(st.session_state["Model_results"])
714
- data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
715
- drop=True
716
- ) # Sprint4 -- Srishti -- only show models with the lowest num of neg coeffs
717
- data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
718
- data.drop_duplicates(subset="Model_iteration", inplace=True)
719
- top_10 = data.head(10)
720
- top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
721
- top_10[["MAPE", "R2", "ADJR2"]] = np.round(
722
- top_10[["MAPE", "R2", "ADJR2"]], 4
723
- ).applymap(to_percentage)
724
- top_10_table = top_10[
725
- ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
726
- ]
727
- # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
728
- gd = GridOptionsBuilder.from_dataframe(top_10_table)
729
- gd.configure_pagination(enabled=True)
730
-
731
- gd.configure_selection(
732
- use_checkbox=True,
733
- selection_mode="single",
734
- pre_select_all_rows=False,
735
- pre_selected_rows=[1],
736
- )
737
-
738
- gridoptions = gd.build()
739
-
740
- table = AgGrid(
741
- top_10,
742
- gridOptions=gridoptions,
743
- update_mode=GridUpdateMode.SELECTION_CHANGED,
744
- )
745
-
746
- selected_rows = table.selected_rows
747
- # if st.session_state["selected_rows"] != selected_rows:
748
- # st.session_state["build_rc_cb"] = False
749
- st.session_state["selected_rows"] = selected_rows
750
-
751
- # Section 6 - Display Results
752
-
753
- if len(selected_rows) > 0:
754
- st.header("2.1 Results Summary")
755
-
756
- model_object = data[
757
- data["Model_iteration"] == selected_rows[0]["Model_iteration"]
758
- ]["Model_object"]
759
- features_set = data[
760
- data["Model_iteration"] == selected_rows[0]["Model_iteration"]
761
- ]["Feature_set"]
762
-
763
- with open(str(model_object.values[0]), "rb") as file:
764
- # print(file)
765
- model = pickle.load(file)
766
- st.write(model.summary())
767
- st.header("2.2 Actual vs. Predicted Plot")
768
-
769
- if is_panel:
770
- df = st.session_state["media_data"]
771
- X = df[features_set.values[0]]
772
- y = df[target_col]
773
-
774
- ss = MinMaxScaler()
775
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
776
-
777
- # Sprint2 changes
778
- X[target_col] = y # new
779
- X[panel_col] = df[panel_col]
780
- X[date_col] = date
781
-
782
- X_train = X.iloc[:train_idx]
783
- X_test = X.iloc[train_idx:].reset_index(drop=True)
784
- y_train = y.iloc[:train_idx]
785
- y_test = y.iloc[train_idx:].reset_index(drop=True)
786
-
787
- test_spends = spends_data[
788
- train_idx:
789
- ] # Sprint3 - test spends for resp curves
790
- random_eff_df = get_random_effects(
791
- media_data, panel_col, model
792
- )
793
- train_pred = model.fittedvalues
794
- test_pred = mdf_predict(X_test, model, random_eff_df)
795
- print("__" * 20, test_pred.isna().sum())
796
-
797
- else:
798
- df = st.session_state["media_data"]
799
- X = df[features_set.values[0]]
800
- y = df[target_col]
801
-
802
- ss = MinMaxScaler()
803
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
804
- X = sm.add_constant(X)
805
-
806
- X[date_col] = date
807
-
808
- X_train = X.iloc[:130]
809
- X_test = X.iloc[130:].reset_index(drop=True)
810
- y_train = y.iloc[:130]
811
- y_test = y.iloc[130:].reset_index(drop=True)
812
-
813
- test_spends = spends_data[
814
- 130:
815
- ] # Sprint3 - test spends for resp curves
816
- train_pred = model.predict(
817
- X_train[features_set.values[0] + ["const"]]
818
- )
819
- test_pred = model.predict(
820
- X_test[features_set.values[0] + ["const"]]
821
- )
822
-
823
- # save x test to test - srishti
824
- # x_test_to_save = X_test.copy()
825
- # x_test_to_save['Actuals'] = y_test
826
- # x_test_to_save['Predictions'] = test_pred
827
- #
828
- # x_train_to_save = X_train.copy()
829
- # x_train_to_save['Actuals'] = y_train
830
- # x_train_to_save['Predictions'] = train_pred
831
- #
832
- # x_train_to_save.to_csv('Test/x_train_to_save.csv', index=False)
833
- # x_test_to_save.to_csv('Test/x_test_to_save.csv', index=False)
834
-
835
- st.session_state["X"] = X_train
836
- st.session_state["features_set"] = features_set.values[0]
837
- print(
838
- "**" * 20, "selected model features : ", features_set.values[0]
839
- )
840
- metrics_table, line, actual_vs_predicted_plot = (
841
- plot_actual_vs_predicted(
842
- X_train[date_col],
843
- y_train,
844
- train_pred,
845
- model,
846
- target_column=sel_target_col,
847
- is_panel=is_panel,
848
- )
849
- ) # Sprint2
850
-
851
- st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
852
-
853
- st.markdown("## 2.3 Residual Analysis")
854
- columns = st.columns(2)
855
- with columns[0]:
856
- fig = plot_residual_predicted(
857
- y_train, train_pred, X_train
858
- ) # Sprint2
859
- st.plotly_chart(fig)
860
-
861
- with columns[1]:
862
- st.empty()
863
- fig = qqplot(y_train, train_pred) # Sprint2
864
- st.plotly_chart(fig)
865
-
866
- with columns[0]:
867
- fig = residual_distribution(y_train, train_pred) # Sprint2
868
- st.pyplot(fig)
869
-
870
- vif_data = pd.DataFrame()
871
- # X=X.drop('const',axis=1)
872
- X_train_orig = (
873
- X_train.copy()
874
- ) # Sprint2 -- creating a copy of xtrain. Later deleting panel, target & date from xtrain
875
- del_col_list = list(
876
- set([target_col, panel_col, date_col]).intersection(
877
- set(X_train.columns)
878
- )
879
- )
880
- X_train.drop(columns=del_col_list, inplace=True) # Sprint2
881
-
882
- vif_data["Variable"] = X_train.columns
883
- vif_data["VIF"] = [
884
- variance_inflation_factor(X_train.values, i)
885
- for i in range(X_train.shape[1])
886
- ]
887
- vif_data.sort_values(by=["VIF"], ascending=False, inplace=True)
888
- vif_data = np.round(vif_data)
889
- vif_data["VIF"] = vif_data["VIF"].astype(float)
890
- st.header("2.4 Variance Inflation Factor (VIF)")
891
- # st.dataframe(vif_data)
892
- color_mapping = {
893
- "darkgreen": (vif_data["VIF"] < 3),
894
- "orange": (vif_data["VIF"] >= 3) & (vif_data["VIF"] <= 10),
895
- "darkred": (vif_data["VIF"] > 10),
896
- }
897
-
898
- # Create a horizontal bar plot
899
- fig, ax = plt.subplots()
900
- fig.set_figwidth(10) # Adjust the width of the figure as needed
901
-
902
- # Sort the bars by descending VIF values
903
- vif_data = vif_data.sort_values(by="VIF", ascending=False)
904
-
905
- # Iterate through the color mapping and plot bars with corresponding colors
906
- for color, condition in color_mapping.items():
907
- subset = vif_data[condition]
908
- bars = ax.barh(
909
- subset["Variable"], subset["VIF"], color=color, label=color
910
- )
911
-
912
- # Add text annotations on top of the bars
913
- for bar in bars:
914
- width = bar.get_width()
915
- ax.annotate(
916
- f"{width:}",
917
- xy=(width, bar.get_y() + bar.get_height() / 2),
918
- xytext=(5, 0),
919
- textcoords="offset points",
920
- va="center",
921
- )
922
-
923
- # Customize the plot
924
- ax.set_xlabel("VIF Values")
925
- # ax.set_title('2.4 Variance Inflation Factor (VIF)')
926
- # ax.legend(loc='upper right')
927
-
928
- # Display the plot in Streamlit
929
- st.pyplot(fig)
930
-
931
- with st.expander("Results Summary Test data"):
932
- # ss = MinMaxScaler()
933
- # X_test = pd.DataFrame(ss.fit_transform(X_test), columns=X_test.columns)
934
- st.header("2.2 Actual vs. Predicted Plot")
935
-
936
- metrics_table, line, actual_vs_predicted_plot = (
937
- plot_actual_vs_predicted(
938
- X_test[date_col],
939
- y_test,
940
- test_pred,
941
- model,
942
- target_column=sel_target_col,
943
- is_panel=is_panel,
944
- )
945
- ) # Sprint2
946
-
947
- st.plotly_chart(
948
- actual_vs_predicted_plot, use_container_width=True
949
- )
950
-
951
- st.markdown("## 2.3 Residual Analysis")
952
- columns = st.columns(2)
953
- with columns[0]:
954
- fig = plot_residual_predicted(
955
- y, test_pred, X_test
956
- ) # Sprint2
957
- st.plotly_chart(fig)
958
-
959
- with columns[1]:
960
- st.empty()
961
- fig = qqplot(y, test_pred) # Sprint2
962
- st.plotly_chart(fig)
963
-
964
- with columns[0]:
965
- fig = residual_distribution(y, test_pred) # Sprint2
966
- st.pyplot(fig)
967
-
968
- value = False
969
- save_button_model = st.checkbox(
970
- "Save this model to tune", key="build_rc_cb"
971
- ) # , on_click=set_save())
972
-
973
- if save_button_model:
974
- mod_name = st.text_input("Enter model name")
975
- if len(mod_name) > 0:
976
- mod_name = (
977
- mod_name + "__" + target_col
978
- ) # Sprint4 - adding target col to model name
979
- if is_panel:
980
- pred_train = model.fittedvalues
981
- pred_test = mdf_predict(X_test, model, random_eff_df)
982
- else:
983
- st.session_state["features_set"] = st.session_state[
984
- "features_set"
985
- ] + ["const"]
986
- pred_train = model.predict(
987
- X_train_orig[st.session_state["features_set"]]
988
- )
989
- pred_test = model.predict(
990
- X_test[st.session_state["features_set"]]
991
- )
992
-
993
- st.session_state["Model"][mod_name] = {
994
- "Model_object": model,
995
- "feature_set": st.session_state["features_set"],
996
- "X_train": X_train_orig,
997
- "X_test": X_test,
998
- "y_train": y_train,
999
- "y_test": y_test,
1000
- "pred_train": pred_train,
1001
- "pred_test": pred_test,
1002
- }
1003
- st.session_state["X_train"] = X_train_orig
1004
- st.session_state["X_test_spends"] = test_spends
1005
- st.session_state["saved_model_names"].append(mod_name)
1006
- # Sprint3 additions
1007
- if is_panel:
1008
- random_eff_df = get_random_effects(
1009
- media_data, panel_col, model
1010
- )
1011
- st.session_state["random_effects"] = random_eff_df
1012
-
1013
- with open(
1014
- os.path.join(
1015
- st.session_state["project_path"], "best_models.pkl"
1016
- ),
1017
- "wb",
1018
- ) as f:
1019
- pickle.dump(st.session_state["Model"], f)
1020
- st.success(
1021
- mod_name
1022
- + " model saved! Proceed to the next page to tune the model"
1023
- )
1024
-
1025
- urm = st.session_state["used_response_metrics"]
1026
- urm.append(sel_target_col)
1027
- st.session_state["used_response_metrics"] = list(
1028
- set(urm)
1029
- )
1030
- mod_name = ""
1031
- # Sprint4 - add the formatted name of the target col to used resp metrics
1032
- value = False
1033
-
1034
- st.session_state["project_dct"]["model_build"][
1035
- "session_state_saved"
1036
- ] = {}
1037
- for key in [
1038
- "Model",
1039
- "bin_dict",
1040
- "used_response_metrics",
1041
- "date",
1042
- "saved_model_names",
1043
- "media_data",
1044
- "X_test_spends",
1045
- ]:
1046
- st.session_state["project_dct"]["model_build"][
1047
- "session_state_saved"
1048
- ][key] = st.session_state[key]
1049
-
1050
- project_dct_path = os.path.join(
1051
- st.session_state["project_path"], "project_dct.pkl"
1052
- )
1053
- with open(project_dct_path, "wb") as f:
1054
- pickle.dump(st.session_state["project_dct"], f)
1055
-
1056
- update_db("4_Model_Build.py")
1057
-
1058
- st.toast("💾 Saved Successfully!")
1059
- else:
1060
- st.session_state["project_dct"]["model_build"][
1061
- "show_results_check"
1062
- ] = False
 
1
+ """
2
+ MMO Build Sprint 3
3
+ additions : adding more variables to session state for saved model : random effect, predicted train & test
4
+
5
+ MMO Build Sprint 4
6
+ additions : ability to run models for different response metrics
7
+ """
8
+
9
+ import streamlit as st
10
+ import pandas as pd
11
+ import plotly.express as px
12
+ import plotly.graph_objects as go
13
+ from Eda_functions import format_numbers
14
+ import numpy as np
15
+ import pickle
16
+ from st_aggrid import AgGrid
17
+ from st_aggrid import GridOptionsBuilder, GridUpdateMode
18
+ from utilities import set_header, load_local_css
19
+ from st_aggrid import GridOptionsBuilder
20
+ import time
21
+ import itertools
22
+ import statsmodels.api as sm
23
+ import numpy as npc
24
+ import re
25
+ import itertools
26
+ from sklearn.metrics import (
27
+ mean_absolute_error,
28
+ r2_score,
29
+ mean_absolute_percentage_error,
30
+ )
31
+ from sklearn.preprocessing import MinMaxScaler
32
+ import os
33
+ import matplotlib.pyplot as plt
34
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
35
+ import yaml
36
+ from yaml import SafeLoader
37
+ import streamlit_authenticator as stauth
38
+
39
+ st.set_option("deprecation.showPyplotGlobalUse", False)
40
+ import statsmodels.api as sm
41
+ import statsmodels.formula.api as smf
42
+
43
+ from datetime import datetime
44
+ import seaborn as sns
45
+ from Data_prep_functions import *
46
+ import sqlite3
47
+ from utilities import update_db
48
+
49
+
50
+ @st.cache_resource(show_spinner=False)
51
+ # def save_to_pickle(file_path, final_df):
52
+ # # Open the file in write-binary mode and dump the objects
53
+ # with open(file_path, "wb") as f:
54
+ # pickle.dump({file_path: final_df}, f)
55
+
56
+
57
+ def get_random_effects(media_data, panel_col, _mdf):
58
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
59
+
60
+ for i, market in enumerate(media_data[panel_col].unique()):
61
+ print(i, end="\r")
62
+ intercept = _mdf.random_effects[market].values[0]
63
+ random_eff_df.loc[i, "random_effect"] = intercept
64
+ random_eff_df.loc[i, panel_col] = market
65
+
66
+ return random_eff_df
67
+
68
+
69
+ def mdf_predict(X_df, mdf, random_eff_df):
70
+ X = X_df.copy()
71
+ X["fixed_effect"] = mdf.predict(X)
72
+ X = pd.merge(X, random_eff_df, on=panel_col, how="left")
73
+ X["pred"] = X["fixed_effect"] + X["random_effect"]
74
+ # X.to_csv('Test/megred_df.csv',index=False)
75
+ X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
76
+ return X["pred"]
77
+
78
+
79
+ st.set_page_config(
80
+ page_title="Model Build",
81
+ page_icon=":shark:",
82
+ layout="wide",
83
+ initial_sidebar_state="collapsed",
84
+ )
85
+
86
+ load_local_css("styles.css")
87
+ set_header()
88
+
89
+ # Check for authentication status
90
+ for k, v in st.session_state.items():
91
+ if k not in [
92
+ "logout",
93
+ "login",
94
+ "config",
95
+ "model_build_button",
96
+ ] and not k.startswith("FormSubmitter"):
97
+ st.session_state[k] = v
98
+ with open("config.yaml") as file:
99
+ config = yaml.load(file, Loader=SafeLoader)
100
+ st.session_state["config"] = config
101
+ authenticator = stauth.Authenticate(
102
+ config["credentials"],
103
+ config["cookie"]["name"],
104
+ config["cookie"]["key"],
105
+ config["cookie"]["expiry_days"],
106
+ config["preauthorized"],
107
+ )
108
+ st.session_state["authenticator"] = authenticator
109
+ name, authentication_status, username = authenticator.login("Login", "main")
110
+ auth_status = st.session_state.get("authentication_status")
111
+
112
+ if auth_status == True:
113
+ authenticator.logout("Logout", "main")
114
+ is_state_initiaized = st.session_state.get("initialized", False)
115
+
116
+ conn = sqlite3.connect(
117
+ r"DB/User.db", check_same_thread=False
118
+ ) # connection with sql db
119
+ c = conn.cursor()
120
+
121
+ if not is_state_initiaized:
122
+
123
+ if "session_name" not in st.session_state:
124
+ st.session_state["session_name"] = None
125
+
126
+ if "project_dct" not in st.session_state:
127
+ st.error("Please load a project from Home page")
128
+ st.stop()
129
+
130
+ st.title("1. Build Your Model")
131
+
132
+ if not os.path.exists(
133
+ os.path.join(st.session_state["project_path"], "data_import.pkl")
134
+ ):
135
+ st.error("Please move to Data Import Page and save.")
136
+ st.stop()
137
+ with open(
138
+ os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
139
+ ) as f:
140
+ data = pickle.load(f)
141
+ st.session_state["bin_dict"] = data["bin_dict"]
142
+
143
+ if not os.path.exists(
144
+ os.path.join(
145
+ st.session_state["project_path"], "final_df_transformed.pkl"
146
+ )
147
+ ):
148
+ st.error(
149
+ "Please move to Transformation Page and save transformations."
150
+ )
151
+ st.stop()
152
+ with open(
153
+ os.path.join(
154
+ st.session_state["project_path"], "final_df_transformed.pkl"
155
+ ),
156
+ "rb",
157
+ ) as f:
158
+ data = pickle.load(f)
159
+ media_data = data["final_df_transformed"]
160
+ #media_data.to_csv("Test/media_data.csv", index=False)
161
+ train_idx = int(len(media_data) / 5) * 4
162
+ # Sprint4 - available response metrics is a list of all reponse metrics in the data
163
+ ## these will be put in a drop down
164
+
165
+ st.session_state["media_data"] = media_data
166
+
167
+ if "available_response_metrics" not in st.session_state:
168
+ # st.session_state['available_response_metrics'] = ['Total Approved Accounts - Revenue',
169
+ # 'Total Approved Accounts - Appsflyer',
170
+ # 'Account Requests - Appsflyer',
171
+ # 'App Installs - Appsflyer']
172
+
173
+ st.session_state["available_response_metrics"] = st.session_state[
174
+ "bin_dict"
175
+ ]["Response Metrics"]
176
+ # Sprint4
177
+ if "is_tuned_model" not in st.session_state:
178
+ st.session_state["is_tuned_model"] = {}
179
+ for resp_metric in st.session_state["available_response_metrics"]:
180
+ resp_metric = (
181
+ resp_metric.lower()
182
+ .replace(" ", "_")
183
+ .replace("-", "")
184
+ .replace(":", "")
185
+ .replace("__", "_")
186
+ )
187
+ st.session_state["is_tuned_model"][resp_metric] = False
188
+
189
+ # Sprint4 - used_response_metrics is a list of resp metrics for which user has created & saved a model
190
+ if "used_response_metrics" not in st.session_state:
191
+ st.session_state["used_response_metrics"] = []
192
+
193
+ # Sprint4 - saved_model_names
194
+ if "saved_model_names" not in st.session_state:
195
+ st.session_state["saved_model_names"] = []
196
+
197
+ if "Model" not in st.session_state:
198
+ if (
199
+ "session_state_saved"
200
+ in st.session_state["project_dct"]["model_build"].keys()
201
+ and st.session_state["project_dct"]["model_build"][
202
+ "session_state_saved"
203
+ ]
204
+ is not None
205
+ and "Model"
206
+ in st.session_state["project_dct"]["model_build"][
207
+ "session_state_saved"
208
+ ].keys()
209
+ ):
210
+ st.session_state["Model"] = st.session_state["project_dct"][
211
+ "model_build"
212
+ ]["session_state_saved"]["Model"]
213
+ else:
214
+ st.session_state["Model"] = {}
215
+
216
+ # Sprint4 - select a response metric
217
+ default_target_idx = (
218
+ st.session_state["project_dct"]["model_build"].get(
219
+ "sel_target_col", None
220
+ )
221
+ if st.session_state["project_dct"]["model_build"].get(
222
+ "sel_target_col", None
223
+ )
224
+ is not None
225
+ else st.session_state["available_response_metrics"][0]
226
+ )
227
+
228
+ sel_target_col = st.selectbox(
229
+ "Select the response metric",
230
+ st.session_state["available_response_metrics"],
231
+ index=st.session_state["available_response_metrics"].index(
232
+ default_target_idx
233
+ ),
234
+ )
235
+ # , on_change=reset_save())
236
+ st.session_state["project_dct"]["model_build"][
237
+ "sel_target_col"
238
+ ] = sel_target_col
239
+
240
+ target_col = (
241
+ sel_target_col.lower()
242
+ .replace(" ", "_")
243
+ .replace("-", "")
244
+ .replace(":", "")
245
+ .replace("__", "_")
246
+ )
247
+ new_name_dct = {
248
+ col: col.lower()
249
+ .replace(".", "_")
250
+ .lower()
251
+ .replace("@", "_")
252
+ .replace(" ", "_")
253
+ .replace("-", "")
254
+ .replace(":", "")
255
+ .replace("__", "_")
256
+ for col in media_data.columns
257
+ }
258
+ media_data.columns = [
259
+ col.lower()
260
+ .replace(".", "_")
261
+ .replace("@", "_")
262
+ .replace(" ", "_")
263
+ .replace("-", "")
264
+ .replace(":", "")
265
+ .replace("__", "_")
266
+ for col in media_data.columns
267
+ ]
268
+ panel_col = [
269
+ col.lower()
270
+ .replace(".", "_")
271
+ .replace("@", "_")
272
+ .replace(" ", "_")
273
+ .replace("-", "")
274
+ .replace(":", "")
275
+ .replace("__", "_")
276
+ for col in st.session_state["bin_dict"]["Panel Level 1"]
277
+ ][
278
+ 0
279
+ ] # set the panel column
280
+ date_col = "date"
281
+
282
+ is_panel = True if len(panel_col) > 0 else False
283
+
284
+ if "is_panel" not in st.session_state:
285
+ st.session_state["is_panel"] = is_panel
286
+
287
+ if is_panel:
288
+ media_data.sort_values([date_col, panel_col], inplace=True)
289
+ else:
290
+ media_data.sort_values(date_col, inplace=True)
291
+
292
+ media_data.reset_index(drop=True, inplace=True)
293
+
294
+ date = media_data[date_col]
295
+ st.session_state["date"] = date
296
+ y = media_data[target_col]
297
+
298
+ if is_panel:
299
+ spends_data = media_data[
300
+ [
301
+ c
302
+ for c in media_data.columns
303
+ if "_cost" in c.lower() or "_spend" in c.lower()
304
+ ]
305
+ + [date_col, panel_col]
306
+ ]
307
+ # Sprint3 - spends for resp curves
308
+ else:
309
+ spends_data = media_data[
310
+ [
311
+ c
312
+ for c in media_data.columns
313
+ if "_cost" in c.lower() or "_spend" in c.lower()
314
+ ]
315
+ + [date_col]
316
+ ]
317
+
318
+ y = media_data[target_col]
319
+ media_data.drop([date_col], axis=1, inplace=True)
320
+ media_data.reset_index(drop=True, inplace=True)
321
+
322
+ columns = st.columns(2)
323
+
324
+ old_shape = media_data.shape
325
+
326
+ if "old_shape" not in st.session_state:
327
+ st.session_state["old_shape"] = old_shape
328
+
329
+ if "media_data" not in st.session_state:
330
+ st.session_state["media_data"] = pd.DataFrame()
331
+
332
+ # Sprint3
333
+ if "orig_media_data" not in st.session_state:
334
+ st.session_state["orig_media_data"] = pd.DataFrame()
335
+
336
+ # Sprint3 additions
337
+ if "random_effects" not in st.session_state:
338
+ st.session_state["random_effects"] = pd.DataFrame()
339
+ if "pred_train" not in st.session_state:
340
+ st.session_state["pred_train"] = []
341
+ if "pred_test" not in st.session_state:
342
+ st.session_state["pred_test"] = []
343
+ # end of Sprint3 additions
344
+
345
+ # Section 3 - Create combinations
346
+
347
+ # bucket=['paid_search', 'kwai','indicacao','infleux', 'influencer','FB: Level Achieved - Tier 1 Impressions',
348
+ # ' FB: Level Achieved - Tier 2 Impressions','paid_social_others',
349
+ # ' GA App: Will And Cid Pequena Baixo Risco Clicks',
350
+ # 'digital_tactic_others',"programmatic"
351
+ # ]
352
+
353
+ # srishti - bucket names changed
354
+ bucket = [
355
+ "paid_search",
356
+ "kwai",
357
+ "indicacao",
358
+ "infleux",
359
+ "influencer",
360
+ "fb_level_achieved_tier_2",
361
+ "fb_level_achieved_tier_1",
362
+ "paid_social_others",
363
+ "ga_app",
364
+ "digital_tactic_others",
365
+ "programmatic",
366
+ ]
367
+
368
+ # with columns[0]:
369
+ # if st.button('Create Combinations of Variables'):
370
+
371
+ top_3_correlated_features = []
372
+ # # for col in st.session_state['media_data'].columns[:19]:
373
+ # original_cols = [c for c in st.session_state['media_data'].columns if
374
+ # "_clicks" in c.lower() or "_impressions" in c.lower()]
375
+ # original_cols = [c for c in original_cols if "_lag" not in c.lower() and "_adstock" not in c.lower()]
376
+
377
+ original_cols = (
378
+ st.session_state["bin_dict"]["Media"]
379
+ + st.session_state["bin_dict"]["Internal"]
380
+ )
381
+
382
+ original_cols = [
383
+ col.lower()
384
+ .replace(".", "_")
385
+ .replace("@", "_")
386
+ .replace(" ", "_")
387
+ .replace("-", "")
388
+ .replace(":", "")
389
+ .replace("__", "_")
390
+ for col in original_cols
391
+ ]
392
+ original_cols = [col for col in original_cols if "_cost" not in col]
393
+ # for col in st.session_state['media_data'].columns[:19]:
394
+ for col in original_cols: # srishti - new
395
+ corr_df = (
396
+ pd.concat(
397
+ [st.session_state["media_data"].filter(regex=col), y], axis=1
398
+ )
399
+ .corr()[target_col]
400
+ .iloc[:-1]
401
+ )
402
+ top_3_correlated_features.append(
403
+ list(corr_df.sort_values(ascending=False).head(2).index)
404
+ )
405
+ flattened_list = [
406
+ item for sublist in top_3_correlated_features for item in sublist
407
+ ]
408
+ # all_features_set={var:[col for col in flattened_list if var in col] for var in bucket}
409
+ all_features_set = {
410
+ var: [col for col in flattened_list if var in col]
411
+ for var in bucket
412
+ if len([col for col in flattened_list if var in col]) > 0
413
+ } # srishti
414
+ channels_all = [values for values in all_features_set.values()]
415
+ st.session_state["combinations"] = list(itertools.product(*channels_all))
416
+ # if 'combinations' not in st.session_state:
417
+ # st.session_state['combinations']=combinations_all
418
+
419
+ st.session_state["final_selection"] = st.session_state["combinations"]
420
+ # st.success('Created combinations')
421
+
422
+ # revenue.reset_index(drop=True,inplace=True)
423
+ y.reset_index(drop=True, inplace=True)
424
+ if "Model_results" not in st.session_state:
425
+ st.session_state["Model_results"] = {
426
+ "Model_object": [],
427
+ "Model_iteration": [],
428
+ "Feature_set": [],
429
+ "MAPE": [],
430
+ "R2": [],
431
+ "ADJR2": [],
432
+ "pos_count": [],
433
+ }
434
+
435
+ def reset_model_result_dct():
436
+ st.session_state["Model_results"] = {
437
+ "Model_object": [],
438
+ "Model_iteration": [],
439
+ "Feature_set": [],
440
+ "MAPE": [],
441
+ "R2": [],
442
+ "ADJR2": [],
443
+ "pos_count": [],
444
+ }
445
+
446
+ # if st.button('Build Model'):
447
+
448
+ if "iterations" not in st.session_state:
449
+ st.session_state["iterations"] = 0
450
+
451
+ if "final_selection" not in st.session_state:
452
+ st.session_state["final_selection"] = False
453
+
454
+ save_path = r"Model/"
455
+ if st.session_state["final_selection"]:
456
+ st.write(
457
+ f'Total combinations created {format_numbers(len(st.session_state["final_selection"]))}'
458
+ )
459
+
460
+ # st.session_state["project_dct"]["model_build"]["all_iters_check"] = False
461
+
462
+ checkbox_default = (
463
+ st.session_state["project_dct"]["model_build"]["all_iters_check"]
464
+ if st.session_state["project_dct"]["model_build"]["all_iters_check"]
465
+ is not None
466
+ else False
467
+ )
468
+
469
+ if st.checkbox("Build all iterations", value=checkbox_default):
470
+ # st.session_state["project_dct"]["model_build"]["all_iters_check"]
471
+ iterations = len(st.session_state["final_selection"])
472
+ st.session_state["project_dct"]["model_build"][
473
+ "all_iters_check"
474
+ ] = True
475
+
476
+ else:
477
+ iterations = st.number_input(
478
+ "Select the number of iterations to perform",
479
+ min_value=0,
480
+ step=100,
481
+ value=st.session_state["iterations"],
482
+ on_change=reset_model_result_dct,
483
+ )
484
+ st.session_state["project_dct"]["model_build"][
485
+ "all_iters_check"
486
+ ] = False
487
+ st.session_state["project_dct"]["model_build"][
488
+ "iterations"
489
+ ] = iterations
490
+
491
+ # st.stop()
492
+
493
+ # build_button = st.session_state["project_dct"]["model_build"]["build_button"] if \
494
+ # "build_button" in st.session_state["project_dct"]["model_build"].keys() else False
495
+ # model_button =st.button('Build Model', on_click=reset_model_result_dct, key='model_build_button')
496
+ # if
497
+ # if model_button:
498
+ if st.button(
499
+ "Build Model",
500
+ on_click=reset_model_result_dct,
501
+ key="model_build_button",
502
+ ):
503
+ if iterations < 1:
504
+ st.error("Please select number of iterations")
505
+ st.stop()
506
+ st.session_state["project_dct"]["model_build"]["build_button"] = True
507
+ st.session_state["iterations"] = iterations
508
+
509
+ # Section 4 - Model
510
+ # st.session_state['media_data'] = st.session_state['media_data'].fillna(method='ffill')
511
+ st.session_state["media_data"] = st.session_state["media_data"].ffill()
512
+ st.markdown(
513
+ "Data Split -- Training Period: May 9th, 2023 - October 5th,2023 , Testing Period: October 6th, 2023 - November 7th, 2023 "
514
+ )
515
+ progress_bar = st.progress(0) # Initialize the progress bar
516
+ # time_remaining_text = st.empty() # Create an empty space for time remaining text
517
+ start_time = time.time() # Record the start time
518
+ progress_text = st.empty()
519
+
520
+ # time_elapsed_text = st.empty()
521
+ # for i, selected_features in enumerate(st.session_state["final_selection"][40000:40000 + int(iterations)]):
522
+ # for i, selected_features in enumerate(st.session_state["final_selection"]):
523
+
524
+ if is_panel == True:
525
+ for i, selected_features in enumerate(
526
+ st.session_state["final_selection"][0 : int(iterations)]
527
+ ): # srishti
528
+ df = st.session_state["media_data"]
529
+
530
+ fet = [var for var in selected_features if len(var) > 0]
531
+ inp_vars_str = " + ".join(fet) # new
532
+
533
+ X = df[fet]
534
+ y = df[target_col]
535
+ ss = MinMaxScaler()
536
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
537
+
538
+ X[target_col] = y # Sprint2
539
+ X[panel_col] = df[panel_col] # Sprint2
540
+
541
+ X_train = X.iloc[:train_idx]
542
+ X_test = X.iloc[train_idx:]
543
+ y_train = y.iloc[:train_idx]
544
+ y_test = y.iloc[train_idx:]
545
+
546
+ print(X_train.shape)
547
+ # model = sm.OLS(y_train, X_train).fit()
548
+ md_str = target_col + " ~ " + inp_vars_str
549
+ # md = smf.mixedlm("total_approved_accounts_revenue ~ {}".format(inp_vars_str),
550
+ # data=X_train[[target_col] + fet],
551
+ # groups=X_train[panel_col])
552
+ md = smf.mixedlm(
553
+ md_str,
554
+ data=X_train[[target_col] + fet],
555
+ groups=X_train[panel_col],
556
+ )
557
+ mdf = md.fit()
558
+ predicted_values = mdf.fittedvalues
559
+
560
+ coefficients = mdf.fe_params.to_dict()
561
+ model_positive = [
562
+ col for col in coefficients.keys() if coefficients[col] > 0
563
+ ]
564
+
565
+ pvalues = [var for var in list(mdf.pvalues) if var <= 0.06]
566
+
567
+ if (len(model_positive) / len(selected_features)) > 0 and (
568
+ len(pvalues) / len(selected_features)
569
+ ) >= 0: # srishti - changed just for testing, revert later
570
+ # predicted_values = model.predict(X_train)
571
+ mape = mean_absolute_percentage_error(
572
+ y_train, predicted_values
573
+ )
574
+ r2 = r2_score(y_train, predicted_values)
575
+ adjr2 = 1 - (1 - r2) * (len(y_train) - 1) / (
576
+ len(y_train) - len(selected_features) - 1
577
+ )
578
+
579
+ filename = os.path.join(save_path, f"model_{i}.pkl")
580
+ with open(filename, "wb") as f:
581
+ pickle.dump(mdf, f)
582
+ # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
583
+ # model = pickle.load(file)
584
+
585
+ st.session_state["Model_results"]["Model_object"].append(
586
+ filename
587
+ )
588
+ st.session_state["Model_results"][
589
+ "Model_iteration"
590
+ ].append(i)
591
+ st.session_state["Model_results"]["Feature_set"].append(
592
+ fet
593
+ )
594
+ st.session_state["Model_results"]["MAPE"].append(mape)
595
+ st.session_state["Model_results"]["R2"].append(r2)
596
+ st.session_state["Model_results"]["pos_count"].append(
597
+ len(model_positive)
598
+ )
599
+ st.session_state["Model_results"]["ADJR2"].append(adjr2)
600
+
601
+ current_time = time.time()
602
+ time_taken = current_time - start_time
603
+ time_elapsed_minutes = time_taken / 60
604
+ completed_iterations_text = f"{i + 1}/{iterations}"
605
+ progress_bar.progress((i + 1) / int(iterations))
606
+ progress_text.text(
607
+ f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
608
+ )
609
+ st.write(
610
+ f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
611
+ )
612
+
613
+ else:
614
+
615
+ for i, selected_features in enumerate(
616
+ st.session_state["final_selection"][0 : int(iterations)]
617
+ ): # srishti
618
+ df = st.session_state["media_data"]
619
+
620
+ fet = [var for var in selected_features if len(var) > 0]
621
+ inp_vars_str = " + ".join(fet)
622
+
623
+ X = df[fet]
624
+ y = df[target_col]
625
+ ss = MinMaxScaler()
626
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
627
+ X = sm.add_constant(X)
628
+ X_train = X.iloc[:130]
629
+ X_test = X.iloc[130:]
630
+ y_train = y.iloc[:130]
631
+ y_test = y.iloc[130:]
632
+
633
+ model = sm.OLS(y_train, X_train).fit()
634
+
635
+ coefficients = model.params.to_list()
636
+ model_positive = [coef for coef in coefficients if coef > 0]
637
+ predicted_values = model.predict(X_train)
638
+ pvalues = [var for var in list(model.pvalues) if var <= 0.06]
639
+
640
+ # if (len(model_possitive) / len(selected_features)) > 0.9 and (len(pvalues) / len(selected_features)) >= 0.8:
641
+ if (len(model_positive) / len(selected_features)) > 0 and (
642
+ len(pvalues) / len(selected_features)
643
+ ) >= 0.5: # srishti - changed just for testing, revert later VALID MODEL CRITERIA
644
+ # predicted_values = model.predict(X_train)
645
+ mape = mean_absolute_percentage_error(
646
+ y_train, predicted_values
647
+ )
648
+ adjr2 = model.rsquared_adj
649
+ r2 = model.rsquared
650
+
651
+ filename = os.path.join(save_path, f"model_{i}.pkl")
652
+ with open(filename, "wb") as f:
653
+ pickle.dump(model, f)
654
+ # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
655
+ # model = pickle.load(file)
656
+
657
+ st.session_state["Model_results"]["Model_object"].append(
658
+ filename
659
+ )
660
+ st.session_state["Model_results"][
661
+ "Model_iteration"
662
+ ].append(i)
663
+ st.session_state["Model_results"]["Feature_set"].append(
664
+ fet
665
+ )
666
+ st.session_state["Model_results"]["MAPE"].append(mape)
667
+ st.session_state["Model_results"]["R2"].append(r2)
668
+ st.session_state["Model_results"]["ADJR2"].append(adjr2)
669
+ st.session_state["Model_results"]["pos_count"].append(
670
+ len(model_positive)
671
+ )
672
+
673
+ current_time = time.time()
674
+ time_taken = current_time - start_time
675
+ time_elapsed_minutes = time_taken / 60
676
+ completed_iterations_text = f"{i + 1}/{iterations}"
677
+ progress_bar.progress((i + 1) / int(iterations))
678
+ progress_text.text(
679
+ f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
680
+ )
681
+ st.write(
682
+ f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
683
+ )
684
+
685
+ pd.DataFrame(st.session_state["Model_results"]).to_csv(
686
+ "model_output.csv"
687
+ )
688
+
689
+ def to_percentage(value):
690
+ return f"{value * 100:.1f}%"
691
+
692
+ ## Section 5 - Select Model
693
+ st.title("2. Select Models")
694
+ show_results_defualt = (
695
+ st.session_state["project_dct"]["model_build"]["show_results_check"]
696
+ if st.session_state["project_dct"]["model_build"]["show_results_check"]
697
+ is not None
698
+ else False
699
+ )
700
+ if "tick" not in st.session_state:
701
+ st.session_state["tick"] = False
702
+ if st.checkbox(
703
+ "Show results of top 10 models (based on MAPE and Adj. R2)",
704
+ value=show_results_defualt,
705
+ ):
706
+ st.session_state["project_dct"]["model_build"][
707
+ "show_results_check"
708
+ ] = True
709
+ st.session_state["tick"] = True
710
+ st.write(
711
+ "Select one model iteration to generate performance metrics for it:"
712
+ )
713
+ data = pd.DataFrame(st.session_state["Model_results"])
714
+ data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
715
+ drop=True
716
+ ) # Sprint4 -- Srishti -- only show models with the lowest num of neg coeffs
717
+ data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
718
+ data.drop_duplicates(subset="Model_iteration", inplace=True)
719
+ top_10 = data.head(10)
720
+ top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
721
+ top_10[["MAPE", "R2", "ADJR2"]] = np.round(
722
+ top_10[["MAPE", "R2", "ADJR2"]], 4
723
+ ).applymap(to_percentage)
724
+ top_10_table = top_10[
725
+ ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
726
+ ]
727
+ # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
728
+ gd = GridOptionsBuilder.from_dataframe(top_10_table)
729
+ gd.configure_pagination(enabled=True)
730
+
731
+ gd.configure_selection(
732
+ use_checkbox=True,
733
+ selection_mode="single",
734
+ pre_select_all_rows=False,
735
+ pre_selected_rows=[1],
736
+ )
737
+
738
+ gridoptions = gd.build()
739
+
740
+ table = AgGrid(
741
+ top_10,
742
+ gridOptions=gridoptions,
743
+ update_mode=GridUpdateMode.SELECTION_CHANGED,
744
+ )
745
+
746
+ selected_rows = table.selected_rows
747
+ # if st.session_state["selected_rows"] != selected_rows:
748
+ # st.session_state["build_rc_cb"] = False
749
+ st.session_state["selected_rows"] = selected_rows
750
+
751
+ # Section 6 - Display Results
752
+
753
+ if len(selected_rows) > 0:
754
+ st.header("2.1 Results Summary")
755
+
756
+ model_object = data[
757
+ data["Model_iteration"] == selected_rows[0]["Model_iteration"]
758
+ ]["Model_object"]
759
+ features_set = data[
760
+ data["Model_iteration"] == selected_rows[0]["Model_iteration"]
761
+ ]["Feature_set"]
762
+
763
+ with open(str(model_object.values[0]), "rb") as file:
764
+ # print(file)
765
+ model = pickle.load(file)
766
+ st.write(model.summary())
767
+ st.header("2.2 Actual vs. Predicted Plot")
768
+
769
+ if is_panel:
770
+ df = st.session_state["media_data"]
771
+ X = df[features_set.values[0]]
772
+ y = df[target_col]
773
+
774
+ ss = MinMaxScaler()
775
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
776
+
777
+ # Sprint2 changes
778
+ X[target_col] = y # new
779
+ X[panel_col] = df[panel_col]
780
+ X[date_col] = date
781
+
782
+ X_train = X.iloc[:train_idx]
783
+ X_test = X.iloc[train_idx:].reset_index(drop=True)
784
+ y_train = y.iloc[:train_idx]
785
+ y_test = y.iloc[train_idx:].reset_index(drop=True)
786
+
787
+ test_spends = spends_data[
788
+ train_idx:
789
+ ] # Sprint3 - test spends for resp curves
790
+ random_eff_df = get_random_effects(
791
+ media_data, panel_col, model
792
+ )
793
+ train_pred = model.fittedvalues
794
+ test_pred = mdf_predict(X_test, model, random_eff_df)
795
+ print("__" * 20, test_pred.isna().sum())
796
+
797
+ else:
798
+ df = st.session_state["media_data"]
799
+ X = df[features_set.values[0]]
800
+ y = df[target_col]
801
+
802
+ ss = MinMaxScaler()
803
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
804
+ X = sm.add_constant(X)
805
+
806
+ X[date_col] = date
807
+
808
+ X_train = X.iloc[:130]
809
+ X_test = X.iloc[130:].reset_index(drop=True)
810
+ y_train = y.iloc[:130]
811
+ y_test = y.iloc[130:].reset_index(drop=True)
812
+
813
+ test_spends = spends_data[
814
+ 130:
815
+ ] # Sprint3 - test spends for resp curves
816
+ train_pred = model.predict(
817
+ X_train[features_set.values[0] + ["const"]]
818
+ )
819
+ test_pred = model.predict(
820
+ X_test[features_set.values[0] + ["const"]]
821
+ )
822
+
823
+ # save x test to test - srishti
824
+ # x_test_to_save = X_test.copy()
825
+ # x_test_to_save['Actuals'] = y_test
826
+ # x_test_to_save['Predictions'] = test_pred
827
+ #
828
+ # x_train_to_save = X_train.copy()
829
+ # x_train_to_save['Actuals'] = y_train
830
+ # x_train_to_save['Predictions'] = train_pred
831
+ #
832
+ # x_train_to_save.to_csv('Test/x_train_to_save.csv', index=False)
833
+ # x_test_to_save.to_csv('Test/x_test_to_save.csv', index=False)
834
+
835
+ st.session_state["X"] = X_train
836
+ st.session_state["features_set"] = features_set.values[0]
837
+ print(
838
+ "**" * 20, "selected model features : ", features_set.values[0]
839
+ )
840
+ metrics_table, line, actual_vs_predicted_plot = (
841
+ plot_actual_vs_predicted(
842
+ X_train[date_col],
843
+ y_train,
844
+ train_pred,
845
+ model,
846
+ target_column=sel_target_col,
847
+ is_panel=is_panel,
848
+ )
849
+ ) # Sprint2
850
+
851
+ st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
852
+
853
+ st.markdown("## 2.3 Residual Analysis")
854
+ columns = st.columns(2)
855
+ with columns[0]:
856
+ fig = plot_residual_predicted(
857
+ y_train, train_pred, X_train
858
+ ) # Sprint2
859
+ st.plotly_chart(fig)
860
+
861
+ with columns[1]:
862
+ st.empty()
863
+ fig = qqplot(y_train, train_pred) # Sprint2
864
+ st.plotly_chart(fig)
865
+
866
+ with columns[0]:
867
+ fig = residual_distribution(y_train, train_pred) # Sprint2
868
+ st.pyplot(fig)
869
+
870
+ vif_data = pd.DataFrame()
871
+ # X=X.drop('const',axis=1)
872
+ X_train_orig = (
873
+ X_train.copy()
874
+ ) # Sprint2 -- creating a copy of xtrain. Later deleting panel, target & date from xtrain
875
+ del_col_list = list(
876
+ set([target_col, panel_col, date_col]).intersection(
877
+ set(X_train.columns)
878
+ )
879
+ )
880
+ X_train.drop(columns=del_col_list, inplace=True) # Sprint2
881
+
882
+ vif_data["Variable"] = X_train.columns
883
+ vif_data["VIF"] = [
884
+ variance_inflation_factor(X_train.values, i)
885
+ for i in range(X_train.shape[1])
886
+ ]
887
+ vif_data.sort_values(by=["VIF"], ascending=False, inplace=True)
888
+ vif_data = np.round(vif_data)
889
+ vif_data["VIF"] = vif_data["VIF"].astype(float)
890
+ st.header("2.4 Variance Inflation Factor (VIF)")
891
+ # st.dataframe(vif_data)
892
+ color_mapping = {
893
+ "darkgreen": (vif_data["VIF"] < 3),
894
+ "orange": (vif_data["VIF"] >= 3) & (vif_data["VIF"] <= 10),
895
+ "darkred": (vif_data["VIF"] > 10),
896
+ }
897
+
898
+ # Create a horizontal bar plot
899
+ fig, ax = plt.subplots()
900
+ fig.set_figwidth(10) # Adjust the width of the figure as needed
901
+
902
+ # Sort the bars by descending VIF values
903
+ vif_data = vif_data.sort_values(by="VIF", ascending=False)
904
+
905
+ # Iterate through the color mapping and plot bars with corresponding colors
906
+ for color, condition in color_mapping.items():
907
+ subset = vif_data[condition]
908
+ bars = ax.barh(
909
+ subset["Variable"], subset["VIF"], color=color, label=color
910
+ )
911
+
912
+ # Add text annotations on top of the bars
913
+ for bar in bars:
914
+ width = bar.get_width()
915
+ ax.annotate(
916
+ f"{width:}",
917
+ xy=(width, bar.get_y() + bar.get_height() / 2),
918
+ xytext=(5, 0),
919
+ textcoords="offset points",
920
+ va="center",
921
+ )
922
+
923
+ # Customize the plot
924
+ ax.set_xlabel("VIF Values")
925
+ # ax.set_title('2.4 Variance Inflation Factor (VIF)')
926
+ # ax.legend(loc='upper right')
927
+
928
+ # Display the plot in Streamlit
929
+ st.pyplot(fig)
930
+
931
+ with st.expander("Results Summary Test data"):
932
+ # ss = MinMaxScaler()
933
+ # X_test = pd.DataFrame(ss.fit_transform(X_test), columns=X_test.columns)
934
+ st.header("2.2 Actual vs. Predicted Plot")
935
+
936
+ metrics_table, line, actual_vs_predicted_plot = (
937
+ plot_actual_vs_predicted(
938
+ X_test[date_col],
939
+ y_test,
940
+ test_pred,
941
+ model,
942
+ target_column=sel_target_col,
943
+ is_panel=is_panel,
944
+ )
945
+ ) # Sprint2
946
+
947
+ st.plotly_chart(
948
+ actual_vs_predicted_plot, use_container_width=True
949
+ )
950
+
951
+ st.markdown("## 2.3 Residual Analysis")
952
+ columns = st.columns(2)
953
+ with columns[0]:
954
+ fig = plot_residual_predicted(
955
+ y, test_pred, X_test
956
+ ) # Sprint2
957
+ st.plotly_chart(fig)
958
+
959
+ with columns[1]:
960
+ st.empty()
961
+ fig = qqplot(y, test_pred) # Sprint2
962
+ st.plotly_chart(fig)
963
+
964
+ with columns[0]:
965
+ fig = residual_distribution(y, test_pred) # Sprint2
966
+ st.pyplot(fig)
967
+
968
+ value = False
969
+ save_button_model = st.checkbox(
970
+ "Save this model to tune", key="build_rc_cb"
971
+ ) # , on_click=set_save())
972
+
973
+ if save_button_model:
974
+ mod_name = st.text_input("Enter model name")
975
+ if len(mod_name) > 0:
976
+ mod_name = (
977
+ mod_name + "__" + target_col
978
+ ) # Sprint4 - adding target col to model name
979
+ if is_panel:
980
+ pred_train = model.fittedvalues
981
+ pred_test = mdf_predict(X_test, model, random_eff_df)
982
+ else:
983
+ st.session_state["features_set"] = st.session_state[
984
+ "features_set"
985
+ ] + ["const"]
986
+ pred_train = model.predict(
987
+ X_train_orig[st.session_state["features_set"]]
988
+ )
989
+ pred_test = model.predict(
990
+ X_test[st.session_state["features_set"]]
991
+ )
992
+
993
+ st.session_state["Model"][mod_name] = {
994
+ "Model_object": model,
995
+ "feature_set": st.session_state["features_set"],
996
+ "X_train": X_train_orig,
997
+ "X_test": X_test,
998
+ "y_train": y_train,
999
+ "y_test": y_test,
1000
+ "pred_train": pred_train,
1001
+ "pred_test": pred_test,
1002
+ }
1003
+ st.session_state["X_train"] = X_train_orig
1004
+ st.session_state["X_test_spends"] = test_spends
1005
+ st.session_state["saved_model_names"].append(mod_name)
1006
+ # Sprint3 additions
1007
+ if is_panel:
1008
+ random_eff_df = get_random_effects(
1009
+ media_data, panel_col, model
1010
+ )
1011
+ st.session_state["random_effects"] = random_eff_df
1012
+
1013
+ with open(
1014
+ os.path.join(
1015
+ st.session_state["project_path"], "best_models.pkl"
1016
+ ),
1017
+ "wb",
1018
+ ) as f:
1019
+ pickle.dump(st.session_state["Model"], f)
1020
+ st.success(
1021
+ mod_name
1022
+ + " model saved! Proceed to the next page to tune the model"
1023
+ )
1024
+
1025
+ urm = st.session_state["used_response_metrics"]
1026
+ urm.append(sel_target_col)
1027
+ st.session_state["used_response_metrics"] = list(
1028
+ set(urm)
1029
+ )
1030
+ mod_name = ""
1031
+ # Sprint4 - add the formatted name of the target col to used resp metrics
1032
+ value = False
1033
+
1034
+ st.session_state["project_dct"]["model_build"][
1035
+ "session_state_saved"
1036
+ ] = {}
1037
+ for key in [
1038
+ "Model",
1039
+ "bin_dict",
1040
+ "used_response_metrics",
1041
+ "date",
1042
+ "saved_model_names",
1043
+ "media_data",
1044
+ "X_test_spends",
1045
+ ]:
1046
+ st.session_state["project_dct"]["model_build"][
1047
+ "session_state_saved"
1048
+ ][key] = st.session_state[key]
1049
+
1050
+ project_dct_path = os.path.join(
1051
+ st.session_state["project_path"], "project_dct.pkl"
1052
+ )
1053
+ with open(project_dct_path, "wb") as f:
1054
+ pickle.dump(st.session_state["project_dct"], f)
1055
+
1056
+ update_db("4_Model_Build.py")
1057
+
1058
+ st.toast("💾 Saved Successfully!")
1059
+ else:
1060
+ st.session_state["project_dct"]["model_build"][
1061
+ "show_results_check"
1062
+ ] = False