BlendMMM committed on
Commit ff89010 · verified · 1 Parent(s): fdbbbbf

Upload 11 files

pages/10_Saved_Scenarios.py ADDED
@@ -0,0 +1,408 @@
import streamlit as st
from numerize.numerize import numerize
import io
import pandas as pd
from utilities import (
    format_numbers,
    decimal_formater,
    channel_name_formating,
    load_local_css,
    set_header,
    initialize_data,
    load_authenticator,
)
from openpyxl import Workbook
from openpyxl.styles import Alignment, Font, PatternFill
import pickle
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
from classes import class_from_dict
from utilities import update_db

st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()

# for k, v in st.session_state.items():
#     if k not in ['logout', 'login', 'config'] and not k.startswith('FormSubmitter'):
#         st.session_state[k] = v


def create_scenario_summary(scenario_dict):
    summary_rows = []
    for channel_dict in scenario_dict["channels"]:
        name_mod = channel_name_formating(channel_dict["name"])
        summary_rows.append(
            [
                name_mod,
                channel_dict.get("actual_total_spends")
                * channel_dict.get("conversion_rate"),
                channel_dict.get("modified_total_spends")
                * channel_dict.get("conversion_rate"),
                channel_dict.get("actual_total_sales"),
                channel_dict.get("modified_total_sales"),
                channel_dict.get("actual_total_sales")
                / (
                    channel_dict.get("actual_total_spends")
                    * channel_dict.get("conversion_rate")
                ),
                channel_dict.get("modified_total_sales")
                / (
                    channel_dict.get("modified_total_spends")
                    * channel_dict.get("conversion_rate")
                ),
                channel_dict.get("actual_mroi"),
                channel_dict.get("modified_mroi"),
                channel_dict.get("actual_total_spends")
                * channel_dict.get("conversion_rate")
                / channel_dict.get("actual_total_sales"),
                channel_dict.get("modified_total_spends")
                * channel_dict.get("conversion_rate")
                / channel_dict.get("modified_total_sales"),
            ]
        )

    summary_rows.append(
        [
            "Total",
            scenario_dict.get("actual_total_spends"),
            scenario_dict.get("modified_total_spends"),
            scenario_dict.get("actual_total_sales"),
            scenario_dict.get("modified_total_sales"),
            scenario_dict.get("actual_total_sales")
            / scenario_dict.get("actual_total_spends"),
            scenario_dict.get("modified_total_sales")
            / scenario_dict.get("modified_total_spends"),
            "-",
            "-",
            scenario_dict.get("actual_total_spends")
            / scenario_dict.get("actual_total_sales"),
            scenario_dict.get("modified_total_spends")
            / scenario_dict.get("modified_total_sales"),
        ]
    )

    columns_index = pd.MultiIndex.from_product(
        [[""], ["Channel"]], names=["first", "second"]
    )
    columns_index = columns_index.append(
        pd.MultiIndex.from_product(
            [
                ["Spends", "NRPU", "ROI", "MROI", "Spend per NRPU"],
                ["Actual", "Simulated"],
            ],
            names=["first", "second"],
        )
    )
    return pd.DataFrame(summary_rows, columns=columns_index)
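

# Illustrative sketch (not part of the original commit): how the two-level
# column header above is assembled. Kept in an unused helper so it has no
# effect on the page.
def _demo_summary_columns():
    idx = pd.MultiIndex.from_product([[""], ["Channel"]], names=["first", "second"])
    idx = idx.append(
        pd.MultiIndex.from_product(
            [["Spends", "NRPU", "ROI", "MROI", "Spend per NRPU"], ["Actual", "Simulated"]],
            names=["first", "second"],
        )
    )
    # -> [('', 'Channel'), ('Spends', 'Actual'), ('Spends', 'Simulated'), ...]
    return idx.to_list()

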
def summary_df_to_worksheet(df, ws):
    heading_fill = PatternFill(
        fill_type="solid", start_color="FF11B6BD", end_color="FF11B6BD"
    )
    for j, header in enumerate(df.columns.values):
        col = j + 1
        for i in range(1, 3):
            ws.cell(row=i, column=j + 1, value=header[i - 1]).font = Font(
                bold=True, color="FF11B6BD"
            )
            ws.cell(row=i, column=j + 1).fill = heading_fill
        if col > 1 and (col - 6) % 5 == 0:
            ws.merge_cells(start_row=1, end_row=1, start_column=col - 3, end_column=col)
            ws.cell(row=1, column=col).alignment = Alignment(horizontal="center")
    for i, row in enumerate(df.itertuples()):
        for j, value in enumerate(row):
            if j == 0:
                continue
            elif (j - 2) % 4 == 0 or (j - 3) % 4 == 0:
                ws.cell(row=i + 3, column=j, value=value).number_format = "$#,##0.0"
            else:
                ws.cell(row=i + 3, column=j, value=value)


from openpyxl.utils import get_column_letter
from openpyxl.styles import Font, PatternFill
import logging


def scenario_df_to_worksheet(df, ws):
    heading_fill = PatternFill(
        start_color="FF11B6BD", end_color="FF11B6BD", fill_type="solid"
    )

    for j, header in enumerate(df.columns.values):
        cell = ws.cell(row=1, column=j + 1, value=header)
        cell.font = Font(bold=True, color="FF11B6BD")
        cell.fill = heading_fill

    for i, row in enumerate(df.itertuples()):
        for j, value in enumerate(
            row[1:], start=1
        ):  # Start from index 1 to skip the index column
            # Fetch the cell first so it is always bound, even if assigning
            # the value raises (the original created it inside the try, which
            # could leave `cell` undefined in the except branch).
            cell = ws.cell(row=i + 2, column=j)
            try:
                if isinstance(value, (int, float)):
                    cell.value = value
                    cell.number_format = "$#,##0.0"
                elif isinstance(value, str):
                    cell.value = value[:32767]  # Excel's per-cell character limit
                else:
                    cell.value = str(value)
            except ValueError as e:
                logging.error(
                    f"Error assigning value '{value}' to cell {get_column_letter(j)}{i+2}: {e}"
                )
                cell.value = None  # Leave the offending cell empty

    return ws
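

# Illustrative sketch (not part of the original commit): the download flow
# below serializes the Workbook into an in-memory buffer rather than a file,
# which is the form st.download_button expects. Unused helper.
def _demo_workbook_to_buffer():
    buffer = io.BytesIO()
    wb = Workbook()
    wb.active.append(["Channel", "Spend"])
    wb.save(buffer)           # write the .xlsx bytes into memory
    return buffer.getvalue()  # bytes suitable for st.download_button(data=...)

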
def download_scenarios():
    """
    Builds an Excel workbook with all selected saved scenarios and stores it
    in an in-memory buffer (st.session_state["xlsx_buffer"]) for download.
    """
    ## create summary page
    if len(scenarios_to_download) == 0:
        return
    wb = Workbook()
    wb.iso_dates = True
    wb.remove(wb.active)
    st.session_state["xlsx_buffer"] = io.BytesIO()
    summary_df = None
    # print(scenarios_to_download)
    for scenario_name in scenarios_to_download:
        scenario_dict = st.session_state["saved_scenarios"][scenario_name]
        _spends = []
        column_names = ["Date"]
        _sales = None
        dates = None
        summary_rows = []
        for channel in scenario_dict["channels"]:
            if dates is None:
                dates = channel.get("dates")
                _spends.append(dates)
            if _sales is None:
                _sales = channel.get("modified_sales")
            else:
                _sales += channel.get("modified_sales")
            _spends.append(
                channel.get("modified_spends") * channel.get("conversion_rate")
            )
            column_names.append(channel.get("name"))

            name_mod = channel_name_formating(channel["name"])
            summary_rows.append(
                [
                    name_mod,
                    channel.get("modified_total_spends")
                    * channel.get("conversion_rate"),
                    channel.get("modified_total_sales"),
                    # ROI: parenthesized to match the formula in
                    # create_scenario_summary (the original divided by spends
                    # and then multiplied by the conversion rate)
                    channel.get("modified_total_sales")
                    / (
                        channel.get("modified_total_spends")
                        * channel.get("conversion_rate")
                    ),
                    channel.get("modified_mroi"),
                    # Spend per NRPU, as in create_scenario_summary (the
                    # original repeated the ROI expression here)
                    channel.get("modified_total_spends")
                    * channel.get("conversion_rate")
                    / channel.get("modified_total_sales"),
                ]
            )
        _spends.append(_sales)
        column_names.append("NRPU")
        scenario_df = pd.DataFrame(_spends).T
        scenario_df.columns = column_names
        ## write to sheet
        ws = wb.create_sheet(scenario_name)
        scenario_df_to_worksheet(scenario_df, ws)
        summary_rows.append(
            [
                "Total",
                scenario_dict.get("modified_total_spends"),
                scenario_dict.get("modified_total_sales"),
                scenario_dict.get("modified_total_sales")
                / scenario_dict.get("modified_total_spends"),
                "-",
                scenario_dict.get("modified_total_spends")
                / scenario_dict.get("modified_total_sales"),
            ]
        )
        columns_index = pd.MultiIndex.from_product(
            [[""], ["Channel"]], names=["first", "second"]
        )
        columns_index = columns_index.append(
            pd.MultiIndex.from_product(
                [[scenario_name], ["Spends", "NRPU", "ROI", "MROI", "Spends per NRPU"]],
                names=["first", "second"],
            )
        )
        if summary_df is None:
            summary_df = pd.DataFrame(summary_rows, columns=columns_index)
            summary_df = summary_df.set_index(("", "Channel"))
        else:
            _df = pd.DataFrame(summary_rows, columns=columns_index)
            _df = _df.set_index(("", "Channel"))
            summary_df = summary_df.merge(_df, left_index=True, right_index=True)
    ws = wb.create_sheet("Summary", 0)
    summary_df_to_worksheet(summary_df.reset_index(), ws)
    wb.save(st.session_state["xlsx_buffer"])
    st.session_state["disable_download_button"] = False
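

# Illustrative sketch (not part of the original commit): each scenario yields
# a small summary frame indexed by channel; merging on the index lines the
# scenarios up side by side on the Summary sheet. Unused helper.
def _demo_summary_merge():
    a = pd.DataFrame({"Spends": [10, 30]}, index=["paid_search", "social"])
    b = pd.DataFrame({"Spends": [12, 28]}, index=["paid_search", "social"])
    a.columns = pd.MultiIndex.from_product([["S1"], a.columns])
    b.columns = pd.MultiIndex.from_product([["S2"], b.columns])
    # One row per channel, one column group per scenario
    return a.merge(b, left_index=True, right_index=True)

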
def disable_download_button():
    st.session_state["disable_download_button"] = True


def transform(x):
    if x.name == ("", "Channel"):
        return x
    elif x.name[0] == "ROI" or x.name[0] == "MROI":
        return x.apply(
            lambda y: (
                y
                if isinstance(y, str)
                else decimal_formater(
                    format_numbers(y, include_indicator=False, n_decimals=4),
                    n_decimals=4,
                )
            )
        )
    else:
        return x.apply(lambda y: y if isinstance(y, str) else format_numbers(y))


def delete_scenario():
    if selected_scenario in st.session_state["saved_scenarios"]:
        del st.session_state["saved_scenarios"][selected_scenario]
        with open("../saved_scenarios.pkl", "wb") as f:
            pickle.dump(st.session_state["saved_scenarios"], f)


def load_scenario():
    if selected_scenario in st.session_state["saved_scenarios"]:
        st.session_state["scenario"] = class_from_dict(selected_scenario_details)


authenticator = st.session_state.get("authenticator")
if authenticator is None:
    authenticator = load_authenticator()

name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status:
    is_state_initialized = st.session_state.get("initialized", False)
    if not is_state_initialized:
        # print("Scenario page state reloaded")
        initialize_data()

    saved_scenarios = st.session_state["saved_scenarios"]

    if len(saved_scenarios) == 0:
        st.header("No saved scenarios")

    else:
        selected_scenario_list = list(saved_scenarios.keys())
        if "selected_scenario_selectbox_key" not in st.session_state:
            st.session_state["selected_scenario_selectbox_key"] = (
                selected_scenario_list[
                    st.session_state["project_dct"]["saved_scenarios"][
                        "selected_scenario_selectbox_key"
                    ]
                ]
            )

        col_a, col_b = st.columns(2)
        selected_scenario = col_a.selectbox(
            "Pick a scenario to view details",
            selected_scenario_list,
            # key="selected_scenario_selectbox_key",
            index=st.session_state["project_dct"]["saved_scenarios"][
                "selected_scenario_selectbox_key"
            ],
        )
        st.session_state["project_dct"]["saved_scenarios"][
            "selected_scenario_selectbox_key"
        ] = selected_scenario_list.index(selected_scenario)

        scenarios_to_download = col_b.multiselect(
            "Select scenarios to download",
            list(saved_scenarios.keys()),
            on_change=disable_download_button,
        )

        with col_a:
            col3, col4 = st.columns(2)

            col4.button(
                "Delete scenarios",
                on_click=delete_scenario,
                use_container_width=True,
            )
            col3.button(
                "Load Scenario",
                on_click=load_scenario,
                use_container_width=True,
            )

        with col_b:
            col1, col2 = st.columns(2)

            col1.button(
                "Prepare download",
                on_click=download_scenarios,
                use_container_width=True,
            )
            col2.download_button(
                label="Download Scenarios",
                # Guard against the first render, before "Prepare download"
                # has populated these session-state keys
                data=st.session_state.get("xlsx_buffer", io.BytesIO()).getvalue(),
                file_name="scenarios.xlsx",
                # Correct MIME type for .xlsx (the original used the legacy
                # "application/vnd.ms-excel" meant for .xls)
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
                disabled=st.session_state.get("disable_download_button", True),
                on_click=disable_download_button,
                use_container_width=True,
            )

        # column_1, column_2, column_3 = st.columns((6, 1, 1))
        # with column_1:
        #     st.header(selected_scenario)
        # with column_2:
        #     st.button("Delete scenarios", on_click=delete_scenario)
        # with column_3:
        #     st.button("Load Scenario", on_click=load_scenario)

        selected_scenario_details = saved_scenarios[selected_scenario]

        # st.write(pd.DataFrame(selected_scenario_details))
        pd.set_option("display.max_colwidth", 100)

        st.markdown(
            create_scenario_summary(selected_scenario_details)
            .transform(transform)
            .style.set_table_styles(
                [
                    {"selector": "th", "props": [("background-color", "#FFFFFF")]},
                    {
                        "selector": "tr:nth-child(even)",
                        "props": [("background-color", "#FFFFFF")],
                    },
                ]
            )
            .to_html(),
            unsafe_allow_html=True,
        )

elif auth_status is False:
    st.error("Username/Password is incorrect")

if not auth_status:
    try:
        username_forgot_pw, email_forgot_password, random_password = (
            authenticator.forgot_password("Forgot password")
        )
        if username_forgot_pw:
            st.success("New password sent securely")
            # Random password to be transferred to user securely
        elif username_forgot_pw is False:
            st.error("Username not found")
    except Exception as e:
        st.error(e)
pages/11_Model_Optimized_Recommendation.py ADDED
@@ -0,0 +1,545 @@
import streamlit as st
from numerize.numerize import numerize
import pandas as pd
from utilities import (format_numbers, decimal_formater,
                       load_local_css, set_header,
                       initialize_data,
                       load_authenticator)
import pickle
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
from classes import class_from_dict
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sqlite3
from utilities import update_db


def format_number(x):
    if x >= 1_000_000:
        return f'{x / 1_000_000:.2f}M'
    elif x >= 1_000:
        return f'{x / 1_000:.2f}K'
    else:
        return f'{x:.2f}'
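

# Illustrative sketch (not part of the original commit): expected outputs of
# format_number. Unused helper, shown for clarity.
def _demo_format_number():
    return [format_number(v) for v in (1_234_567, 45_600, 7.891)]
    # -> ['1.23M', '45.60K', '7.89']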

def summary_plot(data, x, y, title, text_column, color,
                 format_as_percent=False, format_as_decimal=False):
    fig = px.bar(data, x=x, y=y, orientation='h',
                 title=title, text=text_column, color=color)
    fig.update_layout(showlegend=False)
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Update the format of the displayed text based on the chosen format
    if format_as_percent:
        fig.update_traces(texttemplate='%{text:.0%}', textposition='outside', hovertemplate='%{x:.0%}')
    elif format_as_decimal:
        fig.update_traces(texttemplate='%{text:.2f}', textposition='outside', hovertemplate='%{x:.2f}')
    else:
        fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')

    fig.update_layout(xaxis_title=x, yaxis_title='Channel Name', showlegend=False)
    return fig


def stacked_summary_plot(data, x, y, title, text_column, color_column, stack_column=None,
                         format_as_percent=False, format_as_decimal=False):
    fig = px.bar(data, x=x, y=y, orientation='h',
                 title=title, text=text_column, color=color_column, facet_col=stack_column)
    fig.update_layout(showlegend=False)
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Update the format of the displayed text based on the chosen format
    if format_as_percent:
        fig.update_traces(texttemplate='%{text:.0%}', textposition='outside', hovertemplate='%{x:.0%}')
    elif format_as_decimal:
        fig.update_traces(texttemplate='%{text:.2f}', textposition='outside', hovertemplate='%{x:.2f}')
    else:
        fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')

    fig.update_layout(xaxis_title=x, yaxis_title='', showlegend=False)
    return fig


def funnel_plot(data, x, y, title, text_column, color_column,
                format_as_percent=False, format_as_decimal=False):
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Round the numeric values in the text column to two decimal points
    data[text_column] = data[text_column].round(2)

    # Create a color map for categorical data
    color_map = {category: f'rgb({i * 30 % 255},{i * 50 % 255},{i * 70 % 255})'
                 for i, category in enumerate(data[color_column].unique())}

    fig = go.Figure(go.Funnel(
        y=data[y],
        x=data[x],
        text=data[text_column],
        marker=dict(color=data[color_column].map(color_map)),
        textinfo="value",
        hoverinfo='y+x+text'
    ))

    # Update the format of the displayed text based on the chosen format
    if format_as_percent:
        fig.update_layout(title=title, funnelmode="percent")
    elif format_as_decimal:
        fig.update_layout(title=title, funnelmode="overlay")
    else:
        fig.update_layout(title=title, funnelmode="group")

    return fig


st.set_page_config(layout='wide')
load_local_css('styles.css')
set_header()

# for k, v in st.session_state.items():
#     if k not in ['logout', 'login', 'config'] and not k.startswith('FormSubmitter'):
#         st.session_state[k] = v

st.empty()
st.header('Model Result Analysis')
spends_data = pd.read_excel('Overview_data_test.xlsx')

with open('summary_df.pkl', 'rb') as file:
    summary_df_sorted = pickle.load(file)
# st.write(summary_df_sorted)

selected_scenario = st.selectbox('Select Saved Scenarios', ['S1', 'S2'])
summary_df_sorted = summary_df_sorted.sort_values(by=['Optimized_spend'])
st.header('Optimized Spends Overview')

channel_colors = px.colors.qualitative.Plotly

fig = make_subplots(rows=1, cols=3, subplot_titles=('Actual Spend', 'Planned Spend', 'Delta'),
                    horizontal_spacing=0.05)

for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
    channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
    channel_color = channel_colors[i % len(channel_colors)]

    fig.add_trace(go.Bar(x=channel_df['Actual_spend'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Actual_spend'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h'), row=1, col=1)

    fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Optimized_spend'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h', showlegend=False), row=1, col=2)

    fig.add_trace(go.Bar(x=channel_df['Delta_percent'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Delta_percent'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h', showlegend=False), row=1, col=3)

fig.update_layout(
    height=600,
    width=900,
    title='',
    showlegend=False
)

fig.update_yaxes(showticklabels=False, row=1, col=2)
fig.update_yaxes(showticklabels=False, row=1, col=3)

fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_xaxes(showticklabels=False, row=1, col=2)
fig.update_xaxes(showticklabels=False, row=1, col=3)

st.plotly_chart(fig, use_container_width=True)

# ___columns = st.columns(3)
# with ___columns[2]:
#     fig = summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent', color='Channel_name')
#     st.plotly_chart(fig, use_container_width=True)
# with ___columns[0]:
#     fig = summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend', color='Channel_name')
#     st.plotly_chart(fig, use_container_width=True)
# with ___columns[1]:
#     fig = summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend', color='Channel_name')
#     st.plotly_chart(fig, use_container_width=False)

summary_df_sorted['Perc_alloted'] = np.round(
    summary_df_sorted['Optimized_spend'] / summary_df_sorted['Optimized_spend'].sum(), 2)
st.header('Budget Allocation')

fig = make_subplots(rows=1, cols=2, subplot_titles=('Planned Spend', '% Split'),
                    horizontal_spacing=0.05)

for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
    channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
    channel_color = channel_colors[i % len(channel_colors)]

    fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Optimized_spend'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h'), row=1, col=1)

    fig.add_trace(go.Bar(x=channel_df['Perc_alloted'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Perc_alloted'].apply(lambda x: f'{100*x:.0f}%'),
                         marker_color=channel_color,
                         orientation='h', showlegend=False), row=1, col=2)

fig.update_layout(
    height=600,
    width=900,
    title='',
    showlegend=False
)

fig.update_yaxes(showticklabels=False, row=1, col=2)

fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_xaxes(showticklabels=False, row=1, col=2)

st.plotly_chart(fig, use_container_width=True)

# summary_df_sorted['Perc_alloted'] = np.round(summary_df_sorted['Optimized_spend'] / summary_df_sorted['Optimized_spend'].sum(), 2)
# columns2 = st.columns(2)
# with columns2[0]:
#     fig = summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend', color='Channel_name')
#     st.plotly_chart(fig, use_container_width=True)
# with columns2[1]:
#     fig = summary_plot(summary_df_sorted, x='Perc_alloted', y='Channel_name', title='% Split', text_column='Perc_alloted', color='Channel_name', format_as_percent=True)
#     st.plotly_chart(fig, use_container_width=True)

if 'raw_data' not in st.session_state:
    st.session_state['raw_data'] = pd.read_excel('raw_data_nov7_combined1.xlsx')
    st.session_state['raw_data'] = st.session_state['raw_data'][
        st.session_state['raw_data']['MediaChannelName'].isin(summary_df_sorted['Channel_name'].unique())]
    st.session_state['raw_data'] = st.session_state['raw_data'][
        st.session_state['raw_data']['Date'].isin(spends_data['Date'].unique())]

# st.write(st.session_state['raw_data']['ResponseMetricName'])
# st.write(st.session_state['raw_data'])

st.header('Response Forecast Overview')
raw_data = st.session_state['raw_data']
effectiveness_overall = raw_data.groupby('ResponseMetricName').agg({'ResponseMetricValue': 'sum'}).reset_index()
effectiveness_overall['Efficiency'] = effectiveness_overall['ResponseMetricValue'].map(lambda x: x / raw_data['Media Spend'].sum())
# st.write(effectiveness_overall)

columns6 = st.columns(3)

effectiveness_overall.sort_values(by=['ResponseMetricValue'], ascending=False, inplace=True)
effectiveness_overall = np.round(effectiveness_overall, 2)
effectiveness_overall['ResponseMetric'] = effectiveness_overall['ResponseMetricName'].apply(
    lambda x: 'BAU' if 'BAU' in x else ('Gamified' if 'Gamified' in x else x))
# effectiveness_overall = np.where(effectiveness_overall[effectiveness_overall['ResponseMetricName'] == "Adjusted Account Approval BAU"], "Adjusted Account Approval BAU", effectiveness_overall['ResponseMetricName'])

effectiveness_overall.replace({'ResponseMetricName': {'BAU approved clients - Appsflyer': 'Approved clients - Appsflyer',
                                                      'Gamified approved clients - Appsflyer': 'Approved clients - Appsflyer'}},
                              inplace=True)

# st.write(effectiveness_overall.sort_values(by=['ResponseMetricValue'], ascending=False))

condition = effectiveness_overall['ResponseMetricName'] == "Adjusted Account Approval BAU"
condition1 = effectiveness_overall['ResponseMetricName'] == "Approved clients - Appsflyer"
effectiveness_overall['ResponseMetric'] = np.where(condition, "Adjusted Account Approval BAU", effectiveness_overall['ResponseMetric'])

effectiveness_overall['ResponseMetricName'] = np.where(condition1, "Approved clients - Appsflyer (BAU, Gamified)", effectiveness_overall['ResponseMetricName'])
# effectiveness_overall = pd.DataFrame({'ResponseMetricName': ["App Installs - Appsflyer", 'Account Requests - Appsflyer',
#                                                              'Total Adjusted Account Approval', 'Adjusted Account Approval BAU',
#                                                              'Approved clients - Appsflyer', 'Approved clients - Appsflyer'],
#                                       'ResponseMetricValue': [683067, 367020, 112315, 79768, 36661, 16834],
#                                       'Efficiency': [1.24, 0.67, 0.2, 0.14, 0.07, 0.03],
custom_colors = {
    'App Installs - Appsflyer': 'rgb(255, 135, 0)',        # orange
    'Account Requests - Appsflyer': 'rgb(125, 239, 161)',  # light green
    'Adjusted Account Approval': 'rgb(129, 200, 255)',     # light blue
    'Adjusted Account Approval BAU': 'rgb(255, 207, 98)',  # amber
    'Approved clients - Appsflyer': 'rgb(0, 97, 198)',     # dark blue
    "BAU": 'rgb(41, 176, 157)',                            # teal
    "Gamified": 'rgb(213, 218, 229)'                       # silver gray
    # Add more categories and their respective colors as needed
}

with columns6[0]:
    revenue = (effectiveness_overall[effectiveness_overall['ResponseMetricName'] == 'Total Approved Accounts - Revenue']['ResponseMetricValue']).iloc[0]
    revenue = round(revenue / 1_000_000, 2)

    # st.metric('Total Revenue', f"${revenue} M")
# with columns6[1]:
#     BAU = (effectiveness_overall[effectiveness_overall['ResponseMetricName'] == 'BAU approved clients - Revenue']['ResponseMetricValue']).iloc[0]
#     BAU = round(BAU / 1_000_000, 2)
#     st.metric('BAU approved clients - Revenue', f"${BAU} M")
# with columns6[2]:
#     Gam = (effectiveness_overall[effectiveness_overall['ResponseMetricName'] == 'Gamified approved clients - Revenue']['ResponseMetricValue']).iloc[0]
#     Gam = round(Gam / 1_000_000, 2)
#     st.metric('Gamified approved clients - Revenue', f"${Gam} M")

# st.write(effectiveness_overall)
data = {'Revenue': ['BAU approved clients - Revenue', 'Gamified approved clients- Revenue'],
        'ResponseMetricValue': [70200000, 1770000],
        'Efficiency': [127.54, 3.21]}
df = pd.DataFrame(data)

columns9 = st.columns([0.60, 0.40])
with columns9[0]:
    figd = px.pie(df,
                  names='Revenue',
                  values='ResponseMetricValue',
                  hole=0.3,  # set the size of the hole in the donut
                  title='Effectiveness')
    figd.update_layout(
        margin=dict(l=0, r=0, b=0, t=0), width=100, height=180,
        legend=dict(
            orientation='v',  # vertical legend
            x=0,              # anchored at the left edge
            y=0.8             # adjust y as needed
        )
    )

    st.plotly_chart(figd, use_container_width=True)

with columns9[1]:
    figd1 = px.pie(df,
                   names='Revenue',
                   values='Efficiency',
                   hole=0.3,  # set the size of the hole in the donut
                   title='Efficiency')
    figd1.update_layout(
        margin=dict(l=0, r=0, b=0, t=0), width=100, height=180, showlegend=False
    )
    st.plotly_chart(figd1, use_container_width=True)

effectiveness_overall['Response Metric Name'] = effectiveness_overall['ResponseMetricName']

columns4 = st.columns([0.55, 0.45])
with columns4[0]:
    fig = px.funnel(effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(
                        ['Total Approved Accounts - Revenue',
                         'BAU approved clients - Revenue',
                         'Gamified approved clients - Revenue',
                         "Total Approved Accounts - Appsflyer"]))],
                    x='ResponseMetricValue', y='Response Metric Name', color='ResponseMetric',
                    color_discrete_map=custom_colors, title='Effectiveness',
                    labels=None)
    custom_y_labels = ['App Installs - Appsflyer', 'Account Requests - Appsflyer',
                       'Adjusted Account Approval', 'Adjusted Account Approval BAU',
                       "Approved clients - Appsflyer (BAU, Gamified)"]
    fig.update_layout(showlegend=False,
                      yaxis=dict(
                          tickmode='array',
                          ticktext=custom_y_labels,
                      ))
    fig.update_traces(textinfo='value', textposition='inside', texttemplate='%{x:.2s} ', hoverinfo='y+x+percent initial')

    last_trace_index = len(fig.data) - 1
    fig.update_traces(marker=dict(line=dict(color='black', width=2)), selector=dict(marker=dict(color='blue')))

    st.plotly_chart(fig, use_container_width=True)

with columns4[1]:
    # Bar chart of efficiency for the same response metrics
    fig1 = px.bar((effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(
                       ['Total Approved Accounts - Revenue',
                        'BAU approved clients - Revenue',
                        'Gamified approved clients - Revenue',
                        "Total Approved Accounts - Appsflyer"]))]).sort_values(by='ResponseMetricValue'),
                  x='Efficiency', y='Response Metric Name',
                  color_discrete_map=custom_colors, color='ResponseMetric',
                  labels=None, text_auto=True, title='Efficiency')

    # Update layout and traces
    fig1.update_traces(customdata=effectiveness_overall['Efficiency'],
                       textposition='auto')
    fig1.update_layout(showlegend=False)
    fig1.update_yaxes(title='', showticklabels=False)
    fig1.update_xaxes(title='', showticklabels=False)
    fig1.update_xaxes(tickfont=dict(size=20))
    fig1.update_yaxes(tickfont=dict(size=20))
    st.plotly_chart(fig1, use_container_width=True)

effectiveness_overall_revenue = pd.DataFrame({'ResponseMetricName': ['Approved Clients', 'Approved Clients'],
                                              'ResponseMetricValue': [70201070, 1768900],
                                              'Efficiency': [127.54, 3.21],
                                              'ResponseMetric': ['BAU', 'Gamified']})
# from plotly.subplots import make_subplots
# fig = make_subplots(rows=1, cols=2,
#                     subplot_titles=["Effectiveness", "Efficiency"])
#
# # Add first plot as subplot
# fig.add_trace(go.Funnel(
#     x=fig.data[0].x,
#     y=fig.data[0].y,
#     textinfo='value+percent initial',
#     hoverinfo='x+y+percent initial'
# ), row=1, col=1)
#
# # Update layout for first subplot
# fig.update_xaxes(title_text="Response Metric Value", row=1, col=1)
# fig.update_yaxes(ticktext=custom_y_labels, row=1, col=1)
#
# # Add second plot as subplot
# fig.add_trace(go.Bar(
#     x=fig1.data[0].x,
#     y=fig1.data[0].y,
#     customdata=fig1.data[0].customdata,
#     textposition='auto'
# ), row=1, col=2)
#
# # Update layout for second subplot
# fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
# fig.update_yaxes(title='', showticklabels=False, row=1, col=2)
#
# fig.update_layout(height=600, width=800, title_text="Key Metrics")
# st.plotly_chart(fig)

st.header('Return Forecast by Media Channel')
with st.expander("Return Forecast by Media Channel"):
    # `val != np.NaN` is always True, so the original filter was a no-op;
    # pd.notna drops NaNs as intended
    metric_data = [val for val in list(st.session_state['raw_data']['ResponseMetricName'].unique()) if pd.notna(val)]
    # st.write(metric_data)
    metric = st.selectbox('Select Metric', metric_data, index=1)

    selected_metric = st.session_state['raw_data'][st.session_state['raw_data']['ResponseMetricName'] == metric]
    # st.dataframe(selected_metric.head(2))
    effectiveness = selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()
    effectiveness_df = pd.DataFrame({'Channel': effectiveness.index, "ResponseMetricValue": effectiveness.values})

    summary_df_sorted = summary_df_sorted.merge(effectiveness_df, left_on="Channel_name", right_on='Channel')

    summary_df_sorted['Efficiency'] = summary_df_sorted['ResponseMetricValue'] / summary_df_sorted['Optimized_spend']
    summary_df_sorted = summary_df_sorted.sort_values(by='Optimized_spend', ascending=True)
    # st.dataframe(summary_df_sorted)

    channel_colors = px.colors.qualitative.Plotly

    fig = make_subplots(rows=1, cols=3, subplot_titles=('Optimized Spends', 'Effectiveness', 'Efficiency'),
                        horizontal_spacing=0.05)

    for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
        channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
        channel_color = channel_colors[i % len(channel_colors)]

        fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                             y=channel_df['Channel_name'],
                             text=channel_df['Optimized_spend'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h'), row=1, col=1)

        fig.add_trace(go.Bar(x=channel_df['ResponseMetricValue'],
                             y=channel_df['Channel_name'],
                             text=channel_df['ResponseMetricValue'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h', showlegend=False), row=1, col=2)

        fig.add_trace(go.Bar(x=channel_df['Efficiency'],
                             y=channel_df['Channel_name'],
                             text=channel_df['Efficiency'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h', showlegend=False), row=1, col=3)

    fig.update_layout(
        height=600,
        width=900,
        title='Media Channel Performance',
        showlegend=False
    )

    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_yaxes(showticklabels=False, row=1, col=3)

    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=1, col=2)
    fig.update_xaxes(showticklabels=False, row=1, col=3)

    st.plotly_chart(fig, use_container_width=True)

    # columns = st.columns(3)
    # with columns[0]:
    #     fig = summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='', text_column='Optimized_spend', color='Channel_name')
    #     st.plotly_chart(fig, use_container_width=True)
    # with columns[1]:
    #     # effectiveness = (selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()).values
    #     # effectiveness_df = pd.DataFrame({'Channel': st.session_state['raw_data']['MediaChannelName'].unique(), "ResponseMetricValue": effectiveness})
    #     # # effectiveness.reset_index(inplace=True)
    #     # # st.dataframe(effectiveness.head())
    #     fig = summary_plot(summary_df_sorted, x='ResponseMetricValue', y='Channel_name', title='Effectiveness', text_column='ResponseMetricValue', color='Channel_name')
    #     st.plotly_chart(fig, use_container_width=True)
    # with columns[2]:
    #     fig = summary_plot(summary_df_sorted, x='Efficiency', y='Channel_name', title='Efficiency', text_column='Efficiency', color='Channel_name', format_as_decimal=True)
    #     st.plotly_chart(fig, use_container_width=True)

# Create figure with subplots
# fig = make_subplots(rows=1, cols=2)
#
# # Add funnel plot to subplot 1
# fig.add_trace(
#     go.Funnel(
#         x=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricValue'],
#         y=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricName'],
#         textposition="inside",
#         texttemplate="%{x:.2s}",
#         customdata=effectiveness_overall['Efficiency'],
#         hovertemplate="%{customdata:.2f}<extra></extra>"
#     ),
#     row=1, col=1
# )
#
# # Add bar plot to subplot 2
# fig.add_trace(
#     go.Bar(
#         x=effectiveness_overall.sort_values(by='ResponseMetricValue')['Efficiency'],
#         y=effectiveness_overall.sort_values(by='ResponseMetricValue')['ResponseMetricName'],
#         marker_color=effectiveness_overall['ResponseMetric'],
#         customdata=effectiveness_overall['Efficiency'],
#         hovertemplate="%{customdata:.2f}<extra></extra>",
#         textposition="outside"
#     ),
#     row=1, col=2
# )
#
# # Update layout
# fig.update_layout(title_text="Effectiveness")
# fig.update_yaxes(title_text="", row=1, col=1)
# fig.update_yaxes(title_text="", showticklabels=False, row=1, col=2)
# fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
#
# # Show figure
# st.plotly_chart(fig)
pages/1_Data_Import.py ADDED
@@ -0,0 +1,1501 @@
# Importing necessary libraries
import streamlit as st
import os

st.set_page_config(
    page_title="Data Import",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

import pickle
import pandas as pd
from utilities import set_header, load_local_css, update_db, project_selection
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
import sqlite3

load_local_css("styles.css")
set_header()

for k, v in st.session_state.items():
    if (
        k not in ["logout", "login", "config"]
        and not k.startswith("FormSubmitter")
        and not k.startswith("data-editor")
    ):
        st.session_state[k] = v

with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config

authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status:
    authenticator.logout("Logout", "main")
    is_state_initialized = st.session_state.get("initialized", False)

    if not is_state_initialized:
        a = 1  # placeholder; initialization happens elsewhere

    if "project_name" not in st.session_state:
        st.session_state["project_name"] = None

    if "project_dct" not in st.session_state:
        # home()
        project_selection(name)
        # st.write(st.session_state['project_name'])

    cols1 = st.columns([2, 1])

    with cols1[0]:
        st.markdown(f"**Welcome {name}**")
    with cols1[1]:
        st.markdown(
            f"**Current Project: {st.session_state['project_name']}**"
        )

    # st.warning("please select a project from Home page")
    # st.stop()

    # Function to validate the date column in a dataframe
    def validate_date_column(df):
        try:
            # Attempt to convert the 'date' column to datetime
            df["date"] = pd.to_datetime(df["date"], format="%d-%m-%Y")
            return True
        except Exception:
            return False

    # Function to determine data interval
    def determine_data_interval(common_freq):
        if common_freq == 1:
            return "daily"
        elif common_freq == 7:
            return "weekly"
        elif 28 <= common_freq <= 31:
            return "monthly"
        else:
            return "irregular"
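
    # Illustrative sketch (not part of the original commit): the interval label
    # comes from the modal gap, in days, between consecutive dates. Unused helper.
    def _demo_interval_detection():
        dates = pd.to_datetime(["2024-01-01", "2024-01-08", "2024-01-15"])
        common_freq = pd.Series(dates).diff().dt.days.dropna().mode()[0]
        return determine_data_interval(common_freq)  # -> "weekly"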

    # Function to read each uploaded Excel file into a pandas DataFrame and store them in a dictionary
    @st.cache_resource(show_spinner=False)  # the original wrote this as a bare call, which has no effect; applied as a decorator, presumably the intent
    def files_to_dataframes(uploaded_files):
        df_dict = {}
        for uploaded_file in uploaded_files:
            # Extract file name without extension
            file_name = uploaded_file.name.rsplit(".", 1)[0]

            # Check for duplicate file names
            if file_name in df_dict:
                st.warning(
                    f"Duplicate File: {file_name}. This file will be skipped.",
                    icon="⚠️",
                )
                continue

            # Read the file into a DataFrame
            df = pd.read_excel(uploaded_file)

            # Convert all column names to lowercase
            df.columns = df.columns.str.lower().str.strip()

            # Separate numeric and non-numeric columns
            numeric_cols = list(df.select_dtypes(include=["number"]).columns)
            non_numeric_cols = [
                col
                for col in df.select_dtypes(exclude=["number"]).columns
                if col.lower() != "date"
            ]

            # Check for 'Date' column
            if not (validate_date_column(df) and len(numeric_cols) > 0):
                st.warning(
                    f"File Name: {file_name} ➜ Please upload data with a Date column in 'DD-MM-YYYY' format and at least one media/exogenous column. This file will be skipped.",
                    icon="⚠️",
                )
                continue

            # Check for interval
            common_freq = (
                pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
            )
            # Calculate the data interval (daily, weekly, monthly or irregular)
            interval = determine_data_interval(common_freq)
            if interval == "irregular":
                st.warning(
                    f"File Name: {file_name} ➜ Please upload data in daily, weekly or monthly interval. This file will be skipped.",
                    icon="⚠️",
                )
                continue

            # Store both DataFrames in the dictionary under their respective keys
            df_dict[file_name] = {
                "numeric": numeric_cols,
                "non_numeric": non_numeric_cols,
                "interval": interval,
                "df": df,
            }

        return df_dict

    # Function to adjust dataframe granularity
    def adjust_dataframe_granularity(df, current_granularity, target_granularity):
        # Set index
        df.set_index("date", inplace=True)

        # Define aggregation rules for resampling
        aggregation_rules = {
            col: "sum" if pd.api.types.is_numeric_dtype(df[col]) else "first"
            for col in df.columns
        }

        # Initialize resampled_df
        resampled_df = df
        if current_granularity == "daily" and target_granularity == "weekly":
            resampled_df = df.resample("W-MON", closed="left", label="left").agg(
                aggregation_rules
            )

        elif current_granularity == "daily" and target_granularity == "monthly":
            resampled_df = df.resample("MS", closed="left", label="left").agg(
                aggregation_rules
            )

        elif current_granularity == "daily" and target_granularity == "daily":
            resampled_df = df.resample("D").agg(aggregation_rules)

        elif (
            current_granularity in ["weekly", "monthly"]
            and target_granularity == "daily"
        ):
            # For higher to lower granularity, distribute numeric and replicate non-numeric values equally across the new period
            expanded_data = []
            for _, row in df.iterrows():
                if current_granularity == "weekly":
                    period_range = pd.date_range(start=row.name, periods=7)
                elif current_granularity == "monthly":
                    period_range = pd.date_range(
                        start=row.name, periods=row.name.days_in_month
                    )

                for date in period_range:
                    new_row = {}
                    for col in df.columns:
                        if pd.api.types.is_numeric_dtype(df[col]):
                            if current_granularity == "weekly":
                                new_row[col] = row[col] / 7
                            elif current_granularity == "monthly":
                                new_row[col] = row[col] / row.name.days_in_month
                        else:
                            new_row[col] = row[col]
                    expanded_data.append((date, new_row))

            resampled_df = pd.DataFrame(
                [data for _, data in expanded_data],
                index=[date for date, _ in expanded_data],
            )

        # Reset index
        resampled_df = resampled_df.reset_index().rename(columns={"index": "date"})

        return resampled_df
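
    # Illustrative sketch (not part of the original commit): daily -> weekly
    # aggregation sums numeric columns and keeps the first non-numeric value
    # per Monday-anchored week. Unused helper.
    def _demo_weekly_resample():
        demo = pd.DataFrame(
            {
                "date": pd.date_range("2024-01-01", periods=14, freq="D"),
                "spend": [100.0] * 14,
                "region": ["us"] * 14,
            }
        )
        out = adjust_dataframe_granularity(demo.copy(), "daily", "weekly")
        return out  # two rows, spend == 700.0 in each week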
216
+
217
+ # Function to clean and extract unique values of Panel_1 and Panel_2
218
+ st.cache_resource(show_spinner=False)
219
+
220
+ def clean_and_extract_unique_values(files_dict, selections):
221
+ all_panel1_values = set()
222
+ all_panel2_values = set()
223
+
224
+ for file_name, file_data in files_dict.items():
225
+ df = file_data["df"]
226
+
227
+ # 'Panel_1' and 'Panel_2' selections
228
+ selected_panel1 = selections[file_name].get("Panel_1")
229
+ selected_panel2 = selections[file_name].get("Panel_2")
230
+
231
+ # Clean and standardize Panel_1 column if it exists and is selected
232
+ if (
233
+ selected_panel1
234
+ and selected_panel1 != "N/A"
235
+ and selected_panel1 in df.columns
236
+ ):
237
+ df[selected_panel1] = (
238
+ df[selected_panel1].str.lower().str.strip().str.replace("_", " ")
239
+ )
240
+ all_panel1_values.update(df[selected_panel1].dropna().unique())
241
+
242
+ # Clean and standardize Panel_2 column if it exists and is selected
243
+ if (
244
+ selected_panel2
245
+ and selected_panel2 != "N/A"
246
+ and selected_panel2 in df.columns
247
+ ):
248
+ df[selected_panel2] = (
249
+ df[selected_panel2].str.lower().str.strip().str.replace("_", " ")
250
+ )
251
+ all_panel2_values.update(df[selected_panel2].dropna().unique())
252
+
253
+ # Update the processed DataFrame back in the dictionary
254
+ files_dict[file_name]["df"] = df
255
+
256
+ return all_panel1_values, all_panel2_values
257
+
258
+ # Function to format values for display
259
+ st.cache_resource(show_spinner=False)
260
+
261
+ def format_values_for_display(values_list):
262
+ # Capitalize the first letter of each word and replace underscores with spaces
263
+ formatted_list = [value.replace("_", " ").title() for value in values_list]
264
+ # Join values with commas and 'and' before the last value
265
+ if len(formatted_list) > 1:
266
+ return ", ".join(formatted_list[:-1]) + ", and " + formatted_list[-1]
267
+ elif formatted_list:
268
+ return formatted_list[0]
269
+ return "No values available"
270
+
271
+ # Function to normalizes all data within files_dict to a daily granularity
272
+ st.cache(show_spinner=False, allow_output_mutation=True)
273
+
274
+ def standardize_data_to_daily(files_dict, selections):
275
+ # Normalize all data to a daily granularity using a provided function
276
+ files_dict = apply_granularity_to_all(files_dict, "daily", selections)
277
+
278
+ # Update the "interval" attribute for each dataset to indicate the new granularity
279
+ for files_name, files_data in files_dict.items():
280
+ files_data["interval"] = "daily"
281
+
282
+ return files_dict
283
+
284
+ # Function to apply granularity transformation to all DataFrames in files_dict
285
+ st.cache_resource(show_spinner=False)
286
+
287
+ def apply_granularity_to_all(files_dict, granularity_selection, selections):
288
+ for file_name, file_data in files_dict.items():
289
+ df = file_data["df"].copy()
290
+
291
+ # Handling when Panel_1 or Panel_2 might be 'N/A'
292
+ selected_panel1 = selections[file_name].get("Panel_1")
293
+ selected_panel2 = selections[file_name].get("Panel_2")
294
+
295
+ # Correcting the segment selection logic & handling 'N/A'
296
+ if selected_panel1 != "N/A" and selected_panel2 != "N/A":
297
+ unique_combinations = df[
298
+ [selected_panel1, selected_panel2]
299
+ ].drop_duplicates()
300
+ elif selected_panel1 != "N/A":
301
+ unique_combinations = df[[selected_panel1]].drop_duplicates()
302
+ selected_panel2 = None # Ensure Panel_2 is ignored if N/A
303
+ elif selected_panel2 != "N/A":
304
+ unique_combinations = df[[selected_panel2]].drop_duplicates()
305
+ selected_panel1 = None # Ensure Panel_1 is ignored if N/A
306
+ else:
307
+ # If both are 'N/A', process the entire dataframe as is
308
+ df = adjust_dataframe_granularity(
309
+ df, file_data["interval"], granularity_selection
310
+ )
311
+ files_dict[file_name]["df"] = df
312
+ continue # Skip to the next file
313
+
314
+ transformed_segments = []
315
+ for _, combo in unique_combinations.iterrows():
316
+ if selected_panel1 and selected_panel2:
317
+ segment = df[
318
+ (df[selected_panel1] == combo[selected_panel1])
319
+ & (df[selected_panel2] == combo[selected_panel2])
320
+ ]
321
+ elif selected_panel1:
322
+ segment = df[df[selected_panel1] == combo[selected_panel1]]
323
+ elif selected_panel2:
324
+ segment = df[df[selected_panel2] == combo[selected_panel2]]
325
+
326
+ # Adjust granularity of the segment
327
+ transformed_segment = adjust_dataframe_granularity(
328
+ segment, file_data["interval"], granularity_selection
329
+ )
330
+ transformed_segments.append(transformed_segment)
331
+
332
+ # Combine all transformed segments into a single DataFrame for this file
333
+ transformed_df = pd.concat(transformed_segments, ignore_index=True)
334
+ files_dict[file_name]["df"] = transformed_df
335
+
336
+ return files_dict
337
+
338
+ # Function to create main dataframe structure
339
+ st.cache_resource(show_spinner=False)
340
+
341
+ def create_main_dataframe(
342
+ files_dict, all_panel1_values, all_panel2_values, granularity_selection
343
+ ):
344
+ # Determine the global start and end dates across all DataFrames
345
+ global_start = min(df["df"]["date"].min() for df in files_dict.values())
346
+ global_end = max(df["df"]["date"].max() for df in files_dict.values())
347
+
348
+ # Adjust the date_range generation based on the granularity_selection
349
+ if granularity_selection == "weekly":
350
+ # Generate a weekly range, with weeks starting on Monday
351
+ date_range = pd.date_range(start=global_start, end=global_end, freq="W-MON")
352
+ elif granularity_selection == "monthly":
353
+ # Generate a monthly range, starting from the first day of each month
354
+ date_range = pd.date_range(start=global_start, end=global_end, freq="MS")
355
+ else: # Default to daily if not weekly or monthly
356
+ date_range = pd.date_range(start=global_start, end=global_end, freq="D")
357
+
358
+ # Collect all unique Panel_1 and Panel_2 values, excluding 'N/A'
359
+ all_panel1s = all_panel1_values
360
+ all_panel2s = all_panel2_values
361
+
362
+ # Dynamically build the list of dimensions (Panel_1, Panel_2) to include in the main DataFrame based on availability
363
+ dimensions, merge_keys = [], []
364
+ if all_panel1s:
365
+ dimensions.append(all_panel1s)
366
+ merge_keys.append("Panel_1")
367
+ if all_panel2s:
368
+ dimensions.append(all_panel2s)
369
+ merge_keys.append("Panel_2")
370
+
371
+ dimensions.append(date_range) # Date range is always included
372
+ merge_keys.append("date") # Date range is always included
373
+
374
+ # Create a main DataFrame template with the dimensions
375
+ main_df = pd.MultiIndex.from_product(
376
+ dimensions,
377
+ names=[name for name, _ in zip(merge_keys, dimensions)],
378
+ ).to_frame(index=False)
379
+
380
+ return main_df.reset_index(drop=True)
381
+
382
+ # Function to prepare and merge dataFrames
383
+ st.cache_resource(show_spinner=False)
384
+
385
+ def merge_into_main_df(main_df, files_dict, selections):
386
+ for file_name, file_data in files_dict.items():
387
+ df = file_data["df"].copy()
388
+
389
+ # Rename selected Panel_1 and Panel_2 columns if not 'N/A'
390
+ selected_panel1 = selections[file_name].get("Panel_1", "N/A")
391
+ selected_panel2 = selections[file_name].get("Panel_2", "N/A")
392
+ if selected_panel1 != "N/A":
393
+ df.rename(columns={selected_panel1: "Panel_1"}, inplace=True)
394
+ if selected_panel2 != "N/A":
395
+ df.rename(columns={selected_panel2: "Panel_2"}, inplace=True)
396
+
397
+ # Merge current DataFrame into main_df based on 'date', and where applicable, 'Panel_1' and 'Panel_2'
398
+ merge_keys = ["date"]
399
+ if "Panel_1" in df.columns:
400
+ merge_keys.append("Panel_1")
401
+ if "Panel_2" in df.columns:
402
+ merge_keys.append("Panel_2")
403
+ main_df = pd.merge(main_df, df, on=merge_keys, how="left")
404
+
405
+ # After all merges, sort by 'date' and reset index for cleanliness
406
+ sort_by = ["date"]
407
+ if "Panel_1" in main_df.columns:
408
+ sort_by.append("Panel_1")
409
+ if "Panel_2" in main_df.columns:
410
+ sort_by.append("Panel_2")
411
+ main_df.sort_values(by=sort_by, inplace=True)
412
+ main_df.reset_index(drop=True, inplace=True)
413
+
414
+ return main_df
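+ # Example (hypothetical): a file whose "geo" column was selected as Panel_1 is
+ # renamed to "Panel_1" and left-joined on ["date", "Panel_1"], so scaffold rows
+ # with no matching upload rows simply carry NaNs for that file's metrics.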
415
+
416
+ # Function to categorize column
417
+ def categorize_column(column_name):
418
+ # Define keywords for each category
419
+ internal_keywords = [
420
+ "Price",
421
+ "Discount",
422
+ "product_price",
423
+ "cost",
424
+ "margin",
425
+ "inventory",
426
+ "sales",
427
+ "revenue",
428
+ "turnover",
429
+ "expense",
430
+ ]
431
+ exogenous_keywords = [
432
+ "GDP",
433
+ "Tax",
434
+ "Inflation",
435
+ "interest_rate",
436
+ "employment_rate",
437
+ "exchange_rate",
438
+ "consumer_spending",
439
+ "retail_sales",
440
+ "oil_prices",
441
+ "weather",
442
+ ]
443
+
444
+ # Check if the column name matches any of the keywords for Internal or Exogenous categories
445
+
446
+ if (
447
+ column_name
448
+ in st.session_state["project_dct"]["data_import"]["cat_dct"].keys()
449
+ and st.session_state["project_dct"]["data_import"]["cat_dct"][column_name]
450
+ is not None
451
+ ):
452
+
453
+ return st.session_state["project_dct"]["data_import"]["cat_dct"][
454
+ column_name
455
+ ] # resume project manoj
456
+
457
+ else:
458
+ for keyword in internal_keywords:
459
+ if keyword.lower() in column_name.lower():
460
+ return "Internal"
461
+ for keyword in exogenous_keywords:
462
+ if keyword.lower() in column_name.lower():
463
+ return "Exogenous"
464
+
465
+ # Default to Media if no match found
466
+ return "Media"
467
+
468
+ # Function to calculate missing stats and prepare for editable DataFrame
469
+ @st.cache_resource(show_spinner=False)
470
+
471
+ def prepare_missing_stats_df(df):
472
+ missing_stats = []
473
+ for column in df.columns:
474
+ if (
475
+ column == "date" or column == "Panel_2" or column == "Panel_1"
476
+ ):  # Skip the date, Panel_1 and Panel_2 columns
477
+ continue
478
+
479
+ missing = df[column].isnull().sum()
480
+ pct_missing = round((missing / len(df)) * 100, 2)
481
+
482
+ # Dynamically assign category based on column name
483
+ category = categorize_column(column)
484
+ # category = "Media" # Keep default bin as Media
485
+
486
+ missing_stats.append(
487
+ {
488
+ "Column": column,
489
+ "Missing Values": missing,
490
+ "Missing Percentage": pct_missing,
491
+ "Impute Method": "Fill with 0", # Default value
492
+ "Category": category,
493
+ }
494
+ )
495
+
496
+ stats_df = pd.DataFrame(missing_stats)
497
+
498
+ return stats_df
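+ # Example (hypothetical): a column with 4 nulls in 200 rows appears as
+ # Missing Values = 4, Missing Percentage = 2.0, with "Fill with 0" preselected
+ # as the impute method until the user edits it.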
499
+
500
+ # Function to add API DataFrame details to the files dictionary
501
+ @st.cache_resource(show_spinner=False)
502
+
503
+ def add_api_dataframe_to_dict(main_df, files_dict):
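+ # The interval is inferred from the modal day-gap between consecutive unique
+ # dates, e.g. a gap of 7 days would hypothetically be read as weekly data.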
504
+ files_dict["API"] = {
505
+ "numeric": list(main_df.select_dtypes(include=["number"]).columns),
506
+ "non_numeric": [
507
+ col
508
+ for col in main_df.select_dtypes(exclude=["number"]).columns
509
+ if col.lower() != "date"
510
+ ],
511
+ "interval": determine_data_interval(
512
+ pd.Series(main_df["date"].unique()).diff().dt.days.dropna().mode()[0]
513
+ ),
514
+ "df": main_df,
515
+ }
516
+
517
+ return files_dict
518
+
519
+ # Function to read the API data into a DataFrame, parsing specified columns as datetime
520
+ @st.cache_resource(show_spinner=False)
521
+ def read_API_data():
522
+ return pd.read_excel(
523
+ r"./upf_data_converted_randomized_resp_metrics.xlsx",
524
+ parse_dates=["Date"],
525
+ )
526
+
527
+ # Function to set the 'Panel_1_Panel_2_Selected' session state variable to False
528
+ def set_Panel_1_Panel_2_Selected_false():
529
+
530
+ st.session_state["Panel_1_Panel_2_Selected"] = False
531
+
532
+ # Restore project_dct to default values whenever the user modifies any widget
533
+ st.session_state["project_dct"]["data_import"]["edited_stats_df"] = None
534
+ st.session_state["project_dct"]["data_import"]["merged_df"] = None
535
+ st.session_state["project_dct"]["data_import"]["missing_stats_df"] = None
536
+ st.session_state["project_dct"]["data_import"]["cat_dct"] = {}
537
+ st.session_state["project_dct"]["data_import"]["numeric_columns"] = None
538
+ st.session_state["project_dct"]["data_import"]["default_df"] = None
539
+ st.session_state["project_dct"]["data_import"]["final_df"] = None
540
+ st.session_state["project_dct"]["data_import"]["edited_df"] = None
541
+
542
+ # Function to serialize and save the objects into a pickle file
543
+ @st.cache_resource(show_spinner=False)
544
+ def save_to_pickle(file_path, final_df, bin_dict):
545
+ # Open the file in write-binary mode and dump the objects
546
+ with open(file_path, "wb") as f:
547
+ pickle.dump({"final_df": final_df, "bin_dict": bin_dict}, f)
548
+ # Data is now saved to file
549
+
550
+ # Function to processes the merged_df DataFrame based on operations defined in edited_df
551
+ @st.cache_resource(show_spinner=False)
552
+ def process_dataframes(merged_df, edited_df, edited_stats_df):
553
+ # Ensure there are operations defined by the user
554
+ if edited_df.empty:
555
+
556
+ return merged_df, edited_stats_df # No operations to apply
557
+
558
+ # Perform operations as defined by the user
559
+ else:
560
+
561
+ for index, row in edited_df.iterrows():
562
+ result_column_name = (
563
+ f"{row['Column 1']}{row['Operator']}{row['Column 2']}"
564
+ )
565
+ col1 = row["Column 1"]
566
+ col2 = row["Column 2"]
567
+ op = row["Operator"]
568
+
569
+ # Apply the specified operation
570
+ if op == "+":
571
+ merged_df[result_column_name] = merged_df[col1] + merged_df[col2]
572
+ elif op == "-":
573
+ merged_df[result_column_name] = merged_df[col1] - merged_df[col2]
574
+ elif op == "*":
575
+ merged_df[result_column_name] = merged_df[col1] * merged_df[col2]
576
+ elif op == "/":
577
+ merged_df[result_column_name] = merged_df[col1] / merged_df[
578
+ col2
579
+ ].replace(0, 1e-9)
580
+
581
+ # Add summary of operation to edited_stats_df
582
+ new_row = {
583
+ "Column": result_column_name,
584
+ "Missing Values": None,
585
+ "Missing Percentage": None,
586
+ "Impute Method": None,
587
+ "Category": row["Category"],
588
+ }
589
+ new_row_df = pd.DataFrame([new_row])
590
+
591
+ # Use pd.concat to add the new_row_df to edited_stats_df
592
+ edited_stats_df = pd.concat(
593
+ [edited_stats_df, new_row_df], ignore_index=True, axis=0
594
+ )
595
+
596
+ # Combine column names from edited_df for cleanup
597
+ combined_columns = set(edited_df["Column 1"]).union(
598
+ set(edited_df["Column 2"])
599
+ )
600
+
601
+ # Filter out rows in edited_stats_df and drop columns from merged_df
602
+ edited_stats_df = edited_stats_df[
603
+ ~edited_stats_df["Column"].isin(combined_columns)
604
+ ]
605
+ merged_df.drop(
606
+ columns=list(combined_columns), errors="ignore", inplace=True
607
+ )
608
+
609
+ return merged_df, edited_stats_df
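+ # Example (hypothetical row): Column 1 = "tv_spend", Operator = "+", Column 2 =
+ # "radio_spend" creates a new "tv_spend+radio_spend" column, after which both
+ # source columns are dropped from merged_df and filtered out of edited_stats_df.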
610
+
611
+ # Function to prepare a list of numeric column names and initialize an empty DataFrame with predefined structure
612
+ @st.cache_resource(show_spinner=False)
613
+
614
+ def prepare_numeric_columns_and_default_df(merged_df, edited_stats_df):
615
+ # Get columns categorized as 'Response Metrics'
616
+ columns_response_metrics = edited_stats_df[
617
+ edited_stats_df["Category"] == "Response Metrics"
618
+ ]["Column"].tolist()
619
+
620
+ # Filter numeric columns, excluding those categorized as 'Response Metrics'
621
+ numeric_columns = [
622
+ col
623
+ for col in merged_df.select_dtypes(include=["number"]).columns
624
+ if col not in columns_response_metrics
625
+ ]
626
+
627
+ # Define the structure of the empty DataFrame
628
+ data = {
629
+ "Column 1": pd.Series([], dtype="str"),
630
+ "Operator": pd.Series([], dtype="str"),
631
+ "Column 2": pd.Series([], dtype="str"),
632
+ "Category": pd.Series([], dtype="str"),
633
+ }
634
+ default_df = pd.DataFrame(data)
635
+
636
+ return numeric_columns, default_df
637
+
639
+
640
+ # Initialize 'final_df' in session state
641
+ if "final_df" not in st.session_state:
642
+ st.session_state["final_df"] = pd.DataFrame()
643
+
644
+ # Initialize 'bin_dict' in session state
645
+ if "bin_dict" not in st.session_state:
646
+ st.session_state["bin_dict"] = {}
647
+
648
+ # Initialize 'Panel_1_Panel_2_Selected' in session state
649
+ if "Panel_1_Panel_2_Selected" not in st.session_state:
650
+ st.session_state["Panel_1_Panel_2_Selected"] = False
651
+
652
+ # Page Title
653
+ st.write("") # Top padding
654
+ st.title("Data Import")
655
+
656
+ conn = sqlite3.connect(
657
+ r"DB\User.db", check_same_thread=False
658
+ ) # connection with sql db
659
+ c = conn.cursor()
660
+
661
+ #########################################################################################################################################################
662
+ # Create a dictionary to hold all DataFrames and collect user input to specify "Panel_2" and "Panel_1" columns for each file
663
+ #########################################################################################################################################################
664
+
665
+ # Read the Excel file, parsing 'Date' column as datetime
666
+ main_df = read_API_data()
667
+
668
+ # Convert all column names to lowercase
669
+ main_df.columns = main_df.columns.str.lower().str.strip()
670
+
671
+ # File uploader
672
+ uploaded_files = st.file_uploader(
673
+ "Upload additional data",
674
+ type=["xlsx"],
675
+ accept_multiple_files=True,
676
+ on_change=set_Panel_1_Panel_2_Selected_false,
677
+ )
678
+
679
+ # Custom HTML for upload instructions
680
+ recommendation_html = f"""
681
+ <div style="text-align: justify;">
682
+ <strong>Recommendation:</strong> For optimal processing, please ensure that all uploaded datasets (including panel, media, internal, and exogenous data) adhere to the following guidelines: each dataset must include a <code>Date</code> column formatted as <code>DD-MM-YYYY</code> and be free of missing values.
683
+ </div>
684
+ """
685
+ st.markdown(recommendation_html, unsafe_allow_html=True)
686
+
687
+ # Choose Desired Granularity
688
+ st.markdown("#### Choose Desired Granularity")
689
+ # Granularity Selection
690
+
691
+ granularity_selection = st.selectbox(
692
+ "Choose Date Granularity",
693
+ ["Daily", "Weekly", "Monthly"],
694
+ label_visibility="collapsed",
695
+ on_change=set_Panel_1_Panel_2_Selected_false,
696
+ index=st.session_state["project_dct"]["data_import"][
697
+ "granularity_selection"
698
+ ], # resume
699
+ )
700
+
701
+ # st.write(st.session_state['project_dct']['data_import']['granularity_selection'])
702
+
703
+ st.session_state["project_dct"]["data_import"]["granularity_selection"] = [
704
+ "Daily",
705
+ "Weekly",
706
+ "Monthly",
707
+ ].index(granularity_selection)
708
+ # st.write(st.session_state['project_dct']['data_import']['granularity_selection'])
709
+ granularity_selection = str(granularity_selection).lower()
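+ # Note: only the selectbox index (0/1/2) is persisted in project_dct; the
+ # lowercased label ("daily"/"weekly"/"monthly") is what drives processing below.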
710
+
711
+ # Convert files to dataframes
712
+ files_dict = files_to_dataframes(uploaded_files)
713
+
714
+ # Add API Dataframe
715
+ if main_df is not None:
716
+ files_dict = add_api_dataframe_to_dict(main_df, files_dict)
717
+
718
+ # Display a warning message if no files have been uploaded and halt further execution
719
+ if not files_dict:
720
+ st.warning(
721
+ "Please upload at least one file to proceed.",
722
+ icon="⚠️",
723
+ )
724
+ st.stop() # Halts further execution until file is uploaded
725
+
726
+ # Select Panel_1 and Panel_2 columns
727
+ st.markdown("#### Select Panel columns")
728
+ selections = {}
729
+ with st.expander("Select Panel columns", expanded=False):
730
+ count = 0 # Initialize counter to manage the visibility of labels and keys
731
+ for file_name, file_data in files_dict.items():
732
+
733
+ # Generate project_dct keys dynamically
734
+ if (
735
+ f"Panel_1_selectbox{file_name}"
736
+ not in st.session_state["project_dct"]["data_import"].keys()
737
+ ):
738
+ st.session_state["project_dct"]["data_import"][
739
+ f"Panel_1_selectbox{file_name}"
740
+ ] = 0
741
+
742
+ if (
743
+ f"Panel_2_selectbox{file_name}"
744
+ not in st.session_state["project_dct"]["data_import"].keys()
745
+ ):
746
+
747
+ st.session_state["project_dct"]["data_import"][
748
+ f"Panel_2_selectbox{file_name}"
749
+ ] = 0
750
+
751
+ # Determine visibility of the label based on the count
752
+ if count == 0:
753
+ label_visibility = "visible"
754
+ else:
755
+ label_visibility = "collapsed"
756
+
757
+ # Extract non-numeric columns
758
+ non_numeric_cols = file_data["non_numeric"]
759
+
760
+ # Prepare Panel_1 and Panel_2 values for dropdown, adding "N/A" as an option
761
+ panel1_values = non_numeric_cols + ["N/A"]
762
+ panel2_values = non_numeric_cols + ["N/A"]
763
+
764
+ # Skip if only one option is available
765
+ if len(panel1_values) == 1 and len(panel2_values) == 1:
766
+ selected_panel1, selected_panel2 = "N/A", "N/A"
767
+ # Update the selections for Panel_1 and Panel_2 for the current file
768
+ selections[file_name] = {
769
+ "Panel_1": selected_panel1,
770
+ "Panel_2": selected_panel2,
771
+ }
772
+ continue
773
+
774
+ # Create layout columns for File Name, Panel_2, and Panel_1 selections
775
+ file_name_col, Panel_1_col, Panel_2_col = st.columns([2, 4, 4])
776
+
777
+ with file_name_col:
778
+ # Display "File Name" label only for the first file
779
+ if count == 0:
780
+ st.write("File Name")
781
+ else:
782
+ st.write("")
783
+ st.write(file_name) # Display the file name
784
+
785
+ with Panel_1_col:
786
+ # Display a selectbox for Panel_1 values
787
+ selected_panel1 = st.selectbox(
788
+ "Select Panel Level 1",
789
+ panel1_values,
790
+ on_change=set_Panel_1_Panel_2_Selected_false,
791
+ label_visibility=label_visibility, # Control visibility of the label
792
+ key=f"Panel_1_selectbox{count}", # Ensure unique key for each selectbox
793
+ index=st.session_state["project_dct"]["data_import"][
794
+ f"Panel_1_selectbox{file_name}"
795
+ ],
796
+ )
797
+
798
+ st.session_state["project_dct"]["data_import"][
799
+ f"Panel_1_selectbox{file_name}"
800
+ ] = panel1_values.index(selected_panel1)
801
+
802
+ with Panel_2_col:
803
+ # Display a selectbox for Panel_2 values
804
+ selected_panel2 = st.selectbox(
805
+ "Select Panel Level 2",
806
+ panel2_values,
807
+ on_change=set_Panel_1_Panel_2_Selected_false,
808
+ label_visibility=label_visibility, # Control visibility of the label
809
+ key=f"Panel_2_selectbox{count}", # Ensure unique key for each selectbox
810
+ index=st.session_state["project_dct"]["data_import"][
811
+ f"Panel_2_selectbox{file_name}"
812
+ ],
813
+ )
814
+
815
+ st.session_state["project_dct"]["data_import"][
816
+ f"Panel_2_selectbox{file_name}"
817
+ ] = panel2_values.index(selected_panel2)
818
+
819
+ # st.write(st.session_state['project_dct']['data_import'][f"Panel_2_selectbox{file_name}"])
820
+
821
+ # Skip processing if the same column is selected for both Panel_1 and Panel_2 due to potential data integrity issues
822
+
823
+ if selected_panel2 == selected_panel1 and not (
824
+ selected_panel2 == "N/A" and selected_panel1 == "N/A"
825
+ ):
826
+ st.warning(
827
+ f"File: {file_name} → The same column cannot serve as both Panel_1 and Panel_2. Please adjust your selections.",
828
+ )
829
+ selected_panel1, selected_panel2 = "N/A", "N/A"
830
+ st.stop()
831
+
832
+ # Update the selections for Panel_1 and Panel_2 for the current file
833
+ selections[file_name] = {
834
+ "Panel_1": selected_panel1,
835
+ "Panel_2": selected_panel2,
836
+ }
837
+
838
+ count += 1 # Increment the counter after processing each file
839
+ st.write()
840
+ # Accept Panel_1 and Panel_2 selection
841
+ accept = st.button(
842
+ "Accept and Process", use_container_width=True
843
+ ) # resume project manoj
844
+
845
+ if (
846
+ not accept
847
+ and st.session_state["project_dct"]["data_import"]["edited_stats_df"]
848
+ is not None
849
+ ):
850
+
851
+ # st.write(st.session_state['project_dct'])
852
+ st.markdown("#### Unique Panel values")
853
+ # Display Panel_1 and Panel_2 values
854
+ with st.expander("Unique Panel values"):
855
+ st.write("")
856
+ st.markdown(
857
+ f"""
858
+ <style>
859
+ .justify-text {{
860
+ text-align: justify;
861
+ }}
862
+ </style>
863
+ <div class="justify-text">
864
+ <strong>Panel Level 1 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel1_values']}<br>
865
+ <strong>Panel Level 2 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel2_values']}
866
+ </div>
867
+ """,
868
+ unsafe_allow_html=True,
869
+ )
870
+
871
+ # Display total Panel_1 and Panel_2
872
+ st.write("")
873
+ st.markdown(
874
+ f"""
875
+ <div style="text-align: justify;">
876
+ <strong>Number of Level 1 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel1_values'])}<br>
877
+ <strong>Number of Level 2 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel2_values'])}
878
+ </div>
879
+ """,
880
+ unsafe_allow_html=True,
881
+ )
882
+ st.write("")
883
+
884
+ # Create an editable DataFrame in Streamlit
885
+
886
+ st.markdown("#### Select Variables Category & Impute Missing Values")
887
+
888
+ # data_temp_path=os.path.join(st.session_state['project_path'],"edited_stats_df.pkl")
889
+
890
+ # with open(data_temp_path,"rb") as f:
891
+ # saved_edited_stats_df=pickle.load(f)
892
+
893
+ # a=st.data_editor(saved_edited_stats_df)
894
+
895
+ merged_df = st.session_state["project_dct"]["data_import"]["merged_df"].copy()
896
+
897
+ missing_stats_df = st.session_state["project_dct"]["data_import"][
898
+ "missing_stats_df"
899
+ ]
900
+
901
+ editable_df = st.session_state["project_dct"]["data_import"]["edited_stats_df"]
902
+ sorted_editable_df = editable_df.sort_values(
903
+ by="Missing Values", ascending=False, na_position="first"
904
+ )
905
+
906
+ edited_stats_df = st.data_editor(
907
+ sorted_editable_df,
908
+ column_config={
909
+ "Impute Method": st.column_config.SelectboxColumn(
910
+ options=[
911
+ "Drop Column",
912
+ "Fill with Mean",
913
+ "Fill with Median",
914
+ "Fill with 0",
915
+ ],
916
+ required=True,
917
+ default="Fill with 0",
918
+ ),
919
+ "Category": st.column_config.SelectboxColumn(
920
+ options=[
921
+ "Media",
922
+ "Exogenous",
923
+ "Internal",
924
+ "Response Metrics",
925
+ ],
926
+ required=True,
927
+ default="Media",
928
+ ),
929
+ },
930
+ disabled=["Column", "Missing Values", "Missing Percentage"],
931
+ hide_index=True,
932
+ use_container_width=True,
933
+ key="data-editor-1",
934
+ )
935
+
936
+ st.session_state["project_dct"]["data_import"]["cat_dct"] = {
937
+ col: cat
938
+ for col, cat in zip(edited_stats_df["Column"], edited_stats_df["Category"])
939
+ }
940
+
941
+ for i, row in edited_stats_df.iterrows():
942
+ column = row["Column"]
943
+ if row["Impute Method"] == "Drop Column":
944
+ merged_df.drop(columns=[column], inplace=True)
945
+
946
+ elif row["Impute Method"] == "Fill with Mean":
947
+ merged_df[column].fillna(
948
+ st.session_state["project_dct"]["data_import"]["merged_df"][
949
+ column
950
+ ].mean(),
951
+ inplace=True,
952
+ )
953
+
954
+ elif row["Impute Method"] == "Fill with Median":
955
+ merged_df[column].fillna(
956
+ st.session_state["project_dct"]["data_import"]["merged_df"][
957
+ column
958
+ ].median(),
959
+ inplace=True,
960
+ )
961
+
962
+ elif row["Impute Method"] == "Fill with 0":
963
+ merged_df[column].fillna(0, inplace=True)
964
+
965
+ # st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
966
+ #########################################################################################################################################################
967
+ # Group columns
968
+ #########################################################################################################################################################
969
+
970
+ # Display Group columns header
971
+ numeric_columns = st.session_state["project_dct"]["data_import"][
972
+ "numeric_columns"
973
+ ]
974
+ default_df = st.session_state["project_dct"]["data_import"]["default_df"]
975
+
976
+ st.markdown("#### Feature engineering")
977
+
978
+ edited_df = st.data_editor(
979
+ st.session_state["project_dct"]["data_import"]["edited_df"],
980
+ column_config={
981
+ "Column 1": st.column_config.SelectboxColumn(
982
+ options=numeric_columns,
983
+ required=True,
984
+ width=400,
985
+ ),
986
+ "Operator": st.column_config.SelectboxColumn(
987
+ options=["+", "-", "*", "/"],
988
+ required=True,
989
+ default="+",
990
+ width=100,
991
+ ),
992
+ "Column 2": st.column_config.SelectboxColumn(
993
+ options=numeric_columns,
994
+ required=True,
995
+ default=numeric_columns[0],
996
+ width=400,
997
+ ),
998
+ "Category": st.column_config.SelectboxColumn(
999
+ options=[
1000
+ "Media",
1001
+ "Exogenous",
1002
+ "Internal",
1003
+ "Response Metrics",
1004
+ ],
1005
+ required=True,
1006
+ default="Media",
1007
+ width=200,
1008
+ ),
1009
+ },
1010
+ num_rows="dynamic",
1011
+ key="data-editor-4",
1012
+ )
1013
+
1014
+ final_df, edited_stats_df = process_dataframes(
1015
+ merged_df, edited_df, edited_stats_df
1016
+ )
1017
+
1018
+ st.markdown("#### Final DataFrame")
1019
+ sort_col = []
1020
+ for col in final_df.columns:
1021
+ if col in ["Panel_1", "Panel_2", "date"]:
1022
+ sort_col.append(col)
1023
+
1024
+ sorted_final_df = final_df.sort_values(
1025
+ by=sort_col, ascending=True, na_position="first"
1026
+ )
1027
+
1028
+ st.dataframe(sorted_final_df, hide_index=True)
1029
+
1030
+ # Initialize an empty dictionary to hold categories and their variables
1031
+ category_dict = {}
1032
+
1033
+ # Iterate over each row in the edited DataFrame to populate the dictionary
1034
+ for i, row in edited_stats_df.iterrows():
1035
+ column = row["Column"]
1036
+ category = row[
1037
+ "Category"
1038
+ ] # The category chosen by the user for this variable
1039
+
1040
+ # Check if the category already exists in the dictionary
1041
+ if category not in category_dict:
1042
+ # If not, initialize it with the current column as its first element
1043
+ category_dict[category] = [column]
1044
+ else:
1045
+ # If it exists, append the current column to the list of variables under this category
1046
+ category_dict[category].append(column)
1047
+
1048
+ # Add Date, Panel_1 and Panel_2 to the category dictionary
1049
+ category_dict.update({"Date": ["date"]})
1050
+ if "Panel_1" in final_df.columns:
1051
+ category_dict["Panel Level 1"] = ["Panel_1"]
1052
+ if "Panel_2" in final_df.columns:
1053
+ category_dict["Panel Level 2"] = ["Panel_2"]
1054
+
1055
+ # Display the dictionary
1056
+ st.markdown("#### Variable Category")
1057
+ for category, variables in category_dict.items():
1058
+ # Check if there are multiple variables to handle "and" insertion correctly
1059
+ if len(variables) > 1:
1060
+ # Join all but the last variable with ", ", then add " and " before the last variable
1061
+ variables_str = ", ".join(variables[:-1]) + " and " + variables[-1]
1062
+ else:
1063
+ # If there's only one variable, no need for "and"
1064
+ variables_str = variables[0]
1065
+
1066
+ # Display the category and its variables in the desired format
1067
+ st.markdown(
1068
+ f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
1069
+ unsafe_allow_html=True,
1070
+ )
1071
+
1072
+ # Function to check if Response Metrics is selected
1073
+ st.write("")
1074
+ response_metrics_col = category_dict.get("Response Metrics", [])
1075
+ if len(response_metrics_col) == 0:
1076
+ st.warning("Please select Response Metrics column", icon="⚠️")
1077
+ st.stop()
1078
+ # elif len(response_metrics_col) > 1:
1079
+ # st.warning("Please select only one Response Metrics column", icon="⚠️")
1080
+ # st.stop()
1081
+
1082
+ # Store final dataframe and bin dictionary into session state
1083
+ st.session_state["final_df"], st.session_state["bin_dict"] = (
1084
+ final_df,
1085
+ category_dict,
1086
+ )
1087
+
1088
+ # Save the DataFrame and dictionary from the session state to the pickle file
1089
+ if st.button(
1090
+ "Accept and Save",
1091
+ use_container_width=True,
1092
+ key="data-editor-button",
1093
+ ):
1095
+ update_db("1_Data_Import.py")
1096
+ final_df = final_df.loc[:, ~final_df.columns.duplicated()]
1097
+
1098
+ project_dct_path = os.path.join(
1099
+ st.session_state["project_path"], "project_dct.pkl"
1100
+ )
1101
+
1102
+ with open(project_dct_path, "wb") as f:
1103
+ pickle.dump(st.session_state["project_dct"], f)
1104
+
1105
+ data_path = os.path.join(
1106
+ st.session_state["project_path"], "data_import.pkl"
1107
+ )
1108
+
1109
+ st.session_state["data_path"] = data_path
1110
+
1111
+ save_to_pickle(
1112
+ data_path,
1113
+ st.session_state["final_df"],
1114
+ st.session_state["bin_dict"],
1115
+ )
1116
+
1117
+ st.session_state["project_dct"]["data_import"][
1118
+ "edited_stats_df"
1119
+ ] = edited_stats_df
1120
+ st.session_state["project_dct"]["data_import"]["merged_df"] = merged_df
1121
+ st.session_state["project_dct"]["data_import"][
1122
+ "missing_stats_df"
1123
+ ] = missing_stats_df
1124
+ st.session_state["project_dct"]["data_import"]["cat_dct"] = {
1125
+ col: cat
1126
+ for col, cat in zip(
1127
+ edited_stats_df["Column"], edited_stats_df["Category"]
1128
+ )
1129
+ }
1130
+ st.session_state["project_dct"]["data_import"][
1131
+ "numeric_columns"
1132
+ ] = numeric_columns
1133
+ st.session_state["project_dct"]["data_import"]["default_df"] = default_df
1134
+ st.session_state["project_dct"]["data_import"]["final_df"] = final_df
1135
+ st.session_state["project_dct"]["data_import"]["edited_df"] = edited_df
1136
+
1137
+ st.toast("💾 Saved Successfully!")
1138
+
1139
+ if accept:
1140
+ # Normalize all data to a daily granularity. This initial standardization simplifies subsequent conversions to other levels of granularity
1141
+ with st.spinner("Processing..."):
1142
+ files_dict = standardize_data_to_daily(files_dict, selections)
1143
+
1144
+ # Convert all data to the selected level of granularity
1145
+ files_dict = apply_granularity_to_all(
1146
+ files_dict, granularity_selection, selections
1147
+ )
1148
+
1149
+ # Update the 'files_dict' in the session state
1150
+ st.session_state["files_dict"] = files_dict
1151
+
1152
+ # Set a flag in the session state to indicate that selection has been made
1153
+ st.session_state["Panel_1_Panel_2_Selected"] = True
1154
+
1155
+ #########################################################################################################################################################
1156
+ # Display unique Panel_1 and Panel_2 values
1157
+ #########################################################################################################################################################
1158
+
1159
+ # Halts further execution until Panel_1 and Panel_2 columns are selected
1160
+ if st.session_state["project_dct"]["data_import"]["edited_stats_df"] is None:
1161
+
1162
+ if (
1163
+ "files_dict" in st.session_state
1164
+ and st.session_state["Panel_1_Panel_2_Selected"]
1165
+ ):
1166
+ files_dict = st.session_state["files_dict"]
1167
+
1168
+ st.session_state["project_dct"]["data_import"][
1169
+ "files_dict"
1170
+ ] = files_dict # resume
1171
+ else:
1172
+ st.stop()
1173
+
1174
+ # Set to store unique values of Panel_1 and Panel_2
1175
+ with st.spinner("Fetching Panel values..."):
1176
+ all_panel1_values, all_panel2_values = clean_and_extract_unique_values(
1177
+ files_dict, selections
1178
+ )
1179
+
1180
+ # List of Panel_1 and Panel_2 columns unique values
1181
+ list_of_all_panel1_values = list(all_panel1_values)
1182
+ list_of_all_panel2_values = list(all_panel2_values)
1183
+
1184
+ # Format Panel_1 and Panel_2 values for display
1185
+ formatted_panel1_values = format_values_for_display(
1186
+ list_of_all_panel1_values
1187
+ ) ##
1188
+ formatted_panel2_values = format_values_for_display(
1189
+ list_of_all_panel2_values
1190
+ ) ##
1191
+
1192
+ # storing panel values in project_dct
1193
+
1194
+ st.session_state["project_dct"]["data_import"][
1195
+ "formatted_panel1_values"
1196
+ ] = formatted_panel1_values
1197
+ st.session_state["project_dct"]["data_import"][
1198
+ "formatted_panel2_values"
1199
+ ] = formatted_panel2_values
1200
+
1201
+ # Unique Panel_1 and Panel_2 values
1202
+ st.markdown("#### Unique Panel values")
1203
+ # Display Panel_1 and Panel_2 values
1204
+ with st.expander("Unique Panel values"):
1205
+ st.write("")
1206
+ st.markdown(
1207
+ f"""
1208
+ <style>
1209
+ .justify-text {{
1210
+ text-align: justify;
1211
+ }}
1212
+ </style>
1213
+ <div class="justify-text">
1214
+ <strong>Panel Level 1 Values:</strong> {formatted_panel1_values}<br>
1215
+ <strong>Panel Level 2 Values:</strong> {formatted_panel2_values}
1216
+ </div>
1217
+ """,
1218
+ unsafe_allow_html=True,
1219
+ )
1220
+
1221
+ # Display total Panel_1 and Panel_2
1222
+ st.write("")
1223
+ st.markdown(
1224
+ f"""
1225
+ <div style="text-align: justify;">
1226
+ <strong>Number of Level 1 Panels detected:</strong> {len(list_of_all_panel1_values)}<br>
1227
+ <strong>Number of Level 2 Panels detected:</strong> {len(list_of_all_panel2_values)}
1228
+ </div>
1229
+ """,
1230
+ unsafe_allow_html=True,
1231
+ )
1232
+ st.write("")
1233
+
1234
+ #########################################################################################################################################################
1235
+ # Merge all DataFrames
1236
+ #########################################################################################################################################################
1237
+
1238
+ # Merge all DataFrames selected
1239
+
1240
+ main_df = create_main_dataframe(
1241
+ files_dict,
1242
+ all_panel1_values,
1243
+ all_panel2_values,
1244
+ granularity_selection,
1245
+ )
1246
+
1247
+ merged_df = merge_into_main_df(main_df, files_dict, selections) ##
1248
+
1249
+ #########################################################################################################################################################
1250
+ # Categorize Variables and Impute Missing Values
1251
+ #########################################################################################################################################################
1252
+
1253
+ # Create an editable DataFrame in Streamlit
1254
+
1255
+ st.markdown("#### Select Variables Category & Impute Missing Values")
1256
+
1257
+ # Prepare missing stats DataFrame for editing
1258
+ missing_stats_df = prepare_missing_stats_df(merged_df)
1259
+ sorted_missing_stats_df = missing_stats_df.sort_values(
1260
+ by="Missing Values", ascending=False, na_position="first"
1261
+ )
1262
+
1263
+ # storing missing stats df
1264
+
1265
+ edited_stats_df = st.data_editor(
1266
+ sorted_missing_stats_df,
1267
+ column_config={
1268
+ "Impute Method": st.column_config.SelectboxColumn(
1269
+ options=[
1270
+ "Drop Column",
1271
+ "Fill with Mean",
1272
+ "Fill with Median",
1273
+ "Fill with 0",
1274
+ ],
1275
+ required=True,
1276
+ default="Fill with 0",
1277
+ ),
1278
+ "Category": st.column_config.SelectboxColumn(
1279
+ options=[
1280
+ "Media",
1281
+ "Exogenous",
1282
+ "Internal",
1283
+ "Response Metrics",
1284
+ ],
1285
+ required=True,
1286
+ default="Media",
1287
+ ),
1288
+ },
1289
+ disabled=["Column", "Missing Values", "Missing Percentage"],
1290
+ hide_index=True,
1291
+ use_container_width=True,
1292
+ key="data-editor-2",
1293
+ )
1294
+
1295
+ # edited_stats_df_path=os.path.join(st.session_state['project_path'],"edited_stats_df.pkl")
1296
+
1297
+ # edited_stats_df.to_pickle(edited_stats_df_path)
1298
+
1299
+ # Apply changes based on edited DataFrame
1300
+ for i, row in edited_stats_df.iterrows():
1301
+ column = row["Column"]
1302
+ if row["Impute Method"] == "Drop Column":
1303
+ merged_df.drop(columns=[column], inplace=True)
1304
+
1305
+ elif row["Impute Method"] == "Fill with Mean":
1306
+ merged_df[column].fillna(merged_df[column].mean(), inplace=True)
1307
+
1308
+ elif row["Impute Method"] == "Fill with Median":
1309
+ merged_df[column].fillna(merged_df[column].median(), inplace=True)
1310
+
1311
+ elif row["Impute Method"] == "Fill with 0":
1312
+ merged_df[column].fillna(0, inplace=True)
1313
+
1314
+ # st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
1315
+
1316
+ #########################################################################################################################################################
1317
+ # Group columns
1318
+ #########################################################################################################################################################
1319
+
1320
+ # Display Group columns header
1321
+ st.markdown("#### Feature engineering")
1322
+
1323
+ # Prepare the numeric columns and an empty DataFrame for user input
1324
+ numeric_columns, default_df = prepare_numeric_columns_and_default_df(
1325
+ merged_df, edited_stats_df
1326
+ )
1327
+
1328
+ # st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
1329
+
1330
+ # Display editable Dataframe
1331
+ edited_df = st.data_editor(
1332
+ default_df,
1333
+ column_config={
1334
+ "Column 1": st.column_config.SelectboxColumn(
1335
+ options=numeric_columns,
1336
+ required=True,
1337
+ width=400,
1338
+ ),
1339
+ "Operator": st.column_config.SelectboxColumn(
1340
+ options=["+", "-", "*", "/"],
1341
+ required=True,
1342
+ default="+",
1343
+ width=100,
1344
+ ),
1345
+ "Column 2": st.column_config.SelectboxColumn(
1346
+ options=numeric_columns,
1347
+ required=True,
1348
+ default=numeric_columns[0],
1349
+ width=400,
1350
+ ),
1351
+ "Category": st.column_config.SelectboxColumn(
1352
+ options=[
1353
+ "Media",
1354
+ "Exogenous",
1355
+ "Internal",
1356
+ "Response Metrics",
1357
+ ],
1358
+ required=True,
1359
+ default="Media",
1360
+ width=200,
1361
+ ),
1362
+ },
1363
+ num_rows="dynamic",
1364
+ key="data-editor-3",
1365
+ )
1366
+
1367
+ # Process the DataFrame based on user inputs and operations specified in edited_df
1368
+ final_df, edited_stats_df = process_dataframes(
1369
+ merged_df, edited_df, edited_stats_df
1370
+ )
1371
+
1372
+ # edited_df_path=os.path.join(st.session_state['project_path'],'edited_df.pkl')
1373
+ # edited_df.to_pickle(edited_df_path)
1374
+
1375
+ #########################################################################################################################################################
1376
+ # Display the Final DataFrame and variables
1377
+ #########################################################################################################################################################
1378
+
1379
+ # Display the Final DataFrame and variables
1380
+
1381
+ st.markdown("#### Final DataFrame")
1382
+
1383
+ sort_col = []
1384
+ for col in final_df.columns:
1385
+ if col in ["Panel_1", "Panel_2", "date"]:
1386
+ sort_col.append(col)
1387
+
1388
+ sorted_final_df = final_df.sort_values(
1389
+ by=sort_col, ascending=True, na_position="first"
1390
+ )
1391
+ st.dataframe(sorted_final_df, hide_index=True)
1392
+
1393
+ # Initialize an empty dictionary to hold categories and their variables
1394
+ category_dict = {}
1395
+
1396
+ # Iterate over each row in the edited DataFrame to populate the dictionary
1397
+ for i, row in edited_stats_df.iterrows():
1398
+ column = row["Column"]
1399
+ category = row[
1400
+ "Category"
1401
+ ] # The category chosen by the user for this variable
1402
+
1403
+ # Check if the category already exists in the dictionary
1404
+ if category not in category_dict:
1405
+ # If not, initialize it with the current column as its first element
1406
+ category_dict[category] = [column]
1407
+ else:
1408
+ # If it exists, append the current column to the list of variables under this category
1409
+ category_dict[category].append(column)
1410
+
1411
+ # Add Date, Panel_1 and Panel_2 to the category dictionary
1412
+ category_dict.update({"Date": ["date"]})
1413
+ if "Panel_1" in final_df.columns:
1414
+ category_dict["Panel Level 1"] = ["Panel_1"]
1415
+ if "Panel_2" in final_df.columns:
1416
+ category_dict["Panel Level 2"] = ["Panel_2"]
1417
+
1418
+ # Display the dictionary
1419
+ st.markdown("#### Variable Category")
1420
+ for category, variables in category_dict.items():
1421
+ # Check if there are multiple variables to handle "and" insertion correctly
1422
+ if len(variables) > 1:
1423
+ # Join all but the last variable with ", ", then add " and " before the last variable
1424
+ variables_str = ", ".join(variables[:-1]) + " and " + variables[-1]
1425
+ else:
1426
+ # If there's only one variable, no need for "and"
1427
+ variables_str = variables[0]
1428
+
1429
+ # Display the category and its variables in the desired format
1430
+ st.markdown(
1431
+ f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
1432
+ unsafe_allow_html=True,
1433
+ )
1434
+
1435
+ # Function to check if Response Metrics is selected
1436
+ st.write("")
1437
+
1438
+ response_metrics_col = category_dict.get("Response Metrics", [])
1439
+ if len(response_metrics_col) == 0:
1440
+ st.warning("Please select Response Metrics column", icon="⚠️")
1441
+ st.stop()
1442
+ # elif len(response_metrics_col) > 1:
1443
+ # st.warning("Please select only one Response Metrics column", icon="⚠️")
1444
+ # st.stop()
1445
+
1446
+ # Store final dataframe and bin dictionary into session state
1447
+
1448
+ st.session_state["final_df"], st.session_state["bin_dict"] = (
1449
+ final_df,
1450
+ category_dict,
1451
+ )
1452
+
1453
+ # Save the DataFrame and dictionary from the session state to the pickle file
1454
+
1455
+ if st.button("Accept and Save", use_container_width=True):
1456
+
1458
+ update_db("1_Data_Import.py")
1459
+
1460
+ project_dct_path = os.path.join(
1461
+ st.session_state["project_path"], "project_dct.pkl"
1462
+ )
1463
+
1464
+ with open(project_dct_path, "wb") as f:
1465
+ pickle.dump(st.session_state["project_dct"], f)
1466
+
1467
+ data_path = os.path.join(
1468
+ st.session_state["project_path"], "data_import.pkl"
1469
+ )
1470
+ st.session_state["data_path"] = data_path
1471
+
1472
+ save_to_pickle(
1473
+ data_path,
1474
+ st.session_state["final_df"],
1475
+ st.session_state["bin_dict"],
1476
+ )
1477
+
1478
+ st.session_state["project_dct"]["data_import"][
1479
+ "edited_stats_df"
1480
+ ] = edited_stats_df
1481
+ st.session_state["project_dct"]["data_import"]["merged_df"] = merged_df
1482
+ st.session_state["project_dct"]["data_import"][
1483
+ "missing_stats_df"
1484
+ ] = missing_stats_df
1485
+ st.session_state["project_dct"]["data_import"]["cat_dct"] = {
1486
+ col: cat
1487
+ for col, cat in zip(
1488
+ edited_stats_df["Column"], edited_stats_df["Category"]
1489
+ )
1490
+ }
1491
+ st.session_state["project_dct"]["data_import"][
1492
+ "numeric_columns"
1493
+ ] = numeric_columns
1494
+ st.session_state["project_dct"]["data_import"]["default_df"] = default_df
1495
+ st.session_state["project_dct"]["data_import"]["final_df"] = final_df
1496
+ st.session_state["project_dct"]["data_import"]["edited_df"] = edited_df
1497
+
1498
+ st.toast("💾 Saved Successfully!")
1499
+
1500
+ # *****************************************************************
1501
+ # *********************************Persistent flow****************
pages/2_Data_Validation_and_Insights.py ADDED
@@ -0,0 +1,514 @@
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ from Eda_functions import *
6
+ import numpy as np
7
+ import pickle
8
+
10
+ import streamlit.components.v1 as components
11
+ import sweetviz as sv
12
+ from utilities import set_header, load_local_css
13
+ from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
16
+ import base64
17
+ import os
18
+ import tempfile
19
+
20
+ # from ydata_profiling import ProfileReport
21
+ import re
22
+
23
+ # from pygwalker.api.streamlit import StreamlitRenderer
24
+ # from Home_redirecting import home
25
+ import sqlite3
26
+ from utilities import update_db
27
+
28
+ st.set_page_config(
29
+ page_title="Data Validation",
30
+ page_icon=":shark:",
31
+ layout="wide",
32
+ initial_sidebar_state="collapsed",
33
+ )
34
+ load_local_css("styles.css")
35
+ set_header()
36
+
37
+
38
+ if "project_dct" not in st.session_state:
39
+ # home()
40
+ st.warning("Please select a project from home page")
41
+ st.stop()
42
+
43
+
44
+ data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
45
+
46
+ try:
47
+ with open(data_path, "rb") as f:
48
+ data = pickle.load(f)
49
+ except Exception:
50
+ st.error(f"Please import data from the Data Import Page")
51
+ st.stop()
52
+
53
+ conn = sqlite3.connect(r"DB\User.db", check_same_thread=False) # connection with sql db
54
+ c = conn.cursor()
55
+ st.session_state["cleaned_data"] = data["final_df"]
56
+ st.session_state["category_dict"] = data["bin_dict"]
57
+ # st.write(st.session_state['category_dict'])
58
+
59
+ st.title("Data Validation and Insights")
60
+
61
+
62
+ target_variables = [
63
+ st.session_state["category_dict"][key]
64
+ for key in st.session_state["category_dict"].keys()
65
+ if key == "Response Metrics"
66
+ ]
67
+
68
+
69
+ def format_display(inp):
70
+ return inp.title().replace("_", " ").strip()
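+ # e.g. a hypothetical column name "total_approved_accounts" renders as
+ # "Total Approved Accounts" in the selectbox.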
71
+
72
+
73
+ target_variables = list(*target_variables)
74
+ target_column = st.selectbox(
75
+ "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
76
+ target_variables,
77
+ index=st.session_state["project_dct"]["data_validation"]["target_column"],
78
+ format_func=format_display,
79
+ )
80
+
81
+ st.session_state["project_dct"]["data_validation"]["target_column"] = (
82
+ target_variables.index(target_column)
83
+ )
84
+
85
+ st.session_state["target_column"] = target_column
86
+
87
+ panels = st.session_state["category_dict"]["Panel Level 1"][0]
88
+
89
+ selected_panels = st.multiselect(
90
+ "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
91
+ st.session_state["cleaned_data"][panels].unique(),
92
+ default=st.session_state["project_dct"]["data_validation"]["selected_panels"],
93
+ )
94
+
95
+ st.session_state["project_dct"]["data_validation"]["selected_panels"] = selected_panels
96
+
97
+ aggregation_dict = {
98
+ item: "sum" if key == "Media" else "mean"
99
+ for key, value in st.session_state["category_dict"].items()
100
+ for item in value
101
+ if item not in ["date", "Panel_1"]
102
+ }
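+ # Hypothetical result: {"tv_impressions": "sum", "gdp_index": "mean", ...};
+ # Media columns are summed across panels while all other metrics are averaged.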
103
+
104
+ with st.expander("**Reponse Metric Analysis**"):
105
+
106
+ if len(selected_panels) > 0:
107
+ st.session_state["Cleaned_data_panel"] = st.session_state["cleaned_data"][
108
+ st.session_state["cleaned_data"]["Panel_1"].isin(selected_panels)
109
+ ]
110
+
111
+ st.session_state["Cleaned_data_panel"] = (
112
+ st.session_state["Cleaned_data_panel"]
113
+ .groupby(by="date")
114
+ .agg(aggregation_dict)
115
+ )
116
+ st.session_state["Cleaned_data_panel"] = st.session_state[
117
+ "Cleaned_data_panel"
118
+ ].reset_index()
119
+ else:
120
+ # st.write(st.session_state['cleaned_data'])
121
+ st.session_state["Cleaned_data_panel"] = (
122
+ st.session_state["cleaned_data"].groupby(by="date").agg(aggregation_dict)
123
+ )
124
+ st.session_state["Cleaned_data_panel"] = st.session_state[
125
+ "Cleaned_data_panel"
126
+ ].reset_index()
127
+
128
+ fig = line_plot_target(
129
+ st.session_state["Cleaned_data_panel"],
130
+ target=target_column,
131
+ title=f"{target_column} Over Time",
132
+ )
133
+ st.plotly_chart(fig, use_container_width=True)
134
+
135
+ media_channel = list(
136
+ *[
137
+ st.session_state["category_dict"][key]
138
+ for key in st.session_state["category_dict"].keys()
139
+ if key == "Media"
140
+ ]
141
+ )
142
+ # st.write(media_channel)
143
+
144
+ exo_var = list(
145
+ *[
146
+ st.session_state["category_dict"][key]
147
+ for key in st.session_state["category_dict"].keys()
148
+ if key == "Exogenous"
149
+ ]
150
+ )
151
+ internal_var = list(
152
+ *[
153
+ st.session_state["category_dict"][key]
154
+ for key in st.session_state["category_dict"].keys()
155
+ if key == "Internal"
156
+ ]
157
+ )
158
+ Non_media_variables = exo_var + internal_var
159
+
160
+ st.markdown("### Annual Data Summary")
161
+
162
+ summary_df = summary(
163
+ st.session_state["Cleaned_data_panel"],
164
+ media_channel + [target_column],
165
+ spends=None,
166
+ Target=True,
167
+ )
168
+
169
+ st.dataframe(
170
+ summary_df,
171
+ use_container_width=True,
172
+ )
173
+
174
+ if st.checkbox("Show raw data"):
175
+ @st.cache_resource(show_spinner=False)
176
+
177
+ def raw_df_gen():
178
+ # Convert 'date' to datetime but do not convert to string yet for sorting
179
+ dates = pd.to_datetime(st.session_state["Cleaned_data_panel"]["date"])
180
+
181
+ # Concatenate the dates with other numeric columns formatted
182
+ raw_df = pd.concat(
183
+ [
184
+ dates,
185
+ st.session_state["Cleaned_data_panel"]
186
+ .select_dtypes(np.number)
187
+ .applymap(format_numbers),
188
+ ],
189
+ axis=1,
190
+ )
191
+
192
+ # Now sort raw_df by the 'date' column, which is still in datetime format
193
+ sorted_raw_df = raw_df.sort_values(by="date", ascending=True)
194
+
195
+ # After sorting, convert 'date' to string format for display
196
+ sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")
197
+
198
+ return sorted_raw_df
199
+
200
+ # Display the sorted DataFrame in Streamlit
201
+ st.dataframe(raw_df_gen())
202
+
203
+ col1 = st.columns(1)
204
+
205
+ if "selected_feature" not in st.session_state:
206
+ st.session_state["selected_feature"] = None
207
+
208
+
209
+ def generate_report_with_target(channel_data, target_feature):
210
+ report = sv.analyze([channel_data, "Dataset"], target_feat=target_feature)
211
+ temp_dir = tempfile.mkdtemp()
212
+ report_path = os.path.join(temp_dir, "report.html")
213
+ report.show_html(
214
+ filepath=report_path, open_browser=False
215
+ ) # Generate the report as an HTML file
216
+ return report_path
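+ # The HTML report lands in a fresh temp dir; the caller reads it back and serves
+ # it through st.download_button rather than embedding it in the page.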
217
+
218
+
219
+ def generate_profile_report(df):
220
+ pr = df.profile_report()
221
+ temp_dir = tempfile.mkdtemp()
222
+ report_path = os.path.join(temp_dir, "report.html")
223
+ pr.to_file(report_path)
224
+ return report_path
225
+
226
+
227
+ # st.header()
228
+ with st.expander("Univariate and Bivariate Report"):
229
+ eda_columns = st.columns(2)
230
+ with eda_columns[0]:
231
+ if st.button(
232
+ "Generate Profile Report",
233
+ help="Univariate report which inlcudes all statistical analysis",
234
+ ):
235
+ with st.spinner("Generating Report"):
236
+ report_file = generate_profile_report(
237
+ st.session_state["Cleaned_data_panel"]
238
+ )
239
+
240
+ if os.path.exists(report_file):
241
+ with open(report_file, "rb") as f:
242
+ st.success("Report Generated")
243
+ st.download_button(
244
+ label="Download EDA Report",
245
+ data=f.read(),
246
+ file_name="pandas_profiling_report.html",
247
+ mime="text/html",
248
+ )
249
+ else:
250
+ st.warning(
251
+ "Report generation failed. Unable to find the report file."
252
+ )
253
+
254
+ with eda_columns[1]:
255
+ if st.button(
256
+ "Generate Sweetviz Report",
257
+ help="Bivariate report for selected response metric",
258
+ ):
259
+ with st.spinner("Generating Report"):
260
+ report_file = generate_report_with_target(
261
+ st.session_state["Cleaned_data_panel"], target_column
262
+ )
263
+
264
+ if os.path.exists(report_file):
265
+ with open(report_file, "rb") as f:
266
+ st.success("Report Generated")
267
+ st.download_button(
268
+ label="Download EDA Report",
269
+ data=f.read(),
270
+ file_name="report.html",
271
+ mime="text/html",
272
+ )
273
+ else:
274
+ st.warning("Report generation failed. Unable to find the report file.")
275
+
276
+
277
+ # st.warning('Work in Progress')
278
+ with st.expander("Media Variables Analysis"):
279
+ # Get the selected feature
280
+
281
+ media_variables = [
282
+ col
283
+ for col in media_channel
284
+ if "cost" not in col.lower() and "spend" not in col.lower()
285
+ ]
286
+
287
+ st.session_state["selected_feature"] = st.selectbox(
288
+ "Select media", media_variables, format_func=format_display
289
+ )
290
+
291
+ st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
292
+ media_variables.index(st.session_state["selected_feature"])
293
+ )
294
+
295
+ # Filter spends features based on the selected feature
296
+ spends_features = [
297
+ col
298
+ for col in st.session_state["Cleaned_data_panel"].columns
299
+ if any(keyword in col.lower() for keyword in ["cost", "spend"])
300
+ ]
301
+ spends_feature = [
302
+ col
303
+ for col in spends_features
304
+ if re.split(r"_cost|_spend", col.lower())[0]
305
+ in st.session_state["selected_feature"]
306
+ ]
307
+
308
+ if "validation" not in st.session_state:
309
+
310
+ st.session_state["validation"] = st.session_state["project_dct"][
311
+ "data_validation"
312
+ ]["validated_variables"]
313
+
314
+ val_variables = [col for col in media_channel if col != "date"]
315
+
316
+ if not set(
317
+ st.session_state["project_dct"]["data_validation"]["validated_variables"]
318
+ ).issubset(set(val_variables)):
319
+
320
+ st.session_state["validation"] = []
321
+
322
+ if len(spends_feature) == 0:
323
+ st.warning("No spends varaible available for the selected metric in data")
324
+
325
+ else:
326
+ fig_row1 = line_plot(
327
+ st.session_state["Cleaned_data_panel"],
328
+ x_col="date",
329
+ y1_cols=[st.session_state["selected_feature"]],
330
+ y2_cols=[target_column],
331
+ title=f'Analysis of {st.session_state["selected_feature"]} and {target_column} Over Time',
332
+ )
333
+ st.plotly_chart(fig_row1, use_container_width=True)
334
+ st.markdown("### Summary")
335
+ st.dataframe(
336
+ summary(
337
+ st.session_state["cleaned_data"],
338
+ [st.session_state["selected_feature"]],
339
+ spends=spends_feature[0],
340
+ ),
341
+ use_container_width=True,
342
+ )
343
+
344
+ cols2 = st.columns(2)
345
+
346
+ if len(set(st.session_state["validation"]).intersection(val_variables)) == len(
347
+ val_variables
348
+ ):
349
+ disable = True
350
+ help = "All media variables are validated"
351
+ else:
352
+ disable = False
353
+ help = ""
354
+
355
+ with cols2[0]:
356
+ if st.button("Validate", disabled=disable, help=help):
357
+ st.session_state["validation"].append(
358
+ st.session_state["selected_feature"]
359
+ )
360
+ with cols2[1]:
361
+
362
+ if st.checkbox("Validate all", disabled=disable, help=help):
363
+ st.session_state["validation"].extend(val_variables)
364
+ st.success("All media variables are validated ✅")
365
+
366
+ if len(set(st.session_state["validation"]).intersection(val_variables)) != len(
367
+ val_variables
368
+ ):
369
+ validation_data = pd.DataFrame(
370
+ {
371
+ "Validate": [
372
+ (True if col in st.session_state["validation"] else False)
373
+ for col in val_variables
374
+ ],
375
+ "Variables": val_variables,
376
+ }
377
+ )
378
+
379
+ sorted_validation_df = validation_data.sort_values(
380
+ by="Variables", ascending=True, na_position="first"
381
+ )
382
+ cols3 = st.columns([1, 30])
383
+ with cols3[1]:
384
+ validation_df = st.data_editor(
385
+ sorted_validation_df,
386
+ # column_config={
387
+ # 'Validate':st.column_config.CheckboxColumn(wi)
388
+ # },
389
+ column_config={
390
+ "Validate": st.column_config.CheckboxColumn(
391
+ default=False,
392
+ width=100,
393
+ ),
394
+ "Variables": st.column_config.TextColumn(width=1000),
395
+ },
396
+ hide_index=True,
397
+ )
398
+
399
+ selected_rows = validation_df[validation_df["Validate"]][
400
+ "Variables"
401
+ ]
402
+
403
+ # st.write(selected_rows)
404
+
405
+ st.session_state["validation"].extend(selected_rows)
406
+
407
+ st.session_state["project_dct"]["data_validation"][
408
+ "validated_variables"
409
+ ] = st.session_state["validation"]
410
+
411
+ not_validated_variables = [
412
+ col
413
+ for col in val_variables
414
+ if col not in st.session_state["validation"]
415
+ ]
416
+
417
+ if not_validated_variables:
418
+ not_validated_message = f'The following variables are not validated:\n{", ".join(not_validated_variables)}'
419
+ st.warning(not_validated_message)
420
+
421
+
422
+ with st.expander("Non Media Variables Analysis"):
423
+ selected_columns_row4 = st.selectbox(
424
+ "Select Channel",
425
+ Non_media_variables,
426
+ format_func=format_display,
427
+ index=st.session_state["project_dct"]["data_validation"]["Non_media_variables"],
428
+ )
429
+
430
+ st.session_state["project_dct"]["data_validation"]["Non_media_variables"] = (
431
+ Non_media_variables.index(selected_columns_row4)
432
+ )
433
+
434
+ # # Create the dual-axis line plot
435
+ fig_row4 = line_plot(
436
+ st.session_state["Cleaned_data_panel"],
437
+ x_col="date",
438
+ y1_cols=[selected_columns_row4],
439
+ y2_cols=[target_column],
440
+ title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
441
+ )
442
+ st.plotly_chart(fig_row4, use_container_width=True)
443
+ selected_non_media = selected_columns_row4
444
+ sum_df = st.session_state["Cleaned_data_panel"][
445
+ ["date", selected_non_media, target_column]
446
+ ]
447
+ sum_df["Year"] = pd.to_datetime(
448
+ st.session_state["Cleaned_data_panel"]["date"]
449
+ ).dt.year
450
+ # st.dataframe(df)
451
+ # st.dataframe(sum_df.head(2))
452
+ print(sum_df)
453
+ sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
454
+ sum_df.loc["Grand Total"] = sum_df.sum()
455
+ sum_df = sum_df.applymap(format_numbers)
456
+ sum_df.fillna("-", inplace=True)
457
+ sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
458
+ st.markdown("### Summary")
459
+ st.dataframe(sum_df, use_container_width=True)
460
+
461
+ # with st.expander('Interactive Dashboard'):
462
+
463
+ # pygg_app=StreamlitRenderer(st.session_state['cleaned_data'])
464
+
465
+ # pygg_app.explorer()
466
+
467
+ with st.expander("Correlation Analysis"):
468
+ options = list(
469
+ st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
470
+ )
471
+
472
+ # selected_options = []
473
+ # num_columns = 4
474
+ # num_rows = -(-len(options) // num_columns) # Ceiling division to calculate rows
475
+
476
+ # # Create a grid of checkboxes
477
+ # st.header('Select Features for Correlation Plot')
478
+ # tick=False
479
+ # if st.checkbox('Select all'):
480
+ # tick=True
481
+ # selected_options = []
482
+ # for row in range(num_rows):
483
+ # cols = st.columns(num_columns)
484
+ # for col in cols:
485
+ # if options:
486
+ # option = options.pop(0)
487
+ # selected = col.checkbox(option,value=tick)
488
+ # if selected:
489
+ # selected_options.append(option)
490
+ # # Display selected options
491
+
492
+ selected_options = st.multiselect(
493
+ "Select Variables For correlation plot",
494
+ [var for var in options if var != target_column],
495
+ default=options[3],
496
+ )
497
+
498
+ st.pyplot(
499
+ correlation_plot(
500
+ st.session_state["Cleaned_data_panel"],
501
+ selected_options,
502
+ target_column,
503
+ )
504
+ )
505
+
506
+ if st.button("Save Changes", use_container_width=True):
507
+
508
+ update_db("2_Data_Validation.py")
509
+
510
+ project_dct_path = os.path.join(st.session_state["project_path"], "project_dct.pkl")
511
+
512
+ with open(project_dct_path, "wb") as f:
513
+ pickle.dump(st.session_state["project_dct"], f)
514
+ st.success("Changes saved")
pages/3_Transformations.py ADDED
@@ -0,0 +1,639 @@
1
+ # Importing necessary libraries
2
+ import streamlit as st
3
+
4
+ st.set_page_config(
5
+ page_title="Transformations",
6
+ page_icon=":shark:",
7
+ layout="wide",
8
+ initial_sidebar_state="collapsed",
9
+ )
10
+
11
+ import pickle
12
+ import numpy as np
13
+ import pandas as pd
14
+ from utilities import set_header, load_local_css
15
+ import streamlit_authenticator as stauth
16
+ import yaml
17
+ from yaml import SafeLoader
18
+ import os
19
+ import sqlite3
20
+ from utilities import update_db
21
+
22
+
23
+ load_local_css("styles.css")
24
+ set_header()
25
+
26
+
27
+ # Check for authentication status
28
+ for k, v in st.session_state.items():
29
+ if k not in ["logout", "login", "config"] and not k.startswith("FormSubmitter"):
30
+ st.session_state[k] = v
31
+ with open("config.yaml") as file:
32
+ config = yaml.load(file, Loader=SafeLoader)
33
+ st.session_state["config"] = config
34
+ authenticator = stauth.Authenticate(
35
+ config["credentials"],
36
+ config["cookie"]["name"],
37
+ config["cookie"]["key"],
38
+ config["cookie"]["expiry_days"],
39
+ config["preauthorized"],
40
+ )
41
+ st.session_state["authenticator"] = authenticator
42
+ name, authentication_status, username = authenticator.login("Login", "main")
43
+ auth_status = st.session_state.get("authentication_status")
44
+
45
+ if auth_status == True:
46
+ authenticator.logout("Logout", "main")
47
+ is_state_initialized = st.session_state.get("initialized", False)
48
+
49
+ if "project_dct" not in st.session_state:
50
+ st.error("Please load a project from Home page")
51
+ st.stop()
52
+
53
+ conn = sqlite3.connect(
54
+ r"DB/User.db", check_same_thread=False
55
+ ) # connection with sql db
56
+ c = conn.cursor()
57
+
58
+ if not is_state_initialized:
59
+ if "session_name" not in st.session_state:
60
+ st.session_state["session_name"] = None
61
+
62
+ if not os.path.exists(
63
+ os.path.join(st.session_state["project_path"], "data_import.pkl")
64
+ ):
65
+ st.error("Please move to Data Import page")
66
+ # Deserialize and load the objects from the pickle file
67
+ with open(
68
+ os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
69
+ ) as f:
70
+ data = pickle.load(f)
71
+
72
+ # Accessing the loaded objects
73
+ final_df_loaded = data["final_df"]
74
+ bin_dict_loaded = data["bin_dict"]
75
+ # final_df_loaded.to_csv("Test/final_df_loaded.csv",index=False)
76
+ # Initialize session state
77
+ if "transformed_columns_dict" not in st.session_state:
78
+ st.session_state["transformed_columns_dict"] = {} # Default empty dictionary
79
+
80
+ if "final_df" not in st.session_state:
81
+ st.session_state["final_df"] = final_df_loaded # Default as original dataframe
82
+
83
+ if "summary_string" not in st.session_state:
84
+ st.session_state["summary_string"] = None # Default as None
85
+
86
+ # Extract original columns for specified categories
87
+ original_columns = {
88
+ category: bin_dict_loaded[category]
89
+ for category in ["Media", "Internal", "Exogenous"]
90
+ if category in bin_dict_loaded
91
+ }
92
+
93
+ # Retrieve panel columns
94
+ panel_1 = bin_dict_loaded.get("Panel Level 1")
95
+ panel_2 = bin_dict_loaded.get("Panel Level 2")
96
+
97
+ # # For testing on non panel level
98
+ # final_df_loaded = final_df_loaded.drop("Panel_1", axis=1)
99
+ # final_df_loaded = final_df_loaded.groupby("date").mean().reset_index()
100
+ # panel_1 = None
101
+
102
+ # Apply transformations on panel level
103
+ if panel_1:
104
+ panel = panel_1 + panel_2 if panel_2 else panel_1
105
+ else:
106
+ panel = []
107
+
108
+ # Function to build transformation widgets
109
+ def transformation_widgets(category, transform_params, date_granularity):
110
+
111
+ if (
112
+ st.session_state["project_dct"]["transformations"] is None
113
+ or st.session_state["project_dct"]["transformations"] == {}
114
+ ):
115
+ st.session_state["project_dct"]["transformations"] = {}
116
+ if category not in st.session_state["project_dct"]["transformations"].keys():
117
+ st.session_state["project_dct"]["transformations"][category] = {}
118
+
119
+ # Define a dict of pre-defined default values of every transformation
120
+ predefined_defaults = {
121
+ "Lag": (1, 2),
122
+ "Lead": (1, 2),
123
+ "Moving Average": (1, 2),
124
+ "Saturation": (10, 20),
125
+ "Power": (2, 4),
126
+ "Adstock": (0.5, 0.7),
127
+ }
128
+
129
+ def selection_change():
130
+ # Handles removing transformations
131
+ if f"transformation_{category}" in st.session_state:
132
+ current_selection = st.session_state[f"transformation_{category}"]
133
+ past_selection = st.session_state["project_dct"]["transformations"][
134
+ category
135
+ ][f"transformation_{category}"]
136
+ removed_selection = list(set(past_selection) - set(current_selection))
137
+ for selection in removed_selection:
138
+ # Option 1 - revert to default
139
+ # st.session_state['project_dct']['transformations'][category][selection] = predefined_defaults[selection]
140
+
141
+ # option 2 - delete from dict
142
+ del st.session_state["project_dct"]["transformations"][category][
143
+ selection
144
+ ]
145
+
146
+ # Transformation Options
147
+ transformation_options = {
148
+ "Media": [
149
+ "Lag",
150
+ "Moving Average",
151
+ "Saturation",
152
+ "Power",
153
+ "Adstock",
154
+ ],
155
+ "Internal": ["Lead", "Lag", "Moving Average"],
156
+ "Exogenous": ["Lead", "Lag", "Moving Average"],
157
+ }
158
+
159
+ expanded = st.session_state["project_dct"]["transformations"][category].get(
160
+ "expanded", False
161
+ )
162
+ st.session_state["project_dct"]["transformations"][category]["expanded"] = False
163
+ with st.expander(f"{category} Transformations", expanded=expanded):
164
+ st.session_state["project_dct"]["transformations"][category][
165
+ "expanded"
166
+ ] = True
167
+
168
+ # Let users select which transformations to apply
169
+ sel_transformations = st.session_state["project_dct"]["transformations"][
170
+ category
171
+ ].get(f"transformation_{category}", [])
172
+ transformations_to_apply = st.multiselect(
173
+ "Select transformations to apply",
174
+ options=transformation_options[category],
175
+ default=sel_transformations,
176
+ key=f"transformation_{category}",
177
+ # on_change=selection_change(),
178
+ )
179
+ st.session_state["project_dct"]["transformations"][category][
180
+ "transformation_" + category
181
+ ] = transformations_to_apply
182
+ # Determine the number of transformations to put in each column
183
+ transformations_per_column = (
184
+ len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
185
+ )
186
+
187
+ # Create two columns
188
+ col1, col2 = st.columns(2)
189
+
190
+ # Assign transformations to each column
191
+ transformations_col1 = transformations_to_apply[:transformations_per_column]
192
+ transformations_col2 = transformations_to_apply[transformations_per_column:]
193
+
194
+ # Define a helper function to create widgets for each transformation
195
+ def create_transformation_widgets(column, transformations):
196
+ with column:
197
+ for transformation in transformations:
198
+ # Conditionally create widgets for selected transformations
199
+ if transformation == "Lead":
200
+ lead_default = st.session_state["project_dct"][
201
+ "transformations"
202
+ ][category].get("Lead", predefined_defualts["Lead"])
203
+ st.markdown(f"**Lead ({date_granularity})**")
204
+ lead = st.slider(
205
+ "Lead periods",
206
+ 1,
207
+ 10,
208
+ lead_default,
209
+ 1,
210
+ key=f"lead_{category}",
211
+ label_visibility="collapsed",
212
+ )
213
+ st.session_state["project_dct"]["transformations"][
214
+ category
215
+ ]["Lead"] = lead
216
+ start = lead[0]
217
+ end = lead[1]
218
+ step = 1
219
+ transform_params[category]["Lead"] = np.arange(
220
+ start, end + step, step
221
+ )
222
+
223
+ if transformation == "Lag":
224
+ lag_default = st.session_state["project_dct"][
225
+ "transformations"
226
+ ][category].get("Lag", predefined_defualts["Lag"])
227
+ st.markdown(f"**Lag ({date_granularity})**")
228
+ lag = st.slider(
229
+ "Lag periods",
230
+ 1,
231
+ 10,
232
+ lag_default,
233
+ 1,
234
+ key=f"lag_{category}",
235
+ label_visibility="collapsed",
236
+ )
237
+ st.session_state["project_dct"]["transformations"][
238
+ category
239
+ ]["Lag"] = lag
240
+ start = lag[0]
241
+ end = lag[1]
242
+ step = 1
243
+ transform_params[category]["Lag"] = np.arange(
244
+ start, end + step, step
245
+ )
246
+
247
+ if transformation == "Moving Average":
248
+ ma_default = st.session_state["project_dct"][
249
+ "transformations"
250
+ ][category].get("MA", predefined_defualts["Moving Average"])
251
+ st.markdown(f"**Moving Average ({date_granularity})**")
252
+ window = st.slider(
253
+ "Window size for Moving Average",
254
+ 1,
255
+ 10,
256
+ ma_default,
257
+ 1,
258
+ key=f"ma_{category}",
259
+ label_visibility="collapsed",
260
+ )
261
+ st.session_state["project_dct"]["transformations"][
262
+ category
263
+ ]["MA"] = window
264
+ start = window[0]
265
+ end = window[1]
266
+ step = 1
267
+ transform_params[category]["Moving Average"] = np.arange(
268
+ start, end + step, step
269
+ )
270
+
271
+ if transformation == "Saturation":
272
+ st.markdown("**Saturation (%)**")
273
+ saturation_default = st.session_state["project_dct"][
274
+ "transformations"
275
+ ][category].get(
276
+ "Saturation", predefined_defualts["Saturation"]
277
+ )
278
+ saturation_point = st.slider(
279
+ f"Saturation Percentage",
280
+ 0,
281
+ 100,
282
+ saturation_default,
283
+ 10,
284
+ key=f"sat_{category}",
285
+ label_visibility="collapsed",
286
+ )
287
+ st.session_state["project_dct"]["transformations"][
288
+ category
289
+ ]["Saturation"] = saturation_point
290
+ start = saturation_point[0]
291
+ end = saturation_point[1]
292
+ step = 10
293
+ transform_params[category]["Saturation"] = np.arange(
294
+ start, end + step, step
295
+ )
296
+
297
+ if transformation == "Power":
298
+ st.markdown("**Power**")
299
+ power_default = st.session_state["project_dct"][
300
+ "transformations"
301
+ ][category].get("Power", predefined_defualts["Power"])
302
+ power = st.slider(
303
+ f"Power",
304
+ 0,
305
+ 10,
306
+ power_default,
307
+ 1,
308
+ key=f"power_{category}",
309
+ label_visibility="collapsed",
310
+ )
311
+ st.session_state["project_dct"]["transformations"][
312
+ category
313
+ ]["Power"] = power
314
+ start = power[0]
315
+ end = power[1]
316
+ step = 1
317
+ transform_params[category]["Power"] = np.arange(
318
+ start, end + step, step
319
+ )
320
+
321
+ if transformation == "Adstock":
322
+ ads_default = st.session_state["project_dct"][
323
+ "transformations"
324
+ ][category].get("Adstock", predefined_defualts["Adstock"])
325
+ st.markdown("**Adstock**")
326
+ rate = st.slider(
327
+ f"Factor ({category})",
328
+ 0.0,
329
+ 1.0,
330
+ ads_default,
331
+ 0.05,
332
+ key=f"adstock_{category}",
333
+ label_visibility="collapsed",
334
+ )
335
+ st.session_state["project_dct"]["transformations"][
336
+ category
337
+ ]["Adstock"] = rate
338
+ start = rate[0]
339
+ end = rate[1]
340
+ step = 0.05
341
+ adstock_range = [
342
+ round(a, 3) for a in np.arange(start, end + step, step)
343
+ ]
344
+ transform_params[category]["Adstock"] = adstock_range
345
+
346
+ # Create widgets in each column
347
+ create_transformation_widgets(col1, transformations_col1)
348
+ create_transformation_widgets(col2, transformations_col2)
349
+
350
+ # Function to apply Lag transformation
351
+ def apply_lag(df, lag):
352
+ return df.shift(lag)
353
+
354
+ # Function to apply Lead transformation
355
+ def apply_lead(df, lead):
356
+ return df.shift(-lead)
357
+
358
+ # Function to apply Moving Average transformation
359
+ def apply_moving_average(df, window_size):
360
+ return df.rolling(window=window_size).mean()
361
+
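+ # Worked example for the three shift-style transforms (assumed values,
+ # annotation only, not part of the original file):
+ # s = pd.Series([1, 2, 3, 4])
+ # apply_lag(s, 1)            -> [NaN, 1.0, 2.0, 3.0]  (past value carried forward)
+ # apply_lead(s, 1)           -> [2.0, 3.0, 4.0, NaN]  (future value pulled back)
+ # apply_moving_average(s, 2) -> [NaN, 1.5, 2.5, 3.5]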
362
+ # Function to apply Saturation transformation
363
+ def apply_saturation(df, saturation_percent_100):
364
+ # Convert saturation percentage from 100-based to fraction
365
+ saturation_percent = saturation_percent_100 / 100.0
366
+
367
+ # Calculate saturation point and steepness
368
+ column_max = df.max()
369
+ column_min = df.min()
370
+ saturation_point = (column_min + column_max) / 2
371
+
372
+ numerator = np.log(
373
+ (1 / (saturation_percent if saturation_percent != 1 else 1 - 1e-9)) - 1
374
+ )
375
+ denominator = np.log(saturation_point / max(column_max, 1e-9))
376
+
377
+ steepness = numerator / max(
378
+ denominator, 1e-9
379
+ ) # Avoid division by zero with a small constant
380
+
381
+ # Apply the saturation transformation
382
+ transformed_series = df.apply(
383
+ lambda x: (1 / (1 + (saturation_point / x) ** steepness)) * x
384
+ )
385
+
386
+ return transformed_series
387
+
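+ # Annotation (not from the original file): apply_saturation implements a
+ # logistic-style curve f(x) = x / (1 + (saturation_point / x) ** steepness),
+ # with the midpoint fixed at (column_min + column_max) / 2 and steepness solved
+ # so the multiplier 1 / (1 + (saturation_point / x) ** steepness) equals the
+ # chosen saturation fraction at x = column_max; the 1e-9 terms guard the log
+ # and the division against degenerate inputs.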
388
+ # Function to apply Power transformation
389
+ def apply_power(df, power):
390
+ return df**power
391
+
392
+ # Function to apply Adstock transformation
393
+ def apply_adstock(df, factor):
394
+ x = 0
395
+ # Use the walrus operator to update x iteratively with the Adstock formula
396
+ adstock_var = [x := x * factor + v for v in df]
397
+ ans = pd.Series(adstock_var, index=df.index)
398
+ return ans
399
+
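+ # Worked example (assumed values, annotation only):
+ # apply_adstock(pd.Series([100, 0, 0]), 0.5).tolist() -> [100.0, 50.0, 25.0]
+ # Each period keeps `factor` times the previous adstocked value, modelling
+ # media carryover with geometric decay.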
400
+ # Function to generate transformed columns names
401
+ @st.cache_resource(show_spinner=False)
402
+ def generate_transformed_columns(original_columns, transform_params):
403
+ transformed_columns, summary = {}, {}
404
+
405
+ for category, columns in original_columns.items():
406
+ for column in columns:
407
+ transformed_columns[column] = []
408
+ summary_details = (
409
+ []
410
+ ) # List to hold transformation details for the current column
411
+
412
+ if category in transform_params:
413
+ for transformation, values in transform_params[category].items():
414
+ # Generate transformed column names for each value
415
+ for value in values:
416
+ transformed_name = f"{column}@{transformation}_{value}"
417
+ transformed_columns[column].append(transformed_name)
418
+
419
+ # Format the values list as a string with commas and "and" before the last item
420
+ if len(values) > 1:
421
+ formatted_values = (
422
+ ", ".join(map(str, values[:-1]))
423
+ + " and "
424
+ + str(values[-1])
425
+ )
426
+ else:
427
+ formatted_values = str(values[0])
428
+
429
+ # Add transformation details
430
+ summary_details.append(f"{transformation} ({formatted_values})")
431
+
432
+ # Only add to summary if there are transformation details for the column
433
+ if summary_details:
434
+ formatted_summary = "⮕ ".join(summary_details)
435
+ # Use <strong> tags to make the column name bold
436
+ summary[column] = f"<strong>{column}</strong>: {formatted_summary}"
437
+
438
+ # Generate a comprehensive summary string for all columns
439
+ summary_items = [
440
+ f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
441
+ ]
442
+
443
+ summary_string = "\n".join(summary_items)
444
+
445
+ return transformed_columns, summary_string
446
+
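+ # Naming sketch (hypothetical column name, annotation only): a Media column
+ # "tv_spend" with Lag values [1, 2] yields "tv_spend@Lag_1" and "tv_spend@Lag_2",
+ # plus the summary entry "<strong>tv_spend</strong>: Lag (1 and 2)".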
447
+ # Function to apply transformations to DataFrame slices based on specified categories and parameters
448
+ @st.cache_resource(show_spinner=False)
449
+ def apply_category_transformations(df, bin_dict, transform_params, panel):
450
+ # Dictionary for function mapping
451
+ transformation_functions = {
452
+ "Lead": apply_lead,
453
+ "Lag": apply_lag,
454
+ "Moving Average": apply_moving_average,
455
+ "Saturation": apply_saturation,
456
+ "Power": apply_power,
457
+ "Adstock": apply_adstock,
458
+ }
459
+
460
+ # Initialize category_df as an empty DataFrame
461
+ category_df = pd.DataFrame()
462
+
463
+ # Iterate through each category specified in transform_params
464
+ for category in ["Media", "Internal", "Exogenous"]:
465
+ if (
466
+ category not in transform_params
467
+ or category not in bin_dict
468
+ or not transform_params[category]
469
+ ):
470
+ continue # Skip categories without transformations
471
+
472
+ # Slice the DataFrame based on the columns specified in bin_dict for the current category
473
+ df_slice = df[bin_dict[category] + panel]
474
+
475
+ # Iterate through each transformation and its parameters for the current category
476
+ for transformation, parameters in transform_params[category].items():
477
+ transformation_function = transformation_functions[transformation]
478
+
479
+ # Check if there is panel data to group by
480
+ if len(panel) > 0:
481
+ # Apply the transformation to each group
482
+ category_df = pd.concat(
483
+ [
484
+ df_slice.groupby(panel)
485
+ .transform(transformation_function, p)
486
+ .add_suffix(f"@{transformation}_{p}")
487
+ for p in parameters
488
+ ],
489
+ axis=1,
490
+ )
491
+
492
+ # Replace all NaN or null values in category_df with 0
493
+ category_df.fillna(0, inplace=True)
494
+
495
+ # Update df_slice
496
+ df_slice = pd.concat(
497
+ [df[panel], category_df],
498
+ axis=1,
499
+ )
500
+
501
+ else:
502
+ for p in parameters:
503
+ # Apply the transformation function to each column
504
+ temp_df = df_slice.apply(
505
+ lambda x: transformation_function(x, p), axis=0
506
+ ).rename(
507
+ lambda x: f"{x}@{transformation}_{p}",
508
+ axis="columns",
509
+ )
510
+ # Concatenate the transformed DataFrame slice to the category DataFrame
511
+ category_df = pd.concat([category_df, temp_df], axis=1)
512
+
513
+ # Replace all NaN or null values in category_df with 0
514
+ category_df.fillna(0, inplace=True)
515
+
516
+ # Update df_slice
517
+ df_slice = pd.concat(
518
+ [df[panel], category_df],
519
+ axis=1,
520
+ )
521
+
522
+ # If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame
523
+ if not category_df.empty:
524
+ final_df = pd.concat([df, category_df], axis=1)
525
+ else:
526
+ # If no transformations were applied, use the original DataFrame
527
+ final_df = df
528
+
529
+ return final_df
530
+
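+ # Annotation (not from the original file): each (column, transformation,
+ # parameter) combination appends exactly one new column named
+ # "<col>@<transformation>_<param>"; with panel data the transform is applied
+ # within each panel group, and NaNs introduced by shifting are zero-filled.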
531
+ # Function to infer the granularity of the date column in a DataFrame
532
+ @st.cache_resource(show_spinner=False)
533
+ def infer_date_granularity(df):
534
+ # Find the most common difference
535
+ common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
536
+
537
+ # Map the most common difference to a granularity
538
+ if common_freq == 1:
539
+ return "daily"
540
+ elif common_freq == 7:
541
+ return "weekly"
542
+ elif 28 <= common_freq <= 31:
543
+ return "monthly"
544
+ else:
545
+ return "irregular"
546
+
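+ # Usage sketch (assumed dates, annotation only): a "date" column holding
+ # 2023-01-01, 2023-01-08, 2023-01-15 has a modal day-difference of 7, so
+ # infer_date_granularity returns "weekly".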
547
+ #########################################################################################################################################################
548
+ # User input for transformations
549
+ #########################################################################################################################################################
550
+
551
+ # Infer date granularity
552
+ date_granularity = infer_date_granularity(final_df_loaded)
553
+
554
+ # Initialize the main dictionary to store the transformation parameters for each category
555
+ transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
556
+
557
+ # User input for transformations
558
+ st.markdown("### Select Transformations to Apply")
559
+ for category in ["Media", "Internal", "Exogenous"]:
560
+ # Skip Internal
561
+ if category == "Internal":
562
+ continue
563
+
564
+ transformation_widgets(category, transform_params, date_granularity)
565
+
566
+ #########################################################################################################################################################
567
+ # Apply transformations
568
+ #########################################################################################################################################################
569
+
570
+ # Apply category-based transformations to the DataFrame
571
+ if st.button("Accept and Proceed", use_container_width=True):
572
+ with st.spinner("Applying transformations..."):
573
+ final_df = apply_category_transformations(
574
+ final_df_loaded, bin_dict_loaded, transform_params, panel
575
+ )
576
+
577
+ # Generate a dictionary mapping original column names to lists of transformed column names
578
+ transformed_columns_dict, summary_string = generate_transformed_columns(
579
+ original_columns, transform_params
580
+ )
581
+
582
+ # Store into transformed dataframe and summary session state
583
+ st.session_state["final_df"] = final_df
584
+ st.session_state["summary_string"] = summary_string
585
+
586
+ #########################################################################################################################################################
587
+ # Display the transformed DataFrame and summary
588
+ #########################################################################################################################################################
589
+
590
+ # Display the transformed DataFrame in the Streamlit app
591
+ st.markdown("### Transformed DataFrame")
592
+ final_df = st.session_state["final_df"].copy()
593
+
594
+ sort_col = []
595
+ for col in final_df.columns:
596
+ if col in ["Panel_1", "Panel_2", "date"]:
597
+ sort_col.append(col)
598
+
599
+ sorted_final_df = final_df.sort_values(
600
+ by=sort_col, ascending=True, na_position="first"
601
+ )
602
+ st.dataframe(sorted_final_df, hide_index=True)
603
+
604
+ # Total rows and columns
605
+ total_rows, total_columns = st.session_state["final_df"].shape
606
+ st.markdown(
607
+ f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
608
+ unsafe_allow_html=True,
609
+ )
610
+
611
+ # Display the summary of transformations as markdown
612
+ if "summary_string" in st.session_state and st.session_state["summary_string"]:
613
+ with st.expander("Summary of Transformations"):
614
+ st.markdown("### Summary of Transformations")
615
+ st.markdown(st.session_state["summary_string"], unsafe_allow_html=True)
616
+
617
+ @st.cache_resource(show_spinner=False)
618
+ def save_to_pickle(file_path, final_df):
619
+ # Open the file in write-binary mode and dump the objects
620
+ with open(file_path, "wb") as f:
621
+ pickle.dump({"final_df_transformed": final_df}, f)
622
+ # Data is now saved to file
623
+
624
+ if st.button("Accept and Save", use_container_width=True):
625
+
626
+ save_to_pickle(
627
+ os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"),
628
+ st.session_state["final_df"],
629
+ )
630
+ project_dct_path = os.path.join(
631
+ st.session_state["project_path"], "project_dct.pkl"
632
+ )
633
+
634
+ with open(project_dct_path, "wb") as f:
635
+ pickle.dump(st.session_state["project_dct"], f)
636
+
637
+ update_db("3_Transformations.py")
638
+
639
+ st.toast("💾 Saved Successfully!")
pages/4_Model_Build 2.py ADDED
@@ -0,0 +1,1288 @@
1
+ """
2
+ MMO Build Sprint 3
3
+ additions : adding more variables to session state for saved model : random effect, predicted train & test
4
+
5
+ MMO Build Sprint 4
6
+ additions : ability to run models for different response metrics
7
+ """
8
+
9
+ import streamlit as st
10
+ import pandas as pd
11
+ import plotly.express as px
12
+ import plotly.graph_objects as go
13
+ from Eda_functions import format_numbers
14
+ import numpy as np
15
+ import pickle
16
+ from st_aggrid import AgGrid
17
+ from st_aggrid import GridOptionsBuilder, GridUpdateMode
18
+ from utilities import set_header, load_local_css
19
+ from st_aggrid import GridOptionsBuilder
20
+ import time
21
+ import itertools
22
+ import statsmodels.api as sm
23
+ import numpy as np
24
+ import re
25
+ import itertools
26
+ from sklearn.metrics import (
27
+ mean_absolute_error,
28
+ r2_score,
29
+ mean_absolute_percentage_error,
30
+ )
31
+ from sklearn.preprocessing import MinMaxScaler
32
+ import os
33
+ import matplotlib.pyplot as plt
34
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
35
+ import yaml
36
+ from yaml import SafeLoader
37
+ import streamlit_authenticator as stauth
38
+
39
+ st.set_option("deprecation.showPyplotGlobalUse", False)
40
+ import statsmodels.api as sm
41
+ import statsmodels.formula.api as smf
42
+
43
+ from datetime import datetime
44
+ import seaborn as sns
45
+ from Data_prep_functions import *
46
+ import sqlite3
47
+ from utilities import update_db
48
+ from datetime import datetime, timedelta
49
+
50
+ # @st.cache_resource(show_spinner=False)
51
+ # def save_to_pickle(file_path, final_df):
52
+ # # Open the file in write-binary mode and dump the objects
53
+ # with open(file_path, "wb") as f:
54
+ # pickle.dump({file_path: final_df}, f)
55
+ @st.cache_resource(show_spinner=True)
56
+ def prepare_data_df(data):
57
+ data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
58
+ drop=True
59
+ ) # Sprint4 -- Srishti -- only show models with the lowest num of neg coeffs
60
+ data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
61
+ data.drop_duplicates(subset="Model_iteration", inplace=True)
62
+
63
+ # Applying the function to each row in the DataFrame
64
+ data["coefficients"] = data["coefficients"].apply(process_dict)
65
+
66
+ # Convert dictionary items into separate DataFrame columns
67
+ coefficients_df = data["coefficients"].apply(pd.Series)
68
+
69
+ # Rename the columns to remove any trailing underscores and capitalize the words
70
+ coefficients_df.columns = [
71
+ col.strip("_").replace("_", " ").title() for col in coefficients_df.columns
72
+ ]
73
+
74
+ # Normalize each row so that the sum equals 100%
75
+ coefficients_df = coefficients_df.apply(
76
+ lambda x: round((x / x.sum()) * 100, 2), axis=1
77
+ )
78
+
79
+ # Join the new columns back to the original DataFrame
80
+ data = data.join(coefficients_df)
81
+
82
+ data_df = data[
83
+ [
84
+ "Model_iteration",
85
+ "MAPE",
86
+ "ADJR2",
87
+ "R2",
88
+ "Total Positive Contributions",
89
+ "Significance",
90
+ ]
91
+ + list(coefficients_df.columns)
92
+ ]
93
+ data_df.rename(columns={"Model_iteration": "Model Iteration"}, inplace=True)
94
+ data_df.insert(0, "Rank", range(1, len(data_df) + 1))
95
+
96
+ return coefficients_df, data_df
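+ # Annotation (not from the original file): the row-wise normalization above
+ # rescales each model's coefficients to sum to 100, e.g. a hypothetical pair
+ # {"tv": 2, "search": 3} is shown as Tv = 40.0 and Search = 60.0 (percent).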
97
+ def format_display(inp):
98
+ return inp.title().replace("_", " ").strip()
99
+
100
+ def get_random_effects(media_data, panel_col, _mdf):
101
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
102
+
103
+ for i, market in enumerate(media_data[panel_col].unique()):
104
+ print(i, end="\r")
105
+ intercept = _mdf.random_effects[market].values[0]
106
+ random_eff_df.loc[i, "random_effect"] = intercept
107
+ random_eff_df.loc[i, panel_col] = market
108
+
109
+ return random_eff_df
110
+
111
+
112
+ def mdf_predict(X_df, mdf, random_eff_df):
113
+ X = X_df.copy()
114
+ X["fixed_effect"] = mdf.predict(X)
115
+ X = pd.merge(X, random_eff_df, on=panel_col, how="left")
116
+ X["pred"] = X["fixed_effect"] + X["random_effect"]
117
+ # X.to_csv('Test/megred_df.csv',index=False)
118
+ X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
119
+ return X["pred"]
120
+
121
+
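+ # Annotation (not from the original file): mixed-effects prediction decomposes
+ # as pred = fixed effect (mdf.predict on the regressors) + the per-panel random
+ # intercept from get_random_effects, merged on panel_col before summing.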
122
+ st.set_page_config(
123
+ page_title="Model Build",
124
+ page_icon=":shark:",
125
+ layout="wide",
126
+ initial_sidebar_state="collapsed",
127
+ )
128
+
129
+ load_local_css("styles.css")
130
+ set_header()
131
+
132
+ # Check for authentication status
133
+ for k, v in st.session_state.items():
134
+ if k not in [
135
+ "logout",
136
+ "login",
137
+ "config",
138
+ "model_build_button",
139
+ ] and not k.startswith("FormSubmitter"):
140
+ st.session_state[k] = v
141
+ with open("config.yaml") as file:
142
+ config = yaml.load(file, Loader=SafeLoader)
143
+ st.session_state["config"] = config
144
+ authenticator = stauth.Authenticate(
145
+ config["credentials"],
146
+ config["cookie"]["name"],
147
+ config["cookie"]["key"],
148
+ config["cookie"]["expiry_days"],
149
+ config["preauthorized"],
150
+ )
151
+ st.session_state["authenticator"] = authenticator
152
+ name, authentication_status, username = authenticator.login("Login", "main")
153
+ auth_status = st.session_state.get("authentication_status")
154
+
155
+ if auth_status == True:
156
+ authenticator.logout("Logout", "main")
157
+ is_state_initialized = st.session_state.get("initialized", False)
158
+
159
+ conn = sqlite3.connect(
160
+ r"DB/User.db", check_same_thread=False
161
+ ) # connection with sql db
162
+ c = conn.cursor()
163
+
164
+ if not is_state_initialized:
165
+
166
+ if "session_name" not in st.session_state:
167
+ st.session_state["session_name"] = None
168
+
169
+ if "project_dct" not in st.session_state:
170
+ st.error("Please load a project from Home page")
171
+ st.stop()
172
+
173
+ st.title("1. Build Your Model")
174
+
175
+ if not os.path.exists(
176
+ os.path.join(st.session_state["project_path"], "data_import.pkl")
177
+ ):
178
+ st.error("Please move to Data Import Page and save.")
179
+ st.stop()
180
+ with open(
181
+ os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
182
+ ) as f:
183
+ data = pickle.load(f)
184
+ st.session_state["bin_dict"] = data["bin_dict"]
185
+
186
+ if not os.path.exists(
187
+ os.path.join(
188
+ st.session_state["project_path"], "final_df_transformed.pkl"
189
+ )
190
+ ):
191
+ st.error(
192
+ "Please move to Transformation Page and save transformations."
193
+ )
194
+ st.stop()
195
+ with open(
196
+ os.path.join(
197
+ st.session_state["project_path"], "final_df_transformed.pkl"
198
+ ),
199
+ "rb",
200
+ ) as f:
201
+ data = pickle.load(f)
202
+ media_data = data["final_df_transformed"]
203
+
204
+
205
+ # Sprint4 - available response metrics is a list of all response metrics in the data
206
+ ## these will be put in a drop down
207
+
208
+ st.session_state["media_data"] = media_data
209
+
210
+ if "available_response_metrics" not in st.session_state:
211
+ # st.session_state['available_response_metrics'] = ['Total Approved Accounts - Revenue',
212
+ # 'Total Approved Accounts - Appsflyer',
213
+ # 'Account Requests - Appsflyer',
214
+ # 'App Installs - Appsflyer']
215
+
216
+ st.session_state["available_response_metrics"] = st.session_state[
217
+ "bin_dict"
218
+ ]["Response Metrics"]
219
+ # Sprint4
220
+ if "is_tuned_model" not in st.session_state:
221
+ st.session_state["is_tuned_model"] = {}
222
+ for resp_metric in st.session_state["available_response_metrics"]:
223
+ resp_metric = (
224
+ resp_metric.lower()
225
+ .replace(" ", "_")
226
+ .replace("-", "")
227
+ .replace(":", "")
228
+ .replace("__", "_")
229
+ )
230
+ st.session_state["is_tuned_model"][resp_metric] = False
231
+
232
+ # Sprint4 - used_response_metrics is a list of resp metrics for which user has created & saved a model
233
+ if "used_response_metrics" not in st.session_state:
234
+ st.session_state["used_response_metrics"] = []
235
+
236
+ # Sprint4 - saved_model_names
237
+ if "saved_model_names" not in st.session_state:
238
+ st.session_state["saved_model_names"] = []
239
+
240
+ if "Model" not in st.session_state:
241
+ if (
242
+ "session_state_saved"
243
+ in st.session_state["project_dct"]["model_build"].keys()
244
+ and st.session_state["project_dct"]["model_build"][
245
+ "session_state_saved"
246
+ ]
247
+ is not None
248
+ and "Model"
249
+ in st.session_state["project_dct"]["model_build"][
250
+ "session_state_saved"
251
+ ].keys()
252
+ ):
253
+ st.session_state["Model"] = st.session_state["project_dct"][
254
+ "model_build"
255
+ ]["session_state_saved"]["Model"]
256
+ else:
257
+ st.session_state["Model"] = {}
258
+
259
+ date_col = "date"
260
+ date = media_data[date_col]
261
+
262
+ # Sprint4 - select a response metric
263
+ default_target_idx = (
264
+ st.session_state["project_dct"]["model_build"].get(
265
+ "sel_target_col", None
266
+ )
267
+ if st.session_state["project_dct"]["model_build"].get(
268
+ "sel_target_col", None
269
+ )
270
+ is not None
271
+ else st.session_state["available_response_metrics"][0]
272
+ )
273
+
274
+ start_cols = st.columns(2)
275
+ min_date = min(date)
276
+ max_date = max(date)
277
+
278
+ with start_cols[0]:
279
+ sel_target_col = st.selectbox(
280
+ "Select the response metric",
281
+ st.session_state["available_response_metrics"],
282
+ index=st.session_state["available_response_metrics"].index(
283
+ default_target_idx
284
+ ),
285
+ format_func=format_display
286
+ )
287
+ # , on_change=reset_save())
288
+ st.session_state["project_dct"]["model_build"][
289
+ "sel_target_col"
290
+ ] = sel_target_col
291
+
292
+
293
+ default_test_start = min_date + (3*(max_date-min_date)/4)
294
+
295
+ with start_cols[1]:
296
+ test_start = st.date_input(
297
+ "Select test start date",
298
+ default_test_start,
299
+ min_value=min_date,
300
+ max_value=max_date,
301
+ )
302
+ train_idx = media_data[media_data[date_col] <= pd.to_datetime(test_start)].index[-1]
303
+ # st.write(train_idx, media_data.index[-1])
304
+
305
+ target_col = (
306
+ sel_target_col.lower()
307
+ .replace(" ", "_")
308
+ .replace("-", "")
309
+ .replace(":", "")
310
+ .replace("__", "_")
311
+ )
312
+ new_name_dct = {
313
+ col: col.lower()
314
+ .replace(".", "_")
315
+ .lower()
316
+ .replace("@", "_")
317
+ .replace(" ", "_")
318
+ .replace("-", "")
319
+ .replace(":", "")
320
+ .replace("__", "_")
321
+ for col in media_data.columns
322
+ }
323
+ media_data.columns = [
324
+ col.lower()
325
+ .replace(".", "_")
326
+ .replace("@", "_")
327
+ .replace(" ", "_")
328
+ .replace("-", "")
329
+ .replace(":", "")
330
+ .replace("__", "_")
331
+ for col in media_data.columns
332
+ ]
333
+ panel_col = [
334
+ col.lower()
335
+ .replace(".", "_")
336
+ .replace("@", "_")
337
+ .replace(" ", "_")
338
+ .replace("-", "")
339
+ .replace(":", "")
340
+ .replace("__", "_")
341
+ for col in st.session_state["bin_dict"]["Panel Level 1"]
342
+ ][0] # set the panel column
343
+
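+ # Sanitization sketch (annotation only): the chained replaces map a display
+ # name such as "Total Approved Accounts - Revenue" to
+ # "total_approved_accounts_revenue" (lower-case, spaces to underscores,
+ # "-" and ":" dropped, doubled underscores collapsed).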
344
+ is_panel = True if len(panel_col) > 0 else False
345
+
346
+ if "is_panel" not in st.session_state:
347
+ st.session_state["is_panel"] = is_panel
348
+
349
+ if is_panel:
350
+ media_data.sort_values([date_col, panel_col], inplace=True)
351
+ else:
352
+ media_data.sort_values(date_col, inplace=True)
353
+
354
+ media_data.reset_index(drop=True, inplace=True)
355
+
356
+ st.session_state["date"] = date
357
+ y = media_data[target_col]
358
+
359
+ if is_panel:
360
+ spends_data = media_data[
361
+ [
362
+ c
363
+ for c in media_data.columns
364
+ if "_cost" in c.lower() or "_spend" in c.lower()
365
+ ]
366
+ + [date_col, panel_col]
367
+ ]
368
+ # Sprint3 - spends for resp curves
369
+ else:
370
+ spends_data = media_data[
371
+ [
372
+ c
373
+ for c in media_data.columns
374
+ if "_cost" in c.lower() or "_spend" in c.lower()
375
+ ]
376
+ + [date_col]
377
+ ]
378
+
379
+ y = media_data[target_col]
380
+ media_data.drop([date_col], axis=1, inplace=True)
381
+ media_data.reset_index(drop=True, inplace=True)
382
+
383
+ columns = st.columns(2)
384
+
385
+ old_shape = media_data.shape
386
+
387
+ if "old_shape" not in st.session_state:
388
+ st.session_state["old_shape"] = old_shape
389
+
390
+ if "media_data" not in st.session_state:
391
+ st.session_state["media_data"] = pd.DataFrame()
392
+
393
+ # Sprint3
394
+ if "orig_media_data" not in st.session_state:
395
+ st.session_state["orig_media_data"] = pd.DataFrame()
396
+
397
+ # Sprint3 additions
398
+ if "random_effects" not in st.session_state:
399
+ st.session_state["random_effects"] = pd.DataFrame()
400
+ if "pred_train" not in st.session_state:
401
+ st.session_state["pred_train"] = []
402
+ if "pred_test" not in st.session_state:
403
+ st.session_state["pred_test"] = []
404
+ # end of Sprint3 additions
405
+
406
+ # Section 3 - Create combinations
407
+
408
+ # bucket=['paid_search', 'kwai','indicacao','infleux', 'influencer','FB: Level Achieved - Tier 1 Impressions',
409
+ # ' FB: Level Achieved - Tier 2 Impressions','paid_social_others',
410
+ # ' GA App: Will And Cid Pequena Baixo Risco Clicks',
411
+ # 'digital_tactic_others',"programmatic"
412
+ # ]
413
+
414
+ # srishti - bucket names changed
415
+ bucket = [
416
+ "paid_search",
417
+ "kwai",
418
+ "indicacao",
419
+ "infleux",
420
+ "influencer",
421
+ "fb_level_achieved_tier_2",
422
+ "fb_level_achieved_tier_1",
423
+ "paid_social_others",
424
+ "ga_app",
425
+ "digital_tactic_others",
426
+ "programmatic",
427
+ ]
428
+
429
+ # with columns[0]:
430
+ # if st.button('Create Combinations of Variables'):
431
+
432
+ top_3_correlated_features = []
433
+ # # for col in st.session_state['media_data'].columns[:19]:
434
+ # original_cols = [c for c in st.session_state['media_data'].columns if
435
+ # "_clicks" in c.lower() or "_impressions" in c.lower()]
436
+ # original_cols = [c for c in original_cols if "_lag" not in c.lower() and "_adstock" not in c.lower()]
437
+
438
+ original_cols = (
439
+ st.session_state["bin_dict"]["Media"]
440
+ + st.session_state["bin_dict"]["Internal"]
441
+ )
442
+
443
+ original_cols = [
444
+ col.lower()
445
+ .replace(".", "_")
446
+ .replace("@", "_")
447
+ .replace(" ", "_")
448
+ .replace("-", "")
449
+ .replace(":", "")
450
+ .replace("__", "_")
451
+ for col in original_cols
452
+ ]
453
+ original_cols = [col for col in original_cols if "_cost" not in col]
454
+ # for col in st.session_state['media_data'].columns[:19]:
455
+ for col in original_cols: # srishti - new
456
+ corr_df = (
457
+ pd.concat(
458
+ [st.session_state["media_data"].filter(regex=col), y], axis=1
459
+ )
460
+ .corr()[target_col]
461
+ .iloc[:-1]
462
+ )
463
+ top_3_correlated_features.append(
464
+ list(corr_df.sort_values(ascending=False).head(2).index)
465
+ )
466
+ flattened_list = [
467
+ item for sublist in top_3_correlated_features for item in sublist
468
+ ]
469
+ # all_features_set={var:[col for col in flattened_list if var in col] for var in bucket}
470
+ all_features_set = {
471
+ var: [col for col in flattened_list if var in col]
472
+ for var in bucket
473
+ if len([col for col in flattened_list if var in col]) > 0
474
+ } # srishti
475
+ channels_all = [values for values in all_features_set.values()]
476
+ st.session_state["combinations"] = list(itertools.product(*channels_all))
477
+ # if 'combinations' not in st.session_state:
478
+ # st.session_state['combinations']=combinations_all
479
+
480
+ st.session_state["final_selection"] = st.session_state["combinations"]
481
+ # st.success('Created combinations')
482
+
483
+ # revenue.reset_index(drop=True,inplace=True)
484
+ y.reset_index(drop=True, inplace=True)
485
+ if "Model_results" not in st.session_state:
486
+ st.session_state["Model_results"] = {
487
+ "Model_object": [],
488
+ "Model_iteration": [],
489
+ "Feature_set": [],
490
+ "MAPE": [],
491
+ "R2": [],
492
+ "ADJR2": [],
493
+ "pos_count": [],
494
+ }
495
+
496
+ def reset_model_result_dct():
497
+ st.session_state["Model_results"] = {
498
+ "Model_object": [],
499
+ "Model_iteration": [],
500
+ "Feature_set": [],
501
+ "MAPE": [],
502
+ "R2": [],
503
+ "ADJR2": [],
504
+ "pos_count": [],
505
+ }
506
+
507
+ # if st.button('Build Model'):
508
+
509
+ if "iterations" not in st.session_state:
510
+ st.session_state["iterations"] = 0
511
+
512
+ if "final_selection" not in st.session_state:
513
+ st.session_state["final_selection"] = False
514
+
515
+ save_path = r"Model/"
516
+ if st.session_state["final_selection"]:
517
+ st.write(
518
+ f'Total combinations created {format_numbers(len(st.session_state["final_selection"]))}'
519
+ )
520
+
521
+ # st.session_state["project_dct"]["model_build"]["all_iters_check"] = False
522
+
523
+ checkbox_default = (
524
+ st.session_state["project_dct"]["model_build"]["all_iters_check"]
525
+ if st.session_state["project_dct"]["model_build"]["all_iters_check"]
526
+ is not None
527
+ else False
528
+ )
529
+ end_date = test_start - timedelta(days=1)
530
+ disp_str = "Data Split -- Training Period: " + min_date.strftime("%B %d, %Y") + " - " + end_date.strftime("%B %d, %Y") +", Testing Period: " + test_start.strftime("%B %d, %Y") + " - " + max_date.strftime("%B %d, %Y")
531
+ st.markdown(disp_str)
532
+ if st.checkbox("Build all iterations", value=checkbox_default):
533
+ # st.session_state["project_dct"]["model_build"]["all_iters_check"]
534
+ iterations = len(st.session_state["final_selection"])
535
+ st.session_state["project_dct"]["model_build"][
536
+ "all_iters_check"
537
+ ] = True
538
+
539
+ else:
540
+ iterations = st.number_input(
541
+ "Select the number of iterations to perform",
542
+ min_value=0,
543
+ step=100,
544
+ value=st.session_state["iterations"],
545
+ on_change=reset_model_result_dct,
546
+ )
547
+ st.session_state["project_dct"]["model_build"][
548
+ "all_iters_check"
549
+ ] = False
550
+ st.session_state["project_dct"]["model_build"][
551
+ "iterations"
552
+ ] = iterations
553
+
554
+ # st.stop()
555
+
556
+ # build_button = st.session_state["project_dct"]["model_build"]["build_button"] if \
557
+ # "build_button" in st.session_state["project_dct"]["model_build"].keys() else False
558
+ # model_button =st.button('Build Model', on_click=reset_model_result_dct, key='model_build_button')
559
+ # if
560
+ # if model_button:
561
+ if st.button(
562
+ "Build Model",
563
+ on_click=reset_model_result_dct,
564
+ key="model_build_button",
565
+ ):
566
+ if iterations < 1:
567
+ st.error("Please select number of iterations")
568
+ st.stop()
569
+ st.session_state["project_dct"]["model_build"]["build_button"] = True
570
+ st.session_state["iterations"] = iterations
571
+
572
+ # Section 4 - Model
573
+ # st.session_state['media_data'] = st.session_state['media_data'].fillna(method='ffill')
574
+ st.session_state["media_data"] = st.session_state["media_data"].ffill()
575
+ progress_bar = st.progress(0) # Initialize the progress bar
576
+ # time_remaining_text = st.empty() # Create an empty space for time remaining text
577
+ start_time = time.time() # Record the start time
578
+ progress_text = st.empty()
579
+
580
+ # time_elapsed_text = st.empty()
581
+ # for i, selected_features in enumerate(st.session_state["final_selection"][40000:40000 + int(iterations)]):
582
+ # for i, selected_features in enumerate(st.session_state["final_selection"]):
583
+
584
+ if is_panel == True:
585
+ for i, selected_features in enumerate(
586
+ st.session_state["final_selection"][0 : int(iterations)]
587
+ ): # srishti
588
+ df = st.session_state["media_data"]
589
+
590
+ fet = [var for var in selected_features if len(var) > 0]
591
+ inp_vars_str = " + ".join(fet) # new
592
+
593
+ X = df[fet]
594
+ y = df[target_col]
595
+ ss = MinMaxScaler()
596
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
597
+
598
+ X[target_col] = y # Sprint2
599
+ X[panel_col] = df[panel_col] # Sprint2
600
+
601
+ X_train = X.iloc[:train_idx]
602
+ X_test = X.iloc[train_idx:]
603
+ y_train = y.iloc[:train_idx]
604
+ y_test = y.iloc[train_idx:]
605
+
606
+ print(X_train.shape)
607
+ # model = sm.OLS(y_train, X_train).fit()
608
+ md_str = target_col + " ~ " + inp_vars_str
609
+ # md = smf.mixedlm("total_approved_accounts_revenue ~ {}".format(inp_vars_str),
610
+ # data=X_train[[target_col] + fet],
611
+ # groups=X_train[panel_col])
612
+ md = smf.mixedlm(
613
+ md_str,
614
+ data=X_train[[target_col] + fet],
615
+ groups=X_train[panel_col],
616
+ )
617
+ mdf = md.fit()
618
+ predicted_values = mdf.fittedvalues
619
+
620
+ coefficients = mdf.fe_params.to_dict()
621
+ model_positive = [
622
+ col for col in coefficients.keys() if coefficients[col] > 0
623
+ ]
624
+
625
+ pvalues = [var for var in list(mdf.pvalues) if var <= 0.06]
626
+
627
+ if (len(model_positive) / len(selected_features)) > 0 and (
628
+ len(pvalues) / len(selected_features)
629
+ ) >= 0: # srishti - changed just for testing, revert later
630
+ # predicted_values = model.predict(X_train)
631
+ mape = mean_absolute_percentage_error(
632
+ y_train, predicted_values
633
+ )
634
+ r2 = r2_score(y_train, predicted_values)
635
+ adjr2 = 1 - (1 - r2) * (len(y_train) - 1) / (
636
+ len(y_train) - len(selected_features) - 1
637
+ )
638
+
639
+ filename = os.path.join(save_path, f"model_{i}.pkl")
640
+ with open(filename, "wb") as f:
641
+ pickle.dump(mdf, f)
642
+ # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
643
+ # model = pickle.load(file)
644
+
645
+ st.session_state["Model_results"]["Model_object"].append(
646
+ filename
647
+ )
648
+ st.session_state["Model_results"][
649
+ "Model_iteration"
650
+ ].append(i)
651
+ st.session_state["Model_results"]["Feature_set"].append(
652
+ fet
653
+ )
654
+ st.session_state["Model_results"]["MAPE"].append(mape)
655
+ st.session_state["Model_results"]["R2"].append(r2)
656
+ st.session_state["Model_results"]["pos_count"].append(
657
+ len(model_positive)
658
+ )
659
+ st.session_state["Model_results"]["ADJR2"].append(adjr2)
660
+
661
+ current_time = time.time()
662
+ time_taken = current_time - start_time
663
+ time_elapsed_minutes = time_taken / 60
664
+ completed_iterations_text = f"{i + 1}/{iterations}"
665
+ progress_bar.progress((i + 1) / int(iterations))
666
+ progress_text.text(
667
+ f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
668
+ )
669
+ st.write(
670
+ f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
671
+ )
672
+
673
+ else:
674
+
675
+ for i, selected_features in enumerate(
676
+ st.session_state["final_selection"][0 : int(iterations)]
677
+ ): # srishti
678
+ df = st.session_state["media_data"]
679
+
680
+ fet = [var for var in selected_features if len(var) > 0]
681
+ inp_vars_str = " + ".join(fet)
682
+
683
+ X = df[fet]
684
+ y = df[target_col]
685
+ ss = MinMaxScaler()
686
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
687
+ X = sm.add_constant(X)
688
+ X_train = X.iloc[:130]
689
+ X_test = X.iloc[130:]
690
+ y_train = y.iloc[:130]
691
+ y_test = y.iloc[130:]
692
+
693
+ model = sm.OLS(y_train, X_train).fit()
694
+
695
+ coefficients = model.params.to_list()
696
+ model_positive = [coef for coef in coefficients if coef > 0]
697
+ predicted_values = model.predict(X_train)
698
+ pvalues = [var for var in list(model.pvalues) if var <= 0.06]
699
+
700
+ # if (len(model_positive) / len(selected_features)) > 0.9 and (len(pvalues) / len(selected_features)) >= 0.8:
701
+ if (len(model_positive) / len(selected_features)) > 0 and (
702
+ len(pvalues) / len(selected_features)
703
+ ) >= 0.5: # srishti - changed just for testing, revert later VALID MODEL CRITERIA
704
+ # predicted_values = model.predict(X_train)
705
+ mape = mean_absolute_percentage_error(
706
+ y_train, predicted_values
707
+ )
708
+ adjr2 = model.rsquared_adj
709
+ r2 = model.rsquared
710
+
711
+ filename = os.path.join(save_path, f"model_{i}.pkl")
712
+ with open(filename, "wb") as f:
713
+ pickle.dump(model, f)
714
+ # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
715
+ # model = pickle.load(file)
716
+
717
+ st.session_state["Model_results"]["Model_object"].append(
718
+ filename
719
+ )
720
+ st.session_state["Model_results"][
721
+ "Model_iteration"
722
+ ].append(i)
723
+ st.session_state["Model_results"]["Feature_set"].append(
724
+ fet
725
+ )
726
+ st.session_state["Model_results"]["MAPE"].append(mape)
727
+ st.session_state["Model_results"]["R2"].append(r2)
728
+ st.session_state["Model_results"]["ADJR2"].append(adjr2)
729
+ st.session_state["Model_results"]["pos_count"].append(
730
+ len(model_positive)
731
+ )
732
+
733
+ current_time = time.time()
734
+ time_taken = current_time - start_time
735
+ time_elapsed_minutes = time_taken / 60
736
+ completed_iterations_text = f"{i + 1}/{iterations}"
737
+ progress_bar.progress((i + 1) / int(iterations))
738
+ progress_text.text(
739
+ f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
740
+ )
741
+ st.write(
742
+ f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
743
+ )
744
+
745
+ pd.DataFrame(st.session_state["Model_results"]).to_csv(
746
+ "model_output.csv"
747
+ )
748
+
749
+ def to_percentage(value):
750
+ return f"{value * 100:.1f}%"
751
+
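+ # e.g. to_percentage(0.1234) -> "12.3%" (annotation, illustration only)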
752
+ ## Section 5 - Select Model
753
+ st.title("2. Select Models")
754
+ show_results_default = (
755
+ st.session_state["project_dct"]["model_build"]["show_results_check"]
756
+ if st.session_state["project_dct"]["model_build"]["show_results_check"]
757
+ is not None
758
+ else False
759
+ )
760
+ if "tick" not in st.session_state:
761
+ st.session_state["tick"] = False
762
+ if st.checkbox(
763
+ "Show results of top 10 models (based on MAPE and Adj. R2)",
764
+ value=show_results_default,
765
+ ):
766
+ st.session_state["project_dct"]["model_build"][
767
+ "show_results_check"
768
+ ] = True
769
+ st.session_state["tick"] = True
770
+ st.write(
771
+ "Select one model iteration to generate performance metrics for it:"
772
+ )
773
+ data = pd.DataFrame(st.session_state["Model_results"])
774
+ data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
775
+ drop=True
776
+ ) # Sprint4 -- Srishti -- only show models with the lowest num of neg coeffs
777
+ data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
778
+ data.drop_duplicates(subset="Model_iteration", inplace=True)
779
+ top_10 = data.head(10)
780
+ top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
781
+ top_10[["MAPE", "R2", "ADJR2"]] = np.round(
782
+ top_10[["MAPE", "R2", "ADJR2"]], 4
783
+ ).applymap(to_percentage)
784
+ top_10_table = top_10[
785
+ ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
786
+ ]
787
+ # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
788
+ gd = GridOptionsBuilder.from_dataframe(top_10_table)
789
+ gd.configure_pagination(enabled=True)
790
+
791
+ gd.configure_selection(
792
+ use_checkbox=True,
793
+ selection_mode="single",
794
+ pre_select_all_rows=False,
795
+ pre_selected_rows=[1],
796
+ )
797
+
798
+ gridoptions = gd.build()
799
+
800
+ table = AgGrid(
801
+ top_10,
802
+ gridOptions=gridoptions,
803
+ update_mode=GridUpdateMode.SELECTION_CHANGED,
804
+ )
805
+
806
+ selected_rows = table.selected_rows
807
+ # if st.session_state["selected_rows"] != selected_rows:
808
+ # st.session_state["build_rc_cb"] = False
809
+ st.session_state["selected_rows"] = selected_rows
810
+ # st.write(
811
+ # """
812
+ # ### Filter Results
813
+
814
+ # Use the filters below to refine the displayed model results. This helps in isolating models that do not meet the required business criteria, ensuring only the most relevant models are considered for further analysis. If multiple models meet the criteria, select the first model, as it is considered the best-ranked based on evaluation criteria.
815
+ # """
816
+ # )
817
+
818
+ # data = pd.DataFrame(st.session_state["Model_results"])
819
+ # coefficients_df, data_df = prepare_data_df(data)
820
+
821
+ # # Define the structure of the empty DataFrame
822
+ # filter_df_data = {
823
+ # "Channel Name": pd.Series([], dtype="str"),
824
+ # "Filter Condition": pd.Series([], dtype="str"),
825
+ # "Percent Contribution": pd.Series([], dtype="str"),
826
+ # }
827
+ # filter_df = pd.DataFrame(filter_df_data)
828
+
829
+ # filter_df_editable = st.data_editor(
830
+ # filter_df,
831
+ # column_config={
832
+ # "Channel Name": st.column_config.SelectboxColumn(
833
+ # options=list(coefficients_df.columns),
834
+ # required=True,
835
+ # default="Base Sales",
836
+ # ),
837
+ # "Filter Condition": st.column_config.SelectboxColumn(
838
+ # options=[
839
+ # "<",
840
+ # ">",
841
+ # "=",
842
+ # "<=",
843
+ # ">=",
844
+ # ],
845
+ # required=True,
846
+ # default=">",
847
+ # ),
848
+ # "Percent Contribution": st.column_config.NumberColumn(
849
+ # required=True, default=0
850
+ # ),
851
+ # },
852
+ # hide_index=True,
853
+ # use_container_width=True,
854
+ # num_rows="dynamic",
855
+ # )
856
+
857
+ # # Apply filters from filter_df_editable to data_df
858
+ # if "filtered_df" not in st.session_state:
859
+ # st.session_state["filtered_df"] = data_df.copy()
860
+
861
+ # if st.button("Filter", args=(data_df)):
862
+ # st.session_state["filtered_df"] = data_df.copy()
863
+ # for index, row in filter_df_editable.iterrows():
864
+ # channel_name = row["Channel Name"]
865
+ # condition = row["Filter Condition"]
866
+ # value = row["Percent Contribution"]
867
+
868
+ # if channel_name in st.session_state["filtered_df"].columns:
869
+ # # Construct the query string based on the condition
870
+ # query_string = f"`{channel_name}` {condition} {value}"
871
+ # st.session_state["filtered_df"] = st.session_state["filtered_df"].query(
872
+ # query_string
873
+ # )
874
+
875
+ # # After filtering, check if the DataFrame is empty
876
+ # if st.session_state["filtered_df"].empty:
877
+ # # Display a warning message if no rows meet the filter criteria
878
+ # st.warning("No model meets the specified filter conditions", icon="⚠️")
879
+ # st.stop() # Optionally stop further execution
880
+
881
+ # # Output the filtered data
882
+ # st.write("Select one model iteration to generate performance metrics for it:")
883
+ # st.dataframe(st.session_state["filtered_df"], hide_index=True)
884
+
885
+ #############################################################################################
886
+
887
+ # top_10 = data.head(10)
888
+ # top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
889
+ # top_10[["MAPE", "R2", "ADJR2"]] = np.round(
890
+ # top_10[["MAPE", "R2", "ADJR2"]], 4
891
+ # ).applymap(to_percentage)
892
+
893
+ # top_10_table = top_10[
894
+ # ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
895
+ # + list(coefficients_df.columns)
896
+ # ]
897
+ # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
898
+
899
+ # gd = GridOptionsBuilder.from_dataframe(top_10_table)
900
+ # gd.configure_pagination(enabled=True)
901
+
902
+ # gd.configure_selection(
903
+ # use_checkbox=True,
904
+ # selection_mode="single",
905
+ # pre_select_all_rows=False,
906
+ # pre_selected_rows=[1],
907
+ # )
908
+
909
+ # gridoptions = gd.build()
910
+
911
+ # table = AgGrid(
912
+ # top_10, gridOptions=gridoptions, update_mode=GridUpdateMode.SELECTION_CHANGED
913
+ # )
914
+
915
+ # selected_rows = table.selected_rows
916
+
917
+ # gd = GridOptionsBuilder.from_dataframe(st.session_state["filtered_df"])
918
+ # gd.configure_pagination(enabled=True)
919
+
920
+ # gd.configure_selection(
921
+ # use_checkbox=True,
922
+ # selection_mode="single",
923
+ # pre_select_all_rows=False,
924
+ # pre_selected_rows=[1],
925
+ # )
926
+
927
+ # gridoptions = gd.build()
928
+
929
+ # table = AgGrid(
930
+ # st.session_state["filtered_df"],
931
+ # gridOptions=gridoptions,
932
+ # update_mode=GridUpdateMode.SELECTION_CHANGED,
933
+ # )
934
+
935
+ # selected_rows_table = table.selected_rows
936
+
937
+ # Dataframe
938
+ # display_df = st.session_state.filtered_df.rename(columns={"Rank": "Model Number"})
939
+ # st.dataframe(display_df, hide_index=True)
940
+
941
+ # min_rank = min(st.session_state["filtered_df"]["Rank"])
942
+ # max_rank = max(st.session_state["filtered_df"]["Rank"])
943
+ # available_ranks = st.session_state["filtered_df"]["Rank"].unique()
944
+
945
+ # # Get row number input from the user
946
+ # rank_number = st.number_input(
947
+ # "Select model by Model Number:",
948
+ # min_value=min_rank,
949
+ # max_value=max_rank,
950
+ # value=min_rank,
951
+ # step=1,
952
+ # )
953
+
954
+ # # Get row
955
+ # if rank_number not in available_ranks:
956
+ # st.warning("No model is available with selected Rank", icon="⚠️")
957
+ # st.stop()
958
+
959
+ # Find the row that matches the selected rank
960
+ # selected_rows = st.session_state["filtered_df"][
961
+ # st.session_state["filtered_df"]["Rank"] == rank_number
962
+ # ]
963
+
964
+ # selected_rows = [
965
+ # (selected_rows.to_dict(orient="records")[0] if not selected_rows.empty else {})
966
+ # ]
967
+
968
+ # if st.session_state["selected_rows"] != selected_rows:
969
+ # st.session_state["build_rc_cb"] = False
970
+ st.session_state["selected_rows"] = selected_rows
971
+ if "Model" not in st.session_state:
972
+ st.session_state["Model"] = {}
973
+
974
+ # Section 6 - Display Results
978
+
979
+ if len(selected_rows) > 0:
980
+ st.header("2.1 Results Summary")
981
+
982
+ model_object = data[
983
+ data["Model_iteration"] == selected_rows[0]["Model_iteration"]
984
+ ]["Model_object"]
985
+ features_set = data[
986
+ data["Model_iteration"] == selected_rows[0]["Model_iteration"]
987
+ ]["Feature_set"]
988
+
989
+ with open(str(model_object.values[0]), "rb") as file:
990
+ # print(file)
991
+ model = pickle.load(file)
992
+ st.write(model.summary())
993
+ st.header("2.2 Actual vs. Predicted Plot")
994
+
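+ # Two prediction paths: panel data is scored with the MixedLM fixed
+ # effects plus each panel's random intercept (get_random_effects +
+ # mdf_predict); the non-panel path is a plain OLS predict on the scaled
+ # features plus a constant.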
995
+ if is_panel:
996
+ df = st.session_state["media_data"]
997
+ X = df[features_set.values[0]]
998
+ y = df[target_col]
999
+
1000
+ ss = MinMaxScaler()
1001
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
1002
+
1003
+ # Sprint2 changes
1004
+ X[target_col] = y # new
1005
+ X[panel_col] = df[panel_col]
1006
+ X[date_col] = date
1007
+
1008
+ X_train = X.iloc[:train_idx]
1009
+ X_test = X.iloc[train_idx:].reset_index(drop=True)
1010
+ y_train = y.iloc[:train_idx]
1011
+ y_test = y.iloc[train_idx:].reset_index(drop=True)
1012
+
1013
+ test_spends = spends_data[
1014
+ train_idx:
1015
+ ] # Sprint3 - test spends for resp curves
1016
+ random_eff_df = get_random_effects(
1017
+ media_data, panel_col, model
1018
+ )
1019
+ train_pred = model.fittedvalues
1020
+ test_pred = mdf_predict(X_test, model, random_eff_df)
1021
+ print("__" * 20, test_pred.isna().sum())
1022
+
1023
+ else:
1024
+ df = st.session_state["media_data"]
1025
+ X = df[features_set.values[0]]
1026
+ y = df[target_col]
1027
+
1028
+ ss = MinMaxScaler()
1029
+ X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
1030
+ X = sm.add_constant(X)
1031
+
1032
+ X[date_col] = date
1033
+
1034
+ X_train = X.iloc[:130]
1035
+ X_test = X.iloc[130:].reset_index(drop=True)
1036
+ y_train = y.iloc[:130]
1037
+ y_test = y.iloc[130:].reset_index(drop=True)
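+ # NOTE: the non-panel path hard-codes a 130-row training window; the
+ # Model Tuning page relies on the same `date[:130]` split.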
1038
+
1039
+ test_spends = spends_data[
1040
+ 130:
1041
+ ] # Sprint3 - test spends for resp curves
1042
+ train_pred = model.predict(
1043
+ X_train[features_set.values[0] + ["const"]]
1044
+ )
1045
+ test_pred = model.predict(
1046
+ X_test[features_set.values[0] + ["const"]]
1047
+ )
1048
+
1049
+ # save x test to test - srishti
1050
+ # x_test_to_save = X_test.copy()
1051
+ # x_test_to_save['Actuals'] = y_test
1052
+ # x_test_to_save['Predictions'] = test_pred
1053
+ #
1054
+ # x_train_to_save = X_train.copy()
1055
+ # x_train_to_save['Actuals'] = y_train
1056
+ # x_train_to_save['Predictions'] = train_pred
1057
+ #
1058
+ # x_train_to_save.to_csv('Test/x_train_to_save.csv', index=False)
1059
+ # x_test_to_save.to_csv('Test/x_test_to_save.csv', index=False)
1060
+
1061
+ st.session_state["X"] = X_train
1062
+ st.session_state["features_set"] = features_set.values[0]
1063
+ print(
1064
+ "**" * 20, "selected model features : ", features_set.values[0]
1065
+ )
1066
+ metrics_table, line, actual_vs_predicted_plot = (
1067
+ plot_actual_vs_predicted(
1068
+ X_train[date_col],
1069
+ y_train,
1070
+ train_pred,
1071
+ model,
1072
+ target_column=sel_target_col,
1073
+ is_panel=is_panel,
1074
+ )
1075
+ ) # Sprint2
1076
+
1077
+ st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
1078
+
1079
+ st.markdown("## 2.3 Residual Analysis")
1080
+ columns = st.columns(2)
1081
+ with columns[0]:
1082
+ fig = plot_residual_predicted(
1083
+ y_train, train_pred, X_train
1084
+ ) # Sprint2
1085
+ st.plotly_chart(fig)
1086
+
1087
+ with columns[1]:
1088
+ st.empty()
1089
+ fig = qqplot(y_train, train_pred) # Sprint2
1090
+ st.plotly_chart(fig)
1091
+
1092
+ with columns[0]:
1093
+ fig = residual_distribution(y_train, train_pred) # Sprint2
1094
+ st.pyplot(fig)
1095
+
1096
+ vif_data = pd.DataFrame()
1097
+ # X=X.drop('const',axis=1)
1098
+ X_train_orig = (
1099
+ X_train.copy()
1100
+ ) # Sprint2 -- creating a copy of xtrain. Later deleting panel, target & date from xtrain
1101
+ del_col_list = list(
1102
+ set([target_col, panel_col, date_col]).intersection(
1103
+ set(X_train.columns)
1104
+ )
1105
+ )
1106
+ X_train.drop(columns=del_col_list, inplace=True) # Sprint2
1107
+
1108
+ vif_data["Variable"] = X_train.columns
1109
+ vif_data["VIF"] = [
1110
+ variance_inflation_factor(X_train.values, i)
1111
+ for i in range(X_train.shape[1])
1112
+ ]
1113
+ vif_data.sort_values(by=["VIF"], ascending=False, inplace=True)
1114
+ vif_data = np.round(vif_data)
1115
+ vif_data["VIF"] = vif_data["VIF"].astype(float)
1116
+ st.header("2.4 Variance Inflation Factor (VIF)")
1117
+ # st.dataframe(vif_data)
1118
+ color_mapping = {
1119
+ "darkgreen": (vif_data["VIF"] < 3),
1120
+ "orange": (vif_data["VIF"] >= 3) & (vif_data["VIF"] <= 10),
1121
+ "darkred": (vif_data["VIF"] > 10),
1122
+ }
1123
+
1124
+ # Create a horizontal bar plot
1125
+ fig, ax = plt.subplots()
1126
+ fig.set_figwidth(10) # Adjust the width of the figure as needed
1127
+
1128
+ # Sort the bars by descending VIF values
1129
+ vif_data = vif_data.sort_values(by="VIF", ascending=False)
1130
+
1131
+ # Iterate through the color mapping and plot bars with corresponding colors
1132
+ for color, condition in color_mapping.items():
1133
+ subset = vif_data[condition]
1134
+ bars = ax.barh(
1135
+ subset["Variable"], subset["VIF"], color=color, label=color
1136
+ )
1137
+
1138
+ # Add text annotations on top of the bars
1139
+ for bar in bars:
1140
+ width = bar.get_width()
1141
+ ax.annotate(
1142
+ f"{width:}",
1143
+ xy=(width, bar.get_y() + bar.get_height() / 2),
1144
+ xytext=(5, 0),
1145
+ textcoords="offset points",
1146
+ va="center",
1147
+ )
1148
+
1149
+ # Customize the plot
1150
+ ax.set_xlabel("VIF Values")
1151
+ # ax.set_title('2.4 Variance Inflation Factor (VIF)')
1152
+ # ax.legend(loc='upper right')
1153
+
1154
+ # Display the plot in Streamlit
1155
+ st.pyplot(fig)
1156
+
1157
+ with st.expander("Results Summary Test data"):
1158
+ # ss = MinMaxScaler()
1159
+ # X_test = pd.DataFrame(ss.fit_transform(X_test), columns=X_test.columns)
1160
+ st.header("2.2 Actual vs. Predicted Plot")
1161
+
1162
+ metrics_table, line, actual_vs_predicted_plot = (
1163
+ plot_actual_vs_predicted(
1164
+ X_test[date_col],
1165
+ y_test,
1166
+ test_pred,
1167
+ model,
1168
+ target_column=sel_target_col,
1169
+ is_panel=is_panel,
1170
+ )
1171
+ ) # Sprint2
1172
+
1173
+ st.plotly_chart(
1174
+ actual_vs_predicted_plot, use_container_width=True
1175
+ )
1176
+
1177
+ st.markdown("## 2.3 Residual Analysis")
1178
+ columns = st.columns(2)
1179
+ with columns[0]:
1180
+ fig = plot_residual_predicted(
1181
+ y_test, test_pred, X_test
1182
+ ) # Sprint2
1183
+ st.plotly_chart(fig)
1184
+
1185
+ with columns[1]:
1186
+ st.empty()
1187
+ fig = qqplot(y_test, test_pred) # Sprint2
1188
+ st.plotly_chart(fig)
1189
+
1190
+ with columns[0]:
1191
+ fig = residual_distribution(y_test, test_pred) # Sprint2
1192
+ st.pyplot(fig)
1193
+
1194
+ value = False
1195
+ save_button_model = st.checkbox(
1196
+ "Save this model to tune", key="build_rc_cb"
1197
+ ) # , on_click=set_save())
1198
+
1199
+ if save_button_model:
1200
+ mod_name = st.text_input("Enter model name")
1201
+ if len(mod_name) > 0:
1202
+ mod_name = (
1203
+ mod_name + "__" + target_col
1204
+ ) # Sprint4 - adding target col to model name
1205
+ if is_panel:
1206
+ pred_train = model.fittedvalues
1207
+ pred_test = mdf_predict(X_test, model, random_eff_df)
1208
+ else:
1209
+ st.session_state["features_set"] = st.session_state[
1210
+ "features_set"
1211
+ ] + ["const"]
1212
+ pred_train = model.predict(
1213
+ X_train_orig[st.session_state["features_set"]]
1214
+ )
1215
+ pred_test = model.predict(
1216
+ X_test[st.session_state["features_set"]]
1217
+ )
1218
+
1219
+ st.session_state["Model"][mod_name] = {
1220
+ "Model_object": model,
1221
+ "feature_set": st.session_state["features_set"],
1222
+ "X_train": X_train_orig,
1223
+ "X_test": X_test,
1224
+ "y_train": y_train,
1225
+ "y_test": y_test,
1226
+ "pred_train": pred_train,
1227
+ "pred_test": pred_test,
1228
+ }
1229
+ st.session_state["X_train"] = X_train_orig
1230
+ st.session_state["X_test_spends"] = test_spends
1231
+ st.session_state["saved_model_names"].append(mod_name)
1232
+ # Sprint3 additions
1233
+ if is_panel:
1234
+ random_eff_df = get_random_effects(
1235
+ media_data, panel_col, model
1236
+ )
1237
+ st.session_state["random_effects"] = random_eff_df
1238
+
1239
+ with open(
1240
+ os.path.join(
1241
+ st.session_state["project_path"], "best_models.pkl"
1242
+ ),
1243
+ "wb",
1244
+ ) as f:
1245
+ pickle.dump(st.session_state["Model"], f)
1246
+ st.success(
1247
+ mod_name
1248
+ + " model saved! Proceed to the next page to tune the model"
1249
+ )
1250
+
1251
+ urm = st.session_state["used_response_metrics"]
1252
+ urm.append(sel_target_col)
1253
+ st.session_state["used_response_metrics"] = list(
1254
+ set(urm)
1255
+ )
1256
+ mod_name = ""
1257
+ # Sprint4 - add the formatted name of the target col to used resp metrics
1258
+ value = False
1259
+
1260
+ st.session_state["project_dct"]["model_build"][
1261
+ "session_state_saved"
1262
+ ] = {}
1263
+ for key in [
1264
+ "Model",
1265
+ "bin_dict",
1266
+ "used_response_metrics",
1267
+ "date",
1268
+ "saved_model_names",
1269
+ "media_data",
1270
+ "X_test_spends",
1271
+ ]:
1272
+ st.session_state["project_dct"]["model_build"][
1273
+ "session_state_saved"
1274
+ ][key] = st.session_state[key]
1275
+
1276
+ project_dct_path = os.path.join(
1277
+ st.session_state["project_path"], "project_dct.pkl"
1278
+ )
1279
+ with open(project_dct_path, "wb") as f:
1280
+ pickle.dump(st.session_state["project_dct"], f)
1281
+
1282
+ update_db("4_Model_Build.py")
1283
+
1284
+ st.toast("💾 Saved Successfully!")
1285
+ else:
1286
+ st.session_state["project_dct"]["model_build"][
1287
+ "show_results_check"
1288
+ ] = False
pages/5_Model_Tuning.py ADDED
@@ -0,0 +1,917 @@
1
+ """
2
+ MMO Build Sprint 3
3
+ date :
4
+ changes : capability to tune MixedLM as well as simple LR in the same page
5
+ """
6
+
7
+ import os
8
+
9
+ import streamlit as st
10
+ import pandas as pd
11
+ from Eda_functions import format_numbers
12
+ import pickle
13
+ from utilities import set_header, load_local_css
14
+ import statsmodels.api as sm
15
+ import re
16
+ from sklearn.preprocessing import MinMaxScaler
17
+ import matplotlib.pyplot as plt
18
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
19
+ import yaml
20
+ from yaml import SafeLoader
21
+ import streamlit_authenticator as stauth
22
+
23
+ st.set_option("deprecation.showPyplotGlobalUse", False)
24
+ import statsmodels.formula.api as smf
25
+ from Data_prep_functions import *
26
+ import sqlite3
27
+ from utilities import update_db
28
+
29
+ # for i in ["model_tuned", "X_train_tuned", "X_test_tuned", "tuned_model_features", "tuned_model", "tuned_model_dict"] :
30
+
31
+ st.set_page_config(
32
+ page_title="Model Tuning",
33
+ page_icon=":shark:",
34
+ layout="wide",
35
+ initial_sidebar_state="collapsed",
36
+ )
37
+ load_local_css("styles.css")
38
+ set_header()
39
+ # Check for authentication status
40
+ for k, v in st.session_state.items():
41
+ # print(k, v)
42
+ if k not in [
43
+ "logout",
44
+ "login",
45
+ "config",
46
+ "build_tuned_model",
47
+ ] and not k.startswith("FormSubmitter"):
48
+ st.session_state[k] = v
49
+ with open("config.yaml") as file:
50
+ config = yaml.load(file, Loader=SafeLoader)
51
+ st.session_state["config"] = config
52
+ authenticator = stauth.Authenticate(
53
+ config["credentials"],
54
+ config["cookie"]["name"],
55
+ config["cookie"]["key"],
56
+ config["cookie"]["expiry_days"],
57
+ config["preauthorized"],
58
+ )
59
+ st.session_state["authenticator"] = authenticator
60
+ name, authentication_status, username = authenticator.login("Login", "main")
61
+ auth_status = st.session_state.get("authentication_status")
62
+
63
+ if auth_status == True:
64
+ authenticator.logout("Logout", "main")
65
+ is_state_initialized = st.session_state.get("initialized", False)
66
+
67
+ if "project_dct" not in st.session_state:
68
+ st.error("Please load a project from Home page")
69
+ st.stop()
70
+
71
+ if not os.path.exists(
72
+ os.path.join(st.session_state["project_path"], "best_models.pkl")
73
+ ):
74
+ st.error("Please save a model before tuning")
75
+ st.stop()
76
+
77
+ conn = sqlite3.connect(
78
+ r"DB/User.db", check_same_thread=False
79
+ ) # connection with sql db
80
+ c = conn.cursor()
81
+
82
+ if not is_state_initialized:
83
+ if "session_name" not in st.session_state:
84
+ st.session_state["session_name"] = None
85
+
86
+ if (
87
+ "session_state_saved"
88
+ in st.session_state["project_dct"]["model_build"].keys()
89
+ ):
90
+ for key in [
91
+ "Model",
92
+ "date",
93
+ "saved_model_names",
94
+ "media_data",
95
+ "X_test_spends",
96
+ ]:
97
+ if key not in st.session_state:
98
+ st.session_state[key] = st.session_state["project_dct"][
99
+ "model_build"
100
+ ]["session_state_saved"][key]
101
+ st.session_state["bin_dict"] = st.session_state["project_dct"][
102
+ "model_build"
103
+ ]["session_state_saved"]["bin_dict"]
104
+ if (
105
+ "used_response_metrics" not in st.session_state
106
+ or st.session_state["used_response_metrics"] == []
107
+ ):
108
+ st.session_state["used_response_metrics"] = st.session_state[
109
+ "project_dct"
110
+ ]["model_build"]["session_state_saved"][
111
+ "used_response_metrics"
112
+ ]
113
+ else:
114
+ st.error("Please load a session with a built model")
115
+ st.stop()
116
+
117
+ # if 'sel_model' not in st.session_state["project_dct"]["model_tuning"].keys():
118
+ # st.session_state["project_dct"]["model_tuning"]['sel_model']= {}
119
+
120
+ for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
121
+ if key not in st.session_state["project_dct"]["model_tuning"].keys():
122
+ st.session_state["project_dct"]["model_tuning"][key] = {}
123
+ # Sprint3
124
+ # is_panel = st.session_state['is_panel']
125
+ # panel_col = 'markets' # set the panel column
126
+ date_col = "date"
127
+
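+ # The raw "Panel Level 1" names are sanitised to snake_case below so they
+ # match the cleaned column names of the model-ready media data.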
128
+ panel_col = [
129
+ col.lower()
130
+ .replace(".", "_")
131
+ .replace("@", "_")
132
+ .replace(" ", "_")
133
+ .replace("-", "")
134
+ .replace(":", "")
135
+ .replace("__", "_")
136
+ for col in st.session_state["bin_dict"]["Panel Level 1"]
137
+ ][0]  # take the first sanitised panel column name
+ is_panel = len(panel_col) > 0  # a non-empty panel column means a panel (MixedLM) model
141
+
142
+ # flag indicating there is no tuned model yet
143
+
144
+ # Sprint4 - model tuned dict
145
+ if "Model_Tuned" not in st.session_state:
146
+ st.session_state["Model_Tuned"] = {}
147
+
148
+ st.title("1. Model Tuning")
149
+
150
+ if "is_tuned_model" not in st.session_state:
151
+ st.session_state["is_tuned_model"] = {}
152
+ # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
153
+ if (
154
+ "used_response_metrics" in st.session_state
155
+ and st.session_state["used_response_metrics"] != []
156
+ ):
157
+ default_target_idx = (
158
+ st.session_state["project_dct"]["model_tuning"].get(
159
+ "sel_target_col", None
160
+ )
161
+ if st.session_state["project_dct"]["model_tuning"].get(
162
+ "sel_target_col", None
163
+ )
164
+ is not None
165
+ else st.session_state["used_response_metrics"][0]
166
+ )
167
+
168
+ def format_display(inp):
169
+ return inp.title().replace("_", " ").strip()
170
+
171
+ sel_target_col = st.selectbox(
172
+ "Select the response metric",
173
+ st.session_state["used_response_metrics"],
174
+ index=st.session_state["used_response_metrics"].index(
175
+ default_target_idx
176
+ ),
177
+ format_func=format_display,
178
+ )
179
+ target_col = (
180
+ sel_target_col.lower()
181
+ .replace(" ", "_")
182
+ .replace("-", "")
183
+ .replace(":", "")
184
+ .replace("__", "_")
185
+ )
186
+ st.session_state["project_dct"]["model_tuning"][
187
+ "sel_target_col"
188
+ ] = sel_target_col
189
+
190
+ else:
191
+ sel_target_col = "Total Approved Accounts - Revenue"
192
+ target_col = "total_approved_accounts_revenue"
193
+
194
+ # Sprint4 - Look through all saved models, only show saved models of the sel resp metric (target_col)
195
+ # saved_models = st.session_state['saved_model_names']
196
+ with open(
197
+ os.path.join(st.session_state["project_path"], "best_models.pkl"), "rb"
198
+ ) as file:
199
+ model_dict = pickle.load(file)
200
+
201
+ saved_models = model_dict.keys()
202
+ required_saved_models = [
203
+ m.split("__")[0]
204
+ for m in saved_models
205
+ if m.split("__")[1] == target_col
206
+ ]
207
+
208
+ if len(required_saved_models) > 0:
209
+ default_model_idx = st.session_state["project_dct"]["model_tuning"][
210
+ "sel_model"
211
+ ].get(sel_target_col, required_saved_models[0])
212
+ sel_model = st.selectbox(
213
+ "Select the model to tune",
214
+ required_saved_models,
215
+ index=required_saved_models.index(default_model_idx),
216
+ )
217
+ else:
218
+ default_model_idx = st.session_state["project_dct"]["model_tuning"][
219
+ "sel_model"
220
+ ].get(sel_target_col, 0)
221
+ sel_model = st.selectbox(
222
+ "Select the model to tune", required_saved_models
223
+ )
224
+
225
+ st.session_state["project_dct"]["model_tuning"]["sel_model"][
226
+ sel_target_col
227
+ ] = sel_model  # persist the user's current selection rather than the stale default
228
+
229
+ sel_model_dict = model_dict[
230
+ sel_model + "__" + target_col
231
+ ] # Sprint4 - get the model obj of the selected model
232
+
233
+ X_train = sel_model_dict["X_train"]
234
+ X_test = sel_model_dict["X_test"]
235
+ y_train = sel_model_dict["y_train"]
236
+ y_test = sel_model_dict["y_test"]
237
+ df = st.session_state["media_data"]
238
+
239
+ if "selected_model" not in st.session_state:
240
+ st.session_state["selected_model"] = 0
241
+
242
+ st.markdown("### 1.1 Event Flags")
243
+ st.markdown(
244
+ "Helps in quantifying the impact of specific occurrences of events"
245
+ )
246
+
247
+ flag_expander_default = (
248
+ st.session_state["project_dct"]["model_tuning"].get(
249
+ "flag_expander", None
250
+ )
251
+ if st.session_state["project_dct"]["model_tuning"].get(
252
+ "flag_expander", None
253
+ )
254
+ is not None
255
+ else False
256
+ )
257
+
258
+ with st.expander("Apply Event Flags", flag_expander_default):
259
+ st.session_state["project_dct"]["model_tuning"]["flag_expander"] = True
260
+
261
+ model = sel_model_dict["Model_object"]
262
+ date = st.session_state["date"]
263
+ date = pd.to_datetime(date)
264
+ X_train = sel_model_dict["X_train"]
265
+
266
+ # features_set= model_dict[st.session_state["selected_model"]]['feature_set']
267
+ features_set = sel_model_dict["feature_set"]
268
+
269
+ col = st.columns(3)
270
+ min_date = min(date)
271
+ max_date = max(date)
272
+
273
+ start_date_default = (
274
+ st.session_state["project_dct"]["model_tuning"].get(
275
+ "start_date_default"
276
+ )
277
+ if st.session_state["project_dct"]["model_tuning"].get(
278
+ "start_date_default"
279
+ )
280
+ is not None
281
+ else min_date
282
+ )
283
+ end_date_default = (
284
+ st.session_state["project_dct"]["model_tuning"].get(
285
+ "end_date_default"
286
+ )
287
+ if st.session_state["project_dct"]["model_tuning"].get(
288
+ "end_date_default"
289
+ )
290
+ is not None
291
+ else max_date
292
+ )
293
+ with col[0]:
294
+ start_date = st.date_input(
295
+ "Select Start Date",
296
+ start_date_default,
297
+ min_value=min_date,
298
+ max_value=max_date,
299
+ )
300
+ with col[1]:
301
+ end_date_default = (
302
+ end_date_default
303
+ if end_date_default >= start_date
304
+ else start_date
305
+ )
306
+ end_date = st.date_input(
307
+ "Select End Date",
308
+ end_date_default,
309
+ min_value=max(min_date, start_date),
310
+ max_value=max_date,
311
+ )
312
+ with col[2]:
313
+ repeat_default = (
314
+ st.session_state["project_dct"]["model_tuning"].get(
315
+ "repeat_default"
316
+ )
317
+ if st.session_state["project_dct"]["model_tuning"].get(
318
+ "repeat_default"
319
+ )
320
+ is not None
321
+ else "No"
322
+ )
323
+ repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
324
+ repeat = st.selectbox(
325
+ "Repeat Annually", ["Yes", "No"], index=repeat_default_idx
326
+ )
327
+ st.session_state["project_dct"]["model_tuning"][
328
+ "start_date_default"
329
+ ] = start_date
330
+ st.session_state["project_dct"]["model_tuning"][
331
+ "end_date_default"
332
+ ] = end_date
333
+ st.session_state["project_dct"]["model_tuning"][
334
+ "repeat_default"
335
+ ] = repeat
336
+
337
+ if repeat == "Yes":
338
+ repeat = True
339
+ else:
340
+ repeat = False
341
+
342
+ if "Flags" not in st.session_state:
343
+ st.session_state["Flags"] = {}
344
+ if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
345
+ st.session_state["Flags"] = st.session_state["project_dct"][
346
+ "model_tuning"
347
+ ]["flags"]
348
+ # print("**"*50)
349
+ # print(y_train)
350
+ # print("**"*50)
351
+ # print(model.fittedvalues)
352
+ if is_panel: # Sprint3
353
+ met, line_values, fig_flag = plot_actual_vs_predicted(
354
+ X_train[date_col],
355
+ y_train,
356
+ model.fittedvalues,
357
+ model,
358
+ target_column=sel_target_col,
359
+ flag=(start_date, end_date),
360
+ repeat_all_years=repeat,
361
+ is_panel=True,
362
+ )
363
+ st.plotly_chart(fig_flag, use_container_width=True)
364
+
365
+ # create flag on test
366
+ met, test_line_values, fig_flag = plot_actual_vs_predicted(
367
+ X_test[date_col],
368
+ y_test,
369
+ sel_model_dict["pred_test"],
370
+ model,
371
+ target_column=sel_target_col,
372
+ flag=(start_date, end_date),
373
+ repeat_all_years=repeat,
374
+ is_panel=True,
375
+ )
376
+
377
+ else:
378
+ pred_train = model.predict(X_train[features_set])
379
+ met, line_values, fig_flag = plot_actual_vs_predicted(
380
+ X_train[date_col],
381
+ y_train,
382
+ pred_train,
383
+ model,
384
+ flag=(start_date, end_date),
385
+ repeat_all_years=repeat,
386
+ is_panel=False,
387
+ )
388
+ st.plotly_chart(fig_flag, use_container_width=True)
389
+
390
+ pred_test = model.predict(X_test[features_set])
391
+ met, test_line_values, fig_flag = plot_actual_vs_predicted(
392
+ X_test[date_col],
393
+ y_test,
394
+ pred_test,
395
+ model,
396
+ flag=(start_date, end_date),
397
+ repeat_all_years=repeat,
398
+ is_panel=False,
399
+ )
400
+ flag_name = "f1_flag"
401
+ flag_name = st.text_input("Enter Flag Name")
402
+ # Sprint4 - add selected target col to flag name
403
+ if st.button("Update flag"):
404
+ st.session_state["Flags"][flag_name + "__" + target_col] = {}
405
+ st.session_state["Flags"][flag_name + "__" + target_col][
406
+ "train"
407
+ ] = line_values
408
+ st.session_state["Flags"][flag_name + "__" + target_col][
409
+ "test"
410
+ ] = test_line_values
411
+ st.success(f'{flag_name + "__" + target_col} stored')
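+ # Flags are stored per response metric: the "<flag_name>__<target_col>" key
+ # keeps the flag's train and test line values so the same flag can be
+ # re-applied whenever this target is tuned again.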
412
+
413
+ st.session_state["project_dct"]["model_tuning"]["flags"] = (
414
+ st.session_state["Flags"]
415
+ )
416
+ # Sprint4 - only show flag created for the particular target col
417
+ if st.session_state["Flags"] is None:
418
+ st.session_state["Flags"] = {}
419
+ target_model_flags = [
420
+ f.split("__")[0]
421
+ for f in st.session_state["Flags"].keys()
422
+ if f.split("__")[1] == target_col
423
+ ]
424
+ options = list(target_model_flags)
425
+ selected_options = []
426
+ num_columns = 4
427
+ num_rows = -(-len(options) // num_columns)
428
+
429
+ tick = False
430
+ if st.checkbox(
431
+ "Select all",
432
+ value=st.session_state["project_dct"]["model_tuning"][
433
+ "select_all_flags_check"
434
+ ].get(sel_target_col, False),
435
+ ):
436
+ tick = True
437
+ st.session_state["project_dct"]["model_tuning"][
438
+ "select_all_flags_check"
439
+ ][sel_target_col] = True
440
+ else:
441
+ st.session_state["project_dct"]["model_tuning"][
442
+ "select_all_flags_check"
443
+ ][sel_target_col] = False
444
+ selection_defaults = st.session_state["project_dct"]["model_tuning"][
+ "selected_flags"
+ ].get(sel_target_col, [])
+ selected_options = []  # rebuilt from the checkboxes below; saved defaults only seed their initial state
+ for row in range(num_rows):
+ cols = st.columns(num_columns)
+ for col in cols:
+ if options:
+ option = options.pop(0)
+ option_default = option in selection_defaults
456
+ selected = col.checkbox(option, value=(tick or option_default))
457
+ if selected:
458
+ selected_options.append(option)
459
+ st.session_state["project_dct"]["model_tuning"]["selected_flags"][
460
+ sel_target_col
461
+ ] = selected_options
462
+
463
+ st.markdown("### 1.2 Select Parameters to Apply")
464
+ parameters = st.columns(3)
465
+ with parameters[0]:
466
+ Trend = st.checkbox(
467
+ "**Trend**",
468
+ value=st.session_state["project_dct"]["model_tuning"].get(
469
+ "trend_check", False
470
+ ),
471
+ )
472
+ st.markdown(
473
+ "Helps account for long-term trends or seasonality that could influence advertising effectiveness"
474
+ )
475
+ with parameters[1]:
476
+ week_number = st.checkbox(
477
+ "**Week_number**",
478
+ value=st.session_state["project_dct"]["model_tuning"].get(
479
+ "week_num_check", False
480
+ ),
481
+ )
482
+ st.markdown(
483
+ "Assists in detecting and incorporating weekly patterns or seasonality"
484
+ )
485
+ with parameters[2]:
486
+ sine_cosine = st.checkbox(
487
+ "**Sine and Cosine Waves**",
488
+ value=st.session_state["project_dct"]["model_tuning"].get(
489
+ "sine_cosine_check", False
490
+ ),
491
+ )
492
+ st.markdown(
493
+ "Helps in capturing cyclical patterns or seasonality in the data"
494
+ )
495
+ #
496
+ # def get_tuned_model():
497
+ # st.session_state['build_tuned_model']=True
498
+
499
+ if st.button(
500
+ "Build model with Selected Parameters and Flags",
501
+ key="build_tuned_model",use_container_width=True
502
+ ):
503
+ new_features = features_set
504
+ st.header("2.1 Results Summary")
505
+ # date=list(df.index)
506
+ # df = df.reset_index(drop=True)
507
+ # X_train=df[features_set]
508
+ ss = MinMaxScaler()
509
+ if is_panel == True:
510
+ X_train_tuned = X_train[features_set]
511
+ # X_train_tuned = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
512
+ X_train_tuned[target_col] = X_train[target_col]
513
+ X_train_tuned[date_col] = X_train[date_col]
514
+ X_train_tuned[panel_col] = X_train[panel_col]
515
+
516
+ X_test_tuned = X_test[features_set]
517
+ # X_test_tuned = pd.DataFrame(ss.transform(X), columns=X.columns)
518
+ X_test_tuned[target_col] = X_test[target_col]
519
+ X_test_tuned[date_col] = X_test[date_col]
520
+ X_test_tuned[panel_col] = X_test[panel_col]
521
+
522
+ else:
523
+ X_train_tuned = X_train[features_set]
524
+ # X_train_tuned = pd.DataFrame(ss.fit_transform(X_train_tuned), columns=X_train_tuned.columns)
525
+
526
+ X_test_tuned = X_test[features_set]
527
+ # X_test_tuned = pd.DataFrame(ss.transform(X_test_tuned), columns=X_test_tuned.columns)
528
+
529
+ for flag in selected_options:
530
+ # Sprint4 - added target_col to flag name
531
+ X_train_tuned[flag] = st.session_state["Flags"][
532
+ flag + "__" + target_col
533
+ ]["train"]
534
+ X_test_tuned[flag] = st.session_state["Flags"][
535
+ flag + "__" + target_col
536
+ ]["test"]
537
+
538
+ # test
539
+ # X_train_tuned.to_csv("Test/X_train_tuned_flag.csv",index=False)
540
+ # X_test_tuned.to_csv("Test/X_test_tuned_flag.csv",index=False)
541
+
542
+ # print("()()"*20,flag, len(st.session_state['Flags'][flag]))
543
+ if Trend:
544
+ st.session_state["project_dct"]["model_tuning"][
545
+ "trend_check"
546
+ ] = True
547
+ # Sprint3 - group by panel, calculate the trend of each panel separately. Add trend to the new feature set
548
+ if is_panel:
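+ # Each panel gets its own 1..n trend counter on the training window; the
+ # test counter resumes from that panel's training endpoint, so the trend
+ # line continues unbroken across the train/test boundary.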
549
+ newdata = pd.DataFrame()
550
+ panel_wise_end_point_train = {}
551
+ for panel, groupdf in X_train_tuned.groupby(panel_col):
552
+ groupdf.sort_values(date_col, inplace=True)
553
+ groupdf["Trend"] = np.arange(1, len(groupdf) + 1, 1)
554
+ newdata = pd.concat([newdata, groupdf])
555
+ panel_wise_end_point_train[panel] = len(groupdf)
556
+ X_train_tuned = newdata.copy()
557
+
558
+ test_newdata = pd.DataFrame()
559
+ for panel, test_groupdf in X_test_tuned.groupby(panel_col):
560
+ test_groupdf.sort_values(date_col, inplace=True)
561
+ start = panel_wise_end_point_train[panel] + 1
562
+ end = start + len(test_groupdf)  # np.arange(start, end) already yields len(test_groupdf) points, no +1 needed
563
+ # print("??"*20, panel, len(test_groupdf), len(np.arange(start, end, 1)), start)
564
+ test_groupdf["Trend"] = np.arange(start, end, 1)
565
+ test_newdata = pd.concat([test_newdata, test_groupdf])
566
+ X_test_tuned = test_newdata.copy()
567
+
568
+ new_features = new_features + ["Trend"]
569
+
570
+ else:
571
+ X_train_tuned["Trend"] = np.arange(
572
+ 1, len(X_train_tuned) + 1, 1
573
+ )
574
+ X_test_tuned["Trend"] = np.arange(
575
+ len(X_train_tuned) + 1,
576
+ len(X_train_tuned) + len(X_test_tuned) + 1,
577
+ 1,
578
+ )
579
+ new_features = new_features + ["Trend"]
580
+ else:
581
+ st.session_state["project_dct"]["model_tuning"][
582
+ "trend_check"
583
+ ] = False
584
+
585
+ if week_number:
586
+ st.session_state["project_dct"]["model_tuning"][
587
+ "week_num_check"
588
+ ] = True
589
+ # Sprint3 - create weeknumber from date column in xtrain tuned. add week num to new feature set
590
+ if is_panel:
591
+ X_train_tuned[date_col] = pd.to_datetime(
592
+ X_train_tuned[date_col]
593
+ )
594
+ X_train_tuned["Week_number"] = X_train_tuned[
595
+ date_col
596
+ ].dt.day_of_week
597
+ if X_train_tuned["Week_number"].nunique() == 1:
598
+ st.write(
599
+ "All dates in the data are of the same week day. Hence Week number can't be used."
600
+ )
601
+ else:
602
+ X_test_tuned[date_col] = pd.to_datetime(
603
+ X_test_tuned[date_col]
604
+ )
605
+ X_test_tuned["Week_number"] = X_test_tuned[
606
+ date_col
607
+ ].dt.day_of_week
608
+ new_features = new_features + ["Week_number"]
609
+
610
+ else:
611
+ date = pd.to_datetime(date.values)
612
+ X_train_tuned["Week_number"] = pd.to_datetime(
613
+ X_train[date_col]
614
+ ).dt.day_of_week
615
+ X_test_tuned["Week_number"] = pd.to_datetime(
616
+ X_test[date_col]
617
+ ).dt.day_of_week
618
+ new_features = new_features + ["Week_number"]
619
+ else:
620
+ st.session_state["project_dct"]["model_tuning"][
621
+ "week_num_check"
622
+ ] = False
623
+
624
+ if sine_cosine:
625
+ st.session_state["project_dct"]["model_tuning"][
626
+ "sine_cosine_check"
627
+ ] = True
628
+ # Sprint3 - create panel wise sine cosine waves in xtrain tuned. add to new feature set
629
+ if is_panel:
630
+ new_features = new_features + ["sine_wave", "cosine_wave"]
631
+ newdata = pd.DataFrame()
632
+ newdata_test = pd.DataFrame()
633
+ groups = X_train_tuned.groupby(panel_col)
634
+ frequency = 2 * np.pi / 365 # Adjust the frequency as needed
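+ # 2*pi/365 completes one full sine/cosine cycle every 365 rows; for
+ # weekly data, 2*np.pi/52 would encode the same annual seasonality.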
635
+
636
+ train_panel_wise_end_point = {}
637
+ for panel, groupdf in groups:
638
+ num_samples = len(groupdf)
639
+ train_panel_wise_end_point[panel] = num_samples
640
+ days_since_start = np.arange(num_samples)
641
+ sine_wave = np.sin(frequency * days_since_start)
642
+ cosine_wave = np.cos(frequency * days_since_start)
643
+ sine_cosine_df = pd.DataFrame(
644
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
645
+ )
646
+ assert len(sine_cosine_df) == len(groupdf)
647
+ # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
648
+ groupdf["sine_wave"] = sine_wave
649
+ groupdf["cosine_wave"] = cosine_wave
650
+ newdata = pd.concat([newdata, groupdf])
651
+
652
+ X_train_tuned = newdata.copy()
653
+
654
+ test_groups = X_test_tuned.groupby(panel_col)
655
+ for panel, test_groupdf in test_groups:
656
+ num_samples = len(test_groupdf)
657
+ start = train_panel_wise_end_point[panel]
658
+ days_since_start = np.arange(start, start + num_samples, 1)
659
+ # print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1)))
660
+ sine_wave = np.sin(frequency * days_since_start)
661
+ cosine_wave = np.cos(frequency * days_since_start)
662
+ sine_cosine_df = pd.DataFrame(
663
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
664
+ )
665
+ assert len(sine_cosine_df) == len(test_groupdf)
666
+ # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
667
+ test_groupdf["sine_wave"] = sine_wave
668
+ test_groupdf["cosine_wave"] = cosine_wave
669
+ newdata_test = pd.concat([newdata_test, test_groupdf])
670
+
671
+ X_test_tuned = newdata_test.copy()
672
+
673
+ else:
674
+ new_features = new_features + ["sine_wave", "cosine_wave"]
675
+
676
+ num_samples = len(X_train_tuned)
677
+ frequency = 2 * np.pi / 365 # Adjust the frequency as needed
678
+ days_since_start = np.arange(num_samples)
679
+ sine_wave = np.sin(frequency * days_since_start)
680
+ cosine_wave = np.cos(frequency * days_since_start)
681
+ sine_cosine_df = pd.DataFrame(
682
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
683
+ )
684
+ # Concatenate the sine and cosine waves with the scaled X DataFrame
685
+ X_train_tuned = pd.concat(
686
+ [X_train_tuned, sine_cosine_df], axis=1
687
+ )
688
+
689
+ test_num_samples = len(X_test_tuned)
690
+ start = num_samples
691
+ days_since_start = np.arange(
692
+ start, start + test_num_samples, 1
693
+ )
694
+ sine_wave = np.sin(frequency * days_since_start)
695
+ cosine_wave = np.cos(frequency * days_since_start)
696
+ sine_cosine_df = pd.DataFrame(
697
+ {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
698
+ )
699
+ # Concatenate the sine and cosine waves with the scaled X DataFrame
700
+ X_test_tuned = pd.concat(
701
+ [X_test_tuned, sine_cosine_df], axis=1
702
+ )
703
+ else:
704
+ st.session_state["project_dct"]["model_tuning"][
705
+ "sine_cosine_check"
706
+ ] = False
707
+
708
+ # model
709
+ if selected_options:
710
+ new_features = new_features + selected_options
711
+ if is_panel:
712
+ new_features = list(set(new_features))  # dedupe before the formula string is built
+ inp_vars_str = " + ".join(new_features)
714
+
715
+ md_str = target_col + " ~ " + inp_vars_str
716
+ md_tuned = smf.mixedlm(
717
+ md_str,
718
+ data=X_train_tuned[[target_col] + new_features],
719
+ groups=X_train_tuned[panel_col],
720
+ )
721
+ model_tuned = md_tuned.fit()
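+ # smf.mixedlm with `groups=` fits a random-intercept model: one shared set
+ # of fixed-effect coefficients plus a per-panel intercept offset, e.g. an
+ # (illustrative) formula string "total_approved_accounts_revenue ~ tv_spend + Trend".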
722
+
723
+ # plot act v pred for original model and tuned model
724
+ metrics_table, line, actual_vs_predicted_plot = (
725
+ plot_actual_vs_predicted(
726
+ X_train[date_col],
727
+ y_train,
728
+ model.fittedvalues,
729
+ model,
730
+ target_column=sel_target_col,
731
+ is_panel=True,
732
+ )
733
+ )
734
+ metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
735
+ plot_actual_vs_predicted(
736
+ X_train_tuned[date_col],
737
+ X_train_tuned[target_col],
738
+ model_tuned.fittedvalues,
739
+ model_tuned,
740
+ target_column=sel_target_col,
741
+ is_panel=True,
742
+ )
743
+ )
744
+
745
+ else:
746
+ new_features = list(set(new_features))
747
+ model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
748
+ metrics_table, line, actual_vs_predicted_plot = (
749
+ plot_actual_vs_predicted(
750
+ date[:130],
751
+ y_train,
752
+ model.predict(X_train[features_set]),
753
+ model,
754
+ target_column=sel_target_col,
755
+ )
756
+ )
757
+ metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
758
+ plot_actual_vs_predicted(
759
+ date[:130],
760
+ y_train,
761
+ model_tuned.predict(X_train_tuned),
762
+ model_tuned,
763
+ target_column=sel_target_col,
764
+ )
765
+ )
766
+
767
+ mape = np.round(metrics_table.iloc[0, 1], 2)
768
+ r2 = np.round(metrics_table.iloc[1, 1], 2)
769
+ adjr2 = np.round(metrics_table.iloc[2, 1], 2)
770
+
771
+ mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
772
+ r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
773
+ adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)
774
+
775
+ parameters_ = st.columns(3)
776
+ with parameters_[0]:
777
+ st.metric("R2", r2_tuned, np.round(r2_tuned - r2, 2))
778
+ with parameters_[1]:
779
+ st.metric(
780
+ "Adjusted R2", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2)
781
+ )
782
+ with parameters_[2]:
783
+ st.metric(
784
+ "MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse"
785
+ )
786
+ st.write(model_tuned.summary())
787
+
788
+ X_train_tuned[date_col] = X_train[date_col]
789
+ X_test_tuned[date_col] = X_test[date_col]
790
+ X_train_tuned[target_col] = y_train
791
+ X_test_tuned[target_col] = y_test
792
+
793
+ st.header("2.2 Actual vs. Predicted Plot")
794
+ # if is_panel:
795
+ # metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(date, y_train, model.predict(X_train),
796
+ # model, target_column='Revenue',is_panel=True)
797
+ # else:
798
+ # metrics_table,line,actual_vs_predicted_plot=plot_actual_vs_predicted(date, y_train, model.predict(X_train), model,target_column='Revenue')
799
+ if is_panel:
800
+ metrics_table, line, actual_vs_predicted_plot = (
801
+ plot_actual_vs_predicted(
802
+ X_train_tuned[date_col],
803
+ X_train_tuned[target_col],
804
+ model_tuned.fittedvalues,
805
+ model_tuned,
806
+ target_column=sel_target_col,
807
+ is_panel=True,
808
+ )
809
+ )
810
+ else:
811
+ metrics_table, line, actual_vs_predicted_plot = (
812
+ plot_actual_vs_predicted(
813
+ X_train_tuned[date_col],
814
+ X_train_tuned[target_col],
815
+ model_tuned.predict(X_train_tuned[new_features]),
816
+ model_tuned,
817
+ target_column=sel_target_col,
818
+ is_panel=False,
819
+ )
820
+ )
821
+ # plot_actual_vs_predicted(X_train[date_col], y_train,
822
+ # model.fittedvalues, model,
823
+ # target_column='Revenue',
824
+ # is_panel=is_panel)
825
+
826
+ st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
827
+
828
+ st.markdown("## 2.3 Residual Analysis")
829
+ if is_panel:
830
+ columns = st.columns(2)
831
+ with columns[0]:
832
+ fig = plot_residual_predicted(
833
+ y_train, model_tuned.fittedvalues, X_train_tuned
834
+ )
835
+ st.plotly_chart(fig)
836
+
837
+ with columns[1]:
838
+ st.empty()
839
+ fig = qqplot(y_train, model_tuned.fittedvalues)
840
+ st.plotly_chart(fig)
841
+
842
+ with columns[0]:
843
+ fig = residual_distribution(y_train, model_tuned.fittedvalues)
844
+ st.pyplot(fig)
845
+ else:
846
+ columns = st.columns(2)
847
+ with columns[0]:
848
+ fig = plot_residual_predicted(
849
+ y_train,
850
+ model_tuned.predict(X_train_tuned[new_features]),
851
+ X_train,
852
+ )
853
+ st.plotly_chart(fig)
854
+
855
+ with columns[1]:
856
+ st.empty()
857
+ fig = qqplot(
858
+ y_train, model_tuned.predict(X_train_tuned[new_features])
859
+ )
860
+ st.plotly_chart(fig)
861
+
862
+ with columns[0]:
863
+ fig = residual_distribution(
864
+ y_train, model_tuned.predict(X_train_tuned[new_features])
865
+ )
866
+ st.pyplot(fig)
867
+
868
+ # st.session_state['is_tuned_model'][target_col] = True
869
+ # Sprint4 - saved tuned model in a dict
870
+ st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
871
+ "Model_object": model_tuned,
872
+ "feature_set": new_features,
873
+ "X_train_tuned": X_train_tuned,
874
+ "X_test_tuned": X_test_tuned,
875
+ }
876
+
877
+ # Pending
878
+ # if st.session_state['build_tuned_model']==True:
879
+ if st.session_state["Model_Tuned"] is not None:
880
+ if st.button(
881
+ "Use This model for Media Planning",use_container_width=True
882
+ ):
883
+ # save_model = st.button('Use this model to build response curves', key='saved_tuned_model')
884
+ # if save_model:
885
+ st.session_state["is_tuned_model"][target_col] = True
886
+ with open(
887
+ os.path.join(
888
+ st.session_state["project_path"], "tuned_model.pkl"
889
+ ),
890
+ "wb",
891
+ ) as f:
892
+ # pickle.dump(st.session_state['tuned_model'], f)
893
+ pickle.dump(st.session_state["Model_Tuned"], f) # Sprint4
894
+
895
+ st.session_state["project_dct"]["model_tuning"][
896
+ "session_state_saved"
897
+ ] = {}
898
+ for key in [
899
+ "bin_dict",
900
+ "used_response_metrics",
901
+ "is_tuned_model",
902
+ "media_data",
903
+ "X_test_spends",
904
+ ]:
905
+ st.session_state["project_dct"]["model_tuning"][
906
+ "session_state_saved"
907
+ ][key] = st.session_state[key]
908
+
909
+ project_dct_path = os.path.join(
910
+ st.session_state["project_path"], "project_dct.pkl"
911
+ )
912
+ with open(project_dct_path, "wb") as f:
913
+ pickle.dump(st.session_state["project_dct"], f)
914
+
915
+ update_db("5_Model_Tuning.py")
916
+
917
+ st.success(sel_model + "__" + target_col + " Tuned saved!")
pages/6_AI_Model_Results.py ADDED
@@ -0,0 +1,828 @@
1
+ import plotly.express as px
+ import numpy as np
+ import plotly.graph_objects as go
+ import streamlit as st
+ import pandas as pd
+ import statsmodels.api as sm
+ import sys
+ import os
+ import re
+ import pickle
+ import sqlite3
+ import tempfile
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ import sweetviz as sv
+ from sklearn.metrics import r2_score, mean_absolute_percentage_error
+ from sklearn.preprocessing import MinMaxScaler
+ from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
+ from Data_prep_functions import plot_actual_vs_predicted
+ from utilities import set_header, load_local_css, load_authenticator, update_db
26
+
27
+ sys.setrecursionlimit(10**6)
28
+
29
+ original_stdout = sys.stdout
30
+ sys.stdout = open("temp_stdout.txt", "w")
31
+ sys.stdout.close()
32
+ sys.stdout = original_stdout
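+ # Nothing is written between the redirect above and the restore, so this
+ # block is effectively a no-op (presumably left over from silencing
+ # import-time output).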
33
+
34
+ st.set_page_config(layout="wide")
35
+ load_local_css("styles.css")
36
+ set_header()
37
+
38
+ # TODO :
39
+ ## 1. Add non panel model support
40
+ ## 2. EDA Function
41
+
42
+ for k, v in st.session_state.items():
43
+ if k not in ["logout", "login", "config"] and not k.startswith("FormSubmitter"):
44
+ st.session_state[k] = v
45
+
46
+ authenticator = st.session_state.get("authenticator")
47
+ if authenticator is None:
48
+ authenticator = load_authenticator()
49
+
50
+ name, authentication_status, username = authenticator.login("Login", "main")
51
+ auth_status = st.session_state.get("authentication_status")
52
+
53
+ if auth_status == True:
54
+ is_state_initialized = st.session_state.get("initialized", False)
+ if not is_state_initialized:
56
+ if "session_name" not in st.session_state:
57
+ st.session_state["session_name"] = None
58
+
59
+ if "project_dct" not in st.session_state:
60
+ st.error("Please load a project from Home page")
61
+ st.stop()
62
+
63
+ conn = sqlite3.connect(
64
+ r"DB/User.db", check_same_thread=False
65
+ ) # connection with sql db
66
+ c = conn.cursor()
67
+
68
+ if not os.path.exists(
69
+ os.path.join(st.session_state["project_path"], "tuned_model.pkl")
70
+ ):
71
+ st.error("Please save a tuned model")
72
+ st.stop()
73
+
74
+ if (
75
+ "session_state_saved" in st.session_state["project_dct"]["model_tuning"].keys()
76
+ and st.session_state["project_dct"]["model_tuning"]["session_state_saved"] != []
77
+ ):
78
+ for key in ["used_response_metrics", "media_data", "bin_dict"]:
79
+ if key not in st.session_state:
80
+ st.session_state[key] = st.session_state["project_dct"]["model_tuning"][
81
+ "session_state_saved"
82
+ ][key]
83
+ st.session_state["bin_dict"] = st.session_state["project_dct"][
84
+ "model_build"
85
+ ]["session_state_saved"]["bin_dict"]
86
+
87
+ media_data = st.session_state["media_data"]
88
+
89
+ st.write(media_data.columns)
90
+
91
+ panel_col = [
92
+ col.lower()
93
+ .replace(".", "_")
94
+ .replace("@", "_")
95
+ .replace(" ", "_")
96
+ .replace("-", "")
97
+ .replace(":", "")
98
+ .replace("__", "_")
99
+ for col in st.session_state["bin_dict"]["Panel Level 1"]
100
+ ][0]  # take the first sanitised panel column name
+ is_panel = len(panel_col) > 0  # a non-empty panel column means a panel (MixedLM) model
104
+ date_col = "date"
105
+
106
+ def plot_residual_predicted(actual, predicted, df_):
107
+ df_["Residuals"] = actual - pd.Series(predicted)
108
+ df_["StdResidual"] = (df_["Residuals"] - df_["Residuals"].mean()) / df_[
109
+ "Residuals"
110
+ ].std()
111
+
112
+ # Create a Plotly scatter plot
113
+ fig = px.scatter(
114
+ df_,
115
+ x=predicted,
116
+ y="StdResidual",
117
+ opacity=0.5,
118
+ color_discrete_sequence=["#11B6BD"],
119
+ )
120
+
121
+ # Add horizontal lines
122
+ fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
123
+ fig.add_hline(y=2, line_color="red")
124
+ fig.add_hline(y=-2, line_color="red")
125
+
126
+ fig.update_xaxes(title="Predicted")
127
+ fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")
128
+
129
+ # Set the same width and height for both figures
130
+ fig.update_layout(
131
+ title="Residuals over Predicted Values",
132
+ autosize=False,
133
+ width=600,
134
+ height=400,
135
+ )
136
+
137
+ return fig
138
+
139
+ def residual_distribution(actual, predicted):
140
+ Residuals = actual - pd.Series(predicted)
141
+
142
+ # Create a Seaborn distribution plot
143
+ sns.set(style="whitegrid")
144
+ plt.figure(figsize=(6, 4))
145
+ sns.histplot(Residuals, kde=True, color="#11B6BD")
146
+
147
+ plt.title(" Distribution of Residuals")
148
+ plt.xlabel("Residuals")
149
+ plt.ylabel("Probability Density")
150
+
151
+ return plt
152
+
153
+ def qqplot(actual, predicted):
154
+ Residuals = actual - pd.Series(predicted)
155
+ Residuals = pd.Series(Residuals)
156
+ resid_std = (Residuals - Residuals.mean()) / Residuals.std()
+ pp = sm.ProbPlot(resid_std)  # build the probability plot once, reuse for both axes
+
+ # Create a QQ plot using Plotly with custom colors
+ fig = go.Figure()
+ fig.add_trace(
+ go.Scatter(
+ x=pp.theoretical_quantiles,
+ y=pp.sample_quantiles,
164
+ mode="markers",
165
+ marker=dict(size=5, color="#11B6BD"),
166
+ name="QQ Plot",
167
+ )
168
+ )
169
+
170
+ # Add the 45-degree reference line
171
+ diagonal_line = go.Scatter(
172
+ x=[
173
+ -2,
174
+ 2,
175
+ ], # Adjust the x values as needed to fit the range of your data
176
+ y=[-2, 2], # Adjust the y values accordingly
177
+ mode="lines",
178
+ line=dict(color="red"), # Customize the line color and style
179
+ name=" ",
180
+ )
181
+ fig.add_trace(diagonal_line)
182
+
183
+ # Customize the layout
184
+ fig.update_layout(
185
+ title="QQ Plot of Residuals",
186
+ title_x=0.5,
187
+ autosize=False,
188
+ width=600,
189
+ height=400,
190
+ xaxis_title="Theoretical Quantiles",
191
+ yaxis_title="Sample Quantiles",
192
+ )
193
+
194
+ return fig
195
+
196
+ def get_random_effects(media_data, panel_col, mdf):
197
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
198
+ for i, market in enumerate(media_data[panel_col].unique()):
199
+ print(i, end="\r")
200
+ intercept = mdf.random_effects[market].values[0]
201
+ random_eff_df.loc[i, "random_effect"] = intercept
202
+ random_eff_df.loc[i, panel_col] = market
203
+
204
+ return random_eff_df
205
+
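+ # mdf_predict scores a MixedLM out of sample: the fixed-effect prediction
+ # from mdf.predict plus each panel's fitted random intercept, merged in on
+ # the panel column. Illustrative usage with this file's own helpers:
+ #   re_df = get_random_effects(media_data, panel_col, mdf)
+ #   scored = mdf_predict(X_test, mdf, re_df)  # adds a "pred" column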
206
+ def mdf_predict(X_df, mdf, random_eff_df):
207
+ X = X_df.copy()
208
+ X = pd.merge(
209
+ X,
210
+ random_eff_df[[panel_col, "random_effect"]],
211
+ on=panel_col,
212
+ how="left",
213
+ )
214
+ X["pred_fixed_effect"] = mdf.predict(X)
215
+
216
+ X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
217
+ X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
218
+ return X
219
+
220
+ def metrics_df_panel(model_dict):
221
+ metrics_df = pd.DataFrame(
222
+ columns=[
223
+ "Model",
224
+ "R2",
225
+ "ADJR2",
226
+ "Train Mape",
227
+ "Test Mape",
228
+ "Summary",
229
+ "Model_object",
230
+ ]
231
+ )
232
+ i = 0
233
+ for key in model_dict.keys():
234
+ target = key.split("__")[1]
235
+ metrics_df.at[i, "Model"] = target
236
+ y = model_dict[key]["X_train_tuned"][target]
237
+
238
+ random_df = get_random_effects(
239
+ media_data, panel_col, model_dict[key]["Model_object"]
240
+ )
241
+ pred = mdf_predict(
242
+ model_dict[key]["X_train_tuned"],
243
+ model_dict[key]["Model_object"],
244
+ random_df,
245
+ )["pred"]
246
+
247
+ ytest = model_dict[key]["X_test_tuned"][target]
248
+ predtest = mdf_predict(
249
+ model_dict[key]["X_test_tuned"],
250
+ model_dict[key]["Model_object"],
251
+ random_df,
252
+ )["pred"]
253
+
254
+ metrics_df.at[i, "R2"] = r2_score(y, pred)
255
+ metrics_df.at[i, "ADJR2"] = 1 - (1 - metrics_df.loc[i, "R2"]) * (
256
+ len(y) - 1
257
+ ) / (len(y) - len(model_dict[key]["feature_set"]) - 1)
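+ # Adjusted R2 = 1 - (1 - R2) * (n - 1) / (n - p - 1), with n = number of
+ # rows and p = the size of the model's feature set.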
258
+ metrics_df.at[i, "Train Mape"] = mean_absolute_percentage_error(y, pred)
259
+ metrics_df.at[i, "Test Mape"] = mean_absolute_percentage_error(
260
+ ytest, predtest
261
+ )
262
+ metrics_df.at[i, "Summary"] = model_dict[key]["Model_object"].summary()
263
+ metrics_df.at[i, "Model_object"] = model_dict[key]["Model_object"]
264
+ i += 1
265
+ metrics_df = np.round(metrics_df, 2)
266
+ return metrics_df
267
+
268
+ with open(
269
+ os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"),
270
+ "rb",
271
+ ) as f:
272
+ data = pickle.load(f)
273
+ transformed_data = data["final_df_transformed"]
274
+ with open(
275
+ os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
276
+ ) as f:
277
+ data = pickle.load(f)
278
+ st.session_state["bin_dict"] = data["bin_dict"]
279
+ with open(
280
+ os.path.join(st.session_state["project_path"], "tuned_model.pkl"), "rb"
281
+ ) as file:
282
+ tuned_model_dict = pickle.load(file)
283
+ feature_set_dct = {
284
+ key.split("__")[1]: key_dict["feature_set"]
285
+ for key, key_dict in tuned_model_dict.items()
286
+ }
287
+
288
+ # """ the above part should be modified so that we are fetching features set from the saved model"""
289
+
290
+ def contributions(X, model, target):
291
+ X1 = X.copy()
292
+ for j, col in enumerate(X1.columns):
293
+ X1[col] = X1[col] * model.params.values[j]
294
+
295
+ contributions = np.round(
296
+ (X1.sum() / sum(X1.sum()) * 100).sort_values(ascending=False), 2
297
+ )
298
+ contributions = (
299
+ pd.DataFrame(contributions, columns=target)
300
+ .reset_index()
301
+ .rename(columns={"index": "Channel"})
302
+ )
303
+ contributions["Channel"] = [
304
+ re.split(r"_imp|_cli", col)[0] for col in contributions["Channel"]
305
+ ]
306
+
307
+ return contributions
308
+
309
+ if "contribution_df" not in st.session_state:
310
+ st.session_state["contribution_df"] = None
311
+
312
+ def contributions_panel(model_dict):
313
+ media_data = st.session_state["media_data"]
314
+ contribution_df = pd.DataFrame(columns=["Channel"])
315
+ for key in model_dict.keys():
316
+ best_feature_set = model_dict[key]["feature_set"]
317
+ model = model_dict[key]["Model_object"]
318
+ target = key.split("__")[1]
319
+ X_train = model_dict[key]["X_train_tuned"]
320
+ contri_df = pd.DataFrame()
321
+
322
+ y = []
323
+ y_pred = []
324
+
325
+ random_eff_df = get_random_effects(media_data, panel_col, model)
326
+ random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
327
+ random_eff_df["panel_effect"] = (
328
+ random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
329
+ )
330
+
331
+ coef_df = pd.DataFrame(model.fe_params)
332
+ coef_df.reset_index(inplace=True)
333
+ coef_df.columns = ["feature", "coef"]
334
+
335
+ x_train_contribution = X_train.copy()
336
+ x_train_contribution = mdf_predict(
337
+ x_train_contribution, model, random_eff_df
338
+ )
339
+
340
+ x_train_contribution = pd.merge(
341
+ x_train_contribution,
342
+ random_eff_df[[panel_col, "panel_effect"]],
343
+ on=panel_col,
344
+ how="left",
345
+ )
346
+
347
+ for i in range(1, len(coef_df)):  # skip the intercept row
348
+ coef = coef_df.loc[i, "coef"]
349
+ col = coef_df.loc[i, "feature"]
350
+ x_train_contribution[str(col) + "_contr"] = (
351
+ coef * x_train_contribution[col]
352
+ )
353
+
354
+ # x_train_contribution['sum_contributions'] = x_train_contribution.filter(regex="contr").sum(axis=1)
355
+ # x_train_contribution['sum_contributions'] = x_train_contribution['sum_contributions'] + x_train_contribution[
356
+ # 'panel_effect']
357
+
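+ # Fold the panel effect and the non-media drivers (week number, trend,
+ # sine/cosine seasonality terms) into a single "base" contribution.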
358
+ base_cols = ["panel_effect"] + [
359
+ c
360
+ for c in x_train_contribution.filter(regex="contr").columns
361
+ if c
362
+ in [
363
+ "Week_number_contr",
364
+ "Trend_contr",
365
+ "sine_wave_contr",
366
+ "cosine_wave_contr",
367
+ ]
368
+ ]
369
+ x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(
370
+ axis=1
371
+ )
372
+ x_train_contribution.drop(columns=base_cols, inplace=True)
373
+ # x_train_contribution.to_csv("Test/smr_x_train_contribution.csv", index=False)
374
+
375
+ contri_df = pd.DataFrame(
376
+ x_train_contribution.filter(regex="contr").sum(axis=0)
377
+ )
378
+ contri_df.reset_index(inplace=True)
379
+ contri_df.columns = ["Channel", target]
380
+ contri_df["Channel"] = (
381
+ contri_df["Channel"]
382
+ .str.split("(_impres|_clicks)")
383
+ .apply(lambda c: c[0])
384
+ )
385
+ contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
386
+ contri_df["Channel"].replace("base_contr", "base", inplace=True)
387
+ contribution_df = pd.merge(
388
+ contribution_df, contri_df, on="Channel", how="outer"
389
+ )
390
+ # st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
391
+ return contribution_df
392
+
393
+ metrics_table = metrics_df_panel(tuned_model_dict)
394
+
395
+ st.title("AI Model Results")
396
+
397
+ st.header('Contribution Overview')
398
+
399
+ options = st.session_state["used_response_metrics"]
400
+ # st.write(options)  # debug
401
+
402
+ options = [
403
+ opt.lower()
404
+ .replace(" ", "_")
405
+ .replace("-", "")
406
+ .replace(":", "")
407
+ .replace("__", "_")
408
+ for opt in options
409
+ ]
410
+
411
+ default_options = (
412
+ st.session_state["project_dct"]["saved_model_results"].get("selected_options")
413
+ if st.session_state["project_dct"]["saved_model_results"].get(
414
+ "selected_options"
415
+ )
416
+ is not None
417
+ else [options[-1]]
418
+ )
419
+ # drop saved defaults that are no longer valid options; iterate over a copy,
+ # since removing items from the list being iterated would skip elements
420
+ for i in default_options.copy():
421
+ if i not in options:
422
+ default_options.remove(i)
423
+
424
+ def format_display(inp):
425
+ return inp.title().replace("_", " ").strip()
426
+
427
+ contribution_selections = st.multiselect(
428
+ "Select the Response Metrics to compare contributions",
429
+ options,
430
+ default=default_options,
431
+ format_func=format_display,
432
+ )
433
+ trace_data = []
434
+
435
+ st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
436
+ # st.write(st.session_state["contribution_df"].columns)  # debug
437
+ # for selection in contribution_selections:
438
+
439
+ # trace = go.Bar(
440
+ # x=st.session_state["contribution_df"]["Channel"],
441
+ # y=st.session_state["contribution_df"][selection],
442
+ # name=selection,
443
+ # text=np.round(st.session_state["contribution_df"][selection], 0)
444
+ # .astype(int)
445
+ # .astype(str)
446
+ # + "%",
447
+ # textposition="outside",
448
+ # )
449
+ # trace_data.append(trace)
450
+
451
+ # layout = go.Layout(
452
+ # title="Metrics Contribution by Channel",
453
+ # xaxis=dict(title="Channel Name"),
454
+ # yaxis=dict(title="Metrics Contribution"),
455
+ # barmode="group",
456
+ # )
457
+ # fig = go.Figure(data=trace_data, layout=layout)
458
+ # st.plotly_chart(fig, use_container_width=True)
459
+
460
+ def create_grouped_bar_plot(contribution_df, contribution_selections):
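+ # One group of bars per channel, one bar per selected response metric;
+ # channels are ordered by average contribution, with "Base Sales" pinned first.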
461
+ # Extract the 'Channel' names
462
+ channel_names = contribution_df["Channel"].tolist()
463
+
464
+ # Dictionary to store all contributions except 'const' and 'base'
465
+ all_contributions = {
466
+ name: [] for name in channel_names if name not in ["const", "base"]
467
+ }
468
+
469
+ # Dictionary to store base sales for each selection
470
+ base_sales_dict = {}
471
+
472
+ # Accumulate contributions for each channel from each selection
473
+ for selection in contribution_selections:
474
+ contributions = contribution_df[selection].values.astype(float)
475
+ base_sales = 0 # Initialize base sales for the current selection
476
+
477
+ for channel_name, contribution in zip(channel_names, contributions):
478
+ if channel_name in all_contributions:
479
+ all_contributions[channel_name].append(contribution)
480
+ elif channel_name == "base":
481
+ base_sales = (
482
+ contribution # Capture base sales for the current selection
483
+ )
484
+
485
+ # Store base sales for each selection
486
+ base_sales_dict[selection] = base_sales
487
+
488
+ # Calculate the average of contributions and sort by this average
489
+ sorted_channels = sorted(
490
+ all_contributions.items(), key=lambda x: -np.mean(x[1])
491
+ )
492
+ sorted_channel_names = [name for name, _ in sorted_channels]
493
+ sorted_channel_names = [
494
+ "Base Sales"
495
+ ] + sorted_channel_names # Adding 'Base Sales' at the start
496
+
497
+ trace_data = []
498
+ max_value = (
499
+ 0 # Initialize max_value to find the highest bar for y-axis adjustment
500
+ )
501
+
502
+ # Create traces for the grouped bar chart
503
+ for selection in contribution_selections:
504
+ display_name = sorted_channel_names
505
+ display_contribution = [base_sales_dict[selection]] + [
506
+ np.mean(all_contributions[name]) for name in sorted_channel_names[1:]
507
+ ] # Start with base sales for the current selection
508
+
509
+ # Generating text labels for each bar
510
+ text_values = [
511
+ f"{val}%" for val in np.round(display_contribution, 0).astype(int)
512
+ ]
513
+
514
+ # Find the max value for y-axis calculation
515
+ max_contribution = max(display_contribution)
516
+ if max_contribution > max_value:
517
+ max_value = max_contribution
518
+
519
+ # Create a bar trace for each selection
520
+ trace = go.Bar(
521
+ x=display_name,
522
+ y=display_contribution,
523
+ name=selection,
524
+ text=text_values,
525
+ textposition="outside",
526
+ )
527
+ trace_data.append(trace)
528
+
529
+ # Define layout for the bar chart
530
+ layout = go.Layout(
531
+ title="Metrics Contribution by Channel",
532
+ xaxis=dict(title="Channel Name"),
533
+ yaxis=dict(
534
+ title="Metrics Contribution", range=[0, max_value * 1.2]
535
+ ), # Set y-axis 20% higher than the max bar
536
+ barmode="group",
537
+ plot_bgcolor="white",
538
+ )
539
+
540
+ # Create the figure with trace data and layout
541
+ fig = go.Figure(data=trace_data, layout=layout)
542
+
543
+ return fig
544
+
545
+ # Display the chart in Streamlit
546
+ st.plotly_chart(
547
+ create_grouped_bar_plot(
548
+ st.session_state["contribution_df"], contribution_selections
549
+ ),
550
+ use_container_width=True,
551
+ )
552
+
553
+ ############################################ Waterfall Chart ############################################
554
+
555
+ import plotly.graph_objects as go
556
+
557
+ # # Initialize a Plotly figure
558
+ # fig = go.Figure()
559
+
560
+ # for selection in contribution_selections:
561
+ # # Ensure contributions are numeric
562
+ # contributions = (
563
+ # st.session_state["contribution_df"][selection].values.astype(float).tolist()
564
+ # )
565
+ # channel_names = st.session_state["contribution_df"]["Channel"].tolist()
566
+
567
+ # display_name, display_contribution, base_contribution = [], [], 0
568
+ # for channel_name, contribution in zip(channel_names, contributions):
569
+ # if channel_name != "const" and channel_name != "base":
570
+ # display_name.append(channel_name)
571
+ # display_contribution.append(contribution)
572
+ # else:
573
+ # base_contribution = contribution
574
+
575
+ # display_name = ["Base Sales"] + display_name
576
+ # display_contribution = [base_contribution] + display_contribution
577
+
578
+ # # Generating text labels for each bar, ensuring operations are compatible with string formats
579
+ # text_values = [
580
+ # f"{val}%" for val in np.round(display_contribution, 0).astype(int)
581
+ # ]
582
+
583
+ # fig.add_trace(
584
+ # go.Waterfall(
585
+ # orientation="v",
586
+ # measure=["relative"] * len(display_contribution),
587
+ # x=display_name,
588
+ # text=text_values,
589
+ # textposition="outside",
590
+ # y=display_contribution,
591
+ # increasing={"marker": {"color": "green"}},
592
+ # decreasing={"marker": {"color": "red"}},
593
+ # totals={"marker": {"color": "blue"}},
594
+ # name=selection,
595
+ # )
596
+ # )
597
+
598
+ # fig.update_layout(
599
+ # title="Metrics Contribution by Channel",
600
+ # xaxis={"title": "Channel Name"},
601
+ # yaxis={"title": "Metrics Contribution"},
602
+ # height=600,
603
+ # )
604
+
605
+ # # Displaying the waterfall chart in Streamlit
606
+ # st.plotly_chart(fig, use_container_width=True)
607
+
608
+ def preprocess_and_plot(contribution_df, contribution_selections):
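+ # Same aggregation as the grouped bar chart above, rendered as a waterfall
+ # running from base sales through each channel's contribution.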
609
+ # Extract the 'Channel' names
610
+ channel_names = contribution_df["Channel"].tolist()
611
+
612
+ # Dictionary to store all contributions except 'const' and 'base'
613
+ all_contributions = {
614
+ name: [] for name in channel_names if name not in ["const", "base"]
615
+ }
616
+
617
+ # Dictionary to store base sales for each selection
618
+ base_sales_dict = {}
619
+
620
+ # Accumulate contributions for each channel from each selection
621
+ for selection in contribution_selections:
622
+ contributions = contribution_df[selection].values.astype(float)
623
+ base_sales = 0 # Initialize base sales for the current selection
624
+
625
+ for channel_name, contribution in zip(channel_names, contributions):
626
+ if channel_name in all_contributions:
627
+ all_contributions[channel_name].append(contribution)
628
+ elif channel_name == "base":
629
+ base_sales = (
630
+ contribution # Capture base sales for the current selection
631
+ )
632
+
633
+ # Store base sales for each selection
634
+ base_sales_dict[selection] = base_sales
635
+
636
+ # Calculate the average of contributions and sort by this average
637
+ sorted_channels = sorted(
638
+ all_contributions.items(), key=lambda x: -np.mean(x[1])
639
+ )
640
+ sorted_channel_names = [name for name, _ in sorted_channels]
641
+ sorted_channel_names = [
642
+ "Base Sales"
643
+ ] + sorted_channel_names # Adding 'Base Sales' at the start
644
+
645
+ # Initialize a Plotly figure
646
+ fig = go.Figure()
647
+
648
+ for selection in contribution_selections:
649
+ display_name = sorted_channel_names  # already begins with "Base Sales"
652
+ display_contribution = [
653
+ base_sales_dict[selection]
654
+ ] # Start with base sales for the current selection
655
+
656
+ # Append average contributions for other channels
657
+ for name in sorted_channel_names[1:]:
658
+ display_contribution.append(np.mean(all_contributions[name]))
659
+
660
+ # Generating text labels for each bar
661
+ text_values = [
662
+ f"{val}%" for val in np.round(display_contribution, 0).astype(int)
663
+ ]
664
+
665
+ # Add a waterfall trace for each selection
666
+ fig.add_trace(
667
+ go.Waterfall(
668
+ orientation="v",
669
+ measure=["relative"] * len(display_contribution),
670
+ x=display_name,
671
+ text=text_values,
672
+ textposition="outside",
673
+ y=display_contribution,
674
+ increasing={"marker": {"color": "green"}},
675
+ decreasing={"marker": {"color": "red"}},
676
+ totals={"marker": {"color": "blue"}},
677
+ name=selection,
678
+ )
679
+ )
680
+
681
+ # Update layout of the figure
682
+ fig.update_layout(
683
+ title="Metrics Contribution by Channel",
684
+ xaxis={"title": "Channel Name"},
685
+ yaxis=dict(title="Metrics Contribution", range=[0, 100 * 1.2]),
686
+ )
687
+
688
+ return fig
689
+
690
+ # Displaying the waterfall chart
691
+ st.plotly_chart(
692
+ preprocess_and_plot(
693
+ st.session_state["contribution_df"], contribution_selections
694
+ ),
695
+ use_container_width=True,
696
+ )
697
+
698
+ ############################################ Waterfall Chart ############################################
699
+
700
+ st.header("Analysis of Models Result")
701
+ # st.markdown()
702
+ previous_selection = st.session_state["project_dct"]["saved_model_results"].get(
703
+ "model_grid_sel", [1]
704
+ )
705
+ # st.write(np.round(metrics_table, 2))
706
+ gd_table = metrics_table.iloc[:, :-2]
707
+
708
+ gd = GridOptionsBuilder.from_dataframe(gd_table)
709
+ # gd.configure_pagination(enabled=True)
710
+ gd.configure_selection(
711
+ use_checkbox=True,
712
+ selection_mode="single",
713
+ pre_select_all_rows=False,
714
+ pre_selected_rows=previous_selection,
715
+ )
716
+
717
+ gridoptions = gd.build()
718
+ table = AgGrid(
719
+ gd_table,
720
+ gridOptions=gridoptions,
721
+ fit_columns_on_grid_load=True,
722
+ height=200,
723
+ )
724
+ # table=metrics_table.iloc[:,:-2]
725
+ # table.insert(0, "Select", False)
726
+ # selection_table=st.data_editor(table,column_config={"Select": st.column_config.CheckboxColumn(required=True)})
727
+ if len(table.selected_rows) > 0:
728
+ st.session_state["project_dct"]["saved_model_results"]["model_grid_sel"] = (
729
+ [table.selected_rows[0]["_selectedRowNodeInfo"]["nodeRowIndex"]]  # a list, matching pre_selected_rows
730
+ )
731
+ if len(table.selected_rows) == 0:
732
+ st.warning(
733
+ "Click on the checkbox to view comprehensive results of the selected model."
734
+ )
735
+ st.stop()
736
+ else:
737
+ target_column = table.selected_rows[0]["Model"]
738
+ feature_set = feature_set_dct[target_column]
739
+
740
+
741
+ model = metrics_table[metrics_table["Model"] == target_column]["Model_object"].iloc[
742
+ 0
743
+ ]
744
+ target = metrics_table[metrics_table["Model"] == target_column]["Model"].iloc[0]
745
+ st.header("Model Summary")
746
+ st.write(model.summary())
747
+
748
+ sel_dict = tuned_model_dict[
749
+ [k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]
750
+ ]
751
+ X_train = sel_dict["X_train_tuned"]
752
+ y_train = X_train[target]
753
+ random_effects = get_random_effects(media_data, panel_col, model)
754
+ pred = mdf_predict(X_train, model, random_effects)["pred"]
755
+
756
+ X_test = sel_dict["X_test_tuned"]
757
+ y_test = X_test[target]
758
+ predtest = mdf_predict(X_test, model, random_effects)["pred"]
759
+ metrics_table_train, _, fig_train = plot_actual_vs_predicted(
760
+ X_train[date_col],
761
+ y_train,
762
+ pred,
763
+ model,
764
+ target_column=target_column,
765
+ flag=None,
766
+ repeat_all_years=False,
767
+ is_panel=is_panel,
768
+ )
769
+
770
+ metrics_table_test, _, fig_test = plot_actual_vs_predicted(
771
+ X_test[date_col],
772
+ y_test,
773
+ predtest,
774
+ model,
775
+ target_column=target_column,
776
+ flag=None,
777
+ repeat_all_years=False,
778
+ is_panel=is_panel,
779
+ )
780
+
781
+ metrics_table_train = metrics_table_train.set_index("Metric").transpose()
782
+ metrics_table_train.index = ["Train"]
783
+ metrics_table_test = metrics_table_test.set_index("Metric").transpose()
784
+ metrics_table_test.index = ["Test"]
785
+ metrics_table = np.round(pd.concat([metrics_table_train, metrics_table_test]), 2)
786
+
787
+ st.markdown("Result Overview")
788
+ st.dataframe(np.round(metrics_table, 2), use_container_width=True)
789
+
790
+ st.subheader("Actual vs Predicted Plot Train")
791
+
792
+ st.plotly_chart(fig_train, use_container_width=True)
793
+ st.subheader("Actual vs Predicted Plot Test")
794
+ st.plotly_chart(fig_test, use_container_width=True)
795
+
796
+ st.markdown("## Residual Analysis")
797
+ columns = st.columns(2)
798
+
799
+ Xtrain1 = X_train.copy()
800
+ with columns[0]:
801
+ fig = plot_residual_predicted(y_train, model.predict(Xtrain1), Xtrain1)
802
+ st.plotly_chart(fig)
803
+
804
+ with columns[1]:
805
+ st.empty()
806
+ fig = qqplot(y_train, model.predict(X_train))
807
+ st.plotly_chart(fig)
808
+
809
+ with columns[0]:
810
+ fig = residual_distribution(y_train, model.predict(X_train))
811
+ st.pyplot(fig)
812
+
813
+ update_db("6_AI_Model_Result.py")
814
+
815
+
816
+ elif auth_status == False:
817
+ st.error("Username/Password is incorrect")
818
+ try:
819
+ username_forgot_pw, email_forgot_password, random_password = (
820
+ authenticator.forgot_password("Forgot password")
821
+ )
822
+ if username_forgot_pw:
823
+ st.success("New password sent securely")
824
+ # Random password to be transferred to the user securely
825
+ elif username_forgot_pw == False:
826
+ st.error("Username not found")
827
+ except Exception as e:
828
+ st.error(e)
pages/7_Current_Media_Performance.py ADDED
@@ -0,0 +1,573 @@
1
+ """
2
+ MMO Build Sprint 3
3
+ additions : contributions calculated using tuned Mixed LM model
4
+ pending : contributions calculations using - 1. not tuned Mixed LM model, 2. tuned OLS model, 3. not tuned OLS model
5
+
6
+ MMO Build Sprint 4
7
+ additions : response metrics selection
8
+ pending : contributions calculations using - 1. not tuned Mixed LM model, 2. tuned OLS model, 3. not tuned OLS model
9
+ """
10
+
11
+ import streamlit as st
12
+ import pandas as pd
13
+ from sklearn.preprocessing import MinMaxScaler
14
+ import pickle
15
+ import os
16
+
17
+ from utilities_with_panel import load_local_css, set_header
18
+ import yaml
19
+ from yaml import SafeLoader
20
+ import streamlit_authenticator as stauth
21
+ import sqlite3
22
+ from utilities import update_db
23
+
24
+ st.set_page_config(layout="wide")
25
+ load_local_css("styles.css")
26
+ set_header()
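+ # Re-assigning every existing key to itself is a Streamlit idiom that keeps
+ # widget state alive when switching between pages.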
27
+ for k, v in st.session_state.items():
28
+ # print(k, v)
29
+ if k not in [
30
+ "logout",
31
+ "login",
32
+ "config",
33
+ "build_tuned_model",
34
+ ] and not k.startswith("FormSubmitter"):
35
+ st.session_state[k] = v
36
+ with open("config.yaml") as file:
37
+ config = yaml.load(file, Loader=SafeLoader)
38
+ st.session_state["config"] = config
39
+ authenticator = stauth.Authenticate(
40
+ config["credentials"],
41
+ config["cookie"]["name"],
42
+ config["cookie"]["key"],
43
+ config["cookie"]["expiry_days"],
44
+ config["preauthorized"],
45
+ )
46
+ st.session_state["authenticator"] = authenticator
47
+ name, authentication_status, username = authenticator.login("Login", "main")
48
+ auth_status = st.session_state.get("authentication_status")
49
+
50
+ if auth_status == True:
51
+ authenticator.logout("Logout", "main")
52
+ is_state_initialized = st.session_state.get("initialized", False)
53
+
54
+ if "project_dct" not in st.session_state:
55
+ st.error("Please load a project from Home page")
56
+ st.stop()
57
+
58
+ conn = sqlite3.connect(
59
+ r"DB/User.db", check_same_thread=False
60
+ ) # connection with sql db
61
+ c = conn.cursor()
62
+
63
+ if not os.path.exists(
64
+ os.path.join(st.session_state["project_path"], "tuned_model.pkl")
65
+ ):
66
+ st.error("Please save a tuned model")
67
+ st.stop()
68
+
69
+ if (
70
+ "session_state_saved"
71
+ in st.session_state["project_dct"]["model_tuning"].keys()
72
+ and st.session_state["project_dct"]["model_tuning"][
73
+ "session_state_saved"
74
+ ]
75
+ != []
76
+ ):
77
+ for key in [
78
+ "used_response_metrics",
79
+ "is_tuned_model",
80
+ "media_data",
81
+ "X_test_spends",
82
+ ]:
83
+ st.session_state[key] = st.session_state["project_dct"][
84
+ "model_tuning"
85
+ ]["session_state_saved"][key]
86
+ elif (
87
+ "session_state_saved"
88
+ in st.session_state["project_dct"]["model_build"].keys()
89
+ and st.session_state["project_dct"]["model_build"][
90
+ "session_state_saved"
91
+ ]
92
+ != []
93
+ ):
94
+ for key in [
95
+ "used_response_metrics",
96
+ "date",
97
+ "saved_model_names",
98
+ "media_data",
99
+ "X_test_spends",
100
+ ]:
101
+ st.session_state[key] = st.session_state["project_dct"][
102
+ "model_build"
103
+ ]["session_state_saved"][key]
104
+ else:
105
+ st.error("Please tune a model first")
106
+ st.session_state["bin_dict"] = st.session_state["project_dct"][
107
+ "model_build"
108
+ ]["session_state_saved"]["bin_dict"]
109
+ st.session_state["media_data"].columns = [
110
+ c.lower() for c in st.session_state["media_data"].columns
111
+ ]
112
+
113
+ from utilities_with_panel import (
114
+ overview_test_data_prep_panel,
115
+ overview_test_data_prep_nonpanel,
116
+ initialize_data,
117
+ create_channel_summary,
118
+ create_contribution_pie,
119
+ create_contribuion_stacked_plot,
120
+ create_channel_spends_sales_plot,
121
+ format_numbers,
122
+ channel_name_formating,
123
+ )
124
+
125
+ import plotly.graph_objects as go
126
+ import streamlit_authenticator as stauth
127
+ import yaml
128
+ from yaml import SafeLoader
129
+ import time
130
+
131
+ def get_random_effects(media_data, panel_col, mdf):
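+ # Collect each panel's fitted random intercept from the MixedLM into a
+ # lookup table keyed by the panel column.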
132
+ random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
133
+ for i, market in enumerate(media_data[panel_col].unique()):
134
+ print(i, end="\r")
135
+ intercept = mdf.random_effects[market].values[0]
136
+ random_eff_df.loc[i, "random_effect"] = intercept
137
+ random_eff_df.loc[i, panel_col] = market
138
+
139
+ return random_eff_df
140
+
141
+ def process_train_and_test(train, test, features, panel_col, target_col):
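+ # Fit the MinMaxScaler on the training rows only and reuse it to transform
+ # the test rows, so no test information leaks into the scaling.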
142
+ X1 = train[features]
143
+
144
+ ss = MinMaxScaler()
145
+ X1 = pd.DataFrame(ss.fit_transform(X1), columns=X1.columns)
146
+
147
+ X1[panel_col] = train[panel_col]
148
+ X1[target_col] = train[target_col]
149
+
150
+ if test is not None:
151
+ X2 = test[features]
152
+ X2 = pd.DataFrame(ss.transform(X2), columns=X2.columns)
153
+ X2[panel_col] = test[panel_col]
154
+ X2[target_col] = test[target_col]
155
+ return X1, X2
156
+ return X1
157
+
158
+ def mdf_predict(X_df, mdf, random_eff_df):
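+ # Mixed-effects prediction: the fixed-effect prediction from the model plus
+ # the panel's random intercept merged in from random_eff_df.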
159
+ X = X_df.copy()
160
+ X = pd.merge(
161
+ X,
162
+ random_eff_df[[panel_col, "random_effect"]],
163
+ on=panel_col,
164
+ how="left",
165
+ )
166
+ X["pred_fixed_effect"] = mdf.predict(X)
167
+
168
+ X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
169
+ X.to_csv("Test/merged_df_contri.csv", index=False)
170
+ X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
171
+
172
+ return X
173
+
174
+ # target='Revenue'
175
+
176
+ # is_panel=False
177
+ # is_panel = st.session_state['is_panel']
178
+ # set the panel column; guard against an empty panel list before indexing
179
+ panel_cols = [
180
+ col.lower()
181
+ .replace(".", "_")
182
+ .replace("@", "_")
183
+ .replace(" ", "_")
184
+ .replace("-", "")
185
+ .replace(":", "")
186
+ .replace("__", "_")
187
+ for col in st.session_state["bin_dict"]["Panel Level 1"]
188
+ ]
189
+ is_panel = len(panel_cols) > 0
190
+ panel_col = panel_cols[0] if is_panel else None
191
+ date_col = "date"
192
+
193
+ # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
194
+ if (
195
+ "used_response_metrics" in st.session_state
196
+ and st.session_state["used_response_metrics"] != []
197
+ ):
198
+ sel_target_col = st.selectbox(
199
+ "Select the response metric",
200
+ st.session_state["used_response_metrics"],
201
+ )
202
+ target_col = (
203
+ sel_target_col.lower()
204
+ .replace(" ", "_")
205
+ .replace("-", "")
206
+ .replace(":", "")
207
+ .replace("__", "_")
208
+ )
209
+ else:
210
+ sel_target_col = "Total Approved Accounts - Revenue"
211
+ target_col = "total_approved_accounts_revenue"
212
+
213
+ target = sel_target_col
214
+
215
+ # Sprint4 - Look through all saved tuned models, only show saved models of the sel resp metric (target_col)
216
+ # saved_models = st.session_state['saved_model_names']
217
+ # Sprint4 - get the model obj of the selected model
218
+ # st.write(sel_model_dict)
219
+
220
+ # Sprint3 - Contribution
221
+ if is_panel:
222
+ # read tuned mixedLM model
223
+ # if st.session_state["tuned_model"] is not None :
224
+ if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
225
+ with open(
226
+ os.path.join(
227
+ st.session_state["project_path"], "tuned_model.pkl"
228
+ ),
229
+ "rb",
230
+ ) as file:
231
+ model_dict = pickle.load(file)
232
+ saved_models = list(model_dict.keys())
233
+ # st.write(saved_models)
234
+ required_saved_models = [
235
+ m.split("__")[0]
236
+ for m in saved_models
237
+ if m.split("__")[1] == target_col
238
+ ]
239
+ sel_model = st.selectbox(
240
+ "Select the model to review", required_saved_models
241
+ )
242
+ sel_model_dict = model_dict[sel_model + "__" + target_col]
243
+
244
+ model = sel_model_dict["Model_object"]
245
+ X_train = sel_model_dict["X_train_tuned"]
246
+ X_test = sel_model_dict["X_test_tuned"]
247
+ best_feature_set = sel_model_dict["feature_set"]
248
+
249
+ else: # if non tuned model to be used # Pending
250
+ with open(
251
+ os.path.join(
252
+ st.session_state["project_path"], "best_models.pkl"
253
+ ),
254
+ "rb",
255
+ ) as file:
256
+ model_dict = pickle.load(file)
257
+ # st.write(model_dict)
258
+ saved_models = list(model_dict.keys())
259
+ required_saved_models = [
260
+ m.split("__")[0]
261
+ for m in saved_models
262
+ if m.split("__")[1] == target_col
263
+ ]
264
+ sel_model = st.selectbox(
265
+ "Select the model to review", required_saved_models
266
+ )
267
+ sel_model_dict = model_dict[sel_model + "__" + target_col]
268
+ # st.write(sel_model, sel_model_dict)
269
+ model = sel_model_dict["Model_object"]
270
+ X_train = sel_model_dict["X_train"]
271
+ X_test = sel_model_dict["X_test"]
272
+ best_feature_set = sel_model_dict["feature_set"]
273
+
274
+ # Calculate contributions
275
+
276
+ with open(
277
+ os.path.join(st.session_state["project_path"], "data_import.pkl"),
278
+ "rb",
279
+ ) as f:
280
+ data = pickle.load(f)
281
+
282
+ # Accessing the loaded objects
283
+ st.session_state["orig_media_data"] = data["final_df"]
284
+
285
+ st.session_state["orig_media_data"].columns = [
286
+ col.lower()
287
+ .replace(".", "_")
288
+ .replace("@", "_")
289
+ .replace(" ", "_")
290
+ .replace("-", "")
291
+ .replace(":", "")
292
+ .replace("__", "_")
293
+ for col in st.session_state["orig_media_data"].columns
294
+ ]
295
+
296
+ media_data = st.session_state["media_data"]
297
+
298
+ # st.session_state['orig_media_data']=st.session_state["media_data"]
299
+
300
+ # st.write(media_data)
301
+
302
+ contri_df = pd.DataFrame()
303
+
304
+ y = []
305
+ y_pred = []
306
+
307
+ random_eff_df = get_random_effects(media_data, panel_col, model)
308
+ random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
309
+ random_eff_df["panel_effect"] = (
310
+ random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
311
+ )
312
+ # random_eff_df.to_csv("Test/random_eff_df_contri.csv", index=False)
313
+
314
+ coef_df = pd.DataFrame(model.fe_params)
315
+ coef_df.reset_index(inplace=True)
316
+ coef_df.columns = ["feature", "coef"]
317
+
318
+ # coef_df.reset_index().to_csv("Test/coef_df_contri1.csv",index=False)
319
+ # print(model.fe_params)
320
+
321
+ x_train_contribution = X_train.copy()
322
+ x_test_contribution = X_test.copy()
323
+
324
+ # preprocessing not needed since X_train is already preprocessed
325
+ # X1, X2 = process_train_and_test(x_train_contribution, x_test_contribution, best_feature_set, panel_col, target_col)
326
+ # x_train_contribution[best_feature_set] = X1[best_feature_set]
327
+ # x_test_contribution[best_feature_set] = X2[best_feature_set]
328
+
329
+ x_train_contribution = mdf_predict(
330
+ x_train_contribution, model, random_eff_df
331
+ )
332
+ x_test_contribution = mdf_predict(
333
+ x_test_contribution, model, random_eff_df
334
+ )
335
+
336
+ x_train_contribution = pd.merge(
337
+ x_train_contribution,
338
+ random_eff_df[[panel_col, "panel_effect"]],
339
+ on=panel_col,
340
+ how="left",
341
+ )
342
+ x_test_contribution = pd.merge(
343
+ x_test_contribution,
344
+ random_eff_df[[panel_col, "panel_effect"]],
345
+ on=panel_col,
346
+ how="left",
347
+ )
348
+
349
+ for i in range(1, len(coef_df)):  # skip the intercept row
350
+ coef = coef_df.loc[i, "coef"]
351
+ col = coef_df.loc[i, "feature"]
352
+ x_train_contribution[str(col) + "_contr"] = (
353
+ coef * x_train_contribution[col]
354
+ )
355
+ x_test_contribution[str(col) + "_contr"] = (
356
+ coef * x_test_contribution[col]
357
+ )
358
+
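+ # The sums below are a consistency check: channel contributions plus the
+ # panel effect should reproduce the model prediction ("pred") for each row.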
359
+ x_train_contribution["sum_contributions"] = (
360
+ x_train_contribution.filter(regex="contr").sum(axis=1)
361
+ )
362
+ x_train_contribution["sum_contributions"] = (
363
+ x_train_contribution["sum_contributions"]
364
+ + x_train_contribution["panel_effect"]
365
+ )
366
+
367
+ x_test_contribution["sum_contributions"] = x_test_contribution.filter(
368
+ regex="contr"
369
+ ).sum(axis=1)
370
+ x_test_contribution["sum_contributions"] = (
371
+ x_test_contribution["sum_contributions"]
372
+ + x_test_contribution["panel_effect"]
373
+ )
374
+
375
+ # # test
376
+ x_train_contribution.to_csv(
377
+ "Test/x_train_contribution.csv", index=False
378
+ )
379
+ x_test_contribution.to_csv("Test/x_test_contribution.csv", index=False)
380
+ #
381
+ # st.session_state['orig_media_data'].to_csv("Test/transformed_data.csv",index=False)
382
+ # st.session_state['X_test_spends'].to_csv("Test/test_spends.csv",index=False)
383
+ # # st.write(st.session_state['orig_media_data'].columns)
384
+
385
+ # st.write(date_col,panel_col)
386
+ # st.write(x_test_contribution)
387
+
388
+ overview_test_data_prep_panel(
389
+ x_test_contribution,
390
+ st.session_state["orig_media_data"],
391
+ st.session_state["X_test_spends"],
392
+ date_col,
393
+ panel_col,
394
+ target_col,
395
+ )
396
+
397
+ else: # NON PANEL
398
+ if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
399
+ with open(
400
+ os.path.join(
401
+ st.session_state["project_path"], "tuned_model.pkl"
402
+ ),
403
+ "rb",
404
+ ) as file:
405
+ model_dict = pickle.load(file)
406
+ saved_models = list(model_dict.keys())
407
+ required_saved_models = [
408
+ m.split("__")[0]
409
+ for m in saved_models
410
+ if m.split("__")[1] == target_col
411
+ ]
412
+ sel_model = st.selectbox(
413
+ "Select the model to review", required_saved_models
414
+ )
415
+ sel_model_dict = model_dict[sel_model + "__" + target_col]
416
+
417
+ model = sel_model_dict["Model_object"]
418
+ X_train = sel_model_dict["X_train_tuned"]
419
+ X_test = sel_model_dict["X_test_tuned"]
420
+ best_feature_set = sel_model_dict["feature_set"]
421
+
422
+ else: # Sprint4
423
+ with open(
424
+ os.path.join(
425
+ st.session_state["project_path"], "best_models.pkl"
426
+ ),
427
+ "rb",
428
+ ) as file:
429
+ model_dict = pickle.load(file)
430
+ saved_models = list(model_dict.keys())
431
+ required_saved_models = [
432
+ m.split("__")[0]
433
+ for m in saved_models
434
+ if m.split("__")[1] == target_col
435
+ ]
436
+ sel_model = st.selectbox(
437
+ "Select the model to review", required_saved_models
438
+ )
439
+ sel_model_dict = model_dict[sel_model + "__" + target_col]
440
+
441
+ model = sel_model_dict["Model_object"]
442
+ X_train = sel_model_dict["X_train"]
443
+ X_test = sel_model_dict["X_test"]
444
+ best_feature_set = sel_model_dict["feature_set"]
445
+
446
+ x_train_contribution = X_train.copy()
447
+ x_test_contribution = X_test.copy()
448
+
449
+ x_train_contribution["pred"] = model.predict(
450
+ x_train_contribution[best_feature_set]
451
+ )
452
+ x_test_contribution["pred"] = model.predict(
453
+ x_test_contribution[best_feature_set]
454
+ )
455
+
456
+ for num, i in enumerate(model.params.values):
457
+ col = best_feature_set[num]
458
+ x_train_contribution[col + "_contr"] = X_train[col] * i
459
+ x_test_contribution[col + "_contr"] = X_test[col] * i
460
+
461
+ x_test_contribution.to_csv(
462
+ "Test/x_test_contribution_non_panel.csv", index=False
463
+ )
464
+ overview_test_data_prep_nonpanel(
465
+ x_test_contribution,
466
+ st.session_state["orig_media_data"].copy(),
467
+ st.session_state["X_test_spends"].copy(),
468
+ date_col,
469
+ target_col,
470
+ )
471
+ # for k, v in st.session_state.items():
472
+
474
+ # if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
475
+ # st.session_state[k] = v
476
+
477
+ # authenticator = st.session_state.get('authenticator')
478
+
479
+ # if authenticator is None:
480
+ # authenticator = load_authenticator()
481
+
482
+ # name, authentication_status, username = authenticator.login('Login', 'main')
483
+ # auth_status = st.session_state['authentication_status']
484
+
485
+ # if auth_status:
486
+ # authenticator.logout('Logout', 'main')
487
+
488
+ # is_state_initiaized = st.session_state.get('initialized',False)
489
+ # if not is_state_initiaized:
490
+
491
+ initialize_data(target_col)
492
+ scenario = st.session_state["scenario"]
493
+ raw_df = st.session_state["raw_df"]
494
+ st.header("Overview of previous spends")
495
+
496
+ # st.write(scenario.actual_total_spends)
497
+ # st.write(scenario.actual_total_sales)
498
+ columns = st.columns((1, 1, 3))
499
+
500
+ with columns[0]:
501
+ st.metric(
502
+ label="Spends",
503
+ value=format_numbers(float(scenario.actual_total_spends)),
504
+ )
505
+ ###print(f"##################### {scenario.actual_total_sales} ##################")
506
+ with columns[1]:
507
+ st.metric(
508
+ label=target,
509
+ value=format_numbers(
510
+ float(scenario.actual_total_sales), include_indicator=False
511
+ ),
512
+ )
513
+
514
+ actual_summary_df = create_channel_summary(scenario)
515
+ actual_summary_df["Channel"] = actual_summary_df["Channel"].apply(
516
+ channel_name_formating
517
+ )
518
+
519
+ columns = st.columns((2, 1))
520
+ with columns[0]:
521
+ with st.expander("Channel wise overview"):
522
+ st.markdown(
523
+ actual_summary_df.style.set_table_styles(
524
+ [
525
+ {
526
+ "selector": "th",
527
+ "props": [("background-color", "#11B6BD")],
528
+ },
529
+ {
530
+ "selector": "tr:nth-child(even)",
531
+ "props": [("background-color", "#11B6BD")],
532
+ },
533
+ ]
534
+ ).to_html(),
535
+ unsafe_allow_html=True,
536
+ )
537
+
538
+ st.markdown("<hr>", unsafe_allow_html=True)
539
+ ##############################
540
+
541
+ st.plotly_chart(
542
+ create_contribution_pie(scenario), use_container_width=True
543
+ )
544
+ st.markdown("<hr>", unsafe_allow_html=True)
545
+
546
+ ################################3
547
+ st.plotly_chart(
548
+ create_contribuion_stacked_plot(scenario), use_container_width=True
549
+ )
550
+ st.markdown("<hr>", unsafe_allow_html=True)
551
+ #######################################
552
+
553
+ selected_channel_name = st.selectbox(
554
+ "Channel",
555
+ st.session_state["channels_list"] + ["non media"],
556
+ format_func=channel_name_formating,
557
+ )
558
+ selected_channel = scenario.channels.get(selected_channel_name, None)
559
+
560
+ st.plotly_chart(
561
+ create_channel_spends_sales_plot(selected_channel),
562
+ use_container_width=True,
563
+ )
564
+
565
+ st.markdown("<hr>", unsafe_allow_html=True)
566
+
567
+ if st.checkbox("Save this session", key="save"):
568
+ project_dct_path = os.path.join(
569
+ st.session_state["project_path"], "project_dct.pkl"
570
+ )
571
+ with open(project_dct_path, "wb") as f:
572
+ pickle.dump(st.session_state["project_dct"], f)
573
+ update_db("7_Current_Media_Performance.py")
pages/8_Response_Curves.py ADDED
@@ -0,0 +1,596 @@
1
+ import streamlit as st
2
+ import plotly.express as px
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ from utilities import (
6
+ channel_name_formating,
7
+ load_authenticator,
8
+ initialize_data,
9
+ fetch_actual_data,
10
+ )
11
+ from sklearn.metrics import r2_score
12
+ from collections import OrderedDict
13
+ from classes import class_from_dict, class_to_dict
14
+ import pickle
15
+ import json
16
+ import sqlite3
17
+ from utilities import update_db
18
+
19
+ for k, v in st.session_state.items():
20
+ if k not in ["logout", "login", "config"] and not k.startswith(
21
+ "FormSubmitter"
22
+ ):
23
+ st.session_state[k] = v
24
+
25
+
26
+ def s_curve(x, K, b, a, x0):
27
+ return K / (1 + b * np.exp(-a * (x - x0)))
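+ # Generalised logistic response: K is the saturation ceiling, a the steepness,
+ # and b, x0 shift the curve along the spend axis; e.g. with b = 1 the curve
+ # passes through K / 2 exactly at x = x0.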
28
+
29
+
30
+ def save_scenario(scenario_name):
31
+ """
32
+ Save the current scenario with the mentioned name in the session state
33
+
34
+ Parameters
35
+ ----------
36
+ scenario_name
37
+ Name of the scenario to be saved
38
+ """
39
+ if "saved_scenarios" not in st.session_state:
40
+ st.session_state["saved_scenarios"] = OrderedDict()
41
+
42
+ # st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
43
+ st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
44
+ st.session_state["scenario"]
45
+ )
46
+ st.session_state["scenario_input"] = ""
47
+ print(type(st.session_state["saved_scenarios"]))
48
+ with open("../saved_scenarios.pkl", "wb") as f:
49
+ pickle.dump(st.session_state["saved_scenarios"], f)
50
+
51
+
52
+ def reset_curve_parameters(
53
+ metrics=None, panel=None, selected_channel_name=None
54
+ ):
55
+ del st.session_state["K"]
56
+ del st.session_state["b"]
57
+ del st.session_state["a"]
58
+ del st.session_state["x0"]
59
+
60
+ if (
61
+ metrics is not None
62
+ and panel is not None
63
+ and selected_channel_name is not None
64
+ ):
65
+ if f"{metrics}#@{panel}#@{selected_channel_name}" in list(
66
+ st.session_state["update_rcs"].keys()
67
+ ):
68
+ del st.session_state["update_rcs"][
69
+ f"{metrics}#@{panel}#@{selected_channel_name}"
70
+ ]
71
+
72
+
73
+ def update_response_curve(
74
+ K_updated,
75
+ b_updated,
76
+ a_updated,
77
+ x0_updated,
78
+ metrics=None,
79
+ panel=None,
80
+ selected_channel_name=None,
81
+ ):
82
+ print(
83
+ "[DEBUG] update_response_curves: ",
84
+ st.session_state["project_dct"]["scenario_planner"].keys(),
85
+ )
86
+ st.session_state["project_dct"]["scenario_planner"][unique_key].channels[
87
+ selected_channel_name
88
+ ].response_curve_params = {
89
+ "K": st.session_state["K"],
90
+ "b": st.session_state["b"],
91
+ "a": st.session_state["a"],
92
+ "x0": st.session_state["x0"],
93
+ }
94
+
95
+ # if (
96
+ # metrics is not None
97
+ # and panel is not None
98
+ # and selected_channel_name is not None
99
+ # ):
100
+ # st.session_state["update_rcs"][
101
+ # f"{metrics}#@{panel}#@{selected_channel_name}"
102
+ # ] = {
103
+ # "K": K_updated,
104
+ # "b": b_updated,
105
+ # "a": a_updated,
106
+ # "x0": x0_updated,
107
+ # }
108
+
109
+ # st.session_state["scenario"].channels[
110
+ # selected_channel_name
111
+ # ].response_curve_params = {
112
+ # "K": K_updated,
113
+ # "b": b_updated,
114
+ # "a": a_updated,
115
+ # "x0": x0_updated,
116
+ # }
117
+
118
+
119
+ # authenticator = st.session_state.get('authenticator')
120
+ # if authenticator is None:
121
+ # authenticator = load_authenticator()
122
+
123
+ # name, authentication_status, username = authenticator.login('Login', 'main')
124
+ # auth_status = st.session_state.get('authentication_status')
125
+
126
+ # if auth_status == True:
127
+ # is_state_initiaized = st.session_state.get('initialized',False)
128
+ # if not is_state_initiaized:
129
+ # print("Scenario page state reloaded")
130
+
131
+ import pandas as pd
132
+
133
+
134
+ @st.cache_resource(show_spinner=False)
135
+ def panel_fetch(file_selected):
136
+ raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
137
+
138
+ if "Panel" in raw_data_mmm_df.columns:
139
+ panel = list(set(raw_data_mmm_df["Panel"]))
140
+ else:
141
+ raw_data_mmm_df = None
142
+ panel = None
143
+
144
+ return panel
145
+
146
+
147
+ import glob
148
+ import os
149
+
150
+
151
+ def get_excel_names(directory):
152
+ # Create a list to hold the final parts of the filenames
153
+ last_portions = []
154
+
155
+ # Patterns to match Excel files (.xlsx and .xls) that contain @#
156
+ patterns = [
157
+ os.path.join(directory, "*@#*.xlsx"),
158
+ os.path.join(directory, "*@#*.xls"),
159
+ ]
160
+
161
+ # Process each pattern
162
+ for pattern in patterns:
163
+ files = glob.glob(pattern)
164
+
165
+ # Extracting the last portion after @# for each file
166
+ for file in files:
167
+ base_name = os.path.basename(file)
168
+ last_portion = base_name.split("@#")[-1]
169
+ last_portion = last_portion.replace(".xlsx", "").replace(
170
+ ".xls", ""
171
+ ) # Removing extensions
172
+ last_portions.append(last_portion)
173
+
174
+ return last_portions
175
+
176
+
177
+ def name_formating(channel_name):
178
+ # Replace underscores with spaces
179
+ name_mod = channel_name.replace("_", " ")
180
+
181
+ # Capitalize the first letter of each word
182
+ name_mod = name_mod.title()
183
+
184
+ return name_mod
185
+
186
+
187
+ def fetch_panel_data():
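+ # Load the actual spends and contributions for the selected response metric /
+ # panel pair, then initialise (or restore) that pair's saved scenario.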
188
+ print("DEBUG etch_panel_data: running... ")
189
+ file_selected = f"./metrics_level_data/Overview_data_test_panel@#{st.session_state['response_metrics_selectbox']}.xlsx"
190
+ panel_selected = st.session_state["panel_selected_selectbox"]
191
+ print(panel_selected)
192
+ if panel_selected == "Aggregated":
193
+ (
194
+ st.session_state["actual_input_df"],
195
+ st.session_state["actual_contribution_df"],
196
+ ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)
197
+ else:
198
+ (
199
+ st.session_state["actual_input_df"],
200
+ st.session_state["actual_contribution_df"],
201
+ ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)
202
+
203
+ unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
204
+ print("unique_key")
205
+ if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
206
+ if panel_selected == "Aggregated":
207
+ initialize_data(
208
+ panel=panel_selected,
209
+ target_file=file_selected,
210
+ updated_rcs={},
211
+ metrics=metrics_selected,
212
+ )
213
+ panel = None
214
+ else:
215
+ initialize_data(
216
+ panel=panel_selected,
217
+ target_file=file_selected,
218
+ updated_rcs={},
219
+ metrics=metrics_selected,
220
+ )
221
+ st.session_state["project_dct"]["scenario_planner"][unique_key] = (
222
+ st.session_state["scenario"]
223
+ )
224
+ # print(
225
+ # "DEBUG etch_panel_data: ",
226
+ # st.session_state["project_dct"]["scenario_planner"][
227
+ # unique_key
228
+ # ].keys(),
229
+ # )
230
+
231
+ else:
232
+ st.session_state["scenario"] = st.session_state["project_dct"][
233
+ "scenario_planner"
234
+ ][unique_key]
235
+ st.session_state["rcs"] = {}
236
+ st.session_state["powers"] = {}
237
+
238
+ for channel_name, _channel in st.session_state["project_dct"][
239
+ "scenario_planner"
240
+ ][unique_key].channels.items():
241
+ st.session_state["rcs"][
242
+ channel_name
243
+ ] = _channel.response_curve_params
244
+ st.session_state["powers"][channel_name] = _channel.power
245
+
246
+ if "K" in st.session_state:
247
+ del st.session_state["K"]
248
+
249
+ if "b" in st.session_state:
250
+ del st.session_state["b"]
251
+
252
+ if "a" in st.session_state:
253
+ del st.session_state["a"]
254
+
255
+ if "x0" in st.session_state:
256
+ del st.session_state["x0"]
257
+
258
+
259
+ if "project_dct" not in st.session_state:
260
+ st.error("Please load a project from home")
261
+ st.stop()
262
+
263
+ database_file = r"DB\User.db"
264
+
265
+ conn = sqlite3.connect(
266
+ database_file, check_same_thread=False
267
+ ) # connection with sql db
268
+ c = conn.cursor()
269
+
270
+ st.subheader("Build Response Curves")
271
+
272
+
273
+ if "update_rcs" not in st.session_state:
274
+ st.session_state["update_rcs"] = {}
275
+
276
+ st.session_state["first_time"] = True
277
+
278
+ col1, col2, col3 = st.columns([1, 1, 1])
279
+
280
+ directory = "metrics_level_data"
281
+ metrics_list = get_excel_names(directory)
282
+
283
+
284
+ metrics_selected = col1.selectbox(
285
+ "Response Metrics",
286
+ metrics_list,
287
+ on_change=fetch_panel_data,
288
+ format_func=name_formating,
289
+ key="response_metrics_selectbox",
290
+ )
291
+
292
+
293
+ file_selected = (
294
+ f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
295
+ )
296
+
297
+ panel_list = panel_fetch(file_selected)
298
+ final_panel_list = ["Aggregated"] + panel_list
299
+
300
+ panel_selected = col3.selectbox(
301
+ "Panel",
302
+ final_panel_list,
303
+ on_change=fetch_panel_data,
304
+ key="panel_selected_selectbox",
305
+ )
306
+
307
+
308
+ is_state_initialized = st.session_state.get("initialized_rcs", False)
309
+ print(is_state_initialized)
310
+ if not is_state_initialized:
311
+ print("DEBUG.....", "Here")
312
+ fetch_panel_data()
313
+ # if panel_selected == "Aggregated":
314
+ # initialize_data(panel=panel_selected, target_file=file_selected)
315
+ # panel = None
316
+ # else:
317
+ # initialize_data(panel=panel_selected, target_file=file_selected)
318
+
319
+ st.session_state["initialized_rcs"] = True
320
+
321
+ # channels_list = st.session_state["channels_list"]
322
+ unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
323
+ channel_list_final = list(
324
+ st.session_state["project_dct"]["scenario_planner"][
325
+ unique_key
326
+ ].channels.keys()
327
+ ) + ["Others"]
328
+
329
+
330
+ selected_channel_name = col2.selectbox(
331
+ "Channel",
332
+ channel_list_final,
333
+ format_func=channel_name_formating,
334
+ on_change=reset_curve_parameters,
335
+ key="selected_channel_name_selectbox",
336
+ )
337
+
338
+
339
+ rcs = st.session_state["rcs"]
340
+
341
+ if "K" not in st.session_state:
342
+ st.session_state["K"] = rcs[selected_channel_name]["K"]
343
+
344
+ if "b" not in st.session_state:
345
+ st.session_state["b"] = rcs[selected_channel_name]["b"]
346
+
347
+
348
+ if "a" not in st.session_state:
349
+ st.session_state["a"] = rcs[selected_channel_name]["a"]
350
+
351
+ if "x0" not in st.session_state:
352
+ st.session_state["x0"] = rcs[selected_channel_name]["x0"]
353
+
354
+
355
+ x = st.session_state["actual_input_df"][selected_channel_name].values
356
+ y = st.session_state["actual_contribution_df"][selected_channel_name].values
357
+
358
+
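+ # np.log(x.max()) / np.log(10) is log10(x.max()); dividing spends by 10**power
+ # therefore rescales the maximum into the hundreds, which keeps the fitted
+ # S-curve parameters numerically well-behaved.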
359
+ power = np.ceil(np.log(x.max()) / np.log(10)) - 3
360
+
361
+ print(f"DEBUG BUILD RCS: {selected_channel_name}")
362
+ print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
363
+ print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
364
+ print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
365
+ print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")
366
+
367
+ # fig = px.scatter(x, s_curve(x/10**power,
368
+ # st.session_state['K'],
369
+ # st.session_state['b'],
370
+ # st.session_state['a'],
371
+ # st.session_state['x0']))
372
+
373
+ x_plot = np.linspace(0, 5 * max(x), 50)
374
+
375
+ fig = px.scatter(x=x, y=y)
376
+ fig.add_trace(
377
+ go.Scatter(
378
+ x=x_plot,
379
+ y=s_curve(
380
+ x_plot / 10**power,
381
+ st.session_state["K"],
382
+ st.session_state["b"],
383
+ st.session_state["a"],
384
+ st.session_state["x0"],
385
+ ),
386
+ line=dict(color="red"),
387
+ name="Modified",
388
+ ),
389
+ )
390
+
391
+ fig.add_trace(
392
+ go.Scatter(
393
+ x=x_plot,
394
+ y=s_curve(
395
+ x_plot / 10**power,
396
+ rcs[selected_channel_name]["K"],
397
+ rcs[selected_channel_name]["b"],
398
+ rcs[selected_channel_name]["a"],
399
+ rcs[selected_channel_name]["x0"],
400
+ ),
401
+ line=dict(color="rgba(0, 255, 0, 0.4)"),
402
+ name="Actual",
403
+ ),
404
+ )
405
+
406
+ fig.update_layout(title_text="Response Curve", showlegend=True)
407
+ fig.update_annotations(font_size=10)
408
+ fig.update_xaxes(title="Spends")
409
+ fig.update_yaxes(title="Revenue")
410
+
411
+ st.plotly_chart(fig, use_container_width=True)
412
+
413
+ r2 = r2_score(
414
+ y,
415
+ s_curve(
416
+ x / 10**power,
417
+ st.session_state["K"],
418
+ st.session_state["b"],
419
+ st.session_state["a"],
420
+ st.session_state["x0"],
421
+ ),
422
+ )
423
+
424
+ r2_actual = r2_score(
425
+ y,
426
+ s_curve(
427
+ x / 10**power,
428
+ rcs[selected_channel_name]["K"],
429
+ rcs[selected_channel_name]["b"],
430
+ rcs[selected_channel_name]["a"],
431
+ rcs[selected_channel_name]["x0"],
432
+ ),
433
+ )
434
+
435
+ columns = st.columns((1, 1, 2))
436
+ with columns[0]:
437
+ st.metric("R2 Modified", round(r2, 2))
438
+ with columns[1]:
439
+ st.metric("R2 Actual", round(r2_actual, 2))
440
+
441
+
442
+ st.markdown("#### Set Parameters", unsafe_allow_html=True)
443
+ columns = st.columns(4)
444
+
445
+ if "updated_parms" not in st.session_state:
446
+ st.session_state["updated_parms"] = {
447
+ "K_updated": 0,
448
+ "b_updated": 0,
449
+ "a_updated": 0,
450
+ "x0_updated": 0,
451
+ }
452
+
453
+ with columns[0]:
454
+ st.session_state["updated_parms"]["K_updated"] = st.number_input(
455
+ "K", key="K", format="%0.5f"
456
+ )
457
+ with columns[1]:
458
+ st.session_state["updated_parms"]["b_updated"] = st.number_input(
459
+ "b", key="b", format="%0.5f"
460
+ )
461
+ with columns[2]:
462
+ st.session_state["updated_parms"]["a_updated"] = st.number_input(
463
+ "a", key="a", step=0.0001, format="%0.5f"
464
+ )
465
+ with columns[3]:
466
+ st.session_state["updated_parms"]["x0_updated"] = st.number_input(
467
+ "x0", key="x0", format="%0.5f"
468
+ )
469
+
470
+ # st.session_state["project_dct"]["scenario_planner"]["K_number_input"] = (
471
+ # st.session_state["updated_parms"]["K_updated"]
472
+ # )
473
+ # st.session_state["project_dct"]["scenario_planner"]["b_number_input"] = (
474
+ # st.session_state["updated_parms"]["b_updated"]
475
+ # )
476
+ # st.session_state["project_dct"]["scenario_planner"]["a_number_input"] = (
477
+ # st.session_state["updated_parms"]["a_updated"]
478
+ # )
479
+ # st.session_state["project_dct"]["scenario_planner"]["x0_number_input"] = (
480
+ # st.session_state["updated_parms"]["x0_updated"]
481
+ # )
482
+
483
+ update_col, reset_col = st.columns([1, 1])
484
+ if update_col.button(
485
+ "Update Parameters",
486
+ on_click=update_response_curve,
487
+ args=(
488
+ st.session_state["updated_parms"]["K_updated"],
489
+ st.session_state["updated_parms"]["b_updated"],
490
+ st.session_state["updated_parms"]["a_updated"],
491
+ st.session_state["updated_parms"]["x0_updated"],
492
+ metrics_selected,
493
+ panel_selected,
494
+ selected_channel_name,
495
+ ),
496
+ use_container_width=True,
497
+ ):
498
+ st.session_state["rcs"][selected_channel_name]["K"] = st.session_state[
499
+ "updated_parms"
500
+ ]["K_updated"]
501
+ st.session_state["rcs"][selected_channel_name]["b"] = st.session_state[
502
+ "updated_parms"
503
+ ]["b_updated"]
504
+ st.session_state["rcs"][selected_channel_name]["a"] = st.session_state[
505
+ "updated_parms"
506
+ ]["a_updated"]
507
+ st.session_state["rcs"][selected_channel_name]["x0"] = st.session_state[
508
+ "updated_parms"
509
+ ]["x0_updated"]
510
+
511
+ reset_col.button(
512
+ "Reset Parameters",
513
+ on_click=reset_curve_parameters,
514
+ args=(metrics_selected, panel_selected, selected_channel_name),
515
+ use_container_width=True,
516
+ )
517
+
518
+ st.divider()
519
+ save_col, down_col = st.columns([1, 1])
520
+
521
+
522
+ with save_col:
523
+ file_name = st.text_input(
524
+ "rcs download file name",
525
+ key="file_name_input",
526
+ placeholder="File name",
527
+ label_visibility="collapsed",
528
+ )
529
+ down_col.download_button(
530
+ label="Download response curves",
531
+ data=json.dumps(rcs),
532
+ file_name=f"{file_name}.json",
533
+ mime="application/json",
534
+ disabled=len(file_name) == 0,
535
+ use_container_width=True,
536
+ )
537
+
538
+
539
+ def s_curve_derivative(x, K, b, a, x0):
540
+ # Derivative of the S-curve function
541
+ return (
542
+ a
543
+ * b
544
+ * K
545
+ * np.exp(-a * (x - x0))
546
+ / ((1 + b * np.exp(-a * (x - x0))) ** 2)
547
+ )
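+ # The derivative gives the marginal response per unit spend; it tends to zero
+ # at high spends, i.e. diminishing returns.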
548
+
549
+
550
+ # Parameters of the S-curve
551
+ K = st.session_state["K"]
552
+ b = st.session_state["b"]
553
+ a = st.session_state["a"]
554
+ x0 = st.session_state["x0"]
555
+
556
+ # # Optimized spend value obtained from the tool
557
+ # optimized_spend = st.number_input(
558
+ # "value of x"
559
+ # ) # Replace this with your optimized spend value
560
+
561
+ # # Calculate the slope at the optimized spend value
562
+ # slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)
563
+
564
+ # st.write("Slope ", slope_at_optimized_spend)
565
+
566
+
567
+ # Initialize a list to hold our rows
568
+ rows = []
569
+
570
+ # Iterate over the dictionary
571
+ for key, value in st.session_state["update_rcs"].items():
572
+ # Split the key into its components
573
+ metrics, panel, channel_name = key.split("#@")
574
+ # Create a new row with the components and the values
575
+ row = {
576
+ "Metrics": name_formating(metrics),
577
+ "Panel": name_formating(panel),
578
+ "Channel Name": channel_name,
579
+ "K": value["K"],
580
+ "b": value["b"],
581
+ "a": value["a"],
582
+ "x0": value["x0"],
583
+ }
584
+ # Append the row to our list
585
+ rows.append(row)
586
+
587
+ # Convert the list of rows into a DataFrame
588
+ updated_parms_df = pd.DataFrame(rows)
589
+
590
+ if len(list(st.session_state["update_rcs"].keys())) > 0:
591
+ st.markdown("#### Updated Parameters", unsafe_allow_html=True)
592
+ st.dataframe(updated_parms_df, hide_index=True)
593
+ else:
594
+ st.info("No parameters are updated")
595
+
596
+ update_db("8_Build_Response_Curves.py")
pages/9_Scenario_Planner.py ADDED
@@ -0,0 +1,1715 @@
1
+ import streamlit as st
2
+ from numerize.numerize import numerize
3
+ import numpy as np
4
+ from functools import partial
5
+ from collections import OrderedDict
6
+ from plotly.subplots import make_subplots
7
+ import plotly.graph_objects as go
8
+ from utilities import (
9
+ format_numbers,
10
+ load_local_css,
11
+ set_header,
12
+ initialize_data,
13
+ load_authenticator,
14
+ send_email,
15
+ channel_name_formating,
16
+ )
17
+ from classes import class_from_dict, class_to_dict
18
+ import pickle
19
+ import streamlit_authenticator as stauth
20
+ import yaml
21
+ from yaml import SafeLoader
22
+ import re
23
+ import pandas as pd
24
+ import plotly.express as px
25
+ import logging
26
+ from utilities import update_db
27
+ import sqlite3
28
+
29
+
30
+ st.set_page_config(layout="wide")
31
+ load_local_css("styles.css")
32
+ set_header()
33
+
34
+ for k, v in st.session_state.items():
35
+ if k not in ["logout", "login", "config"] and not k.startswith(
36
+ "FormSubmitter"
37
+ ):
38
+ st.session_state[k] = v
39
+ # ======================================================== #
40
+ # ======================= Functions ====================== #
41
+ # ======================================================== #
42
+
43
+
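+ # Two optimization modes: "Media Spends" redistributes a given total-spend
+ # change across the selected channels, while any other key goal-seeks the
+ # spends needed for the requested change in the target metric.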
44
+ def optimize(key, status_placeholder):
45
+ """
46
+ Optimize the selected channels' spends for the chosen optimization mode
47
+ """
48
+
49
+ channel_list = [
50
+ key
51
+ for key, value in st.session_state["optimization_channels"].items()
52
+ if value
53
+ ]
54
+
55
+ if len(channel_list) > 0:
56
+ scenario = st.session_state["scenario"]
57
+ if key.lower() == "media spends":
58
+ with status_placeholder:
59
+ with st.spinner("Optimizing"):
60
+ result = st.session_state["scenario"].optimize(
61
+ st.session_state["total_spends_change"], channel_list
62
+ )
63
+ # elif key.lower() == "revenue":
64
+ else:
65
+ with status_placeholder:
66
+ with st.spinner("Optimizing"):
67
+
68
+ result = st.session_state["scenario"].optimize_spends(
69
+ st.session_state["total_sales_change"], channel_list
70
+ )
71
+ for channel_name, modified_spends in result:
72
+
73
+ st.session_state[channel_name] = numerize(
74
+ modified_spends
75
+ * scenario.channels[channel_name].conversion_rate,
76
+ 1,
77
+ )
78
+ prev_spends = (
79
+ st.session_state["scenario"]
80
+ .channels[channel_name]
81
+ .actual_total_spends
82
+ )
83
+ st.session_state[f"{channel_name}_change"] = round(
84
+ 100 * (modified_spends - prev_spends) / prev_spends, 2
85
+ )
86
+
87
+
88
+ def save_scenario(scenario_name):
89
+ """
90
+ Save the current scenario under the given name in the session state and persist it to disk
91
+
92
+ Parameters
93
+ ----------
94
+ scenario_name
95
+ Name of the scenario to be saved
96
+ """
97
+ if "saved_scenarios" not in st.session_state:
98
+ st.session_state["saved_scenarios"] = OrderedDict()
99
+
100
+ # st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
101
+ st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
102
+ st.session_state["scenario"]
103
+ )
104
+ st.session_state["scenario_input"] = ""
105
+ # print(type(st.session_state['saved_scenarios']))
106
+ with open("../saved_scenarios.pkl", "wb") as f:
107
+ pickle.dump(st.session_state["saved_scenarios"], f)
108
+
109
+
110
+ def update_sales_abs_slider():
111
+ actual_sales = st.session_state["scenario"].actual_total_sales
112
+ if validate_input(st.session_state["total_sales_change_abs_slider"]):
113
+ modified_sales = extract_number_for_string(
114
+ st.session_state["total_sales_change_abs_slider"]
115
+ )
116
+ st.session_state["total_sales_change"] = round(
117
+ ((modified_sales / actual_sales) - 1) * 100
118
+ )
119
+ st.session_state["total_sales_change_abs"] = numerize(
120
+ modified_sales, 1
121
+ )
122
+
123
+ st.session_state["project_dct"]["scenario_planner"][
124
+ "total_sales_change"
125
+ ] = st.session_state.total_sales_change
126
+
127
+
128
+ def update_sales_abs():
129
+ actual_sales = st.session_state["scenario"].actual_total_sales
130
+ if validate_input(st.session_state["total_sales_change_abs"]):
131
+ modified_sales = extract_number_for_string(
132
+ st.session_state["total_sales_change_abs"]
133
+ )
134
+ st.session_state["total_sales_change"] = round(
135
+ ((modified_sales / actual_sales) - 1) * 100
136
+ )
137
+ st.session_state["total_sales_change_abs_slider"] = numerize(
138
+ modified_sales, 1
139
+ )
140
+
141
+
142
+ def update_sales():
143
+ # print("DEBUG: running update_sales")
144
+ # st.session_state["project_dct"]["scenario_planner"][
145
+ # "total_sales_change"
146
+ # ] = st.session_state.total_sales_change
147
+ # st.session_state["total_spends_change"] = st.session_state[
148
+ # "total_sales_change"
149
+ # ]
150
+
151
+ st.session_state["total_sales_change_abs"] = numerize(
152
+ (1 + st.session_state["total_sales_change"] / 100)
153
+ * st.session_state["scenario"].actual_total_sales,
154
+ 1,
155
+ )
156
+ st.session_state["total_sales_change_abs_slider"] = numerize(
157
+ (1 + st.session_state["total_sales_change"] / 100)
158
+ * st.session_state["scenario"].actual_total_sales,
159
+ 1,
160
+ )
161
+ # update_spends()
162
+
163
+
164
+ def update_all_spends_abs_slider():
165
+ actual_spends = st.session_state["scenario"].actual_total_spends
166
+ if validate_input(st.session_state["total_spends_change_abs_slider"]):
167
+ modified_spends = extract_number_for_string(
168
+ st.session_state["total_spends_change_abs_slider"]
169
+ )
170
+ st.session_state["total_spends_change"] = round(
171
+ ((modified_spends / actual_spends) - 1) * 100
172
+ )
173
+ st.session_state["total_spends_change_abs"] = numerize(
174
+ modified_spends, 1
175
+ )
176
+
177
+ st.session_state["project_dct"]["scenario_planner"][
178
+ "total_spends_change"
179
+ ] = st.session_state.total_spends_change
180
+
181
+ update_all_spends()
182
+
183
+
184
+ # def update_all_spends_abs_slider():
185
+ # actual_spends = _scenario.actual_total_spends
186
+ # if validate_input(st.session_state["total_spends_change_abs_slider"]):
187
+ # print("#" * 100)
188
+ # print(st.session_state["total_spends_change_abs_slider"])
189
+ # print("#" * 100)
190
+
191
+ # modified_spends = extract_number_for_string(
192
+ # st.session_state["total_spends_change_abs_slider"]
193
+ # )
194
+ # st.session_state["total_spends_change"] = (
195
+ # (modified_spends / actual_spends) - 1
196
+ # ) * 100
197
+ # st.session_state["total_spends_change_abs"] = st.session_state[
198
+ # "total_spends_change_abs_slider"
199
+ # ]
200
+
201
+ # update_all_spends()
202
+
203
+
204
+ def update_all_spends_abs():
205
+ print("DEBUG: ", "inside update_all_spends_abs")
206
+ # print(st.session_state["total_spends_change_abs_slider_options"])
207
+
208
+ actual_spends = st.session_state["scenario"].actual_total_spends
209
+ if validate_input(st.session_state["total_spends_change_abs"]):
210
+ modified_spends = extract_number_for_string(
211
+ st.session_state["total_spends_change_abs"]
212
+ )
213
+ st.session_state["total_spends_change"] = (
214
+ (modified_spends / actual_spends) - 1
215
+ ) * 100
216
+ st.session_state["total_spends_change_abs_slider"] = numerize(
217
+ extract_number_for_string(
218
+ st.session_state["total_spends_change_abs"]
219
+ ),
220
+ 1,
221
+ )
222
+
223
+ st.session_state["project_dct"]["scenario_planner"][
224
+ "total_spends_change"
225
+ ] = st.session_state.total_spends_change
226
+
227
+ # print(
228
+ # "DEBUG UPDATE_ALL_SPENDS_ABS: ",
229
+ # st.session_state["total_spends_change"],
230
+ # )
231
+ update_all_spends()
232
+
233
+
234
+ def update_spends():
235
+ print("update_spends")
236
+ st.session_state["total_spends_change_abs"] = numerize(
237
+ (1 + st.session_state["total_spends_change"] / 100)
238
+ * st.session_state["scenario"].actual_total_spends,
239
+ 1,
240
+ )
241
+ st.session_state["total_spends_change_abs_slider"] = numerize(
242
+ (1 + st.session_state["total_spends_change"] / 100)
243
+ * st.session_state["scenario"].actual_total_spends,
244
+ 1,
245
+ )
246
+
247
+ st.session_state["project_dct"]["scenario_planner"][
248
+ "total_spends_change"
249
+ ] = st.session_state.total_spends_change
250
+
251
+ update_all_spends()
252
+
253
+
254
+ def update_all_spends():
255
+ """
256
+ Updates spends for all the channels with the given overall spends change
257
+ """
258
+ percent_change = st.session_state["total_spends_change"]
259
+ print("runs update_all")
260
+ for channel_name in list(
261
+ st.session_state["project_dct"]["scenario_planner"][
262
+ unique_key
263
+ ].channels.keys()
264
+ ):
265
+ st.session_state[f"{channel_name}_percent"] = percent_change
266
+ channel = st.session_state["scenario"].channels[channel_name]
267
+ current_spends = channel.actual_total_spends
268
+ modified_spends = (1 + percent_change / 100) * current_spends
269
+ st.session_state["scenario"].update(channel_name, modified_spends)
270
+ st.session_state[channel_name] = numerize(
271
+ modified_spends * channel.conversion_rate, 1
272
+ )
273
+ st.session_state[f"{channel_name}_change"] = percent_change
274
+
275
+
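+ # Parses abbreviated amounts ("2.5K" -> 2500.0, "1.2M" -> 1200000.0,
+ # "3B" -> 3e9). Assumes a magnitude suffix is present; inputs are checked
+ # with validate_input() before this is called, so no fallback is needed.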
276
+ def extract_number_for_string(string_input):
277
+ string_input = string_input.upper()
278
+ if string_input.endswith("K"):
279
+ return float(string_input[:-1]) * 10**3
280
+ elif string_input.endswith("M"):
281
+ return float(string_input[:-1]) * 10**6
282
+ elif string_input.endswith("B"):
283
+ return float(string_input[:-1]) * 10**9
284
+
285
+
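+ # Returns True only for strings like "12K", "3.5M" or "1B"; plain numbers
+ # without a magnitude suffix are rejected.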
286
+ def validate_input(string_input):
287
+ pattern = r"\d+\.?\d*[KMB]$"
288
+ match = re.match(pattern, string_input.upper())
289
+ if match is None:
290
+ return False
291
+ return True
292
+
293
+
294
+ def update_data_by_percent(channel_name):
295
+ prev_spends = (
296
+ st.session_state["scenario"].channels[channel_name].actual_total_spends
297
+ * st.session_state["scenario"].channels[channel_name].conversion_rate
298
+ )
299
+ modified_spends = prev_spends * (
300
+ 1 + st.session_state[f"{channel_name}_percent"] / 100
301
+ )
302
+
303
+ st.session_state[channel_name] = numerize(modified_spends, 1)
304
+
305
+ st.session_state["scenario"].update(
306
+ channel_name,
307
+ modified_spends
308
+ / st.session_state["scenario"].channels[channel_name].conversion_rate,
309
+ )
310
+
311
+
312
+ def update_data(channel_name):
313
+ """
314
+ Updates the spends for the given channel
315
+ """
316
+ print("tuns update_Data")
317
+ if validate_input(st.session_state[channel_name]):
318
+ modified_spends = extract_number_for_string(
319
+ st.session_state[channel_name]
320
+ )
321
+
322
+ prev_spends = (
323
+ st.session_state["scenario"]
324
+ .channels[channel_name]
325
+ .actual_total_spends
326
+ * st.session_state["scenario"]
327
+ .channels[channel_name]
328
+ .conversion_rate
329
+ )
330
+ st.session_state[f"{channel_name}_percent"] = round(
331
+ 100 * (modified_spends - prev_spends) / prev_spends, 2
332
+ )
333
+ st.session_state["scenario"].update(
334
+ channel_name,
335
+ modified_spends
336
+ / st.session_state["scenario"]
337
+ .channels[channel_name]
338
+ .conversion_rate,
339
+ )
340
+ # st.session_state['scenario'].update(channel_name, modified_spends)
341
+ # else:
342
+ # try:
343
+ # modified_spends = float(st.session_state[channel_name])
344
+ # prev_spends = st.session_state['scenario'].channels[channel_name].actual_total_spends * st.session_state['scenario'].channels[channel_name].conversion_rate
345
+ # st.session_state[f'{channel_name}_change'] = round(100*(modified_spends - prev_spends) / prev_spends,2)
346
+ # st.session_state['scenario'].update(channel_name, modified_spends/st.session_state['scenario'].channels[channel_name].conversion_rate)
347
+ # st.session_state[f'{channel_name}'] = numerize(modified_spends,1)
348
+ # except ValueError:
349
+ # st.write('Invalid input')
350
+
351
+
352
+ def select_channel_for_optimization(channel_name):
353
+ """
354
+ Marks the given channel for optimization
355
+ """
356
+ st.session_state["optimization_channels"][channel_name] = st.session_state[
357
+ f"{channel_name}_selected"
358
+ ]
359
+
360
+
361
+ def select_all_channels_for_optimization():
362
+ """
363
+ Marks all the channels for optimization
364
+ """
365
+ # print(
366
+ # "DEBUG: select_all_channels_for_opt",
367
+ # st.session_state["optimze_all_channels"],
368
+ # )
369
+
370
+ for channel_name in st.session_state["optimization_channels"].keys():
371
+ st.session_state[f"{channel_name}_selected"] = st.session_state[
372
+ "optimze_all_channels"
373
+ ]
374
+ st.session_state["optimization_channels"][channel_name] = (
375
+ st.session_state["optimze_all_channels"]
376
+ )
377
+ from pprint import pprint
378
+
379
+
380
+ def update_penalty():
381
+ """
382
+ Updates the penalty flag for sales calculation
383
+ """
384
+ st.session_state["scenario"].update_penalty(
385
+ st.session_state["apply_penalty"]
386
+ )
387
+
388
+
389
+ def reset_optimization():
390
+ print("DEBUG: ", "Running reset_optimization")
391
+ for channel_name in list(
392
+ st.session_state["project_dct"]["scenario_planner"][
393
+ unique_key
394
+ ].channels.keys()
395
+ ):
396
+ st.session_state[f"{channel_name}_selected"] = False
397
+ # st.session_state[f"{channel_name}_change"] = 0
398
+ st.session_state["optimze_all_channels"] = False
399
+ st.session_state["initialized"] = False
400
+ del st.session_state["total_sales_change_abs_slider"]
401
+ del st.session_state["total_sales_change_abs"]
402
+ del st.session_state["total_sales_change"]
403
+
404
+
405
+ def reset_scenario():
406
+ print("[DEBUG]: reset_scenario")
407
+ # def reset_scenario(panel_selected, file_selected, updated_rcs):
408
+ # #print(st.session_state['default_scenario_dict'])
409
+ # st.session_state['scenario'] = class_from_dict(st.session_state['default_scenario_dict'])
410
+ # for channel in st.session_state['scenario'].channels.values():
411
+ # st.session_state[channel.name] = float(channel.actual_total_spends * channel.conversion_rate)
412
+ for channel_name in list(
413
+ st.session_state["project_dct"]["scenario_planner"][
414
+ unique_key
415
+ ].channels.keys()
416
+ ):
417
+ st.session_state[f"{channel_name}_selected"] = False
418
+ # st.session_state[f"{channel_name}_change"] = 0
419
+ st.session_state["optimze_all_channels"] = False
420
+ st.session_state["initialized"] = False
421
+
422
+ del st.session_state["optimization_channels"]
423
+ panel_selected = st.session_state.get("panel_selected", 0)
424
+ file_selected = st.session_state["file_selected"]
425
+ update_rcs = st.session_state.get("update_rcs", None)
426
+
427
+ # print(f"## [DEBUG] [SCENARIO PLANNER][RESET SCENARIO]: {}")
428
+ del st.session_state["project_dct"]["scenario_planner"][
429
+ f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"
430
+ ]
431
+ del st.session_state["total_sales_change_abs_slider"]
432
+ del st.session_state["total_sales_change_abs"]
433
+ del st.session_state["total_sales_change"]
434
+ # if panel_selected == "Aggregated":
435
+ # initialize_data(
436
+ # panel=panel_selected,
437
+ # target_file=file_selected,
438
+ # updated_rcs=updated_rcs,
439
+ # metrics=metrics_selected,
440
+ # )
441
+ # panel = None
442
+ # else:
443
+ # initialize_data(
444
+ # panel=panel_selected,
445
+ # target_file=file_selected,
446
+ # updated_rcs=updated_rcs,
447
+ # metrics=metrics_selected,
448
+ # )
449
+ # st.session_state["total_spends_change"] = 0
450
+ # update_all_spends()
451
+
452
+
453
+ def format_number(num):
454
+ if num >= 1_000_000:
455
+ return f"{num / 1_000_000:.2f}M"
456
+ elif num >= 1_000:
457
+ return f"{num / 1_000:.0f}K"
458
+ else:
459
+ return f"{num:.2f}"
460
+
461
+
462
+ def summary_plot(data, x, y, title, text_column):
463
+ fig = px.bar(
464
+ data,
465
+ x=x,
466
+ y=y,
467
+ orientation="h",
468
+ title=title,
469
+ text=text_column,
470
+ color="Channel_name",
471
+ )
472
+
473
+ # Convert text_column to numeric values
474
+ data[text_column] = pd.to_numeric(data[text_column], errors="coerce")
475
+
476
+ # Update the format of the displayed text based on magnitude
477
+ fig.update_traces(
478
+ texttemplate="%{text:.2s}",
479
+ textposition="outside",
480
+ hovertemplate="%{x:.2s}",
481
+ )
482
+
483
+ fig.update_layout(
484
+ xaxis_title=x, yaxis_title="Channel Name", showlegend=False
485
+ )
486
+ return fig
487
+
488
+
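+ # Logistic S-curve: K is the saturation ceiling, a the steepness, x0 the
+ # inflection shift, and b scales the initial offset. Callers pass spends
+ # scaled down by 10**power before applying the curve.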
489
+ def s_curve(x, K, b, a, x0):
490
+ return K / (1 + b * np.exp(-a * (x - x0)))
491
+
492
+
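+ # Splits the spend axis into segments: the "green" zone is where both ROI
+ # and marginal ROI exceed 1; everything left of it is under-invested
+ # (yellow) and everything right of it is saturated (red).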
493
+ def find_segment_value(x, roi, mroi):
494
+ start_value = x[0]
495
+ end_value = x[len(x) - 1]
496
+
497
+ # Condition for green region: Both MROI and ROI > 1
498
+ green_condition = (roi > 1) & (mroi > 1)
499
+ left_indices = np.where(green_condition)[0]
500
+ left_value = x[left_indices[0]] if left_indices.size > 0 else x[0]
501
+
502
+ right_indices = np.where(green_condition)[0]
503
+ right_value = x[right_indices[-1]] if right_indices.size > 0 else x[0]
504
+
505
+ return start_value, end_value, left_value, right_value
506
+
507
+
508
+ def calculate_rgba(
509
+ start_value, end_value, left_value, right_value, current_channel_spends
510
+ ):
511
+ # Initialize alpha to None for clarity
512
+ alpha = None
513
+
514
+ # Determine the color and calculate relative_position and alpha based on the point's position
515
+ if start_value <= current_channel_spends <= left_value:
516
+ color = "yellow"
517
+ relative_position = (current_channel_spends - start_value) / (
518
+ left_value - start_value
519
+ )
520
+ alpha = 0.8 - (
521
+ 0.6 * relative_position
522
+ ) # Alpha decreases from start to end
523
+
524
+ elif left_value < current_channel_spends <= right_value:
525
+ color = "green"
526
+ relative_position = (current_channel_spends - left_value) / (
527
+ right_value - left_value
528
+ )
529
+ alpha = 0.8 - (
530
+ 0.6 * relative_position
531
+ ) # Alpha decreases from start to end
532
+
533
+ elif right_value < current_channel_spends <= end_value:
534
+ color = "red"
535
+ relative_position = (current_channel_spends - right_value) / (
536
+ end_value - right_value
537
+ )
538
+ alpha = 0.2 + (
539
+ 0.6 * relative_position
540
+ ) # Alpha increases from start to end
541
+
542
+ else:
543
+ # Default case, if the spends are outside the defined ranges
544
+ return "rgba(136, 136, 136, 0.5)" # Grey for values outside the range
545
+
546
+ # Ensure alpha is within the intended range in case of any calculation overshoot
547
+ alpha = max(0.2, min(alpha, 0.8))
548
+
549
+ # Define color codes for RGBA
550
+ color_codes = {
551
+ "yellow": "255, 255, 0", # RGB for yellow
552
+ "green": "0, 128, 0", # RGB for green
553
+ "red": "255, 0, 0", # RGB for red
554
+ }
555
+
556
+ rgba = f"rgba({color_codes[color]}, {alpha})"
557
+ return rgba
558
+
559
+
560
+ def debug_temp(x_test, power, K, b, a, x0):
561
+ print("*" * 100)
562
+ # Calculate the count of bins
563
+ count_lower_bin = sum(1 for x in x_test if x <= 2524)
564
+ count_center_bin = sum(1 for x in x_test if x > 2524 and x <= 3377)
565
+ count_ = sum(1 for x in x_test if x > 3377)
566
+
567
+ print(
568
+ f"""
569
+ lower : {count_lower_bin}
570
+ center : {count_center_bin}
571
+ upper : {count_}
572
+ """
573
+ )
574
+
575
+
576
+ # @st.cache
577
+ def plot_response_curves():
578
+ cols = 4
579
+ rows = (
580
+ len(channels_list) // cols
581
+ if len(channels_list) % cols == 0
582
+ else len(channels_list) // cols + 1
583
+ )
584
+ rcs = st.session_state["rcs"]
585
+ shapes = []
586
+ fig = make_subplots(rows=rows, cols=cols, subplot_titles=channels_list)
587
+ for i in range(0, len(channels_list)):
588
+ col = channels_list[i]
589
+ x_actual = st.session_state["scenario"].channels[col].actual_spends
590
+ # x_modified = st.session_state["scenario"].channels[col].modified_spends
591
+
592
+ power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3
593
+
594
+ K = rcs[col]["K"]
595
+ b = rcs[col]["b"]
596
+ a = rcs[col]["a"]
597
+ x0 = rcs[col]["x0"]
598
+
599
+ x_plot = np.linspace(0, 5 * x_actual.sum(), 50)
600
+
601
+ x, y, marginal_roi = [], [], []
602
+ for x_p in x_plot:
603
+ x.append(x_p * x_actual / x_actual.sum())
604
+
605
+ for index in range(len(x_plot)):
606
+ y.append(s_curve(x[index] / 10**power, K, b, a, x0))
607
+
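+ # For the logistic curve dy/dx = a * y * (1 - y / K), so the marginal
+ # response can be computed from y directly; the eps guard avoids division
+ # by zero when K is 0.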
608
+ for index in range(len(x_plot)):
609
+ marginal_roi.append(
610
+ a
611
+ * y[index]
612
+ * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
613
+ )
614
+
615
+ x = (
616
+ np.sum(x, axis=1)
617
+ * st.session_state["scenario"].channels[col].conversion_rate
618
+ )
619
+ y = np.sum(y, axis=1)
620
+ marginal_roi = (
621
+ np.average(marginal_roi, axis=1)
622
+ / st.session_state["scenario"].channels[col].conversion_rate
623
+ )
624
+
625
+ roi = y / np.maximum(x, np.finfo(float).eps)
626
+
627
+ fig.add_trace(
628
+ go.Scatter(
629
+ x=x,
630
+ y=y,
631
+ name=col,
632
+ customdata=np.stack((roi, marginal_roi), axis=-1),
633
+ hovertemplate="Spend:%{x:$.2s}<br>Sale:%{y:$.2s}<br>ROI:%{customdata[0]:.3f}<br>MROI:%{customdata[1]:.3f}",
634
+ line=dict(color="blue"),
635
+ ),
636
+ row=1 + (i) // cols,
637
+ col=i % cols + 1,
638
+ )
639
+
640
+ x_optimal = (
641
+ st.session_state["scenario"].channels[col].modified_total_spends
642
+ * st.session_state["scenario"].channels[col].conversion_rate
643
+ )
644
+ y_optimal = (
645
+ st.session_state["scenario"].channels[col].modified_total_sales
646
+ )
647
+
648
+ # if col == "Paid_social_others":
649
+ # debug_temp(x_optimal * x_actual / x_actual.sum(), power, K, b, a, x0)
650
+
651
+ fig.add_trace(
652
+ go.Scatter(
653
+ x=[x_optimal],
654
+ y=[y_optimal],
655
+ name=col,
656
+ legendgroup=col,
657
+ showlegend=False,
658
+ marker=dict(color=["black"]),
659
+ ),
660
+ row=1 + (i) // cols,
661
+ col=i % cols + 1,
662
+ )
663
+
664
+ shapes.append(
665
+ go.layout.Shape(
666
+ type="line",
667
+ x0=0,
668
+ y0=y_optimal,
669
+ x1=x_optimal,
670
+ y1=y_optimal,
671
+ line_width=1,
672
+ line_dash="dash",
673
+ line_color="black",
674
+ xref=f"x{i+1}",
675
+ yref=f"y{i+1}",
676
+ )
677
+ )
678
+
679
+ shapes.append(
680
+ go.layout.Shape(
681
+ type="line",
682
+ x0=x_optimal,
683
+ y0=0,
684
+ x1=x_optimal,
685
+ y1=y_optimal,
686
+ line_width=1,
687
+ line_dash="dash",
688
+ line_color="black",
689
+ xref=f"x{i+1}",
690
+ yref=f"y{i+1}",
691
+ )
692
+ )
693
+
694
+ start_value, end_value, left_value, right_value = find_segment_value(
695
+ x,
696
+ roi,
697
+ marginal_roi,
698
+ )
699
+
700
+ # Adding background colors
701
+ y_max = y.max() * 1.3 # 30% extra space above the max
702
+
703
+ # Yellow region
704
+ shapes.append(
705
+ go.layout.Shape(
706
+ type="rect",
707
+ x0=start_value,
708
+ y0=0,
709
+ x1=left_value,
710
+ y1=y_max,
711
+ line=dict(width=0),
712
+ fillcolor="rgba(255, 255, 0, 0.3)",
713
+ layer="below",
714
+ xref=f"x{i+1}",
715
+ yref=f"y{i+1}",
716
+ )
717
+ )
718
+
719
+ # Green region
720
+ shapes.append(
721
+ go.layout.Shape(
722
+ type="rect",
723
+ x0=left_value,
724
+ y0=0,
725
+ x1=right_value,
726
+ y1=y_max,
727
+ line=dict(width=0),
728
+ fillcolor="rgba(0, 255, 0, 0.3)",
729
+ layer="below",
730
+ xref=f"x{i+1}",
731
+ yref=f"y{i+1}",
732
+ )
733
+ )
734
+
735
+ # Red region
736
+ shapes.append(
737
+ go.layout.Shape(
738
+ type="rect",
739
+ x0=right_value,
740
+ y0=0,
741
+ x1=end_value,
742
+ y1=y_max,
743
+ line=dict(width=0),
744
+ fillcolor="rgba(255, 0, 0, 0.3)",
745
+ layer="below",
746
+ xref=f"x{i+1}",
747
+ yref=f"y{i+1}",
748
+ )
749
+ )
750
+
751
+ fig.update_layout(
752
+ # height=1000,
753
+ # width=1000,
754
+ title_text=f"Response Curves (X: Spends Vs Y: {target})",
755
+ showlegend=False,
756
+ shapes=shapes,
757
+ )
758
+ fig.update_annotations(font_size=10)
759
+ # fig.update_xaxes(title="Spends")
760
+ # fig.update_yaxes(title=target)
761
+ fig.update_yaxes(
762
+ gridcolor="rgba(136, 136, 136, 0.5)", gridwidth=0.5, griddash="dash"
763
+ )
764
+
765
+ return fig
766
+
767
+
768
+ # ======================================================== #
769
+ # ==================== HTML Components =================== #
770
+ # ======================================================== #
771
+
772
+
773
+ def generate_spending_header(heading):
774
+ return st.markdown(
775
+ f"""<h2 class="spends-header">{heading}</h2>""", unsafe_allow_html=True
776
+ )
777
+
778
+
779
+ def save_checkpoint():
780
+ project_dct_path = os.path.join(
781
+ st.session_state["project_path"], "project_dct.pkl"
782
+ )
783
+
784
+ try:
785
+ pickle.dumps(st.session_state["project_dct"])
786
+ with open(project_dct_path, "wb") as f:
787
+ pickle.dump(st.session_state["project_dct"], f)
788
+ except Exception:
789
+ # with warning_placeholder:
790
+ st.toast("Unknown Issue, please reload the page.")
791
+
792
+
793
+ def reset_checkpoint():
794
+ st.session_state["project_dct"]["scenario_planner"] = {}
795
+ save_checkpoint()
796
+
797
+
798
+ # ======================================================== #
799
+ # =================== Session variables ================== #
800
+ # ======================================================== #
801
+
802
+ with open("config.yaml") as file:
803
+ config = yaml.load(file, Loader=SafeLoader)
804
+ st.session_state["config"] = config
805
+
806
+ authenticator = stauth.Authenticate(
807
+ config["credentials"],
808
+ config["cookie"]["name"],
809
+ config["cookie"]["key"],
810
+ config["cookie"]["expiry_days"],
811
+ config["preauthorized"],
812
+ )
813
+ st.session_state["authenticator"] = authenticator
814
+ name, authentication_status, username = authenticator.login("Login", "main")
815
+ auth_status = st.session_state.get("authentication_status")
816
+
817
+ import os
818
+ import glob
819
+
820
+
821
+ def get_excel_names(directory):
822
+ # Create a list to hold the final parts of the filenames
823
+ last_portions = []
824
+
825
+ # Patterns to match Excel files (.xlsx and .xls) that contain @#
826
+ patterns = [
827
+ os.path.join(directory, "*@#*.xlsx"),
828
+ os.path.join(directory, "*@#*.xls"),
829
+ ]
830
+
831
+ # Process each pattern
832
+ for pattern in patterns:
833
+ files = glob.glob(pattern)
834
+
835
+ # Extracting the last portion after @# for each file
836
+ for file in files:
837
+ base_name = os.path.basename(file)
838
+ last_portion = base_name.split("@#")[-1]
839
+ last_portion = last_portion.replace(".xlsx", "").replace(
840
+ ".xls", ""
841
+ ) # Removing extensions
842
+ last_portions.append(last_portion)
843
+
844
+ return last_portions
845
+
846
+
847
+ def name_formating(channel_name):
848
+ # Replace underscores with spaces
849
+ name_mod = channel_name.replace("_", " ")
850
+
851
+ # Capitalize the first letter of each word
852
+ name_mod = name_mod.title()
853
+
854
+ return name_mod
855
+
856
+
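+ # Cached so the Excel workbook is read only once per file path;
+ # st.cache_resource shares the result across reruns of the script.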
857
+ @st.cache_resource(show_spinner=False)
858
+ def panel_fetch(file_selected):
859
+ raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
860
+
861
+ if "Panel" in raw_data_mmm_df.columns:
862
+ panel = list(set(raw_data_mmm_df["Panel"]))
863
+ else:
864
+ raw_data_mmm_df = None
865
+ panel = []  # empty list keeps the "Aggregated"-only fallback from crashing
866
+
867
+ return panel
868
+
869
+
870
+ if auth_status is True:
871
+ authenticator.logout("Logout", "main")
872
+
873
+ if "project_dct" not in st.session_state:
874
+ st.error("Please load a project from home")
875
+ st.stop()
876
+
877
+ database_file = r"DB\User.db"
878
+
879
+ conn = sqlite3.connect(
880
+ database_file, check_same_thread=False
881
+ ) # connection with sql db
882
+ c = conn.cursor()
883
+
884
+ with st.sidebar:
885
+ st.button("Save checkpoint", on_click=save_checkpoint)
886
+ st.button("Reset Checkpoint", on_click=reset_checkpoint)
887
+
888
+ warning_placeholder = st.empty()
889
+
890
+ st.header("Scenario Planner")
891
+
892
+ st.markdown("**Simulation**")
893
+
894
+ # st.subheader("Simulation")
895
+ col1, col2 = st.columns([1, 1])
896
+
897
+ # Get metric and panel from last saved state
898
+ if "last_saved_metric" not in st.session_state:
899
+ st.session_state["last_saved_metric"] = st.session_state[
900
+ "project_dct"
901
+ ]["scenario_planner"].get("metric_selected", 0)
902
+ # st.session_state["last_saved_metric"] = st.session_state[
903
+ # "project_dct"
904
+ # ]["scenario_planner"].get("metric_selected", 0)
905
+
906
+ if "last_saved_panel" not in st.session_state:
907
+ st.session_state["last_saved_panel"] = st.session_state["project_dct"][
908
+ "scenario_planner"
909
+ ].get("panel_selected", 0)
910
+ # st.session_state["last_saved_panel"] = st.session_state["project_dct"][
911
+ # "scenario_planner"
912
+ # ].get("panel_selected", 0)
913
+
914
+ # Response Metrics
915
+ directory = "metrics_level_data"
916
+ metrics_list = get_excel_names(directory)
917
+ metrics_selected = col1.selectbox(
918
+ "Response Metrics",
919
+ metrics_list,
920
+ format_func=name_formating,
921
+ index=st.session_state["last_saved_metric"],
922
+ on_change=reset_optimization,
923
+ key="metric_selected",
924
+ )
925
+
926
+ # Target
927
+ target = name_formating(metrics_selected)
928
+
929
+ file_selected = f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
930
+ # print(f"[DEBUG]: {metrics_selected}")
931
+ # print(f"[DEBUG]: {file_selected}")
932
+ st.session_state["file_selected"] = file_selected
933
+ # Panel List
934
+ panel_list = panel_fetch(file_selected)
935
+ panel_list_final = ["Aggregated"] + panel_list
936
+
937
+ # Panel Selected
938
+ panel_selected = col2.selectbox(
939
+ "Panel",
940
+ panel_list_final,
941
+ on_change=reset_optimization,
942
+ key="panel_selected",
943
+ index=st.session_state["last_saved_panel"],
944
+ )
945
+
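+ # Scenarios are cached per (metric, panel) pair, so switching the selectors
+ # and coming back restores the in-progress scenario instead of re-reading
+ # the data files.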
946
+ unique_key = f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"
947
+
948
+ if "update_rcs" in st.session_state:
949
+ updated_rcs = st.session_state["update_rcs"]
950
+ else:
951
+ updated_rcs = None
952
+
953
+ if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
954
+ if panel_selected == "Aggregated":
955
+ initialize_data(
956
+ panel=panel_selected,
957
+ target_file=file_selected,
958
+ updated_rcs=updated_rcs,
959
+ metrics=metrics_selected,
960
+ )
961
+ panel = None
962
+ else:
963
+ initialize_data(
964
+ panel=panel_selected,
965
+ target_file=file_selected,
966
+ updated_rcs=updated_rcs,
967
+ metrics=metrics_selected,
968
+ )
969
+ st.session_state["project_dct"]["scenario_planner"][unique_key] = (
970
+ st.session_state["scenario"]
971
+ )
972
+
973
+ else:
974
+ st.session_state["scenario"] = st.session_state["project_dct"][
975
+ "scenario_planner"
976
+ ][unique_key]
977
+ st.session_state["rcs"] = {}
978
+ st.session_state["powers"] = {}
979
+ if "optimization_channels" not in st.session_state:
980
+ st.session_state["optimization_channels"] = {}
981
+
982
+ for channel_name, _channel in st.session_state["project_dct"][
983
+ "scenario_planner"
984
+ ][unique_key].channels.items():
985
+ st.session_state[channel_name] = numerize(
986
+ _channel.modified_total_spends, 1
987
+ )
988
+ st.session_state["rcs"][
989
+ channel_name
990
+ ] = _channel.response_curve_params
991
+ st.session_state["powers"][channel_name] = _channel.power
992
+ if channel_name not in st.session_state["optimization_channels"]:
993
+ st.session_state["optimization_channels"][channel_name] = False
994
+
995
+ if "first_time" not in st.session_state:
996
+ st.session_state["first_time"] = True
997
+ st.session_state["first_run_scenario"] = True
998
+
999
+ # Check if state is initialized
1000
+ is_state_initialized = st.session_state.get("initialized", False)
1001
+
1002
+ # if not is_state_initialized:
1003
+ # print("running initialize...")
1004
+ # # initialize_data()
1005
+ # if panel_selected == "Aggregated":
1006
+ # initialize_data(
1007
+ # panel=panel_selected,
1008
+ # target_file=file_selected,
1009
+ # updated_rcs=updated_rcs,
1010
+ # metrics=metrics_selected,
1011
+ # )
1012
+ # panel = None
1013
+ # else:
1014
+ # initialize_data(
1015
+ # panel=panel_selected,
1016
+ # target_file=file_selected,
1017
+ # updated_rcs=updated_rcs,
1018
+ # metrics=metrics_selected,
1019
+ # )
1020
+ # st.session_state["initialized"] = True
1021
+ # st.session_state["first_time"] = False
1022
+
1023
+ # Channels List
1024
+ channels_list = list(
1025
+ st.session_state["project_dct"]["scenario_planner"][
1026
+ unique_key
1027
+ ].channels.keys()
1028
+ )
1029
+
1030
+ # ======================================================== #
1031
+ # ========================== UI ========================== #
1032
+ # ======================================================== #
1033
+
1034
+ main_header = st.columns((2, 2))
1035
+ sub_header = st.columns((1, 1, 1, 1))
1036
+ # _scenario = st.session_state["scenario"]
1037
+
1038
+ st.session_state.total_spends_change = round(
1039
+ (
1040
+ st.session_state["scenario"].modified_total_spends
1041
+ / st.session_state["scenario"].actual_total_spends
1042
+ - 1
1043
+ )
1044
+ * 100
1045
+ )
1046
+
1047
+ if "total_sales_change" not in st.session_state:
1048
+ st.session_state.total_sales_change = round(
1049
+ (
1050
+ st.session_state["scenario"].modified_total_sales
1051
+ / st.session_state["scenario"].actual_total_sales
1052
+ - 1
1053
+ )
1054
+ * 100
1055
+ )
1056
+
1057
+ st.session_state["total_spends_change_abs"] = numerize(
1058
+ st.session_state["scenario"].modified_total_spends,
1059
+ 1,
1060
+ )
1061
+ if "total_sales_change_abs" not in st.session_state:
1062
+ st.session_state["total_sales_change_abs"] = numerize(
1063
+ st.session_state["scenario"].modified_total_sales,
1064
+ 1,
1065
+ )
1066
+
1067
+ # if "total_spends_change_abs_slider" not in st.session_state:
1068
+ st.session_state.total_spends_change_abs_slider = numerize(
1069
+ st.session_state["scenario"].modified_total_spends, 1
1070
+ )
1071
+
1072
+ if "total_sales_change_abs_slider" not in st.session_state:
1073
+ st.session_state.total_sales_change_abs_slider = numerize(
1074
+ st.session_state["scenario"].actual_total_sales, 1
1075
+ )
1076
+
1077
+ st.session_state["allow_sales_update"] = True
1078
+
1079
+ st.session_state["allow_spends_update"] = True
1080
+
1081
+ # if "panel_selected" not in st.session_state:
1082
+ # st.session_state["panel_selected"] = 0
1083
+
1084
+ with main_header[0]:
1085
+ st.subheader("Actual")
1086
+
1087
+ with main_header[-1]:
1088
+ st.subheader("Simulated")
1089
+
1090
+ with sub_header[0]:
1091
+ st.metric(
1092
+ label="Spends",
1093
+ value=format_numbers(
1094
+ st.session_state["scenario"].actual_total_spends
1095
+ ),
1096
+ )
1097
+
1098
+ with sub_header[1]:
1099
+ st.metric(
1100
+ label=target,
1101
+ value=format_numbers(
1102
+ float(st.session_state["scenario"].actual_total_sales),
1103
+ include_indicator=False,
1104
+ ),
1105
+ )
1106
+
1107
+ with sub_header[2]:
1108
+ st.metric(
1109
+ label="Spends",
1110
+ value=format_numbers(
1111
+ st.session_state["scenario"].modified_total_spends
1112
+ ),
1113
+ delta=numerize(st.session_state["scenario"].delta_spends, 1),
1114
+ )
1115
+
1116
+ with sub_header[3]:
1117
+ st.metric(
1118
+ label=target,
1119
+ value=format_numbers(
1120
+ float(st.session_state["scenario"].modified_total_sales),
1121
+ include_indicator=False,
1122
+ ),
1123
+ delta=numerize(st.session_state["scenario"].delta_sales, 1),
1124
+ )
1125
+
1126
+ with st.expander("Channel Spends Simulator", expanded=True):
1127
+ _columns1 = st.columns((2, 2, 1, 1))
1128
+ with _columns1[0]:
1129
+ optimization_selection = st.selectbox(
1130
+ "Optimize",
1131
+ options=["Media Spends", target],
1132
+ key="optimization_key_value",
1133
+ )
1134
+
1135
+ with _columns1[1]:
1136
+ st.markdown("#")
1137
+ # if st.checkbox(
1138
+ # label="Optimize all Channels",
1139
+ # key="optimze_all_channels",
1140
+ # value=False,
1141
+ # # on_change=select_all_channels_for_optimization,
1142
+ # ):
1143
+ # select_all_channels_for_optimization()
1144
+
1145
+ st.checkbox(
1146
+ label="Optimize all Channels",
1147
+ key="optimze_all_channels",
1148
+ on_change=select_all_channels_for_optimization,
1149
+ )
1150
+
1151
+ with _columns1[2]:
1152
+ st.markdown("#")
1153
+ # st.button(
1154
+ # "Optimize",
1155
+ # on_click=optimize,
1156
+ # args=(st.session_state["optimization_key_value"]),
1157
+ # use_container_width=True,
1158
+ # )
1159
+
1160
+ optimize_placeholder = st.empty()
1161
+
1162
+ with _columns1[3]:
1163
+ st.markdown("#")
1164
+ st.button(
1165
+ "Reset",
1166
+ on_click=reset_scenario,
1167
+ # args=(panel_selected, file_selected, updated_rcs),
1168
+ use_container_width=True,
1169
+ )
1170
+
1171
+ _columns2 = st.columns((2, 2, 2))
1172
+ if st.session_state["optimization_key_value"] == "Media Spends":
1173
+
1174
+ # update_spends()
1175
+
1176
+ with _columns2[0]:
1177
+ spend_input = st.text_input(
1178
+ "Absolute",
1179
+ key="total_spends_change_abs",
1180
+ # label_visibility="collapsed",
1181
+ on_change=update_all_spends_abs,
1182
+ )
1183
+
1184
+ with _columns2[1]:
1185
+ st.number_input(
1186
+ "Percent Change",
1187
+ key="total_spends_change",
1188
+ min_value=-50,
1189
+ max_value=50,
1190
+ step=1,
1191
+ on_change=update_spends,
1192
+ )
1193
+
1194
+ with _columns2[2]:
1195
+ scenario = st.session_state["project_dct"]["scenario_planner"][
1196
+ unique_key
1197
+ ]
1198
+ min_value = round(scenario.actual_total_spends * 0.5)
1199
+ max_value = round(scenario.actual_total_spends * 1.5)
1200
+ st.session_state["total_spends_change_abs_slider_options"] = [
1201
+ numerize(value, 1)
1202
+ for value in range(min_value, max_value + 1, int(1e4))
1203
+ ]
1204
+
1205
+ st.select_slider(
1206
+ "Absolute Slider",
1207
+ options=st.session_state[
1208
+ "total_spends_change_abs_slider_options"
1209
+ ],
1210
+ key="total_spends_change_abs_slider",
1211
+ on_change=update_all_spends_abs_slider,
1212
+ )
1213
+
1214
+ elif st.session_state["optimization_key_value"] == target:
1215
+ # update_sales()
1216
+
1217
+ with _columns2[0]:
1218
+ sales_input = st.text_input(
1219
+ "Absolute",
1220
+ key="total_sales_change_abs",
1221
+ on_change=update_sales_abs,
1222
+ )
1223
+
1224
+ with _columns2[1]:
1225
+ st.number_input(
1226
+ "Percent Change",
1227
+ key="total_sales_change",
1228
+ min_value=-50,
1229
+ max_value=50,
1230
+ step=1,
1231
+ on_change=update_sales,
1232
+ )
1233
+
1234
+ with _columns2[2]:
1235
+ min_value = round(
1236
+ st.session_state["scenario"].actual_total_sales * 0.5
1237
+ )
1238
+ max_value = round(
1239
+ st.session_state["scenario"].actual_total_sales * 1.5
1240
+ )
1241
+ st.session_state["total_sales_change_abs_slider_options"] = [
1242
+ numerize(value, 1)
1243
+ for value in range(min_value, max_value + 1, int(1e5))
1244
+ ]
1245
+
1246
+ st.select_slider(
1247
+ "Absolute Slider",
1248
+ options=st.session_state[
1249
+ "total_sales_change_abs_slider_options"
1250
+ ],
1251
+ key="total_sales_change_abs_slider",
1252
+ on_change=update_sales_abs_slider,
1253
+ )
1254
+
1255
+ if (
1256
+ not st.session_state["allow_sales_update"]
1257
+ and optimization_selection == target
1258
+ ):
1259
+ st.warning("Invalid Input")
1260
+
1261
+ if (
1262
+ not st.session_state["allow_spends_update"]
1263
+ and optimization_selection == "Media Spends"
1264
+ ):
1265
+ st.warning("Invalid Input")
1266
+
1267
+ status_placeholder = st.empty()
1268
+
1269
+ # if optimize_placeholder.button("Optimize", use_container_width=True):
1270
+ # optimize(st.session_state["optimization_key_value"], status_placeholder)
1271
+ # st.rerun()
1272
+
1273
+ optimize_placeholder.button(
1274
+ "Optimize",
1275
+ on_click=optimize,
1276
+ args=(
1277
+ st.session_state["optimization_key_value"],
1278
+ status_placeholder,
1279
+ ),
1280
+ use_container_width=True,
1281
+ )
1282
+
1283
+ st.markdown(
1284
+ """<hr class="spends-heading-seperator">""", unsafe_allow_html=True
1285
+ )
1286
+ _columns = st.columns((2.5, 2, 1.5, 1.5, 1))
1287
+ with _columns[0]:
1288
+ generate_spending_header("Channel")
1289
+ with _columns[1]:
1290
+ generate_spending_header("Spends Input")
1291
+ with _columns[2]:
1292
+ generate_spending_header("Spends")
1293
+ with _columns[3]:
1294
+ generate_spending_header(target)
1295
+ with _columns[4]:
1296
+ generate_spending_header("Optimize")
1297
+
1298
+ st.markdown(
1299
+ """<hr class="spends-heading-seperator">""", unsafe_allow_html=True
1300
+ )
1301
+
1302
+ if "acutual_predicted" not in st.session_state:
1303
+ st.session_state["acutual_predicted"] = {
1304
+ "Channel_name": [],
1305
+ "Actual_spend": [],
1306
+ "Optimized_spend": [],
1307
+ "Delta": [],
1308
+ }
1309
+ for i, channel_name in enumerate(channels_list):
1310
+ _channel_class = st.session_state["scenario"].channels[
1311
+ channel_name
1312
+ ]
1313
+
1314
+ st.session_state[f"{channel_name}_percent"] = round(
1315
+ (
1316
+ _channel_class.modified_total_spends
1317
+ / _channel_class.actual_total_spends
1318
+ - 1
1319
+ )
1320
+ * 100
1321
+ )
1322
+
1323
+ _columns = st.columns((2.5, 1.5, 1.5, 1.5, 1))
1324
+ with _columns[0]:
1325
+ st.write(channel_name_formating(channel_name))
1326
+ bin_placeholder = st.container()
1327
+
1328
+ with _columns[1]:
1329
+ channel_bounds = _channel_class.bounds
1330
+ channel_spends = float(_channel_class.actual_total_spends)
1331
+ min_value = float(
1332
+ (1 + channel_bounds[0] / 100) * channel_spends
1333
+ )
1334
+ max_value = float(
1335
+ (1 + channel_bounds[1] / 100) * channel_spends
1336
+ )
1337
+ # print("##########", st.session_state[channel_name])
1338
+ spend_input = st.text_input(
1339
+ channel_name,
1340
+ key=channel_name,
1341
+ label_visibility="collapsed",
1342
+ on_change=partial(update_data, channel_name),
1343
+ )
1344
+ if not validate_input(spend_input):
1345
+ st.error("Invalid input")
1346
+
1347
+ channel_name_current = f"{channel_name}_change"
1348
+
1349
+ st.number_input(
1350
+ "Percent Change",
1351
+ key=f"{channel_name}_percent",
1352
+ step=1,
1353
+ on_change=partial(update_data_by_percent, channel_name),
1354
+ )
1355
+
1356
+ with _columns[2]:
1357
+ # spends
1358
+ current_channel_spends = float(
1359
+ _channel_class.modified_total_spends
1360
+ * _channel_class.conversion_rate
1361
+ )
1362
+ actual_channel_spends = float(
1363
+ _channel_class.actual_total_spends
1364
+ * _channel_class.conversion_rate
1365
+ )
1366
+ spends_delta = float(
1367
+ _channel_class.delta_spends
1368
+ * _channel_class.conversion_rate
1369
+ )
1370
+ st.session_state["acutual_predicted"]["Channel_name"].append(
1371
+ channel_name
1372
+ )
1373
+ st.session_state["acutual_predicted"]["Actual_spend"].append(
1374
+ actual_channel_spends
1375
+ )
1376
+ st.session_state["acutual_predicted"][
1377
+ "Optimized_spend"
1378
+ ].append(current_channel_spends)
1379
+ st.session_state["acutual_predicted"]["Delta"].append(
1380
+ spends_delta
1381
+ )
1382
+ ## REMOVE
1383
+ st.metric(
1384
+ "Spends",
1385
+ format_numbers(current_channel_spends),
1386
+ delta=numerize(spends_delta, 1),
1387
+ label_visibility="collapsed",
1388
+ )
1389
+
1390
+ with _columns[3]:
1391
+ # sales
1392
+ current_channel_sales = float(
1393
+ _channel_class.modified_total_sales
1394
+ )
1395
+ actual_channel_sales = float(_channel_class.actual_total_sales)
1396
+ sales_delta = float(_channel_class.delta_sales)
1397
+ st.metric(
1398
+ target,
1399
+ format_numbers(
1400
+ current_channel_sales, include_indicator=False
1401
+ ),
1402
+ delta=numerize(sales_delta, 1),
1403
+ label_visibility="collapsed",
1404
+ )
1405
+
1406
+ with _columns[4]:
1407
+
1408
+ # if st.checkbox(
1409
+ # label="select for optimization",
1410
+ # key=f"{channel_name}_selected",
1411
+ # value=False,
1412
+ # # on_change=partial(select_channel_for_optimization, channel_name),
1413
+ # label_visibility="collapsed",
1414
+ # ):
1415
+ # select_channel_for_optimization(channel_name)
1416
+
1417
+ st.checkbox(
1418
+ label="select for optimization",
1419
+ key=f"{channel_name}_selected",
1420
+ value=False,
1421
+ on_change=partial(
1422
+ select_channel_for_optimization, channel_name
1423
+ ),
1424
+ label_visibility="collapsed",
1425
+ )
1426
+
1427
+ st.markdown(
1428
+ """<hr class="spends-child-seperator">""",
1429
+ unsafe_allow_html=True,
1430
+ )
1431
+
1432
+ # Bins
1433
+ col = channels_list[i]
1434
+ x_actual = st.session_state["scenario"].channels[col].actual_spends
1435
+ x_modified = (
1436
+ st.session_state["scenario"].channels[col].modified_spends
1437
+ )
1438
+
1439
+ x_total = x_modified.sum()
1440
+ power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3
1441
+
1442
+ updated_rcs_key = (
1443
+ f"{metrics_selected}#@{panel_selected}#@{channel_name}"
1444
+ )
1445
+
1446
+ if updated_rcs and updated_rcs_key in list(updated_rcs.keys()):
1447
+ K = updated_rcs[updated_rcs_key]["K"]
1448
+ b = updated_rcs[updated_rcs_key]["b"]
1449
+ a = updated_rcs[updated_rcs_key]["a"]
1450
+ x0 = updated_rcs[updated_rcs_key]["x0"]
1451
+ else:
1452
+ K = st.session_state["rcs"][col]["K"]
1453
+ b = st.session_state["rcs"][col]["b"]
1454
+ a = st.session_state["rcs"][col]["a"]
1455
+ x0 = st.session_state["rcs"][col]["x0"]
1456
+
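+ # The current spend level is appended to the sampling grid below so its
+ # ROI and marginal ROI come out of the same vectorized computation; the
+ # extra point is dropped again once those two values are read off.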
1457
+ x_plot = np.linspace(0, 5 * x_actual.sum(), 200)
1458
+
1459
+ # Append current_channel_spends to the end of x_plot
1460
+ x_plot = np.append(x_plot, current_channel_spends)
1461
+
1462
+ x, y, marginal_roi = [], [], []
1463
+ for x_p in x_plot:
1464
+ x.append(x_p * x_actual / x_actual.sum())
1465
+
1466
+ for index in range(len(x_plot)):
1467
+ y.append(s_curve(x[index] / 10**power, K, b, a, x0))
1468
+
1469
+ for index in range(len(x_plot)):
1470
+ marginal_roi.append(
1471
+ a
1472
+ * y[index]
1473
+ * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
1474
+ )
1475
+
1476
+ x = (
1477
+ np.sum(x, axis=1)
1478
+ * st.session_state["scenario"].channels[col].conversion_rate
1479
+ )
1480
+ y = np.sum(y, axis=1)
1481
+ marginal_roi = (
1482
+ np.average(marginal_roi, axis=1)
1483
+ / st.session_state["scenario"].channels[col].conversion_rate
1484
+ )
1485
+
1486
+ roi = y / np.maximum(x, np.finfo(float).eps)
1487
+
1488
+ roi_current, marginal_roi_current = roi[-1], marginal_roi[-1]
1489
+ x, y, roi, marginal_roi = (
1490
+ x[:-1],
1491
+ y[:-1],
1492
+ roi[:-1],
1493
+ marginal_roi[:-1],
1494
+ ) # Drop data for current spends
1495
+
1496
+ start_value, end_value, left_value, right_value = (
1497
+ find_segment_value(
1498
+ x,
1499
+ roi,
1500
+ marginal_roi,
1501
+ )
1502
+ )
1503
+
1504
+ rgba = calculate_rgba(
1505
+ start_value,
1506
+ end_value,
1507
+ left_value,
1508
+ right_value,
1509
+ current_channel_spends,
1510
+ )
1511
+
1512
+ with bin_placeholder:
1513
+ st.markdown(
1514
+ f"""
1515
+ <div style="
1516
+ border-radius: 12px;
1517
+ background-color: {rgba};
1518
+ padding: 10px;
1519
+ text-align: center;
1520
+ color: #006EC0;
1521
+ ">
1522
+ <p style="margin: 0; font-size: 20px;">ROI: {round(roi_current,1)}</p>
1523
+ <p style="margin: 0; font-size: 20px;">Marginal ROI: {round(marginal_roi_current,1)}</p>
1524
+ </div>
1525
+ """,
1526
+ unsafe_allow_html=True,
1527
+ )
1528
+
1529
+ st.session_state["project_dct"]["scenario_planner"]["scenario"] = (
1530
+ st.session_state["scenario"]
1531
+ )
1532
+
1533
+ with st.expander("See Response Curves", expanded=True):
1534
+ fig = plot_response_curves()
1535
+ st.plotly_chart(fig, use_container_width=True)
1536
+
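+ # Channel bounds are stored as [lower, upper] percent changes relative to
+ # actual spends; index 0/1 picks which end of the interval to update.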
1537
+ def update_optimization_bounds(channel_name, bound_type):
1538
+ index = 0 if bound_type == "lower" else 1
1539
+ update_key = (
1540
+ f"{channel_name}_b_lower"
1541
+ if bound_type == "lower"
1542
+ else f"{channel_name}_b_upper"
1543
+ )
1544
+ st.session_state["project_dct"]["scenario_planner"][
1545
+ unique_key
1546
+ ].channels[channel_name].bounds[index] = st.session_state[update_key]
1547
+
1548
+ def update_optimization_bounds_all(bound_type):
1549
+ index = 0 if bound_type == "lower" else 1
1550
+ update_key = (
1551
+ f"all_b_lower" if bound_type == "lower" else f"all_b_upper"
1552
+ )
1553
+
1554
+ for channel_name, _channel in st.session_state["project_dct"][
1555
+ "scenario_planner"
1556
+ ][unique_key].channels.items():
1557
+ _channel.bounds[index] = st.session_state[update_key]
1558
+
1559
+ with st.expander("Optimization setup"):
1560
+ bounds_placeholder = st.container()
1561
+ with bounds_placeholder:
1562
+ st.subheader("Optimization Bounds")
1563
+ with st.container():
1564
+ bounds_columns = st.columns((1, 0.35, 0.35, 1))
1565
+ with bounds_columns[0]:
1566
+ st.write("##")
1567
+ st.write("Update all channels")
1568
+
1569
+ with bounds_columns[1]:
1570
+ st.number_input(
1571
+ "Lower",
1572
+ min_value=-100,
1573
+ max_value=500,
1574
+ key=f"all_b_lower",
1575
+ # label_visibility="hidden",
1576
+ on_change=update_optimization_bounds_all,
1577
+ args=("lower",),
1578
+ step=5,
1579
+ value=-10,
1580
+ )
1581
+
1582
+ with bounds_columns[2]:
1583
+ st.number_input(
1584
+ "Higher",
1585
+ value=10,
1586
+ min_value=-100,
1587
+ max_value=500,
1588
+ key=f"all_b_upper",
1589
+ # label_visibility="hidden",
1590
+ on_change=update_optimization_bounds_all,
1591
+ args=("upper",),
1592
+ step=5,
1593
+ )
1594
+ st.divider()
1595
+
1596
+ st.write("#### Channel wise bounds")
1597
+ # st.divider()
1598
+ # bounds_columns = st.columns((1, 0.35, 0.35, 1))
1599
+
1600
+ # with bounds_columns[0]:
1601
+ # st.write("Channel")
1602
+ # with bounds_columns[1]:
1603
+ # st.write("Lower")
1604
+ # with bounds_columns[2]:
1605
+ # st.write("Upper")
1606
+ # st.divider()
1607
+
1608
+ for channel_name, _channel in st.session_state["project_dct"][
1609
+ "scenario_planner"
1610
+ ][unique_key].channels.items():
1611
+ st.session_state[f"{channel_name}_b_lower"] = _channel.bounds[0]
1612
+ st.session_state[f"{channel_name}_b_upper"] = _channel.bounds[1]
1613
+ with bounds_placeholder:
1614
+ with st.container():
1615
+ bounds_columns = st.columns((1, 0.35, 0.35, 1))
1616
+ with bounds_columns[0]:
1617
+ st.write("##")
1618
+ st.write(channel_name)
1619
+ with bounds_columns[1]:
1620
+ st.number_input(
1621
+ "Lower",
1622
+ min_value=-100,
1623
+ max_value=500,
1624
+ key=f"{channel_name}_b_lower",
1625
+ label_visibility="hidden",
1626
+ on_change=update_optimization_bounds,
1627
+ args=(
1628
+ channel_name,
1629
+ "lower",
1630
+ ),
1631
+ )
1632
+
1633
+ with bounds_columns[2]:
1634
+ st.number_input(
1635
+ "Higher",
1636
+ min_value=-100,
1637
+ max_value=500,
1638
+ key=f"{channel_name}_b_upper",
1639
+ label_visibility="hidden",
1640
+ on_change=update_optimization_bounds,
1641
+ args=(
1642
+ channel_name,
1643
+ "upper",
1644
+ ),
1645
+ )
1646
+
1647
+ st.divider()
1648
+ _columns = st.columns(2)
1649
+ with _columns[0]:
1650
+ st.subheader("Save Scenario")
1651
+ scenario_name = st.text_input(
1652
+ "Scenario name",
1653
+ key="scenario_input",
1654
+ placeholder="Scenario name",
1655
+ label_visibility="collapsed",
1656
+ )
1657
+ st.button(
1658
+ "Save",
1659
+ on_click=lambda: save_scenario(scenario_name),
1660
+ disabled=len(st.session_state["scenario_input"]) == 0,
1661
+ )
1662
+
1663
+ summary_df = pd.DataFrame(st.session_state["acutual_predicted"])
1664
+ summary_df.drop_duplicates(
1665
+ subset="Channel_name", keep="last", inplace=True
1666
+ )
1667
+
1668
+ summary_df_sorted = summary_df.sort_values(by="Delta", ascending=False)
1669
+ summary_df_sorted["Delta_percent"] = np.round(
1670
+ (
1671
+ (
1672
+ summary_df_sorted["Optimized_spend"]
1673
+ / summary_df_sorted["Actual_spend"]
1674
+ )
1675
+ - 1
1676
+ )
1677
+ * 100,
1678
+ 2,
1679
+ )
1680
+
1681
+ with open("summary_df.pkl", "wb") as f:
1682
+ pickle.dump(summary_df_sorted, f)
1683
+ # st.dataframe(summary_df_sorted)
1684
+ # ___columns=st.columns(3)
1685
+ # with ___columns[2]:
1686
+ # fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent')
1687
+ # st.plotly_chart(fig,use_container_width=True)
1688
+ # with ___columns[0]:
1689
+ # fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend')
1690
+ # st.plotly_chart(fig,use_container_width=True)
1691
+ # with ___columns[1]:
1692
+ # fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend')
1693
+ # st.plotly_chart(fig,use_container_width=True)
1694
+
1695
+ elif auth_status == False:
1696
+ st.error("Username/Password is incorrect")
1697
+
1698
+ if auth_status != True:
1699
+ try:
1700
+ username_forgot_pw, email_forgot_password, random_password = (
1701
+ authenticator.forgot_password("Forgot password")
1702
+ )
1703
+ if username_forgot_pw:
1704
+ st.session_state["config"]["credentials"]["usernames"][
1705
+ username_forgot_pw
1706
+ ]["password"] = stauth.Hasher([random_password]).generate()[0]
1707
+ send_email(email_forgot_password, random_password)
1708
+ st.success("New password sent securely")
1709
+ # Random password to be transferred to user securely
1710
+ elif username_forgot_pw == False:
1711
+ st.error("Username not found")
1712
+ except Exception as e:
1713
+ st.error(e)
1714
+
1715
+ update_db("9_Scenario_Planner.py")