Upload 11 files
- pages/10_Saved_Scenarios.py +408 -0
- pages/11_Model_Optimized_Recommendation.py +545 -0
- pages/1_Data_Import.py +1501 -0
- pages/2_Data_Validation_and_Insights.py +514 -0
- pages/3_Transformations.py +639 -0
- pages/4_Model_Build 2.py +1288 -0
- pages/5_Model_Tuning.py +917 -0
- pages/6_AI_Model_Results.py +828 -0
- pages/7_Current_Media_Performance.py +573 -0
- pages/8_Response_Curves.py +596 -0
- pages/9_Scenario_Planner.py +1715 -0
pages/10_Saved_Scenarios.py
ADDED
@@ -0,0 +1,408 @@
import streamlit as st
from numerize.numerize import numerize
import io
import pandas as pd
from utilities import (
    format_numbers,
    decimal_formater,
    channel_name_formating,
    load_local_css,
    set_header,
    initialize_data,
    load_authenticator,
)
from openpyxl import Workbook
from openpyxl.styles import Alignment, Font, PatternFill
import pickle
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
from classes import class_from_dict
from utilities import update_db

st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()

# for k, v in st.session_state.items():
#     if k not in ['logout', 'login', 'config'] and not k.startswith('FormSubmitter'):
#         st.session_state[k] = v


def create_scenario_summary(scenario_dict):
    """Build a summary DataFrame (one row per channel plus a Total row) with
    Actual/Simulated spends, NRPU, ROI, MROI and spend per NRPU."""
    summary_rows = []
    for channel_dict in scenario_dict["channels"]:
        name_mod = channel_name_formating(channel_dict["name"])
        summary_rows.append(
            [
                name_mod,
                channel_dict.get("actual_total_spends")
                * channel_dict.get("conversion_rate"),
                channel_dict.get("modified_total_spends")
                * channel_dict.get("conversion_rate"),
                channel_dict.get("actual_total_sales"),
                channel_dict.get("modified_total_sales"),
                channel_dict.get("actual_total_sales")
                / (
                    channel_dict.get("actual_total_spends")
                    * channel_dict.get("conversion_rate")
                ),
                channel_dict.get("modified_total_sales")
                / (
                    channel_dict.get("modified_total_spends")
                    * channel_dict.get("conversion_rate")
                ),
                channel_dict.get("actual_mroi"),
                channel_dict.get("modified_mroi"),
                channel_dict.get("actual_total_spends")
                * channel_dict.get("conversion_rate")
                / channel_dict.get("actual_total_sales"),
                channel_dict.get("modified_total_spends")
                * channel_dict.get("conversion_rate")
                / channel_dict.get("modified_total_sales"),
            ]
        )

    summary_rows.append(
        [
            "Total",
            scenario_dict.get("actual_total_spends"),
            scenario_dict.get("modified_total_spends"),
            scenario_dict.get("actual_total_sales"),
            scenario_dict.get("modified_total_sales"),
            scenario_dict.get("actual_total_sales")
            / scenario_dict.get("actual_total_spends"),
            scenario_dict.get("modified_total_sales")
            / scenario_dict.get("modified_total_spends"),
            "-",
            "-",
            scenario_dict.get("actual_total_spends")
            / scenario_dict.get("actual_total_sales"),
            scenario_dict.get("modified_total_spends")
            / scenario_dict.get("modified_total_sales"),
        ]
    )

    columns_index = pd.MultiIndex.from_product(
        [[""], ["Channel"]], names=["first", "second"]
    )
    columns_index = columns_index.append(
        pd.MultiIndex.from_product(
            [
                ["Spends", "NRPU", "ROI", "MROI", "Spend per NRPU"],
                ["Actual", "Simulated"],
            ],
            names=["first", "second"],
        )
    )
    return pd.DataFrame(summary_rows, columns=columns_index)


def summary_df_to_worksheet(df, ws):
    """Write the two-level summary DataFrame to an openpyxl worksheet with a
    styled two-row header and currency formatting on the spend columns."""
    heading_fill = PatternFill(
        fill_type="solid", start_color="FF11B6BD", end_color="FF11B6BD"
    )
    for j, header in enumerate(df.columns.values):
        col = j + 1
        for i in range(1, 3):
            ws.cell(row=i, column=j + 1, value=header[i - 1]).font = Font(
                bold=True, color="FF11B6BD"
            )
            ws.cell(row=i, column=j + 1).fill = heading_fill
        if col > 1 and (col - 6) % 5 == 0:
            ws.merge_cells(
                start_row=1, end_row=1, start_column=col - 3, end_column=col
            )
            ws.cell(row=1, column=col).alignment = Alignment(horizontal="center")
    for i, row in enumerate(df.itertuples()):
        for j, value in enumerate(row):
            if j == 0:
                continue
            elif (j - 2) % 4 == 0 or (j - 3) % 4 == 0:
                ws.cell(row=i + 3, column=j, value=value).number_format = "$#,##0.0"
            else:
                ws.cell(row=i + 3, column=j, value=value)


from openpyxl.utils import get_column_letter
from openpyxl.styles import Font, PatternFill
import logging


def scenario_df_to_worksheet(df, ws):
    heading_fill = PatternFill(
        start_color="FF11B6BD", end_color="FF11B6BD", fill_type="solid"
    )

    for j, header in enumerate(df.columns.values):
        cell = ws.cell(row=1, column=j + 1, value=header)
        cell.font = Font(bold=True, color="FF11B6BD")
        cell.fill = heading_fill

    for i, row in enumerate(df.itertuples()):
        for j, value in enumerate(row[1:], start=1):  # skip the index column
            # Bind the cell before assigning the value so the except branch
            # always has a cell to clear.
            cell = ws.cell(row=i + 2, column=j)
            try:
                if isinstance(value, (int, float)):
                    cell.value = value
                    cell.number_format = "$#,##0.0"
                elif isinstance(value, str):
                    cell.value = value[:32767]  # Excel's per-cell character limit
                else:
                    cell.value = str(value)
            except ValueError as e:
                logging.error(
                    f"Error assigning value '{value}' to cell {get_column_letter(j)}{i+2}: {e}"
                )
                cell.value = None  # leave the offending cell empty

    return ws


def download_scenarios():
    """
    Build an Excel workbook with every selected scenario (one sheet per
    scenario plus a Summary sheet) and store it in the in-memory buffer
    that backs the download button.
    """
    ## create summary page
    if len(scenarios_to_download) == 0:
        return
    wb = Workbook()
    wb.iso_dates = True
    wb.remove(wb.active)
    st.session_state["xlsx_buffer"] = io.BytesIO()
    summary_df = None
    # print(scenarios_to_download)
    for scenario_name in scenarios_to_download:
        scenario_dict = st.session_state["saved_scenarios"][scenario_name]
        _spends = []
        column_names = ["Date"]
        _sales = None
        dates = None
        summary_rows = []
        for channel in scenario_dict["channels"]:
            if dates is None:
                dates = channel.get("dates")
                _spends.append(dates)
            if _sales is None:
                _sales = channel.get("modified_sales")
            else:
                _sales += channel.get("modified_sales")
            _spends.append(
                channel.get("modified_spends") * channel.get("conversion_rate")
            )
            column_names.append(channel.get("name"))

            name_mod = channel_name_formating(channel["name"])
            # ROI and spend per NRPU follow the same definitions used in
            # create_scenario_summary above.
            summary_rows.append(
                [
                    name_mod,
                    channel.get("modified_total_spends")
                    * channel.get("conversion_rate"),
                    channel.get("modified_total_sales"),
                    channel.get("modified_total_sales")
                    / (
                        channel.get("modified_total_spends")
                        * channel.get("conversion_rate")
                    ),
                    channel.get("modified_mroi"),
                    channel.get("modified_total_spends")
                    * channel.get("conversion_rate")
                    / channel.get("modified_total_sales"),
                ]
            )
        _spends.append(_sales)
        column_names.append("NRPU")
        scenario_df = pd.DataFrame(_spends).T
        scenario_df.columns = column_names
        ## write to sheet
        ws = wb.create_sheet(scenario_name)
        scenario_df_to_worksheet(scenario_df, ws)
        summary_rows.append(
            [
                "Total",
                scenario_dict.get("modified_total_spends"),
                scenario_dict.get("modified_total_sales"),
                scenario_dict.get("modified_total_sales")
                / scenario_dict.get("modified_total_spends"),
                "-",
                scenario_dict.get("modified_total_spends")
                / scenario_dict.get("modified_total_sales"),
            ]
        )
        columns_index = pd.MultiIndex.from_product(
            [[""], ["Channel"]], names=["first", "second"]
        )
        columns_index = columns_index.append(
            pd.MultiIndex.from_product(
                [[scenario_name], ["Spends", "NRPU", "ROI", "MROI", "Spends per NRPU"]],
                names=["first", "second"],
            )
        )
        if summary_df is None:
            summary_df = pd.DataFrame(summary_rows, columns=columns_index)
            summary_df = summary_df.set_index(("", "Channel"))
        else:
            _df = pd.DataFrame(summary_rows, columns=columns_index)
            _df = _df.set_index(("", "Channel"))
            summary_df = summary_df.merge(_df, left_index=True, right_index=True)
    ws = wb.create_sheet("Summary", 0)
    summary_df_to_worksheet(summary_df.reset_index(), ws)
    wb.save(st.session_state["xlsx_buffer"])
    st.session_state["disable_download_button"] = False


def disable_download_button():
    st.session_state["disable_download_button"] = True


def transform(x):
    """Column-wise display formatting: 4-decimal formatting for ROI/MROI
    columns, compact number formatting for everything else."""
    if x.name == ("", "Channel"):
        return x
    elif x.name[0] == "ROI" or x.name[0] == "MROI":
        return x.apply(
            lambda y: (
                y
                if isinstance(y, str)
                else decimal_formater(
                    format_numbers(y, include_indicator=False, n_decimals=4),
                    n_decimals=4,
                )
            )
        )
    else:
        return x.apply(lambda y: y if isinstance(y, str) else format_numbers(y))


def delete_scenario():
    if selected_scenario in st.session_state["saved_scenarios"]:
        del st.session_state["saved_scenarios"][selected_scenario]
        with open("../saved_scenarios.pkl", "wb") as f:
            pickle.dump(st.session_state["saved_scenarios"], f)


def load_scenario():
    if selected_scenario in st.session_state["saved_scenarios"]:
        st.session_state["scenario"] = class_from_dict(selected_scenario_details)


authenticator = st.session_state.get("authenticator")
if authenticator is None:
    authenticator = load_authenticator()

name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status == True:
    is_state_initialized = st.session_state.get("initialized", False)
    if not is_state_initialized:
        # print("Scenario page state reloaded")
        initialize_data()

    saved_scenarios = st.session_state["saved_scenarios"]

    if len(saved_scenarios) == 0:
        st.header("No saved scenarios")

    else:
        selected_scenario_list = list(saved_scenarios.keys())
        if "selected_scenario_selectbox_key" not in st.session_state:
            st.session_state["selected_scenario_selectbox_key"] = (
                selected_scenario_list[
                    st.session_state["project_dct"]["saved_scenarios"][
                        "selected_scenario_selectbox_key"
                    ]
                ]
            )

        col_a, col_b = st.columns(2)
        selected_scenario = col_a.selectbox(
            "Pick a scenario to view details",
            selected_scenario_list,
            # key="selected_scenario_selectbox_key",
            index=st.session_state["project_dct"]["saved_scenarios"][
                "selected_scenario_selectbox_key"
            ],
        )
        st.session_state["project_dct"]["saved_scenarios"][
            "selected_scenario_selectbox_key"
        ] = selected_scenario_list.index(selected_scenario)

        scenarios_to_download = col_b.multiselect(
            "Select scenarios to download",
            list(saved_scenarios.keys()),
            on_change=disable_download_button,
        )

        with col_a:
            col3, col4 = st.columns(2)

            col4.button(
                "Delete scenarios",
                on_click=delete_scenario,
                use_container_width=True,
            )
            col3.button(
                "Load Scenario",
                on_click=load_scenario,
                use_container_width=True,
            )

        with col_b:
            col1, col2 = st.columns(2)

            col1.button(
                "Prepare download",
                on_click=download_scenarios,
                use_container_width=True,
            )
            col2.download_button(
                label="Download Scenarios",
                data=st.session_state["xlsx_buffer"].getvalue(),
                file_name="scenarios.xlsx",
                mime="application/vnd.ms-excel",
                disabled=st.session_state["disable_download_button"],
                on_click=disable_download_button,
                use_container_width=True,
            )

        # column_1, column_2, column_3 = st.columns((6, 1, 1))
        # with column_1:
        #     st.header(selected_scenario)
        # with column_2:
        #     st.button("Delete scenarios", on_click=delete_scenario)
        # with column_3:
        #     st.button("Load Scenario", on_click=load_scenario)

        selected_scenario_details = saved_scenarios[selected_scenario]

        # st.write(pd.DataFrame(selected_scenario_details))
        pd.set_option("display.max_colwidth", 100)

        st.markdown(
            create_scenario_summary(selected_scenario_details)
            .transform(transform)
            .style.set_table_styles(
                [
                    {"selector": "th", "props": [("background-color", "#FFFFFF")]},
                    {
                        "selector": "tr:nth-child(even)",
                        "props": [("background-color", "#FFFFFF")],
                    },
                ]
            )
            .to_html(),
            unsafe_allow_html=True,
        )

elif auth_status == False:
    st.error("Username/Password is incorrect")

if auth_status != True:
    try:
        username_forgot_pw, email_forgot_password, random_password = (
            authenticator.forgot_password("Forgot password")
        )
        if username_forgot_pw:
            st.success("New password sent securely")
            # Random password to be transferred to user securely
        elif username_forgot_pw == False:
            st.error("Username not found")
    except Exception as e:
        st.error(e)
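The workbook built by download_scenarios never touches disk: it is rendered into a BytesIO buffer whose bytes back st.download_button. A minimal standalone sketch of the same pattern, separate from the committed file (the two-level header mirrors create_scenario_summary's columns_index; the sample columns and values are invented):

import io

import pandas as pd
from openpyxl import Workbook

# Two-level column header, like the summary page's MultiIndex.
columns = pd.MultiIndex.from_product(
    [["Spends", "NRPU"], ["Actual", "Simulated"]], names=["first", "second"]
)
df = pd.DataFrame([[100.0, 120.0, 10.0, 12.0]], columns=columns)

wb = Workbook()
ws = wb.active
for j, header in enumerate(df.columns.values, start=1):
    ws.cell(row=1, column=j, value=header[0])  # first header level in row 1
    ws.cell(row=2, column=j, value=header[1])  # second header level in row 2
for i, row in enumerate(df.itertuples(index=False), start=3):
    for j, value in enumerate(row, start=1):
        # Data rows start at row 3, formatted as currency.
        ws.cell(row=i, column=j, value=value).number_format = "$#,##0.0"

buffer = io.BytesIO()
wb.save(buffer)  # buffer.getvalue() is what st.download_button would serve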
pages/11_Model_Optimized_Recommendation.py
ADDED
@@ -0,0 +1,545 @@
import streamlit as st
from numerize.numerize import numerize
import pandas as pd
from utilities import (format_numbers, decimal_formater,
                       load_local_css, set_header,
                       initialize_data,
                       load_authenticator)
import pickle
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
from classes import class_from_dict
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sqlite3
from utilities import update_db


def format_number(x):
    """Compact display formatting: 1.23M, 4.56K, or plain two decimals."""
    if x >= 1_000_000:
        return f'{x / 1_000_000:.2f}M'
    elif x >= 1_000:
        return f'{x / 1_000:.2f}K'
    else:
        return f'{x:.2f}'


def summary_plot(data, x, y, title, text_column, color, format_as_percent=False, format_as_decimal=False):
    fig = px.bar(data, x=x, y=y, orientation='h',
                 title=title, text=text_column, color=color)
    fig.update_layout(showlegend=False)
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Update the format of the displayed text based on the chosen format
    if format_as_percent:
        fig.update_traces(texttemplate='%{text:.0%}', textposition='outside', hovertemplate='%{x:.0%}')
    elif format_as_decimal:
        fig.update_traces(texttemplate='%{text:.2f}', textposition='outside', hovertemplate='%{x:.2f}')
    else:
        fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')

    fig.update_layout(xaxis_title=x, yaxis_title='Channel Name', showlegend=False)
    return fig


def stacked_summary_plot(data, x, y, title, text_column, color_column, stack_column=None, format_as_percent=False, format_as_decimal=False):
    fig = px.bar(data, x=x, y=y, orientation='h',
                 title=title, text=text_column, color=color_column, facet_col=stack_column)
    fig.update_layout(showlegend=False)
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Update the format of the displayed text based on the chosen format
    if format_as_percent:
        fig.update_traces(texttemplate='%{text:.0%}', textposition='outside', hovertemplate='%{x:.0%}')
    elif format_as_decimal:
        fig.update_traces(texttemplate='%{text:.2f}', textposition='outside', hovertemplate='%{x:.2f}')
    else:
        fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')

    fig.update_layout(xaxis_title=x, yaxis_title='', showlegend=False)
    return fig


def funnel_plot(data, x, y, title, text_column, color_column, format_as_percent=False, format_as_decimal=False):
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Round the numeric values in the text column to two decimal points
    data[text_column] = data[text_column].round(2)

    # Create a color map for categorical data
    color_map = {category: f'rgb({i * 30 % 255},{i * 50 % 255},{i * 70 % 255})' for i, category in enumerate(data[color_column].unique())}

    fig = go.Figure(go.Funnel(
        y=data[y],
        x=data[x],
        text=data[text_column],
        marker=dict(color=data[color_column].map(color_map)),
        textinfo="value",
        hoverinfo='y+x+text'
    ))

    # Update the format of the displayed text based on the chosen format
    if format_as_percent:
        fig.update_layout(title=title, funnelmode="percent")
    elif format_as_decimal:
        fig.update_layout(title=title, funnelmode="overlay")
    else:
        fig.update_layout(title=title, funnelmode="group")

    return fig


st.set_page_config(layout='wide')
load_local_css('styles.css')
set_header()

# for k, v in st.session_state.items():
#     if k not in ['logout', 'login', 'config'] and not k.startswith('FormSubmitter'):
#         st.session_state[k] = v

st.empty()
st.header('Model Result Analysis')
spends_data = pd.read_excel('Overview_data_test.xlsx')

with open('summary_df.pkl', 'rb') as file:
    summary_df_sorted = pickle.load(file)
# st.write(summary_df_sorted)

selected_scenario = st.selectbox('Select Saved Scenarios', ['S1', 'S2'])
summary_df_sorted = summary_df_sorted.sort_values(by=['Optimized_spend'])
st.header('Optimized Spends Overview')

channel_colors = px.colors.qualitative.Plotly

fig = make_subplots(rows=1, cols=3, subplot_titles=('Actual Spend', 'Planned Spend', 'Delta'), horizontal_spacing=0.05)

for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
    channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
    channel_color = channel_colors[i % len(channel_colors)]  # same color per channel in every panel

    fig.add_trace(go.Bar(x=channel_df['Actual_spend'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Actual_spend'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h'), row=1, col=1)

    fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Optimized_spend'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h', showlegend=False), row=1, col=2)

    fig.add_trace(go.Bar(x=channel_df['Delta_percent'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Delta_percent'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h', showlegend=False), row=1, col=3)

fig.update_layout(
    height=600,
    width=900,
    title='',
    showlegend=False
)

fig.update_yaxes(showticklabels=False, row=1, col=2)
fig.update_yaxes(showticklabels=False, row=1, col=3)

fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_xaxes(showticklabels=False, row=1, col=2)
fig.update_xaxes(showticklabels=False, row=1, col=3)

st.plotly_chart(fig, use_container_width=True)


# ___columns=st.columns(3)

# with ___columns[2]:
#     fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent',color='Channel_name')
#     st.plotly_chart(fig,use_container_width=True)
# with ___columns[0]:
#     fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend',color='Channel_name')
#     st.plotly_chart(fig,use_container_width=True)
# with ___columns[1]:
#     fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend',color='Channel_name')
#     st.plotly_chart(fig,use_container_width=False)

summary_df_sorted['Perc_alloted'] = np.round(summary_df_sorted['Optimized_spend'] / summary_df_sorted['Optimized_spend'].sum(), 2)
st.header('Budget Allocation')

fig = make_subplots(rows=1, cols=2, subplot_titles=('Planned Spend', '% Split'), horizontal_spacing=0.05)

for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
    channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
    channel_color = channel_colors[i % len(channel_colors)]

    fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Optimized_spend'].apply(format_number),
                         marker_color=channel_color,
                         orientation='h'), row=1, col=1)

    fig.add_trace(go.Bar(x=channel_df['Perc_alloted'],
                         y=channel_df['Channel_name'],
                         text=channel_df['Perc_alloted'].apply(lambda x: f'{100*x:.0f}%'),
                         marker_color=channel_color,
                         orientation='h', showlegend=False), row=1, col=2)

fig.update_layout(
    height=600,
    width=900,
    title='',
    showlegend=False
)

fig.update_yaxes(showticklabels=False, row=1, col=2)

fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_xaxes(showticklabels=False, row=1, col=2)

st.plotly_chart(fig, use_container_width=True)

# summary_df_sorted['Perc_alloted']=np.round(summary_df_sorted['Optimized_spend']/summary_df_sorted['Optimized_spend'].sum(),2)
# columns2=st.columns(2)
# with columns2[0]:
#     fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend',color='Channel_name')
#     st.plotly_chart(fig,use_container_width=True)
# with columns2[1]:
#     fig=summary_plot(summary_df_sorted, x='Perc_alloted', y='Channel_name', title='% Split', text_column='Perc_alloted',color='Channel_name',format_as_percent=True)
#     st.plotly_chart(fig,use_container_width=True)


if 'raw_data' not in st.session_state:
    st.session_state['raw_data'] = pd.read_excel('raw_data_nov7_combined1.xlsx')
    st.session_state['raw_data'] = st.session_state['raw_data'][st.session_state['raw_data']['MediaChannelName'].isin(summary_df_sorted['Channel_name'].unique())]
    st.session_state['raw_data'] = st.session_state['raw_data'][st.session_state['raw_data']['Date'].isin(spends_data["Date"].unique())]

# st.write(st.session_state['raw_data']['ResponseMetricName'])
# st.write(st.session_state['raw_data'])

st.header('Response Forecast Overview')
raw_data = st.session_state['raw_data']
effectiveness_overall = raw_data.groupby('ResponseMetricName').agg({'ResponseMetricValue': 'sum'}).reset_index()
effectiveness_overall['Efficiency'] = effectiveness_overall['ResponseMetricValue'].map(lambda x: x / raw_data['Media Spend'].sum())
# st.write(effectiveness_overall)

columns6 = st.columns(3)

effectiveness_overall.sort_values(by=['ResponseMetricValue'], ascending=False, inplace=True)
effectiveness_overall = np.round(effectiveness_overall, 2)
effectiveness_overall['ResponseMetric'] = effectiveness_overall['ResponseMetricName'].apply(lambda x: 'BAU' if 'BAU' in x else ('Gamified' if 'Gamified' in x else x))
# effectiveness_overall=np.where(effectiveness_overall[effectiveness_overall['ResponseMetricName']=="Adjusted Account Approval BAU"],"Adjusted Account Approval BAU",effectiveness_overall['ResponseMetricName'])

effectiveness_overall.replace({'ResponseMetricName': {'BAU approved clients - Appsflyer': 'Approved clients - Appsflyer',
                                                      'Gamified approved clients - Appsflyer': 'Approved clients - Appsflyer'}}, inplace=True)

# st.write(effectiveness_overall.sort_values(by=['ResponseMetricValue'],ascending=False))

condition = effectiveness_overall['ResponseMetricName'] == "Adjusted Account Approval BAU"
condition1 = effectiveness_overall['ResponseMetricName'] == "Approved clients - Appsflyer"
effectiveness_overall['ResponseMetric'] = np.where(condition, "Adjusted Account Approval BAU", effectiveness_overall['ResponseMetric'])

effectiveness_overall['ResponseMetricName'] = np.where(condition1, "Approved clients - Appsflyer (BAU, Gamified)", effectiveness_overall['ResponseMetricName'])
# effectiveness_overall=pd.DataFrame({'ResponseMetricName':["App Installs - Appsflyer",'Account Requests - Appsflyer',
#                                                           'Total Adjusted Account Approval','Adjusted Account Approval BAU',
#                                                           'Approved clients - Appsflyer','Approved clients - Appsflyer'],
#                                     'ResponseMetricValue':[683067,367020,112315,79768,36661,16834],
#                                     'Efficiency':[1.24,0.67,0.2,0.14,0.07,0.03],
custom_colors = {
    'App Installs - Appsflyer': 'rgb(255, 135, 0)',        # orange
    'Account Requests - Appsflyer': 'rgb(125, 239, 161)',  # light green
    'Adjusted Account Approval': 'rgb(129, 200, 255)',     # light blue
    'Adjusted Account Approval BAU': 'rgb(255, 207, 98)',  # amber
    'Approved clients - Appsflyer': 'rgb(0, 97, 198)',     # deep blue
    "BAU": 'rgb(41, 176, 157)',                            # teal
    "Gamified": 'rgb(213, 218, 229)'                       # light gray
    # Add more categories and their respective colors as needed
}

with columns6[0]:
    revenue = (effectiveness_overall[effectiveness_overall['ResponseMetricName'] == 'Total Approved Accounts - Revenue']['ResponseMetricValue']).iloc[0]
    revenue = round(revenue / 1_000_000, 2)

    # st.metric('Total Revenue', f"${revenue} M")
# with columns6[1]:
#     BAU=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='BAU approved clients - Revenue']['ResponseMetricValue']).iloc[0]
#     BAU=round(BAU / 1_000_000, 2)
#     st.metric('BAU approved clients - Revenue', f"${BAU} M")
# with columns6[2]:
#     Gam=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='Gamified approved clients - Revenue']['ResponseMetricValue']).iloc[0]
#     Gam=round(Gam / 1_000_000, 2)
#     st.metric('Gamified approved clients - Revenue', f"${Gam} M")

# st.write(effectiveness_overall)
data = {'Revenue': ['BAU approved clients - Revenue', 'Gamified approved clients- Revenue'],
        'ResponseMetricValue': [70200000, 1770000],
        'Efficiency': [127.54, 3.21]}
df = pd.DataFrame(data)

columns9 = st.columns([0.60, 0.40])
with columns9[0]:
    figd = px.pie(df,
                  names='Revenue',
                  values='ResponseMetricValue',
                  hole=0.3,  # set the size of the hole in the donut
                  title='Effectiveness')
    figd.update_layout(
        margin=dict(l=0, r=0, b=0, t=0), width=100, height=180, legend=dict(
            orientation='v',  # vertical legend
            x=0,              # anchored at the left edge
            y=0.8             # adjust y as needed
        )
    )

    st.plotly_chart(figd, use_container_width=True)

with columns9[1]:
    figd1 = px.pie(df,
                   names='Revenue',
                   values='Efficiency',
                   hole=0.3,  # set the size of the hole in the donut
                   title='Efficiency')
    figd1.update_layout(
        margin=dict(l=0, r=0, b=0, t=0), width=100, height=180, showlegend=False
    )
    st.plotly_chart(figd1, use_container_width=True)

effectiveness_overall['Response Metric Name'] = effectiveness_overall['ResponseMetricName']

columns4 = st.columns([0.55, 0.45])
with columns4[0]:
    fig = px.funnel(effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue',
                                                                                              'BAU approved clients - Revenue',
                                                                                              'Gamified approved clients - Revenue',
                                                                                              "Total Approved Accounts - Appsflyer"]))],
                    x='ResponseMetricValue', y='Response Metric Name', color='ResponseMetric',
                    color_discrete_map=custom_colors, title='Effectiveness',
                    labels=None)
    custom_y_labels = ['App Installs - Appsflyer', 'Account Requests - Appsflyer',
                       'Adjusted Account Approval', 'Adjusted Account Approval BAU',
                       "Approved clients - Appsflyer (BAU, Gamified)"]
    fig.update_layout(showlegend=False,
                      yaxis=dict(
                          tickmode='array',
                          ticktext=custom_y_labels,
                      ))
    fig.update_traces(textinfo='value', textposition='inside', texttemplate='%{x:.2s} ', hoverinfo='y+x+percent initial')

    last_trace_index = len(fig.data) - 1
    fig.update_traces(marker=dict(line=dict(color='black', width=2)), selector=dict(marker=dict(color='blue')))

    st.plotly_chart(fig, use_container_width=True)


with columns4[1]:
    # Efficiency bar chart for the same metrics shown in the funnel
    fig1 = px.bar((effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue',
                                                                                             'BAU approved clients - Revenue',
                                                                                             'Gamified approved clients - Revenue',
                                                                                             "Total Approved Accounts - Appsflyer"]))]).sort_values(by='ResponseMetricValue'),
                  x='Efficiency', y='Response Metric Name',
                  color_discrete_map=custom_colors, color='ResponseMetric',
                  labels=None, text_auto=True, title='Efficiency')

    # Update layout and traces
    fig1.update_traces(customdata=effectiveness_overall['Efficiency'],
                       textposition='auto')
    fig1.update_layout(showlegend=False)
    fig1.update_yaxes(title='', showticklabels=False)
    fig1.update_xaxes(title='', showticklabels=False)
    fig1.update_xaxes(tickfont=dict(size=20))
    fig1.update_yaxes(tickfont=dict(size=20))
    st.plotly_chart(fig1, use_container_width=True)


effectiveness_overall_revenue = pd.DataFrame({'ResponseMetricName': ['Approved Clients', 'Approved Clients'],
                                              'ResponseMetricValue': [70201070, 1768900],
                                              'Efficiency': [127.54, 3.21],
                                              'ResponseMetric': ['BAU', 'Gamified']
                                              })
# from plotly.subplots import make_subplots
# fig = make_subplots(rows=1, cols=2,
#                     subplot_titles=["Effectiveness", "Efficiency"])

# # Add first plot as subplot
# fig.add_trace(go.Funnel(
#     x = fig.data[0].x,
#     y = fig.data[0].y,
#     textinfo = 'value+percent initial',
#     hoverinfo = 'x+y+percent initial'
# ), row=1, col=1)

# # Update layout for first subplot
# fig.update_xaxes(title_text="Response Metric Value", row=1, col=1)
# fig.update_yaxes(ticktext = custom_y_labels, row=1, col=1)

# # Add second plot as subplot
# fig.add_trace(go.Bar(
#     x = fig1.data[0].x,
#     y = fig1.data[0].y,
#     customdata = fig1.data[0].customdata,
#     textposition = 'auto'
# ), row=1, col=2)

# # Update layout for second subplot
# fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
# fig.update_yaxes(title='', showticklabels=False, row=1, col=2)

# fig.update_layout(height=600, width=800, title_text="Key Metrics")
# st.plotly_chart(fig)


st.header('Return Forecast by Media Channel')
with st.expander("Return Forecast by Media Channel"):
    # Drop NaN metric names before offering them in the selectbox.
    metric_data = [val for val in list(st.session_state['raw_data']['ResponseMetricName'].unique()) if pd.notna(val)]
    # st.write(metric_data)
    metric = st.selectbox('Select Metric', metric_data, index=1)

    selected_metric = st.session_state['raw_data'][st.session_state['raw_data']['ResponseMetricName'] == metric]
    # st.dataframe(selected_metric.head(2))
    effectiveness = selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()
    effectiveness_df = pd.DataFrame({'Channel': effectiveness.index, "ResponseMetricValue": effectiveness.values})

    summary_df_sorted = summary_df_sorted.merge(effectiveness_df, left_on="Channel_name", right_on='Channel')

    summary_df_sorted['Efficiency'] = summary_df_sorted['ResponseMetricValue'] / summary_df_sorted['Optimized_spend']
    summary_df_sorted = summary_df_sorted.sort_values(by='Optimized_spend', ascending=True)
    # st.dataframe(summary_df_sorted)

    channel_colors = px.colors.qualitative.Plotly

    fig = make_subplots(rows=1, cols=3, subplot_titles=('Optimized Spends', 'Effectiveness', 'Efficiency'), horizontal_spacing=0.05)

    for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
        channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
        channel_color = channel_colors[i % len(channel_colors)]

        fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                             y=channel_df['Channel_name'],
                             text=channel_df['Optimized_spend'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h'), row=1, col=1)

        fig.add_trace(go.Bar(x=channel_df['ResponseMetricValue'],
                             y=channel_df['Channel_name'],
                             text=channel_df['ResponseMetricValue'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h', showlegend=False), row=1, col=2)

        fig.add_trace(go.Bar(x=channel_df['Efficiency'],
                             y=channel_df['Channel_name'],
                             text=channel_df['Efficiency'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h', showlegend=False), row=1, col=3)

    fig.update_layout(
        height=600,
        width=900,
        title='Media Channel Performance',
        showlegend=False
    )

    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_yaxes(showticklabels=False, row=1, col=3)

    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=1, col=2)
    fig.update_xaxes(showticklabels=False, row=1, col=3)

    st.plotly_chart(fig, use_container_width=True)


# columns= st.columns(3)
# with columns[0]:
#     fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='', text_column='Optimized_spend',color='Channel_name')
#     st.plotly_chart(fig,use_container_width=True)
# with columns[1]:
#     # effectiveness=(selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()).values
#     # effectiveness_df=pd.DataFrame({'Channel':st.session_state['raw_data']['MediaChannelName'].unique(),"ResponseMetricValue":effectiveness})
#     # # effectiveness.reset_index(inplace=True)
#     # # st.dataframe(effectiveness.head())
#     fig=summary_plot(summary_df_sorted, x='ResponseMetricValue', y='Channel_name', title='Effectiveness', text_column='ResponseMetricValue',color='Channel_name')
#     st.plotly_chart(fig,use_container_width=True)
# with columns[2]:
#     fig=summary_plot(summary_df_sorted, x='Efficiency', y='Channel_name', title='Efficiency', text_column='Efficiency',color='Channel_name',format_as_decimal=True)
#     st.plotly_chart(fig,use_container_width=True)

# Create figure with subplots
# fig = make_subplots(rows=1, cols=2)

# # Add funnel plot to subplot 1
# fig.add_trace(
#     go.Funnel(
#         x=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricValue'],
#         y=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricName'],
#         textposition="inside",
#         texttemplate="%{x:.2s}",
#         customdata=effectiveness_overall['Efficiency'],
#         hovertemplate="%{customdata:.2f}<extra></extra>"
#     ),
#     row=1, col=1
# )

# # Add bar plot to subplot 2
# fig.add_trace(
#     go.Bar(
#         x=effectiveness_overall.sort_values(by='ResponseMetricValue')['Efficiency'],
#         y=effectiveness_overall.sort_values(by='ResponseMetricValue')['ResponseMetricName'],
#         marker_color=effectiveness_overall['ResponseMetric'],
#         customdata=effectiveness_overall['Efficiency'],
#         hovertemplate="%{customdata:.2f}<extra></extra>",
#         textposition="outside"
#     ),
#     row=1, col=2
# )

# # Update layout
# fig.update_layout(title_text="Effectiveness")
# fig.update_yaxes(title_text="", row=1, col=1)
# fig.update_yaxes(title_text="", showticklabels=False, row=1, col=2)
# fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)

# # Show figure
# st.plotly_chart(fig)
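Every overview on this page repeats one layout idiom: a single make_subplots row of horizontal bars in which each channel keeps the same palette color across panels and only the first panel shows y-axis tick labels. A minimal standalone sketch of just that idiom, separate from the committed file (channel names and spend figures are invented):

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

channels = ["Search", "Social", "Display"]
actual = [120_000, 80_000, 40_000]
planned = [150_000, 60_000, 45_000]

colors = px.colors.qualitative.Plotly
fig = make_subplots(rows=1, cols=2, subplot_titles=("Actual", "Planned"),
                    horizontal_spacing=0.05)
for i, ch in enumerate(channels):
    color = colors[i % len(colors)]  # same color for a channel in both panels
    fig.add_trace(go.Bar(x=[actual[i]], y=[ch], orientation="h",
                         marker_color=color), row=1, col=1)
    fig.add_trace(go.Bar(x=[planned[i]], y=[ch], orientation="h",
                         marker_color=color, showlegend=False), row=1, col=2)
fig.update_yaxes(showticklabels=False, row=1, col=2)  # labels only on panel 1
fig.update_layout(showlegend=False)
fig.show()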
pages/1_Data_Import.py
ADDED
@@ -0,0 +1,1501 @@
1 |
+
# Importing necessary libraries
|
2 |
+
import streamlit as st
|
3 |
+
import os
|
4 |
+
|
5 |
+
st.set_page_config(
|
6 |
+
page_title="Data Import",
|
7 |
+
page_icon=":shark:",
|
8 |
+
layout="wide",
|
9 |
+
initial_sidebar_state="collapsed",
|
10 |
+
)
|
11 |
+
|
12 |
+
import pickle
|
13 |
+
import pandas as pd
|
14 |
+
from utilities import set_header, load_local_css,update_db,project_selection
|
15 |
+
import streamlit_authenticator as stauth
|
16 |
+
import yaml
|
17 |
+
from yaml import SafeLoader
|
18 |
+
import sqlite3
|
19 |
+
|
20 |
+
load_local_css("styles.css")
|
21 |
+
set_header()
|
22 |
+
|
23 |
+
for k, v in st.session_state.items():
|
24 |
+
if (
|
25 |
+
k not in ["logout", "login", "config"]
|
26 |
+
and not k.startswith("FormSubmitter")
|
27 |
+
and not k.startswith("data-editor")
|
28 |
+
):
|
29 |
+
st.session_state[k] = v
|
30 |
+
with open("config.yaml") as file:
|
31 |
+
config = yaml.load(file, Loader=SafeLoader)
|
32 |
+
st.session_state["config"] = config
|
33 |
+
authenticator = stauth.Authenticate(
|
34 |
+
config["credentials"],
|
35 |
+
config["cookie"]["name"],
|
36 |
+
config["cookie"]["key"],
|
37 |
+
config["cookie"]["expiry_days"],
|
38 |
+
config["preauthorized"],
|
39 |
+
)
|
40 |
+
st.session_state["authenticator"] = authenticator
|
41 |
+
name, authentication_status, username = authenticator.login("Login", "main")
|
42 |
+
auth_status = st.session_state.get("authentication_status")
|
43 |
+
|
44 |
+
if auth_status == True:
|
45 |
+
authenticator.logout("Logout", "main")
|
46 |
+
is_state_initiaized = st.session_state.get("initialized", False)
|
47 |
+
|
48 |
+
if not is_state_initiaized:
|
49 |
+
a=1
|
50 |
+
|
51 |
+
if "project_name" not in st.session_state:
|
52 |
+
st.session_state["project_name"] = None
|
53 |
+
|
54 |
+
if "project_dct" not in st.session_state:
|
55 |
+
# home()
|
56 |
+
project_selection(name)
|
57 |
+
#st.write(st.session_state['project_name'])
|
58 |
+
|
59 |
+
cols1 = st.columns([2, 1])
|
60 |
+
|
61 |
+
with cols1[0]:
|
62 |
+
st.markdown(f"**Welcome {name}**")
|
63 |
+
with cols1[1]:
|
64 |
+
st.markdown(
|
65 |
+
f"**Current Project: {st.session_state['project_name']}**"
|
66 |
+
)
|
67 |
+
|
68 |
+
# Function to validate date column in dataframe
|
69 |
+
|
70 |
+
|
71 |
+
#st.warning("please select a project from Home page")
|
72 |
+
#st.stop()
|
73 |
+
|
74 |
+
def validate_date_column(df):
|
75 |
+
try:
|
76 |
+
# Attempt to convert the 'Date' column to datetime
|
77 |
+
df["date"] = pd.to_datetime(df["date"], format="%d-%m-%Y")
|
78 |
+
return True
|
79 |
+
except:
|
80 |
+
return False
|
81 |
+
|
82 |
+
# Function to determine data interval
|
83 |
+
def determine_data_interval(common_freq):
|
84 |
+
if common_freq == 1:
|
85 |
+
return "daily"
|
86 |
+
elif common_freq == 7:
|
87 |
+
return "weekly"
|
88 |
+
elif 28 <= common_freq <= 31:
|
89 |
+
return "monthly"
|
90 |
+
else:
|
91 |
+
return "irregular"
|
92 |
+
|
93 |
+
# Function to read each uploaded Excel file into a pandas DataFrame and store it in a dictionary
@st.cache_resource(show_spinner=False)
def files_to_dataframes(uploaded_files):
    df_dict = {}
    for uploaded_file in uploaded_files:
        # Extract file name without extension
        file_name = uploaded_file.name.rsplit(".", 1)[0]

        # Check for duplicate file names
        if file_name in df_dict:
            st.warning(
                f"Duplicate File: {file_name}. This file will be skipped.",
                icon="⚠️",
            )
            continue

        # Read the file into a DataFrame
        df = pd.read_excel(uploaded_file)

        # Convert all column names to lowercase
        df.columns = df.columns.str.lower().str.strip()

        # Separate numeric and non-numeric columns
        numeric_cols = list(df.select_dtypes(include=["number"]).columns)
        non_numeric_cols = [
            col
            for col in df.select_dtypes(exclude=["number"]).columns
            if col.lower() != "date"
        ]

        # Check for a valid 'date' column
        if not (validate_date_column(df) and len(numeric_cols) > 0):
            st.warning(
                f"File Name: {file_name} ➜ Please upload data with a Date column in 'DD-MM-YYYY' format and at least one media/exogenous column. This file will be skipped.",
                icon="⚠️",
            )
            continue

        # Check the interval between consecutive dates
        common_freq = (
            pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
        )
        # Calculate the data interval (daily, weekly, monthly or irregular)
        interval = determine_data_interval(common_freq)
        if interval == "irregular":
            st.warning(
                f"File Name: {file_name} ➜ Please upload data in a daily, weekly or monthly interval. This file will be skipped.",
                icon="⚠️",
            )
            continue

        # Store the DataFrame and its metadata in the dictionary under the file name
        df_dict[file_name] = {
            "numeric": numeric_cols,
            "non_numeric": non_numeric_cols,
            "interval": interval,
            "df": df,
        }

    return df_dict


# Function to adjust dataframe granularity
def adjust_dataframe_granularity(df, current_granularity, target_granularity):
    # Set index
    df.set_index("date", inplace=True)

    # Define aggregation rules for resampling
    aggregation_rules = {
        col: "sum" if pd.api.types.is_numeric_dtype(df[col]) else "first"
        for col in df.columns
    }

    # Initialize resampled_df
    resampled_df = df
    if current_granularity == "daily" and target_granularity == "weekly":
        resampled_df = df.resample("W-MON", closed="left", label="left").agg(
            aggregation_rules
        )

    elif current_granularity == "daily" and target_granularity == "monthly":
        resampled_df = df.resample("MS", closed="left", label="left").agg(
            aggregation_rules
        )

    elif current_granularity == "daily" and target_granularity == "daily":
        resampled_df = df.resample("D").agg(aggregation_rules)

    elif (
        current_granularity in ["weekly", "monthly"]
        and target_granularity == "daily"
    ):
        # For higher to lower granularity, distribute numeric values equally and replicate non-numeric values across the new period
        expanded_data = []
        for _, row in df.iterrows():
            if current_granularity == "weekly":
                period_range = pd.date_range(start=row.name, periods=7)
            elif current_granularity == "monthly":
                period_range = pd.date_range(
                    start=row.name, periods=row.name.days_in_month
                )

            for date in period_range:
                new_row = {}
                for col in df.columns:
                    if pd.api.types.is_numeric_dtype(df[col]):
                        if current_granularity == "weekly":
                            new_row[col] = row[col] / 7
                        elif current_granularity == "monthly":
                            new_row[col] = row[col] / row.name.days_in_month
                    else:
                        new_row[col] = row[col]
                expanded_data.append((date, new_row))

        resampled_df = pd.DataFrame(
            [data for _, data in expanded_data],
            index=[date for date, _ in expanded_data],
        )

    # Reset index
    resampled_df = resampled_df.reset_index().rename(columns={"index": "date"})

    return resampled_df
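
# Illustration (comments only, column values are hypothetical): a weekly
# dataframe expands to daily rows with each numeric value split evenly over
# the 7 days of its week.
#
#   weekly = pd.DataFrame(
#       {"date": pd.date_range("2024-01-01", periods=2, freq="W-MON"),
#        "spend": [70.0, 140.0]}
#   )
#   daily = adjust_dataframe_granularity(weekly.copy(), "weekly", "daily")
#   # -> 14 rows; the first 7 have spend == 10.0, the next 7 have spend == 20.0
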
# Function to clean and extract unique values of Panel_1 and Panel_2
@st.cache_resource(show_spinner=False)
def clean_and_extract_unique_values(files_dict, selections):
    all_panel1_values = set()
    all_panel2_values = set()

    for file_name, file_data in files_dict.items():
        df = file_data["df"]

        # 'Panel_1' and 'Panel_2' selections
        selected_panel1 = selections[file_name].get("Panel_1")
        selected_panel2 = selections[file_name].get("Panel_2")

        # Clean and standardize the Panel_1 column if it exists and is selected
        if (
            selected_panel1
            and selected_panel1 != "N/A"
            and selected_panel1 in df.columns
        ):
            df[selected_panel1] = (
                df[selected_panel1].str.lower().str.strip().str.replace("_", " ")
            )
            all_panel1_values.update(df[selected_panel1].dropna().unique())

        # Clean and standardize the Panel_2 column if it exists and is selected
        if (
            selected_panel2
            and selected_panel2 != "N/A"
            and selected_panel2 in df.columns
        ):
            df[selected_panel2] = (
                df[selected_panel2].str.lower().str.strip().str.replace("_", " ")
            )
            all_panel2_values.update(df[selected_panel2].dropna().unique())

        # Update the processed DataFrame back in the dictionary
        files_dict[file_name]["df"] = df

    return all_panel1_values, all_panel2_values
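
# Illustration: the normalization above makes panel labels comparable across
# files before the unique values are collected. The sample values are
# hypothetical.
#
#   pd.Series([" DMA_One ", "dma_one"]).str.lower().str.strip().str.replace("_", " ")
#   # -> both rows become "dma one", so they collapse to one unique panel value
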
# Function to format values for display
@st.cache_resource(show_spinner=False)
def format_values_for_display(values_list):
    # Capitalize the first letter of each word and replace underscores with spaces
    formatted_list = [value.replace("_", " ").title() for value in values_list]
    # Join values with commas and 'and' before the last value
    if len(formatted_list) > 1:
        return ", ".join(formatted_list[:-1]) + ", and " + formatted_list[-1]
    elif formatted_list:
        return formatted_list[0]
    return "No values available"
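
# Illustration (hypothetical values), following the function's three branches:
#
#   format_values_for_display(["dma one", "dma two", "dma three"])
#   # -> "Dma One, Dma Two, and Dma Three"
#   format_values_for_display(["dma one"])   # -> "Dma One"
#   format_values_for_display([])            # -> "No values available"
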
# Function to normalize all data within files_dict to a daily granularity
@st.cache_resource(show_spinner=False)
def standardize_data_to_daily(files_dict, selections):
    # Normalize all data to a daily granularity using the helper below
    files_dict = apply_granularity_to_all(files_dict, "daily", selections)

    # Update the "interval" attribute for each dataset to indicate the new granularity
    for files_name, files_data in files_dict.items():
        files_data["interval"] = "daily"

    return files_dict


# Function to apply granularity transformation to all DataFrames in files_dict
@st.cache_resource(show_spinner=False)
def apply_granularity_to_all(files_dict, granularity_selection, selections):
    for file_name, file_data in files_dict.items():
        df = file_data["df"].copy()

        # Handling when Panel_1 or Panel_2 might be 'N/A'
        selected_panel1 = selections[file_name].get("Panel_1")
        selected_panel2 = selections[file_name].get("Panel_2")

        # Correcting the segment selection logic & handling 'N/A'
        if selected_panel1 != "N/A" and selected_panel2 != "N/A":
            unique_combinations = df[
                [selected_panel1, selected_panel2]
            ].drop_duplicates()
        elif selected_panel1 != "N/A":
            unique_combinations = df[[selected_panel1]].drop_duplicates()
            selected_panel2 = None  # Ensure Panel_2 is ignored if N/A
        elif selected_panel2 != "N/A":
            unique_combinations = df[[selected_panel2]].drop_duplicates()
            selected_panel1 = None  # Ensure Panel_1 is ignored if N/A
        else:
            # If both are 'N/A', process the entire dataframe as is
            df = adjust_dataframe_granularity(
                df, file_data["interval"], granularity_selection
            )
            files_dict[file_name]["df"] = df
            continue  # Skip to the next file

        transformed_segments = []
        for _, combo in unique_combinations.iterrows():
            if selected_panel1 and selected_panel2:
                segment = df[
                    (df[selected_panel1] == combo[selected_panel1])
                    & (df[selected_panel2] == combo[selected_panel2])
                ]
            elif selected_panel1:
                segment = df[df[selected_panel1] == combo[selected_panel1]]
            elif selected_panel2:
                segment = df[df[selected_panel2] == combo[selected_panel2]]

            # Adjust granularity of the segment
            transformed_segment = adjust_dataframe_granularity(
                segment, file_data["interval"], granularity_selection
            )
            transformed_segments.append(transformed_segment)

        # Combine all transformed segments into a single DataFrame for this file
        transformed_df = pd.concat(transformed_segments, ignore_index=True)
        files_dict[file_name]["df"] = transformed_df

    return files_dict
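
# Illustration (hypothetical file and column names): for a file whose selected
# Panel_1 column "region" holds "north" and "south", the function resamples
# each region's rows separately and then concatenates them, so values are
# never aggregated across panels that happen to share dates.
#
#   selections = {"media": {"Panel_1": "region", "Panel_2": "N/A"}}
#   files_dict = apply_granularity_to_all(files_dict, "weekly", selections)
#   # each region segment goes through adjust_dataframe_granularity(...)
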
# Function to create the main dataframe structure
@st.cache_resource(show_spinner=False)
def create_main_dataframe(
    files_dict, all_panel1_values, all_panel2_values, granularity_selection
):
    # Determine the global start and end dates across all DataFrames
    global_start = min(df["df"]["date"].min() for df in files_dict.values())
    global_end = max(df["df"]["date"].max() for df in files_dict.values())

    # Adjust the date_range generation based on the granularity_selection
    if granularity_selection == "weekly":
        # Generate a weekly range, with weeks starting on Monday
        date_range = pd.date_range(start=global_start, end=global_end, freq="W-MON")
    elif granularity_selection == "monthly":
        # Generate a monthly range, starting from the first day of each month
        date_range = pd.date_range(start=global_start, end=global_end, freq="MS")
    else:  # Default to daily if not weekly or monthly
        date_range = pd.date_range(start=global_start, end=global_end, freq="D")

    # Collect all unique Panel_1 and Panel_2 values, excluding 'N/A'
    all_panel1s = all_panel1_values
    all_panel2s = all_panel2_values

    # Dynamically build the list of dimensions (Panel_1, Panel_2) to include in the main DataFrame based on availability
    dimensions, merge_keys = [], []
    if all_panel1s:
        dimensions.append(all_panel1s)
        merge_keys.append("Panel_1")
    if all_panel2s:
        dimensions.append(all_panel2s)
        merge_keys.append("Panel_2")

    dimensions.append(date_range)  # The date range is always included
    merge_keys.append("date")

    # Create a main DataFrame template with the dimensions
    main_df = pd.MultiIndex.from_product(
        dimensions,
        names=[name for name, _ in zip(merge_keys, dimensions)],
    ).to_frame(index=False)

    return main_df.reset_index(drop=True)
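
# Illustration: the scaffold built above is the cross product of the panel
# values and the date range. With hypothetical inputs:
#
#   pd.MultiIndex.from_product(
#       [["north", "south"], pd.date_range("2024-01-01", periods=2)],
#       names=["Panel_1", "date"],
#   ).to_frame(index=False)
#   # -> 4 rows: (north, Jan 1), (north, Jan 2), (south, Jan 1), (south, Jan 2)
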
# Function to prepare and merge DataFrames
@st.cache_resource(show_spinner=False)
def merge_into_main_df(main_df, files_dict, selections):
    for file_name, file_data in files_dict.items():
        df = file_data["df"].copy()

        # Rename the selected Panel_1 and Panel_2 columns if not 'N/A'
        selected_panel1 = selections[file_name].get("Panel_1", "N/A")
        selected_panel2 = selections[file_name].get("Panel_2", "N/A")
        if selected_panel1 != "N/A":
            df.rename(columns={selected_panel1: "Panel_1"}, inplace=True)
        if selected_panel2 != "N/A":
            df.rename(columns={selected_panel2: "Panel_2"}, inplace=True)

        # Merge the current DataFrame into main_df based on 'date', and where applicable, 'Panel_1' and 'Panel_2'
        merge_keys = ["date"]
        if "Panel_1" in df.columns:
            merge_keys.append("Panel_1")
        if "Panel_2" in df.columns:
            merge_keys.append("Panel_2")
        main_df = pd.merge(main_df, df, on=merge_keys, how="left")

    # After all merges, sort by 'date' and reset the index for cleanliness
    sort_by = ["date"]
    if "Panel_1" in main_df.columns:
        sort_by.append("Panel_1")
    if "Panel_2" in main_df.columns:
        sort_by.append("Panel_2")
    main_df.sort_values(by=sort_by, inplace=True)
    main_df.reset_index(drop=True, inplace=True)

    return main_df
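
# Illustration: each file is left-joined onto the scaffold on exactly the keys
# it carries, e.g. a file without panel columns merges on "date" alone while a
# file whose panel column was renamed merges on ["date", "Panel_1"]:
#
#   merge_keys = ["date", "Panel_1"]
#   main_df = pd.merge(main_df, df, on=merge_keys, how="left")
#
# Scaffold rows with no match keep NaNs for that file's columns, which are
# handled later by the imputation step.
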
# Function to categorize a column
def categorize_column(column_name):
    # Define keywords for each category
    internal_keywords = [
        "Price",
        "Discount",
        "product_price",
        "cost",
        "margin",
        "inventory",
        "sales",
        "revenue",
        "turnover",
        "expense",
    ]
    exogenous_keywords = [
        "GDP",
        "Tax",
        "Inflation",
        "interest_rate",
        "employment_rate",
        "exchange_rate",
        "consumer_spending",
        "retail_sales",
        "oil_prices",
        "weather",
    ]

    # Use the category saved in the project dictionary if one exists; otherwise
    # check whether the column name matches any Internal or Exogenous keywords
    if (
        column_name
        in st.session_state["project_dct"]["data_import"]["cat_dct"].keys()
        and st.session_state["project_dct"]["data_import"]["cat_dct"][column_name]
        is not None
    ):
        return st.session_state["project_dct"]["data_import"]["cat_dct"][
            column_name
        ]  # resume project manoj
    else:
        for keyword in internal_keywords:
            if keyword.lower() in column_name.lower():
                return "Internal"
        for keyword in exogenous_keywords:
            if keyword.lower() in column_name.lower():
                return "Exogenous"

    # Default to Media if no match is found
    return "Media"
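
# Illustration (assuming no category is already saved for these names in
# project_dct; the column names are hypothetical):
#
#   categorize_column("avg_product_price")  # -> "Internal"  (matches "price")
#   categorize_column("gdp_index")          # -> "Exogenous" (matches "gdp")
#   categorize_column("tv_spend")           # -> "Media"     (no keyword match)
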
# Function to calculate missing stats and prepare an editable DataFrame
@st.cache_resource(show_spinner=False)
def prepare_missing_stats_df(df):
    missing_stats = []
    for column in df.columns:
        if (
            column == "date" or column == "Panel_2" or column == "Panel_1"
        ):  # Skip the date, Panel_1 and Panel_2 columns
            continue

        missing = df[column].isnull().sum()
        pct_missing = round((missing / len(df)) * 100, 2)

        # Dynamically assign a category based on the column name
        category = categorize_column(column)
        # category = "Media"  # Keep default bin as Media

        missing_stats.append(
            {
                "Column": column,
                "Missing Values": missing,
                "Missing Percentage": pct_missing,
                "Impute Method": "Fill with 0",  # Default value
                "Category": category,
            }
        )

    stats_df = pd.DataFrame(missing_stats)

    return stats_df


# Function to add the API DataFrame's details to the files dictionary
@st.cache_resource(show_spinner=False)
def add_api_dataframe_to_dict(main_df, files_dict):
    files_dict["API"] = {
        "numeric": list(main_df.select_dtypes(include=["number"]).columns),
        "non_numeric": [
            col
            for col in main_df.select_dtypes(exclude=["number"]).columns
            if col.lower() != "date"
        ],
        "interval": determine_data_interval(
            pd.Series(main_df["date"].unique()).diff().dt.days.dropna().mode()[0]
        ),
        "df": main_df,
    }

    return files_dict


# Function to read the API data into a DataFrame, parsing the specified columns as datetime
@st.cache_resource(show_spinner=False)
def read_API_data():
    return pd.read_excel(
        r"./upf_data_converted_randomized_resp_metrics.xlsx",
        parse_dates=["Date"],
    )


# Function to set the 'Panel_1_Panel_2_Selected' session state variable to False
def set_Panel_1_Panel_2_Selected_false():
    st.session_state["Panel_1_Panel_2_Selected"] = False

    # Restore project_dct to default values whenever the user modifies any widget
    st.session_state["project_dct"]["data_import"]["edited_stats_df"] = None
    st.session_state["project_dct"]["data_import"]["merged_df"] = None
    st.session_state["project_dct"]["data_import"]["missing_stats_df"] = None
    st.session_state["project_dct"]["data_import"]["cat_dct"] = {}
    st.session_state["project_dct"]["data_import"]["numeric_columns"] = None
    st.session_state["project_dct"]["data_import"]["default_df"] = None
    st.session_state["project_dct"]["data_import"]["final_df"] = None
    st.session_state["project_dct"]["data_import"]["edited_df"] = None


# Function to serialize and save the objects into a pickle file
@st.cache_resource(show_spinner=False)
def save_to_pickle(file_path, final_df, bin_dict):
    # Open the file in write-binary mode and dump the objects
    with open(file_path, "wb") as f:
        pickle.dump({"final_df": final_df, "bin_dict": bin_dict}, f)
    # Data is now saved to file
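
# Illustration: downstream pages reload the same payload by reversing the
# dump, e.g. (sketch):
#
#   with open(file_path, "rb") as f:
#       data = pickle.load(f)
#   final_df, bin_dict = data["final_df"], data["bin_dict"]
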
# Function to process the merged_df DataFrame based on operations defined in edited_df
@st.cache_resource(show_spinner=False)
def process_dataframes(merged_df, edited_df, edited_stats_df):
    # Ensure there are operations defined by the user
    if edited_df.empty:
        return merged_df, edited_stats_df  # No operations to apply

    # Perform the operations as defined by the user
    else:
        for index, row in edited_df.iterrows():
            result_column_name = (
                f"{row['Column 1']}{row['Operator']}{row['Column 2']}"
            )
            col1 = row["Column 1"]
            col2 = row["Column 2"]
            op = row["Operator"]

            # Apply the specified operation
            if op == "+":
                merged_df[result_column_name] = merged_df[col1] + merged_df[col2]
            elif op == "-":
                merged_df[result_column_name] = merged_df[col1] - merged_df[col2]
            elif op == "*":
                merged_df[result_column_name] = merged_df[col1] * merged_df[col2]
            elif op == "/":
                merged_df[result_column_name] = merged_df[col1] / merged_df[
                    col2
                ].replace(0, 1e-9)

            # Add a summary of the operation to edited_stats_df
            new_row = {
                "Column": result_column_name,
                "Missing Values": None,
                "Missing Percentage": None,
                "Impute Method": None,
                "Category": row["Category"],
            }
            new_row_df = pd.DataFrame([new_row])

            # Use pd.concat to add the new_row_df to edited_stats_df
            edited_stats_df = pd.concat(
                [edited_stats_df, new_row_df], ignore_index=True, axis=0
            )

        # Combine column names from edited_df for cleanup
        combined_columns = set(edited_df["Column 1"]).union(
            set(edited_df["Column 2"])
        )

        # Filter those rows out of edited_stats_df and drop the source columns from merged_df
        edited_stats_df = edited_stats_df[
            ~edited_stats_df["Column"].isin(combined_columns)
        ]
        merged_df.drop(
            columns=list(combined_columns), errors="ignore", inplace=True
        )

        return merged_df, edited_stats_df
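
# Illustration (hypothetical column names): a single edited_df row such as
#
#   {"Column 1": "clicks", "Operator": "/", "Column 2": "impressions",
#    "Category": "Media"}
#
# creates merged_df["clicks/impressions"] (zeros in the denominator are
# replaced with 1e-9), appends a stats row for the new column, and then drops
# "clicks" and "impressions" from both merged_df and edited_stats_df.
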
# Function to prepare a list of numeric column names and initialize an empty DataFrame with a predefined structure
@st.cache_resource(show_spinner=False)
def prepare_numeric_columns_and_default_df(merged_df, edited_stats_df):
    # Get columns categorized as 'Response Metrics'
    columns_response_metrics = edited_stats_df[
        edited_stats_df["Category"] == "Response Metrics"
    ]["Column"].tolist()

    # Filter numeric columns, excluding those categorized as 'Response Metrics'
    numeric_columns = [
        col
        for col in merged_df.select_dtypes(include=["number"]).columns
        if col not in columns_response_metrics
    ]

    # Define the structure of the empty DataFrame
    data = {
        "Column 1": pd.Series([], dtype="str"),
        "Operator": pd.Series([], dtype="str"),
        "Column 2": pd.Series([], dtype="str"),
        "Category": pd.Series([], dtype="str"),
    }
    default_df = pd.DataFrame(data)

    return numeric_columns, default_df


# Initialize 'final_df' in session state
if "final_df" not in st.session_state:
    st.session_state["final_df"] = pd.DataFrame()

# Initialize 'bin_dict' in session state
if "bin_dict" not in st.session_state:
    st.session_state["bin_dict"] = {}

# Initialize 'Panel_1_Panel_2_Selected' in session state
if "Panel_1_Panel_2_Selected" not in st.session_state:
    st.session_state["Panel_1_Panel_2_Selected"] = False

# Page Title
st.write("")  # Top padding
st.title("Data Import")

conn = sqlite3.connect(
    r"DB\User.db", check_same_thread=False
)  # connection with the SQL DB
c = conn.cursor()
#########################################################################################################################################################
# Create a dictionary to hold all DataFrames and collect user input to specify "Panel_1" and "Panel_2" columns for each file
#########################################################################################################################################################

# Read the Excel file, parsing the 'Date' column as datetime
main_df = read_API_data()

# Convert all column names to lowercase
main_df.columns = main_df.columns.str.lower().str.strip()

# File uploader
uploaded_files = st.file_uploader(
    "Upload additional data",
    type=["xlsx"],
    accept_multiple_files=True,
    on_change=set_Panel_1_Panel_2_Selected_false,
)

# Custom HTML for upload instructions
recommendation_html = """
<div style="text-align: justify;">
<strong>Recommendation:</strong> For optimal processing, please ensure that all uploaded datasets, including panel, media, internal, and exogenous data, adhere to the following guidelines: each dataset must include a <code>Date</code> column formatted as <code>DD-MM-YYYY</code> and be free of missing values.
</div>
"""
st.markdown(recommendation_html, unsafe_allow_html=True)

# Choose Desired Granularity
st.markdown("#### Choose Desired Granularity")
# Granularity Selection
granularity_selection = st.selectbox(
    "Choose Date Granularity",
    ["Daily", "Weekly", "Monthly"],
    label_visibility="collapsed",
    on_change=set_Panel_1_Panel_2_Selected_false,
    index=st.session_state["project_dct"]["data_import"][
        "granularity_selection"
    ],  # resume
)

# st.write(st.session_state['project_dct']['data_import']['granularity_selection'])

st.session_state["project_dct"]["data_import"]["granularity_selection"] = [
    "Daily",
    "Weekly",
    "Monthly",
].index(granularity_selection)
granularity_selection = str(granularity_selection).lower()

# Convert files to dataframes
files_dict = files_to_dataframes(uploaded_files)

# Add the API DataFrame
if main_df is not None:
    files_dict = add_api_dataframe_to_dict(main_df, files_dict)

# Display a warning message if no files have been uploaded and halt further execution
if not files_dict:
    st.warning(
        "Please upload at least one file to proceed.",
        icon="⚠️",
    )
    st.stop()  # Halts further execution until a file is uploaded
# Select Panel_1 and Panel_2 columns
st.markdown("#### Select Panel columns")
selections = {}
with st.expander("Select Panel columns", expanded=False):
    count = 0  # Initialize a counter to manage the visibility of labels and keys
    for file_name, file_data in files_dict.items():

        # Generating project_dct keys dynamically
        if (
            f"Panel_1_selectbox{file_name}"
            not in st.session_state["project_dct"]["data_import"].keys()
        ):
            st.session_state["project_dct"]["data_import"][
                f"Panel_1_selectbox{file_name}"
            ] = 0

        if (
            f"Panel_2_selectbox{file_name}"
            not in st.session_state["project_dct"]["data_import"].keys()
        ):
            st.session_state["project_dct"]["data_import"][
                f"Panel_2_selectbox{file_name}"
            ] = 0

        # Determine the visibility of the label based on the count
        if count == 0:
            label_visibility = "visible"
        else:
            label_visibility = "collapsed"

        # Extract non-numeric columns
        non_numeric_cols = file_data["non_numeric"]

        # Prepare Panel_1 and Panel_2 values for the dropdowns, adding "N/A" as an option
        panel1_values = non_numeric_cols + ["N/A"]
        panel2_values = non_numeric_cols + ["N/A"]

        # Skip if only one option is available
        if len(panel1_values) == 1 and len(panel2_values) == 1:
            selected_panel1, selected_panel2 = "N/A", "N/A"
            # Update the selections for Panel_1 and Panel_2 for the current file
            selections[file_name] = {
                "Panel_1": selected_panel1,
                "Panel_2": selected_panel2,
            }
            continue

        # Create layout columns for the File Name, Panel_1, and Panel_2 selections
        file_name_col, Panel_1_col, Panel_2_col = st.columns([2, 4, 4])

        with file_name_col:
            # Display the "File Name" label only for the first file
            if count == 0:
                st.write("File Name")
            else:
                st.write("")
            st.write(file_name)  # Display the file name

        with Panel_1_col:
            # Display a selectbox for Panel_1 values
            selected_panel1 = st.selectbox(
                "Select Panel Level 1",
                panel1_values,
                on_change=set_Panel_1_Panel_2_Selected_false,
                label_visibility=label_visibility,  # Control visibility of the label
                key=f"Panel_1_selectbox{count}",  # Ensure a unique key for each selectbox
                index=st.session_state["project_dct"]["data_import"][
                    f"Panel_1_selectbox{file_name}"
                ],
            )

            st.session_state["project_dct"]["data_import"][
                f"Panel_1_selectbox{file_name}"
            ] = panel1_values.index(selected_panel1)

        with Panel_2_col:
            # Display a selectbox for Panel_2 values
            selected_panel2 = st.selectbox(
                "Select Panel Level 2",
                panel2_values,
                on_change=set_Panel_1_Panel_2_Selected_false,
                label_visibility=label_visibility,  # Control visibility of the label
                key=f"Panel_2_selectbox{count}",  # Ensure a unique key for each selectbox
                index=st.session_state["project_dct"]["data_import"][
                    f"Panel_2_selectbox{file_name}"
                ],
            )

            st.session_state["project_dct"]["data_import"][
                f"Panel_2_selectbox{file_name}"
            ] = panel2_values.index(selected_panel2)

        # Skip processing if the same column is selected for both Panel_1 and Panel_2 due to potential data integrity issues
        if selected_panel2 == selected_panel1 and not (
            selected_panel2 == "N/A" and selected_panel1 == "N/A"
        ):
            st.warning(
                f"File: {file_name} → The same column cannot serve as both Panel_1 and Panel_2. Please adjust your selections.",
            )
            selected_panel1, selected_panel2 = "N/A", "N/A"
            st.stop()

        # Update the selections for Panel_1 and Panel_2 for the current file
        selections[file_name] = {
            "Panel_1": selected_panel1,
            "Panel_2": selected_panel2,
        }

        count += 1  # Increment the counter after processing each file
    st.write()

# Accept Panel_1 and Panel_2 selection
accept = st.button(
    "Accept and Process", use_container_width=True
)  # resume project manoj
if (
    not accept
    and st.session_state["project_dct"]["data_import"]["edited_stats_df"]
    is not None
):

    # st.write(st.session_state['project_dct'])
    st.markdown("#### Unique Panel values")
    # Display Panel_1 and Panel_2 values
    with st.expander("Unique Panel values"):
        st.write("")
        st.markdown(
            f"""
            <style>
            .justify-text {{
                text-align: justify;
            }}
            </style>
            <div class="justify-text">
                <strong>Panel Level 1 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel1_values']}<br>
                <strong>Panel Level 2 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel2_values']}
            </div>
            """,
            unsafe_allow_html=True,
        )

        # Display the total number of Panel_1 and Panel_2 values
        st.write("")
        st.markdown(
            f"""
            <div style="text-align: justify;">
                <strong>Number of Level 1 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel1_values'])}<br>
                <strong>Number of Level 2 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel2_values'])}
            </div>
            """,
            unsafe_allow_html=True,
        )
        st.write("")

    # Create an editable DataFrame in Streamlit
    st.markdown("#### Select Variables Category & Impute Missing Values")

    merged_df = st.session_state["project_dct"]["data_import"]["merged_df"].copy()

    missing_stats_df = st.session_state["project_dct"]["data_import"][
        "missing_stats_df"
    ]

    editable_df = st.session_state["project_dct"]["data_import"]["edited_stats_df"]
    sorted_editable_df = editable_df.sort_values(
        by="Missing Values", ascending=False, na_position="first"
    )

    edited_stats_df = st.data_editor(
        sorted_editable_df,
        column_config={
            "Impute Method": st.column_config.SelectboxColumn(
                options=[
                    "Drop Column",
                    "Fill with Mean",
                    "Fill with Median",
                    "Fill with 0",
                ],
                required=True,
                default="Fill with 0",
            ),
            "Category": st.column_config.SelectboxColumn(
                options=[
                    "Media",
                    "Exogenous",
                    "Internal",
                    "Response Metrics",
                ],
                required=True,
                default="Media",
            ),
        },
        disabled=["Column", "Missing Values", "Missing Percentage"],
        hide_index=True,
        use_container_width=True,
        key="data-editor-1",
    )

    st.session_state["project_dct"]["data_import"]["cat_dct"] = {
        col: cat
        for col, cat in zip(edited_stats_df["Column"], edited_stats_df["Category"])
    }

    for i, row in edited_stats_df.iterrows():
        column = row["Column"]
        if row["Impute Method"] == "Drop Column":
            merged_df.drop(columns=[column], inplace=True)

        elif row["Impute Method"] == "Fill with Mean":
            merged_df[column].fillna(
                st.session_state["project_dct"]["data_import"]["merged_df"][
                    column
                ].mean(),
                inplace=True,
            )

        elif row["Impute Method"] == "Fill with Median":
            merged_df[column].fillna(
                st.session_state["project_dct"]["data_import"]["merged_df"][
                    column
                ].median(),
                inplace=True,
            )

        elif row["Impute Method"] == "Fill with 0":
            merged_df[column].fillna(0, inplace=True)

    #########################################################################################################################################################
    # Group columns
    #########################################################################################################################################################

    numeric_columns = st.session_state["project_dct"]["data_import"][
        "numeric_columns"
    ]
    default_df = st.session_state["project_dct"]["data_import"]["default_df"]

    st.markdown("#### Feature engineering")

    edited_df = st.data_editor(
        st.session_state["project_dct"]["data_import"]["edited_df"],
        column_config={
            "Column 1": st.column_config.SelectboxColumn(
                options=numeric_columns,
                required=True,
                width=400,
            ),
            "Operator": st.column_config.SelectboxColumn(
                options=["+", "-", "*", "/"],
                required=True,
                default="+",
                width=100,
            ),
            "Column 2": st.column_config.SelectboxColumn(
                options=numeric_columns,
                required=True,
                default=numeric_columns[0],
                width=400,
            ),
            "Category": st.column_config.SelectboxColumn(
                options=[
                    "Media",
                    "Exogenous",
                    "Internal",
                    "Response Metrics",
                ],
                required=True,
                default="Media",
                width=200,
            ),
        },
        num_rows="dynamic",
        key="data-editor-4",
    )

    final_df, edited_stats_df = process_dataframes(
        merged_df, edited_df, edited_stats_df
    )

    st.markdown("#### Final DataFrame")
    sort_col = []
    for col in final_df.columns:
        if col in ["Panel_1", "Panel_2", "date"]:
            sort_col.append(col)

    sorted_final_df = final_df.sort_values(
        by=sort_col, ascending=True, na_position="first"
    )

    st.dataframe(sorted_final_df, hide_index=True)

    # Initialize an empty dictionary to hold categories and their variables
    category_dict = {}

    # Iterate over each row in the edited DataFrame to populate the dictionary
    for i, row in edited_stats_df.iterrows():
        column = row["Column"]
        category = row[
            "Category"
        ]  # The category chosen by the user for this variable

        # Check if the category already exists in the dictionary
        if category not in category_dict:
            # If not, initialize it with the current column as its first element
            category_dict[category] = [column]
        else:
            # If it exists, append the current column to the list of variables under this category
            category_dict[category].append(column)

    # Add Date, Panel_1 and Panel_2 to the category dictionary
    category_dict.update({"Date": ["date"]})
    if "Panel_1" in final_df.columns:
        category_dict["Panel Level 1"] = ["Panel_1"]
    if "Panel_2" in final_df.columns:
        category_dict["Panel Level 2"] = ["Panel_2"]

    # Display the dictionary
    st.markdown("#### Variable Category")
    for category, variables in category_dict.items():
        # Check if there are multiple variables to handle the "and" insertion correctly
        if len(variables) > 1:
            # Join all but the last variable with ", ", then add " and " before the last variable
            variables_str = ", ".join(variables[:-1]) + " and " + variables[-1]
        else:
            # If there's only one variable, no need for "and"
            variables_str = variables[0]

        # Display the category and its variables in the desired format
        st.markdown(
            f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
            unsafe_allow_html=True,
        )

    # Check that a Response Metrics column is selected
    st.write("")
    response_metrics_col = category_dict.get("Response Metrics", [])
    if len(response_metrics_col) == 0:
        st.warning("Please select a Response Metrics column", icon="⚠️")
        st.stop()
    # elif len(response_metrics_col) > 1:
    #     st.warning("Please select only one Response Metrics column", icon="⚠️")
    #     st.stop()

    # Store the final dataframe and the bin dictionary in session state
    st.session_state["final_df"], st.session_state["bin_dict"] = (
        final_df,
        category_dict,
    )

    # Save the DataFrame and dictionary from the session state to the pickle file
    if st.button(
        "Accept and Save",
        use_container_width=True,
        key="data-editor-button",
    ):
        update_db("1_Data_Import.py")
        final_df = final_df.loc[:, ~final_df.columns.duplicated()]

        project_dct_path = os.path.join(
            st.session_state["project_path"], "project_dct.pkl"
        )

        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)

        data_path = os.path.join(
            st.session_state["project_path"], "data_import.pkl"
        )

        st.session_state["data_path"] = data_path

        save_to_pickle(
            data_path,
            st.session_state["final_df"],
            st.session_state["bin_dict"],
        )

        st.session_state["project_dct"]["data_import"][
            "edited_stats_df"
        ] = edited_stats_df
        st.session_state["project_dct"]["data_import"]["merged_df"] = merged_df
        st.session_state["project_dct"]["data_import"][
            "missing_stats_df"
        ] = missing_stats_df
        st.session_state["project_dct"]["data_import"]["cat_dct"] = {
            col: cat
            for col, cat in zip(
                edited_stats_df["Column"], edited_stats_df["Category"]
            )
        }
        st.session_state["project_dct"]["data_import"][
            "numeric_columns"
        ] = numeric_columns
        st.session_state["project_dct"]["data_import"]["default_df"] = default_df
        st.session_state["project_dct"]["data_import"]["final_df"] = final_df
        st.session_state["project_dct"]["data_import"]["edited_df"] = edited_df

        st.toast("💾 Saved Successfully!")
if accept:
    # Normalize all data to a daily granularity. This initial standardization simplifies subsequent conversions to other levels of granularity
    with st.spinner("Processing..."):
        files_dict = standardize_data_to_daily(files_dict, selections)

        # Convert all data to the selected level of granularity
        files_dict = apply_granularity_to_all(
            files_dict, granularity_selection, selections
        )

        # Update the 'files_dict' in the session state
        st.session_state["files_dict"] = files_dict

        # Set a flag in the session state to indicate that the selection has been made
        st.session_state["Panel_1_Panel_2_Selected"] = True

#########################################################################################################################################################
# Display unique Panel_1 and Panel_2 values
#########################################################################################################################################################

# Halts further execution until the Panel_1 and Panel_2 columns are selected
if st.session_state["project_dct"]["data_import"]["edited_stats_df"] is None:

    if (
        "files_dict" in st.session_state
        and st.session_state["Panel_1_Panel_2_Selected"]
    ):
        files_dict = st.session_state["files_dict"]

        st.session_state["project_dct"]["data_import"][
            "files_dict"
        ] = files_dict  # resume
    else:
        st.stop()

    # Sets to store the unique values of Panel_1 and Panel_2
    with st.spinner("Fetching Panel values..."):
        all_panel1_values, all_panel2_values = clean_and_extract_unique_values(
            files_dict, selections
        )

    # Lists of the unique Panel_1 and Panel_2 column values
    list_of_all_panel1_values = list(all_panel1_values)
    list_of_all_panel2_values = list(all_panel2_values)

    # Format the Panel_1 and Panel_2 values for display
    formatted_panel1_values = format_values_for_display(
        list_of_all_panel1_values
    )
    formatted_panel2_values = format_values_for_display(
        list_of_all_panel2_values
    )

    # Store the panel values in project_dct
    st.session_state["project_dct"]["data_import"][
        "formatted_panel1_values"
    ] = formatted_panel1_values
    st.session_state["project_dct"]["data_import"][
        "formatted_panel2_values"
    ] = formatted_panel2_values

    # Unique Panel_1 and Panel_2 values
    st.markdown("#### Unique Panel values")
    # Display Panel_1 and Panel_2 values
    with st.expander("Unique Panel values"):
        st.write("")
        st.markdown(
            f"""
            <style>
            .justify-text {{
                text-align: justify;
            }}
            </style>
            <div class="justify-text">
                <strong>Panel Level 1 Values:</strong> {formatted_panel1_values}<br>
                <strong>Panel Level 2 Values:</strong> {formatted_panel2_values}
            </div>
            """,
            unsafe_allow_html=True,
        )

        # Display the total number of Panel_1 and Panel_2 values
        st.write("")
        st.markdown(
            f"""
            <div style="text-align: justify;">
                <strong>Number of Level 1 Panels detected:</strong> {len(list_of_all_panel1_values)}<br>
                <strong>Number of Level 2 Panels detected:</strong> {len(list_of_all_panel2_values)}
            </div>
            """,
            unsafe_allow_html=True,
        )
        st.write("")
    #########################################################################################################################################################
    # Merge all DataFrames
    #########################################################################################################################################################

    # Merge all selected DataFrames
    main_df = create_main_dataframe(
        files_dict,
        all_panel1_values,
        all_panel2_values,
        granularity_selection,
    )

    merged_df = merge_into_main_df(main_df, files_dict, selections)

    #########################################################################################################################################################
    # Categorize Variables and Impute Missing Values
    #########################################################################################################################################################

    # Create an editable DataFrame in Streamlit
    st.markdown("#### Select Variables Category & Impute Missing Values")

    # Prepare the missing stats DataFrame for editing
    missing_stats_df = prepare_missing_stats_df(merged_df)
    sorted_missing_stats_df = missing_stats_df.sort_values(
        by="Missing Values", ascending=False, na_position="first"
    )

    edited_stats_df = st.data_editor(
        sorted_missing_stats_df,
        column_config={
            "Impute Method": st.column_config.SelectboxColumn(
                options=[
                    "Drop Column",
                    "Fill with Mean",
                    "Fill with Median",
                    "Fill with 0",
                ],
                required=True,
                default="Fill with 0",
            ),
            "Category": st.column_config.SelectboxColumn(
                options=[
                    "Media",
                    "Exogenous",
                    "Internal",
                    "Response Metrics",
                ],
                required=True,
                default="Media",
            ),
        },
        disabled=["Column", "Missing Values", "Missing Percentage"],
        hide_index=True,
        use_container_width=True,
        key="data-editor-2",
    )

    # Apply changes based on the edited DataFrame
    for i, row in edited_stats_df.iterrows():
        column = row["Column"]
        if row["Impute Method"] == "Drop Column":
            merged_df.drop(columns=[column], inplace=True)

        elif row["Impute Method"] == "Fill with Mean":
            merged_df[column].fillna(merged_df[column].mean(), inplace=True)

        elif row["Impute Method"] == "Fill with Median":
            merged_df[column].fillna(merged_df[column].median(), inplace=True)

        elif row["Impute Method"] == "Fill with 0":
            merged_df[column].fillna(0, inplace=True)

    #########################################################################################################################################################
    # Group columns
    #########################################################################################################################################################

    # Display the Group columns header
    st.markdown("#### Feature engineering")

    # Prepare the numeric columns and an empty DataFrame for user input
    numeric_columns, default_df = prepare_numeric_columns_and_default_df(
        merged_df, edited_stats_df
    )

    # Display the editable DataFrame
    edited_df = st.data_editor(
        default_df,
        column_config={
            "Column 1": st.column_config.SelectboxColumn(
                options=numeric_columns,
                required=True,
                width=400,
            ),
            "Operator": st.column_config.SelectboxColumn(
                options=["+", "-", "*", "/"],
                required=True,
                default="+",
                width=100,
            ),
            "Column 2": st.column_config.SelectboxColumn(
                options=numeric_columns,
                required=True,
                default=numeric_columns[0],
                width=400,
            ),
            "Category": st.column_config.SelectboxColumn(
                options=[
                    "Media",
                    "Exogenous",
                    "Internal",
                    "Response Metrics",
                ],
                required=True,
                default="Media",
                width=200,
            ),
        },
        num_rows="dynamic",
        key="data-editor-3",
    )

    # Process the DataFrame based on the user inputs and operations specified in edited_df
    final_df, edited_stats_df = process_dataframes(
        merged_df, edited_df, edited_stats_df
    )

    #########################################################################################################################################################
    # Display the Final DataFrame and variables
    #########################################################################################################################################################

    st.markdown("#### Final DataFrame")

    sort_col = []
    for col in final_df.columns:
        if col in ["Panel_1", "Panel_2", "date"]:
            sort_col.append(col)

    sorted_final_df = final_df.sort_values(
        by=sort_col, ascending=True, na_position="first"
    )
    st.dataframe(sorted_final_df, hide_index=True)
    # Initialize an empty dictionary to hold categories and their variables
    category_dict = {}

    # Iterate over each row in the edited DataFrame to populate the dictionary
    for i, row in edited_stats_df.iterrows():
        column = row["Column"]
        category = row[
            "Category"
        ]  # The category chosen by the user for this variable

        # Check if the category already exists in the dictionary
        if category not in category_dict:
            # If not, initialize it with the current column as its first element
            category_dict[category] = [column]
        else:
            # If it exists, append the current column to the list of variables under this category
            category_dict[category].append(column)

    # Add Date, Panel_1 and Panel_2 to the category dictionary
    category_dict.update({"Date": ["date"]})
    if "Panel_1" in final_df.columns:
        category_dict["Panel Level 1"] = ["Panel_1"]
    if "Panel_2" in final_df.columns:
        category_dict["Panel Level 2"] = ["Panel_2"]

    # Display the dictionary
    st.markdown("#### Variable Category")
    for category, variables in category_dict.items():
        # Check if there are multiple variables to handle the "and" insertion correctly
        if len(variables) > 1:
            # Join all but the last variable with ", ", then add " and " before the last variable
            variables_str = ", ".join(variables[:-1]) + " and " + variables[-1]
        else:
            # If there's only one variable, no need for "and"
            variables_str = variables[0]

        # Display the category and its variables in the desired format
        st.markdown(
            f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
            unsafe_allow_html=True,
        )

    # Check that a Response Metrics column is selected
    st.write("")

    response_metrics_col = category_dict.get("Response Metrics", [])
    if len(response_metrics_col) == 0:
        st.warning("Please select a Response Metrics column", icon="⚠️")
        st.stop()
    # elif len(response_metrics_col) > 1:
    #     st.warning("Please select only one Response Metrics column", icon="⚠️")
    #     st.stop()

    # Store the final dataframe and the bin dictionary in session state
    st.session_state["final_df"], st.session_state["bin_dict"] = (
        final_df,
        category_dict,
    )

    # Save the DataFrame and dictionary from the session state to the pickle file
    if st.button("Accept and Save", use_container_width=True):

        update_db("1_Data_Import.py")

        project_dct_path = os.path.join(
            st.session_state["project_path"], "project_dct.pkl"
        )

        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)

        data_path = os.path.join(
            st.session_state["project_path"], "data_import.pkl"
        )
        st.session_state["data_path"] = data_path

        save_to_pickle(
            data_path,
            st.session_state["final_df"],
            st.session_state["bin_dict"],
        )

        st.session_state["project_dct"]["data_import"][
            "edited_stats_df"
        ] = edited_stats_df
        st.session_state["project_dct"]["data_import"]["merged_df"] = merged_df
        st.session_state["project_dct"]["data_import"][
            "missing_stats_df"
        ] = missing_stats_df
        st.session_state["project_dct"]["data_import"]["cat_dct"] = {
            col: cat
            for col, cat in zip(
                edited_stats_df["Column"], edited_stats_df["Category"]
            )
        }
        st.session_state["project_dct"]["data_import"][
            "numeric_columns"
        ] = numeric_columns
        st.session_state["project_dct"]["data_import"]["default_df"] = default_df
        st.session_state["project_dct"]["data_import"]["final_df"] = final_df
        st.session_state["project_dct"]["data_import"]["edited_df"] = edited_df

        st.toast("💾 Saved Successfully!")

# *****************************************************************
# ********************************* Persistent flow ****************
pages/2_Data_Validation_and_Insights.py
ADDED
@@ -0,0 +1,514 @@
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from Eda_functions import *
import numpy as np
import pickle

import streamlit.components.v1 as components
import sweetviz as sv
from utilities import set_header, load_local_css
from st_aggrid import GridOptionsBuilder, GridUpdateMode
from st_aggrid import AgGrid
import base64
import os
import tempfile

# from ydata_profiling import ProfileReport
import re

# from pygwalker.api.streamlit import StreamlitRenderer
# from Home_redirecting import home
import sqlite3
from utilities import update_db

st.set_page_config(
    page_title="Data Validation",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)
load_local_css("styles.css")
set_header()

if "project_dct" not in st.session_state:
    # home()
    st.warning("Please select a project from home page")
    st.stop()


data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")

try:
    with open(data_path, "rb") as f:
        data = pickle.load(f)
except Exception as e:
    st.error("Please import data from the Data Import Page")
    st.stop()

# Connection with the SQL db (forward slashes keep the path portable)
conn = sqlite3.connect(r"DB/User.db", check_same_thread=False)
c = conn.cursor()
st.session_state["cleaned_data"] = data["final_df"]
st.session_state["category_dict"] = data["bin_dict"]
# st.write(st.session_state['category_dict'])

st.title("Data Validation and Insights")


target_variables = [
    st.session_state["category_dict"][key]
    for key in st.session_state["category_dict"].keys()
    if key == "Response Metrics"
]


def format_display(inp):
    return inp.title().replace("_", " ").strip()


target_variables = list(*target_variables)
target_column = st.selectbox(
    "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
    target_variables,
    index=st.session_state["project_dct"]["data_validation"]["target_column"],
    format_func=format_display,
)

st.session_state["project_dct"]["data_validation"]["target_column"] = (
    target_variables.index(target_column)
)

st.session_state["target_column"] = target_column

panels = st.session_state["category_dict"]["Panel Level 1"][0]

selected_panels = st.multiselect(
    "Please choose the panels you wish to analyze. If no panels are selected, insights will be derived from the overall data.",
    st.session_state["cleaned_data"][panels].unique(),
    default=st.session_state["project_dct"]["data_validation"]["selected_panels"],
)

st.session_state["project_dct"]["data_validation"]["selected_panels"] = selected_panels

# Media columns are aggregated with "sum", every other category with "mean"
aggregation_dict = {
    item: "sum" if key == "Media" else "mean"
    for key, value in st.session_state["category_dict"].items()
    for item in value
    if item not in ["date", "Panel_1"]
}
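# Illustrative sketch (not part of the app): with a hypothetical category_dict
# such as {"Media": ["tv_spend"], "Exogenous": ["gdp_index"], "Response Metrics": ["sales"]},
# the comprehension above would produce
#     {"tv_spend": "sum", "gdp_index": "mean", "sales": "mean"}
# i.e. media activity is totalled per date while every other category is averaged.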

with st.expander("**Response Metric Analysis**"):

    if len(selected_panels) > 0:
        st.session_state["Cleaned_data_panel"] = st.session_state["cleaned_data"][
            st.session_state["cleaned_data"]["Panel_1"].isin(selected_panels)
        ]

        st.session_state["Cleaned_data_panel"] = (
            st.session_state["Cleaned_data_panel"]
            .groupby(by="date")
            .agg(aggregation_dict)
        )
        st.session_state["Cleaned_data_panel"] = st.session_state[
            "Cleaned_data_panel"
        ].reset_index()
    else:
        # st.write(st.session_state['cleaned_data'])
        st.session_state["Cleaned_data_panel"] = (
            st.session_state["cleaned_data"].groupby(by="date").agg(aggregation_dict)
        )
        st.session_state["Cleaned_data_panel"] = st.session_state[
            "Cleaned_data_panel"
        ].reset_index()

    fig = line_plot_target(
        st.session_state["Cleaned_data_panel"],
        target=target_column,
        title=f"{target_column} Over Time",
    )
    st.plotly_chart(fig, use_container_width=True)

media_channel = list(
    *[
        st.session_state["category_dict"][key]
        for key in st.session_state["category_dict"].keys()
        if key == "Media"
    ]
)
# st.write(media_channel)

exo_var = list(
    *[
        st.session_state["category_dict"][key]
        for key in st.session_state["category_dict"].keys()
        if key == "Exogenous"
    ]
)
internal_var = list(
    *[
        st.session_state["category_dict"][key]
        for key in st.session_state["category_dict"].keys()
        if key == "Internal"
    ]
)
Non_media_variables = exo_var + internal_var

st.markdown("### Annual Data Summary")

summary_df = summary(
    st.session_state["Cleaned_data_panel"],
    media_channel + [target_column],
    spends=None,
    Target=True,
)

st.dataframe(
    summary_df,
    use_container_width=True,
)

if st.checkbox("Show raw data"):
    st.cache_resource(show_spinner=False)

    def raw_df_gen():
        # Convert 'date' to datetime but do not convert to string yet for sorting
        dates = pd.to_datetime(st.session_state["Cleaned_data_panel"]["date"])

        # Concatenate the dates with other numeric columns formatted
        raw_df = pd.concat(
            [
                dates,
                st.session_state["Cleaned_data_panel"]
                .select_dtypes(np.number)
                .applymap(format_numbers),
            ],
            axis=1,
        )

        # Now sort raw_df by the 'date' column, which is still in datetime format
        sorted_raw_df = raw_df.sort_values(by="date", ascending=True)

        # After sorting, convert 'date' to string format for display
        sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")

        return sorted_raw_df

    # Display the sorted DataFrame in Streamlit
    st.dataframe(raw_df_gen())

col1 = st.columns(1)

if "selected_feature" not in st.session_state:
    st.session_state["selected_feature"] = None


def generate_report_with_target(channel_data, target_feature):
    report = sv.analyze([channel_data, "Dataset"], target_feat=target_feature)
    temp_dir = tempfile.mkdtemp()
    report_path = os.path.join(temp_dir, "report.html")
    report.show_html(
        filepath=report_path, open_browser=False
    )  # Generate the report as an HTML file
    return report_path


def generate_profile_report(df):
    pr = df.profile_report()
    temp_dir = tempfile.mkdtemp()
    report_path = os.path.join(temp_dir, "report.html")
    pr.to_file(report_path)
    return report_path


# st.header()
with st.expander("Univariate and Bivariate Report"):
    eda_columns = st.columns(2)
    with eda_columns[0]:
        if st.button(
            "Generate Profile Report",
            help="Univariate report which includes all statistical analysis",
        ):
            with st.spinner("Generating Report"):
                report_file = generate_profile_report(
                    st.session_state["Cleaned_data_panel"]
                )

            if os.path.exists(report_file):
                with open(report_file, "rb") as f:
                    st.success("Report Generated")
                    st.download_button(
                        label="Download EDA Report",
                        data=f.read(),
                        file_name="pandas_profiling_report.html",
                        mime="text/html",
                    )
            else:
                st.warning(
                    "Report generation failed. Unable to find the report file."
                )

    with eda_columns[1]:
        if st.button(
            "Generate Sweetviz Report",
            help="Bivariate report for selected response metric",
        ):
            with st.spinner("Generating Report"):
                report_file = generate_report_with_target(
                    st.session_state["Cleaned_data_panel"], target_column
                )

            if os.path.exists(report_file):
                with open(report_file, "rb") as f:
                    st.success("Report Generated")
                    st.download_button(
                        label="Download EDA Report",
                        data=f.read(),
                        file_name="report.html",
                        mime="text/html",
                    )
            else:
                st.warning("Report generation failed. Unable to find the report file.")


# st.warning('Work in Progress')
with st.expander("Media Variables Analysis"):
    # Get the selected feature

    media_variables = [
        col
        for col in media_channel
        if "cost" not in col.lower() and "spend" not in col.lower()
    ]

    st.session_state["selected_feature"] = st.selectbox(
        "Select media", media_variables, format_func=format_display
    )

    st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
        media_variables.index(st.session_state["selected_feature"])
    )

    # Filter spends features based on the selected feature
    spends_features = [
        col
        for col in st.session_state["Cleaned_data_panel"].columns
        if any(keyword in col.lower() for keyword in ["cost", "spend"])
    ]
    spends_feature = [
        col
        for col in spends_features
        if re.split(r"_cost|_spend", col.lower())[0]
        in st.session_state["selected_feature"]
    ]

    if "validation" not in st.session_state:

        st.session_state["validation"] = st.session_state["project_dct"][
            "data_validation"
        ]["validated_variables"]

    val_variables = [col for col in media_channel if col != "date"]

    if not set(
        st.session_state["project_dct"]["data_validation"]["validated_variables"]
    ).issubset(set(val_variables)):

        st.session_state["validation"] = []

    if len(spends_feature) == 0:
        st.warning("No spends variable available for the selected metric in data")

    else:
        fig_row1 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[st.session_state["selected_feature"]],
            y2_cols=[target_column],
            title=f'Analysis of {st.session_state["selected_feature"]} and {[target_column][0]} Over Time',
        )
        st.plotly_chart(fig_row1, use_container_width=True)
        st.markdown("### Summary")
        st.dataframe(
            summary(
                st.session_state["cleaned_data"],
                [st.session_state["selected_feature"]],
                spends=spends_feature[0],
            ),
            use_container_width=True,
        )

    cols2 = st.columns(2)

    if len(set(st.session_state["validation"]).intersection(val_variables)) == len(
        val_variables
    ):
        disable = True
        help = "All media variables are validated"
    else:
        disable = False
        help = ""

    with cols2[0]:
        if st.button("Validate", disabled=disable, help=help):
            st.session_state["validation"].append(
                st.session_state["selected_feature"]
            )
    with cols2[1]:

        if st.checkbox("Validate all", disabled=disable, help=help):
            st.session_state["validation"].extend(val_variables)
            st.success("All media variables are validated ✅")

    if len(set(st.session_state["validation"]).intersection(val_variables)) != len(
        val_variables
    ):
        validation_data = pd.DataFrame(
            {
                "Validate": [
                    (True if col in st.session_state["validation"] else False)
                    for col in val_variables
                ],
                "Variables": val_variables,
            }
        )

        sorted_validation_df = validation_data.sort_values(
            by="Variables", ascending=True, na_position="first"
        )
        cols3 = st.columns([1, 30])
        with cols3[1]:
            validation_df = st.data_editor(
                sorted_validation_df,
                # column_config={
                #     'Validate': st.column_config.CheckboxColumn(wi)
                # },
                column_config={
                    "Validate": st.column_config.CheckboxColumn(
                        default=False,
                        width=100,
                    ),
                    "Variables": st.column_config.TextColumn(width=1000),
                },
                hide_index=True,
            )

            selected_rows = validation_df[validation_df["Validate"] == True][
                "Variables"
            ]

            # st.write(selected_rows)

            st.session_state["validation"].extend(selected_rows)

            st.session_state["project_dct"]["data_validation"][
                "validated_variables"
            ] = st.session_state["validation"]

            not_validated_variables = [
                col
                for col in val_variables
                if col not in st.session_state["validation"]
            ]

            if not_validated_variables:
                not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
                st.warning(not_validated_message)


with st.expander("Non Media Variables Analysis"):
    selected_columns_row4 = st.selectbox(
        "Select Channel",
        Non_media_variables,
        format_func=format_display,
        index=st.session_state["project_dct"]["data_validation"]["Non_media_variables"],
    )

    st.session_state["project_dct"]["data_validation"]["Non_media_variables"] = (
        Non_media_variables.index(selected_columns_row4)
    )

    # # Create the dual-axis line plot
    fig_row4 = line_plot(
        st.session_state["Cleaned_data_panel"],
        x_col="date",
        y1_cols=[selected_columns_row4],
        y2_cols=[target_column],
        title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
    )
    st.plotly_chart(fig_row4, use_container_width=True)
    selected_non_media = selected_columns_row4
    sum_df = st.session_state["Cleaned_data_panel"][
        ["date", selected_non_media, target_column]
    ]
    sum_df["Year"] = pd.to_datetime(
        st.session_state["Cleaned_data_panel"]["date"]
    ).dt.year
    # st.dataframe(df)
    # st.dataframe(sum_df.head(2))
    print(sum_df)
    sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
    sum_df.loc["Grand Total"] = sum_df.sum()
    sum_df = sum_df.applymap(format_numbers)
    sum_df.fillna("-", inplace=True)
    sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
    st.markdown("### Summary")
    st.dataframe(sum_df, use_container_width=True)

# with st.expander('Interactive Dashboard'):

#     pygg_app = StreamlitRenderer(st.session_state['cleaned_data'])

#     pygg_app.explorer()

with st.expander("Correlation Analysis"):
    options = list(
        st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
    )

    # selected_options = []
    # num_columns = 4
    # num_rows = -(-len(options) // num_columns)  # Ceiling division to calculate rows

    # # Create a grid of checkboxes
    # st.header('Select Features for Correlation Plot')
    # tick = False
    # if st.checkbox('Select all'):
    #     tick = True
    #     selected_options = []
    # for row in range(num_rows):
    #     cols = st.columns(num_columns)
    #     for col in cols:
    #         if options:
    #             option = options.pop(0)
    #             selected = col.checkbox(option, value=tick)
    #             if selected:
    #                 selected_options.append(option)
    # # Display selected options

    selected_options = st.multiselect(
        "Select Variables For correlation plot",
        [var for var in options if var != target_column],
        default=options[3],
    )

    st.pyplot(
        correlation_plot(
            st.session_state["Cleaned_data_panel"],
            selected_options,
            target_column,
        )
    )

if st.button("Save Changes", use_container_width=True):

    update_db("2_Data_Validation.py")

    project_dct_path = os.path.join(st.session_state["project_path"], "project_dct.pkl")

    with open(project_dct_path, "wb") as f:
        pickle.dump(st.session_state["project_dct"], f)
    st.success("Changes saved")
pages/3_Transformations.py
ADDED
@@ -0,0 +1,639 @@
# Importing necessary libraries
import streamlit as st

st.set_page_config(
    page_title="Transformations",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

import pickle
import numpy as np
import pandas as pd
from utilities import set_header, load_local_css
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
import os
import sqlite3
from utilities import update_db


load_local_css("styles.css")
set_header()


# Check for authentication status
for k, v in st.session_state.items():
    if k not in ["logout", "login", "config"] and not k.startswith("FormSubmitter"):
        st.session_state[k] = v
with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config
authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status == True:
    authenticator.logout("Logout", "main")
    is_state_initialized = st.session_state.get("initialized", False)

    if "project_dct" not in st.session_state:
        st.error("Please load a project from Home page")
        st.stop()

    conn = sqlite3.connect(
        r"DB/User.db", check_same_thread=False
    )  # connection with sql db
    c = conn.cursor()

    if not is_state_initialized:
        if "session_name" not in st.session_state:
            st.session_state["session_name"] = None

    if not os.path.exists(
        os.path.join(st.session_state["project_path"], "data_import.pkl")
    ):
        st.error("Please move to Data Import page")
    # Deserialize and load the objects from the pickle file
    with open(
        os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
    ) as f:
        data = pickle.load(f)

    # Accessing the loaded objects
    final_df_loaded = data["final_df"]
    bin_dict_loaded = data["bin_dict"]
    # final_df_loaded.to_csv("Test/final_df_loaded.csv", index=False)
    # Initialize session state
    if "transformed_columns_dict" not in st.session_state:
        st.session_state["transformed_columns_dict"] = {}  # Default empty dictionary

    if "final_df" not in st.session_state:
        st.session_state["final_df"] = final_df_loaded  # Default as original dataframe

    if "summary_string" not in st.session_state:
        st.session_state["summary_string"] = None  # Default as None

    # Extract original columns for specified categories
    original_columns = {
        category: bin_dict_loaded[category]
        for category in ["Media", "Internal", "Exogenous"]
        if category in bin_dict_loaded
    }

    # Retrieve panel columns
    panel_1 = bin_dict_loaded.get("Panel Level 1")
    panel_2 = bin_dict_loaded.get("Panel Level 2")

    # # For testing on non panel level
    # final_df_loaded = final_df_loaded.drop("Panel_1", axis=1)
    # final_df_loaded = final_df_loaded.groupby("date").mean().reset_index()
    # panel_1 = None

    # Apply transformations on panel level
    if panel_1:
        panel = panel_1 + panel_2 if panel_2 else panel_1
    else:
        panel = []

    # Function to build transformation widgets
    def transformation_widgets(category, transform_params, date_granularity):

        if (
            st.session_state["project_dct"]["transformations"] is None
            or st.session_state["project_dct"]["transformations"] == {}
        ):
            st.session_state["project_dct"]["transformations"] = {}
        if category not in st.session_state["project_dct"]["transformations"].keys():
            st.session_state["project_dct"]["transformations"][category] = {}

        # Define a dict of pre-defined default values for every transformation
        predefined_defaults = {
            "Lag": (1, 2),
            "Lead": (1, 2),
            "Moving Average": (1, 2),
            "Saturation": (10, 20),
            "Power": (2, 4),
            "Adstock": (0.5, 0.7),
        }

        def selection_change():
            # Handles removing transformations
            if f"transformation_{category}" in st.session_state:
                current_selection = st.session_state[f"transformation_{category}"]
                past_selection = st.session_state["project_dct"]["transformations"][
                    category
                ][f"transformation_{category}"]
                removed_selection = list(set(past_selection) - set(current_selection))
                for selection in removed_selection:
                    # Option 1 - revert to default
                    # st.session_state['project_dct']['transformations'][category][selection] = predefined_defaults[selection]

                    # Option 2 - delete from dict
                    del st.session_state["project_dct"]["transformations"][category][
                        selection
                    ]

        # Transformation Options
        transformation_options = {
            "Media": [
                "Lag",
                "Moving Average",
                "Saturation",
                "Power",
                "Adstock",
            ],
            "Internal": ["Lead", "Lag", "Moving Average"],
            "Exogenous": ["Lead", "Lag", "Moving Average"],
        }

        expanded = st.session_state["project_dct"]["transformations"][category].get(
            "expanded", False
        )
        st.session_state["project_dct"]["transformations"][category]["expanded"] = False
        with st.expander(f"{category} Transformations", expanded=expanded):
            st.session_state["project_dct"]["transformations"][category][
                "expanded"
            ] = True

            # Let users select which transformations to apply
            sel_transformations = st.session_state["project_dct"]["transformations"][
                category
            ].get(f"transformation_{category}", [])
            transformations_to_apply = st.multiselect(
                "Select transformations to apply",
                options=transformation_options[category],
                default=sel_transformations,
                key=f"transformation_{category}",
                # on_change=selection_change(),
            )
            st.session_state["project_dct"]["transformations"][category][
                "transformation_" + category
            ] = transformations_to_apply
            # Determine the number of transformations to put in each column
            transformations_per_column = (
                len(transformations_to_apply) // 2 + len(transformations_to_apply) % 2
            )

            # Create two columns
            col1, col2 = st.columns(2)

            # Assign transformations to each column
            transformations_col1 = transformations_to_apply[:transformations_per_column]
            transformations_col2 = transformations_to_apply[transformations_per_column:]

            # Define a helper function to create widgets for each transformation
            def create_transformation_widgets(column, transformations):
                with column:
                    for transformation in transformations:
                        # Conditionally create widgets for selected transformations
                        if transformation == "Lead":
                            lead_default = st.session_state["project_dct"][
                                "transformations"
                            ][category].get("Lead", predefined_defaults["Lead"])
                            st.markdown(f"**Lead ({date_granularity})**")
                            lead = st.slider(
                                "Lead periods",
                                1,
                                10,
                                lead_default,
                                1,
                                key=f"lead_{category}",
                                label_visibility="collapsed",
                            )
                            st.session_state["project_dct"]["transformations"][
                                category
                            ]["Lead"] = lead
                            start = lead[0]
                            end = lead[1]
                            step = 1
                            transform_params[category]["Lead"] = np.arange(
                                start, end + step, step
                            )

                        if transformation == "Lag":
                            lag_default = st.session_state["project_dct"][
                                "transformations"
                            ][category].get("Lag", predefined_defaults["Lag"])
                            st.markdown(f"**Lag ({date_granularity})**")
                            lag = st.slider(
                                "Lag periods",
                                1,
                                10,
                                lag_default,
                                1,
                                key=f"lag_{category}",
                                label_visibility="collapsed",
                            )
                            st.session_state["project_dct"]["transformations"][
                                category
                            ]["Lag"] = lag
                            start = lag[0]
                            end = lag[1]
                            step = 1
                            transform_params[category]["Lag"] = np.arange(
                                start, end + step, step
                            )

                        if transformation == "Moving Average":
                            ma_default = st.session_state["project_dct"][
                                "transformations"
                            ][category].get("MA", predefined_defaults["Moving Average"])
                            st.markdown(f"**Moving Average ({date_granularity})**")
                            window = st.slider(
                                "Window size for Moving Average",
                                1,
                                10,
                                ma_default,
                                1,
                                key=f"ma_{category}",
                                label_visibility="collapsed",
                            )
                            st.session_state["project_dct"]["transformations"][
                                category
                            ]["MA"] = window
                            start = window[0]
                            end = window[1]
                            step = 1
                            transform_params[category]["Moving Average"] = np.arange(
                                start, end + step, step
                            )

                        if transformation == "Saturation":
                            st.markdown("**Saturation (%)**")
                            saturation_default = st.session_state["project_dct"][
                                "transformations"
                            ][category].get(
                                "Saturation", predefined_defaults["Saturation"]
                            )
                            saturation_point = st.slider(
                                "Saturation Percentage",
                                0,
                                100,
                                saturation_default,
                                10,
                                key=f"sat_{category}",
                                label_visibility="collapsed",
                            )
                            st.session_state["project_dct"]["transformations"][
                                category
                            ]["Saturation"] = saturation_point
                            start = saturation_point[0]
                            end = saturation_point[1]
                            step = 10
                            transform_params[category]["Saturation"] = np.arange(
                                start, end + step, step
                            )

                        if transformation == "Power":
                            st.markdown("**Power**")
                            power_default = st.session_state["project_dct"][
                                "transformations"
                            ][category].get("Power", predefined_defaults["Power"])
                            power = st.slider(
                                "Power",
                                0,
                                10,
                                power_default,
                                1,
                                key=f"power_{category}",
                                label_visibility="collapsed",
                            )
                            st.session_state["project_dct"]["transformations"][
                                category
                            ]["Power"] = power
                            start = power[0]
                            end = power[1]
                            step = 1
                            transform_params[category]["Power"] = np.arange(
                                start, end + step, step
                            )

                        if transformation == "Adstock":
                            ads_default = st.session_state["project_dct"][
                                "transformations"
                            ][category].get("Adstock", predefined_defaults["Adstock"])
                            st.markdown("**Adstock**")
                            rate = st.slider(
                                f"Factor ({category})",
                                0.0,
                                1.0,
                                ads_default,
                                0.05,
                                key=f"adstock_{category}",
                                label_visibility="collapsed",
                            )
                            st.session_state["project_dct"]["transformations"][
                                category
                            ]["Adstock"] = rate
                            start = rate[0]
                            end = rate[1]
                            step = 0.05
                            adstock_range = [
                                round(a, 3) for a in np.arange(start, end + step, step)
                            ]
                            transform_params[category]["Adstock"] = adstock_range

            # Create widgets in each column
            create_transformation_widgets(col1, transformations_col1)
            create_transformation_widgets(col2, transformations_col2)

    # Function to apply Lag transformation
    def apply_lag(df, lag):
        return df.shift(lag)

    # Function to apply Lead transformation
    def apply_lead(df, lead):
        return df.shift(-lead)

    # Function to apply Moving Average transformation
    def apply_moving_average(df, window_size):
        return df.rolling(window=window_size).mean()
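    # Illustrative sketch (not part of the app) of the three shift-based
    # transforms on a tiny series, assuming s = pd.Series([1, 2, 3]):
    #     apply_lag(s, 1)            -> [NaN, 1.0, 2.0]   (past values slide forward)
    #     apply_lead(s, 1)           -> [2.0, 3.0, NaN]   (future values slide back)
    #     apply_moving_average(s, 2) -> [NaN, 1.5, 2.5]   (trailing 2-period mean)
    # The NaNs introduced at the edges are later replaced with 0 in
    # apply_category_transformations below.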

    # Function to apply Saturation transformation
    def apply_saturation(df, saturation_percent_100):
        # Convert saturation percentage from 100-based to fraction
        saturation_percent = saturation_percent_100 / 100.0

        # Calculate saturation point and steepness
        column_max = df.max()
        column_min = df.min()
        saturation_point = (column_min + column_max) / 2

        numerator = np.log(
            (1 / (saturation_percent if saturation_percent != 1 else 1 - 1e-9)) - 1
        )
        denominator = np.log(saturation_point / max(column_max, 1e-9))

        # Avoid division by zero while preserving the sign of the denominator
        # (a plain max() would clamp the negative log to 1e-9)
        steepness = numerator / (denominator if abs(denominator) > 1e-9 else 1e-9)

        # Apply the saturation transformation (assumes strictly positive inputs)
        transformed_series = df.apply(
            lambda x: (1 / (1 + (saturation_point / x) ** steepness)) * x
        )

        return transformed_series
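    # A reading of the formula above (a sketch, not an app feature): with
    # s = saturation_percent, m = column_max and mid = saturation_point, the
    # multiplier 1 / (1 + (mid / x) ** steepness) is a logistic-style curve in x.
    # With steepness = ln(1/s - 1) / ln(mid / m) and the sign-preserving guard
    # above, the curve is anchored so that, algebraically, an input equal to mid
    # keeps exactly half its value and an input equal to m keeps the fraction s
    # chosen on the slider.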

    # Function to apply Power transformation
    def apply_power(df, power):
        return df**power

    # Function to apply Adstock transformation
    def apply_adstock(df, factor):
        x = 0
        # Use the walrus operator to update x iteratively with the Adstock formula
        adstock_var = [x := x * factor + v for v in df]
        ans = pd.Series(adstock_var, index=df.index)
        return ans
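    # Illustrative sketch (not part of the app): adstock carries a geometrically
    # decaying share of past activity forward, adstock_t = factor * adstock_{t-1} + x_t.
    # For example, apply_adstock(pd.Series([100.0, 0.0, 0.0]), 0.5) returns
    # [100.0, 50.0, 25.0]: each period retains half of the previous period's
    # accumulated value.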

    # Function to generate transformed column names
    @st.cache_resource(show_spinner=False)
    def generate_transformed_columns(original_columns, transform_params):
        transformed_columns, summary = {}, {}

        for category, columns in original_columns.items():
            for column in columns:
                transformed_columns[column] = []
                summary_details = (
                    []
                )  # List to hold transformation details for the current column

                if category in transform_params:
                    for transformation, values in transform_params[category].items():
                        # Generate transformed column names for each value
                        for value in values:
                            transformed_name = f"{column}@{transformation}_{value}"
                            transformed_columns[column].append(transformed_name)

                        # Format the values list as a string with commas and "and" before the last item
                        if len(values) > 1:
                            formatted_values = (
                                ", ".join(map(str, values[:-1]))
                                + " and "
                                + str(values[-1])
                            )
                        else:
                            formatted_values = str(values[0])

                        # Add transformation details
                        summary_details.append(f"{transformation} ({formatted_values})")

                # Only add to summary if there are transformation details for the column
                if summary_details:
                    formatted_summary = "⮕ ".join(summary_details)
                    # Use <strong> tags to make the column name bold
                    summary[column] = f"<strong>{column}</strong>: {formatted_summary}"

        # Generate a comprehensive summary string for all columns
        summary_items = [
            f"{idx + 1}. {details}" for idx, details in enumerate(summary.values())
        ]

        summary_string = "\n".join(summary_items)

        return transformed_columns, summary_string
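    # Illustrative sketch (not part of the app), for a hypothetical media column
    # "tv_impressions" with Adstock parameters [0.5, 0.55]: the names generated
    # above would be "tv_impressions@Adstock_0.5" and "tv_impressions@Adstock_0.55",
    # and the summary line would read "tv_impressions: Adstock (0.5 and 0.55)".
    # The "@" separator keeps the original variable name recoverable from the
    # transformed one.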

    # Function to apply transformations to DataFrame slices based on specified categories and parameters
    @st.cache_resource(show_spinner=False)
    def apply_category_transformations(df, bin_dict, transform_params, panel):
        # Dictionary for function mapping
        transformation_functions = {
            "Lead": apply_lead,
            "Lag": apply_lag,
            "Moving Average": apply_moving_average,
            "Saturation": apply_saturation,
            "Power": apply_power,
            "Adstock": apply_adstock,
        }

        # Initialize category_df as an empty DataFrame
        category_df = pd.DataFrame()

        # Iterate through each category specified in transform_params
        for category in ["Media", "Internal", "Exogenous"]:
            if (
                category not in transform_params
                or category not in bin_dict
                or not transform_params[category]
            ):
                continue  # Skip categories without transformations

            # Slice the DataFrame based on the columns specified in bin_dict for the current category
            df_slice = df[bin_dict[category] + panel]

            # Iterate through each transformation and its parameters for the current category
            for transformation, parameters in transform_params[category].items():
                transformation_function = transformation_functions[transformation]

                # Check if there is panel data to group by
                if len(panel) > 0:
                    # Apply the transformation to each group
                    category_df = pd.concat(
                        [
                            df_slice.groupby(panel)
                            .transform(transformation_function, p)
                            .add_suffix(f"@{transformation}_{p}")
                            for p in parameters
                        ],
                        axis=1,
                    )

                    # Replace all NaN or null values in category_df with 0
                    category_df.fillna(0, inplace=True)

                    # Update df_slice so the next transformation chains on the current output
                    df_slice = pd.concat(
                        [df[panel], category_df],
                        axis=1,
                    )

                else:
                    for p in parameters:
                        # Apply the transformation function to each column
                        temp_df = df_slice.apply(
                            lambda x: transformation_function(x, p), axis=0
                        ).rename(
                            lambda x: f"{x}@{transformation}_{p}",
                            axis="columns",
                        )
                        # Concatenate the transformed DataFrame slice to the category DataFrame
                        category_df = pd.concat([category_df, temp_df], axis=1)

                    # Replace all NaN or null values in category_df with 0
                    category_df.fillna(0, inplace=True)

                    # Update df_slice
                    df_slice = pd.concat(
                        [df[panel], category_df],
                        axis=1,
                    )

        # If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame
        if not category_df.empty:
            final_df = pd.concat([df, category_df], axis=1)
        else:
            # If no transformations were applied, use the original DataFrame
            final_df = df

        return final_df

    # Function that infers the granularity of the date column in a DataFrame
    @st.cache_resource(show_spinner=False)
    def infer_date_granularity(df):
        # Find the most common difference between consecutive dates
        common_freq = pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]

        # Map the most common difference to a granularity
        if common_freq == 1:
            return "daily"
        elif common_freq == 7:
            return "weekly"
        elif 28 <= common_freq <= 31:
            return "monthly"
        else:
            return "irregular"
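    # Illustrative sketch (not part of the app): for a hypothetical frame whose
    # "date" column holds 2024-01-01, 2024-01-08 and 2024-01-15, the consecutive
    # differences are [7, 7] days, the mode is 7, and infer_date_granularity
    # returns "weekly". The label is used to annotate the Lag/Lead/Moving Average
    # sliders with the right time unit.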

    #########################################################################################################################################################
    # User input for transformations
    #########################################################################################################################################################

    # Infer date granularity
    date_granularity = infer_date_granularity(final_df_loaded)

    # Initialize the main dictionary to store the transformation parameters for each category
    transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}

    # User input for transformations
    st.markdown("### Select Transformations to Apply")
    for category in ["Media", "Internal", "Exogenous"]:
        # Skip Internal
        if category == "Internal":
            continue

        transformation_widgets(category, transform_params, date_granularity)

    #########################################################################################################################################################
    # Apply transformations
    #########################################################################################################################################################

    # Apply category-based transformations to the DataFrame
    if st.button("Accept and Proceed", use_container_width=True):
        with st.spinner("Applying transformations..."):
            final_df = apply_category_transformations(
                final_df_loaded, bin_dict_loaded, transform_params, panel
            )

            # Generate a dictionary mapping original column names to lists of transformed column names
            transformed_columns_dict, summary_string = generate_transformed_columns(
                original_columns, transform_params
            )

            # Store into transformed dataframe and summary session state
            st.session_state["final_df"] = final_df
            st.session_state["summary_string"] = summary_string

    #########################################################################################################################################################
    # Display the transformed DataFrame and summary
    #########################################################################################################################################################

    # Display the transformed DataFrame in the Streamlit app
    st.markdown("### Transformed DataFrame")
    final_df = st.session_state["final_df"].copy()

    sort_col = []
    for col in final_df.columns:
        if col in ["Panel_1", "Panel_2", "date"]:
            sort_col.append(col)

    sorted_final_df = final_df.sort_values(
        by=sort_col, ascending=True, na_position="first"
    )
    st.dataframe(sorted_final_df, hide_index=True)

    # Total rows and columns
    total_rows, total_columns = st.session_state["final_df"].shape
    st.markdown(
        f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
        unsafe_allow_html=True,
    )

    # Display the summary of transformations as markdown
    if "summary_string" in st.session_state and st.session_state["summary_string"]:
        with st.expander("Summary of Transformations"):
            st.markdown("### Summary of Transformations")
            st.markdown(st.session_state["summary_string"], unsafe_allow_html=True)

    @st.cache_resource(show_spinner=False)
    def save_to_pickle(file_path, final_df):
        # Open the file in write-binary mode and dump the objects
        with open(file_path, "wb") as f:
            pickle.dump({"final_df_transformed": final_df}, f)
        # Data is now saved to file

    if st.button("Accept and Save", use_container_width=True):

        save_to_pickle(
            os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"),
            st.session_state["final_df"],
        )
        project_dct_path = os.path.join(
            st.session_state["project_path"], "project_dct.pkl"
        )

        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)

        update_db("3_Transformations.py")

        st.toast("💾 Saved Successfully!")
pages/4_Model_Build 2.py
ADDED
@@ -0,0 +1,1288 @@
"""
MMO Build Sprint 3
additions : adding more variables to session state for saved model : random effect, predicted train & test

MMO Build Sprint 4
additions : ability to run models for different response metrics
"""

import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from Eda_functions import format_numbers
import numpy as np
import pickle
from st_aggrid import AgGrid
from st_aggrid import GridOptionsBuilder, GridUpdateMode
from utilities import set_header, load_local_css
import time
import itertools
import statsmodels.api as sm
import re
from sklearn.metrics import (
    mean_absolute_error,
    r2_score,
    mean_absolute_percentage_error,
)
from sklearn.preprocessing import MinMaxScaler
import os
import matplotlib.pyplot as plt
from statsmodels.stats.outliers_influence import variance_inflation_factor
import yaml
from yaml import SafeLoader
import streamlit_authenticator as stauth

st.set_option("deprecation.showPyplotGlobalUse", False)
import statsmodels.formula.api as smf

from datetime import datetime, timedelta
import seaborn as sns
from Data_prep_functions import *
import sqlite3
from utilities import update_db


# @st.cache_resource(show_spinner=False)
# def save_to_pickle(file_path, final_df):
#     # Open the file in write-binary mode and dump the objects
#     with open(file_path, "wb") as f:
#         pickle.dump({file_path: final_df}, f)
@st.cache_resource(show_spinner=True)
def prepare_data_df(data):
    data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
        drop=True
    )  # Sprint4 -- Srishti -- only show models with the lowest num of neg coeffs
    data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
    data.drop_duplicates(subset="Model_iteration", inplace=True)

    # Applying the function to each row in the DataFrame
    data["coefficients"] = data["coefficients"].apply(process_dict)

    # Convert dictionary items into separate DataFrame columns
    coefficients_df = data["coefficients"].apply(pd.Series)

    # Rename the columns to remove any trailing underscores and capitalize the words
    coefficients_df.columns = [
        col.strip("_").replace("_", " ").title() for col in coefficients_df.columns
    ]

    # Normalize each row so that the sum equals 100%
    coefficients_df = coefficients_df.apply(
        lambda x: round((x / x.sum()) * 100, 2), axis=1
    )

    # Join the new columns back to the original DataFrame
    data = data.join(coefficients_df)

    data_df = data[
        [
            "Model_iteration",
            "MAPE",
            "ADJR2",
            "R2",
            "Total Positive Contributions",
            "Significance",
        ]
        + list(coefficients_df.columns)
    ]
    data_df.rename(columns={"Model_iteration": "Model Iteration"}, inplace=True)
    data_df.insert(0, "Rank", range(1, len(data_df) + 1))

    return coefficients_df, data_df
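
# Illustrative sketch (not part of the app) of the row-wise normalisation above,
# using hypothetical channel names: a coefficients row {"Tv": 2.0, "Radio": 1.0,
# "Base": 1.0} sums to 4.0 and becomes {"Tv": 50.0, "Radio": 25.0, "Base": 25.0},
# so each model's coefficients are displayed as percentage shares.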


def format_display(inp):
    return inp.title().replace("_", " ").strip()


def get_random_effects(media_data, panel_col, _mdf):
    random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])

    for i, market in enumerate(media_data[panel_col].unique()):
        print(i, end="\r")
        intercept = _mdf.random_effects[market].values[0]
        random_eff_df.loc[i, "random_effect"] = intercept
        random_eff_df.loc[i, panel_col] = market

    return random_eff_df


def mdf_predict(X_df, mdf, random_eff_df):
    X = X_df.copy()
    X["fixed_effect"] = mdf.predict(X)
    X = pd.merge(X, random_eff_df, on=panel_col, how="left")
    X["pred"] = X["fixed_effect"] + X["random_effect"]
    # X.to_csv('Test/megred_df.csv', index=False)
    X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
    return X["pred"]
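
# A sketch of the mixed-effects prediction above (not part of the app), assuming
# mdf is a fitted statsmodels MixedLM result with a random intercept per panel:
# the final prediction for a row is the population-level (fixed-effect) prediction
# plus that row's panel-specific intercept. E.g. for a hypothetical panel "panel_a"
# with random intercept 5.0, a fixed-effect prediction of 100.0 becomes 105.0.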
120 |
+
|
121 |
+
|
122 |
+
st.set_page_config(
|
123 |
+
page_title="Model Build",
|
124 |
+
page_icon=":shark:",
|
125 |
+
layout="wide",
|
126 |
+
initial_sidebar_state="collapsed",
|
127 |
+
)
|
128 |
+
|
129 |
+
load_local_css("styles.css")
|
130 |
+
set_header()
|
131 |
+
|
132 |
+
# Check for authentication status
|
133 |
+
for k, v in st.session_state.items():
|
134 |
+
if k not in [
|
135 |
+
"logout",
|
136 |
+
"login",
|
137 |
+
"config",
|
138 |
+
"model_build_button",
|
139 |
+
] and not k.startswith("FormSubmitter"):
|
140 |
+
st.session_state[k] = v
|
141 |
+
with open("config.yaml") as file:
|
142 |
+
config = yaml.load(file, Loader=SafeLoader)
|
143 |
+
st.session_state["config"] = config
|
144 |
+
authenticator = stauth.Authenticate(
|
145 |
+
config["credentials"],
|
146 |
+
config["cookie"]["name"],
|
147 |
+
config["cookie"]["key"],
|
148 |
+
config["cookie"]["expiry_days"],
|
149 |
+
config["preauthorized"],
|
150 |
+
)
|
151 |
+
st.session_state["authenticator"] = authenticator
|
152 |
+
name, authentication_status, username = authenticator.login("Login", "main")
|
153 |
+
auth_status = st.session_state.get("authentication_status")
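# Note: authenticator.login renders the login widget and also writes
# "authentication_status" (True / False / None) into st.session_state,
# which is what the gate below reads.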

if auth_status:
    authenticator.logout("Logout", "main")
    is_state_initialized = st.session_state.get("initialized", False)

    conn = sqlite3.connect(
        r"DB/User.db", check_same_thread=False
    )  # connection with the SQL db
    c = conn.cursor()

    if not is_state_initialized:
        if "session_name" not in st.session_state:
            st.session_state["session_name"] = None

    if "project_dct" not in st.session_state:
        st.error("Please load a project from the Home page")
        st.stop()

    st.title("1. Build Your Model")

    if not os.path.exists(
        os.path.join(st.session_state["project_path"], "data_import.pkl")
    ):
        st.error("Please move to the Data Import page and save.")
        st.stop()
    with open(
        os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
    ) as f:
        data = pickle.load(f)
    st.session_state["bin_dict"] = data["bin_dict"]

    if not os.path.exists(
        os.path.join(st.session_state["project_path"], "final_df_transformed.pkl")
    ):
        st.error("Please move to the Transformations page and save transformations.")
        st.stop()
    with open(
        os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"),
        "rb",
    ) as f:
        data = pickle.load(f)
    media_data = data["final_df_transformed"]
203 |
+
|
204 |
+
|
205 |
+
# Sprint4 - available response metrics is a list of all reponse metrics in the data
|
206 |
+
## these will be put in a drop down
|
207 |
+
|
208 |
+
st.session_state["media_data"] = media_data
|
209 |
+
|
210 |
+
if "available_response_metrics" not in st.session_state:
|
211 |
+
# st.session_state['available_response_metrics'] = ['Total Approved Accounts - Revenue',
|
212 |
+
# 'Total Approved Accounts - Appsflyer',
|
213 |
+
# 'Account Requests - Appsflyer',
|
214 |
+
# 'App Installs - Appsflyer']
|
215 |
+
|
216 |
+
st.session_state["available_response_metrics"] = st.session_state[
|
217 |
+
"bin_dict"
|
218 |
+
]["Response Metrics"]
|
219 |
+
# Sprint4
|
220 |
+
if "is_tuned_model" not in st.session_state:
|
221 |
+
st.session_state["is_tuned_model"] = {}
|
222 |
+
for resp_metric in st.session_state["available_response_metrics"]:
|
223 |
+
resp_metric = (
|
224 |
+
resp_metric.lower()
|
225 |
+
.replace(" ", "_")
|
226 |
+
.replace("-", "")
|
227 |
+
.replace(":", "")
|
228 |
+
.replace("__", "_")
|
229 |
+
)
|
230 |
+
st.session_state["is_tuned_model"][resp_metric] = False
|
231 |
+
|
232 |
+
# Sprint4 - used_response_metrics is a list of resp metrics for which user has created & saved a model
|
233 |
+
if "used_response_metrics" not in st.session_state:
|
234 |
+
st.session_state["used_response_metrics"] = []
|
235 |
+
|
236 |
+
# Sprint4 - saved_model_names
|
237 |
+
if "saved_model_names" not in st.session_state:
|
238 |
+
st.session_state["saved_model_names"] = []
|
239 |
+
|
240 |
+
if "Model" not in st.session_state:
|
241 |
+
if (
|
242 |
+
"session_state_saved"
|
243 |
+
in st.session_state["project_dct"]["model_build"].keys()
|
244 |
+
and st.session_state["project_dct"]["model_build"][
|
245 |
+
"session_state_saved"
|
246 |
+
]
|
247 |
+
is not None
|
248 |
+
and "Model"
|
249 |
+
in st.session_state["project_dct"]["model_build"][
|
250 |
+
"session_state_saved"
|
251 |
+
].keys()
|
252 |
+
):
|
253 |
+
st.session_state["Model"] = st.session_state["project_dct"][
|
254 |
+
"model_build"
|
255 |
+
]["session_state_saved"]["Model"]
|
256 |
+
else:
|
257 |
+
st.session_state["Model"] = {}
|
258 |
+
|
259 |
+
    date_col = "date"
    date = media_data[date_col]

    # Sprint4 - select a response metric
    default_target_idx = (
        st.session_state["project_dct"]["model_build"].get("sel_target_col", None)
        if st.session_state["project_dct"]["model_build"].get("sel_target_col", None)
        is not None
        else st.session_state["available_response_metrics"][0]
    )

    start_cols = st.columns(2)
    min_date = min(date)
    max_date = max(date)

    with start_cols[0]:
        sel_target_col = st.selectbox(
            "Select the response metric",
            st.session_state["available_response_metrics"],
            index=st.session_state["available_response_metrics"].index(
                default_target_idx
            ),
            format_func=format_display,
        )
        # , on_change=reset_save())
        st.session_state["project_dct"]["model_build"]["sel_target_col"] = sel_target_col

    # Default test start: three quarters of the way through the date range
    default_test_start = min_date + (3 * (max_date - min_date) / 4)

    with start_cols[1]:
        test_start = st.date_input(
            "Select test start date",
            default_test_start,
            min_value=min_date,
            max_value=max_date,
        )
    train_idx = media_data[media_data[date_col] <= pd.to_datetime(test_start)].index[-1]
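    # Note: this is a positional split; for panel data it assumes rows for the
    # same date are contiguous once media_data is sorted (see the sort_values
    # call below), so that train/test do not split a single date across sets.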
    # st.write(train_idx, media_data.index[-1])

    target_col = (
        sel_target_col.lower()
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
    )
    new_name_dct = {
        col: col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in media_data.columns
    }
    media_data.columns = [
        col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in media_data.columns
    ]
    panel_col = [
        col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in st.session_state["bin_dict"]["Panel Level 1"]
    ][0]  # set the panel column

    is_panel = len(panel_col) > 0

    if "is_panel" not in st.session_state:
        st.session_state["is_panel"] = is_panel

    if is_panel:
        media_data.sort_values([date_col, panel_col], inplace=True)
    else:
        media_data.sort_values(date_col, inplace=True)

    media_data.reset_index(drop=True, inplace=True)

    st.session_state["date"] = date
    y = media_data[target_col]

    if is_panel:
        spends_data = media_data[
            [
                c
                for c in media_data.columns
                if "_cost" in c.lower() or "_spend" in c.lower()
            ]
            + [date_col, panel_col]
        ]
        # Sprint3 - spends for response curves
    else:
        spends_data = media_data[
            [
                c
                for c in media_data.columns
                if "_cost" in c.lower() or "_spend" in c.lower()
            ]
            + [date_col]
        ]

    y = media_data[target_col]
    media_data.drop([date_col], axis=1, inplace=True)
    media_data.reset_index(drop=True, inplace=True)

    columns = st.columns(2)

    old_shape = media_data.shape

    if "old_shape" not in st.session_state:
        st.session_state["old_shape"] = old_shape

    if "media_data" not in st.session_state:
        st.session_state["media_data"] = pd.DataFrame()

    # Sprint3
    if "orig_media_data" not in st.session_state:
        st.session_state["orig_media_data"] = pd.DataFrame()

    # Sprint3 additions
    if "random_effects" not in st.session_state:
        st.session_state["random_effects"] = pd.DataFrame()
    if "pred_train" not in st.session_state:
        st.session_state["pred_train"] = []
    if "pred_test" not in st.session_state:
        st.session_state["pred_test"] = []
    # end of Sprint3 additions

    # Section 3 - Create combinations

    # bucket=['paid_search', 'kwai','indicacao','infleux', 'influencer','FB: Level Achieved - Tier 1 Impressions',
    #         ' FB: Level Achieved - Tier 2 Impressions','paid_social_others',
    #         ' GA App: Will And Cid Pequena Baixo Risco Clicks',
    #         'digital_tactic_others',"programmatic"
    #         ]

    # srishti - bucket names changed
    bucket = [
        "paid_search",
        "kwai",
        "indicacao",
        "infleux",
        "influencer",
        "fb_level_achieved_tier_2",
        "fb_level_achieved_tier_1",
        "paid_social_others",
        "ga_app",
        "digital_tactic_others",
        "programmatic",
    ]

    # with columns[0]:
    #     if st.button('Create Combinations of Variables'):

    # For each original variable, keep its 2 transformations most correlated
    # with the target (the list name says "top 3" but head(2) is used below)
    top_3_correlated_features = []
    # original_cols = [c for c in st.session_state['media_data'].columns if
    #                  "_clicks" in c.lower() or "_impressions" in c.lower()]
    # original_cols = [c for c in original_cols if "_lag" not in c.lower() and "_adstock" not in c.lower()]

    original_cols = (
        st.session_state["bin_dict"]["Media"]
        + st.session_state["bin_dict"]["Internal"]
    )

    original_cols = [
        col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in original_cols
    ]
    original_cols = [col for col in original_cols if "_cost" not in col]
    for col in original_cols:  # srishti - new
        corr_df = (
            pd.concat(
                [st.session_state["media_data"].filter(regex=col), y], axis=1
            )
            .corr()[target_col]
            .iloc[:-1]
        )
        top_3_correlated_features.append(
            list(corr_df.sort_values(ascending=False).head(2).index)
        )
    flattened_list = [
        item for sublist in top_3_correlated_features for item in sublist
    ]
    # all_features_set={var:[col for col in flattened_list if var in col] for var in bucket}
    all_features_set = {
        var: [col for col in flattened_list if var in col]
        for var in bucket
        if len([col for col in flattened_list if var in col]) > 0
    }  # srishti
    channels_all = [values for values in all_features_set.values()]
    st.session_state["combinations"] = list(itertools.product(*channels_all))
    # if 'combinations' not in st.session_state:
    #     st.session_state['combinations'] = combinations_all

    st.session_state["final_selection"] = st.session_state["combinations"]
    # st.success('Created combinations')

    # revenue.reset_index(drop=True, inplace=True)
    y.reset_index(drop=True, inplace=True)
    if "Model_results" not in st.session_state:
        st.session_state["Model_results"] = {
            "Model_object": [],
            "Model_iteration": [],
            "Feature_set": [],
            "MAPE": [],
            "R2": [],
            "ADJR2": [],
            "pos_count": [],
        }

    def reset_model_result_dct():
        st.session_state["Model_results"] = {
            "Model_object": [],
            "Model_iteration": [],
            "Feature_set": [],
            "MAPE": [],
            "R2": [],
            "ADJR2": [],
            "pos_count": [],
        }

    # if st.button('Build Model'):

    if "iterations" not in st.session_state:
        st.session_state["iterations"] = 0

    if "final_selection" not in st.session_state:
        st.session_state["final_selection"] = False

    save_path = r"Model/"
    if st.session_state["final_selection"]:
        st.write(
            f'Total combinations created {format_numbers(len(st.session_state["final_selection"]))}'
        )

    # st.session_state["project_dct"]["model_build"]["all_iters_check"] = False

    checkbox_default = (
        st.session_state["project_dct"]["model_build"]["all_iters_check"]
        if st.session_state["project_dct"]["model_build"]["all_iters_check"]
        is not None
        else False
    )
    end_date = test_start - timedelta(days=1)
    disp_str = (
        "Data Split -- Training Period: "
        + min_date.strftime("%B %d, %Y")
        + " - "
        + end_date.strftime("%B %d, %Y")
        + ", Testing Period: "
        + test_start.strftime("%B %d, %Y")
        + " - "
        + max_date.strftime("%B %d, %Y")
    )
    st.markdown(disp_str)
    if st.checkbox("Build all iterations", value=checkbox_default):
        iterations = len(st.session_state["final_selection"])
        st.session_state["project_dct"]["model_build"]["all_iters_check"] = True

    else:
        iterations = st.number_input(
            "Select the number of iterations to perform",
            min_value=0,
            step=100,
            value=st.session_state["iterations"],
            on_change=reset_model_result_dct,
        )
        st.session_state["project_dct"]["model_build"]["all_iters_check"] = False
        st.session_state["project_dct"]["model_build"]["iterations"] = iterations

    # st.stop()

    # build_button = st.session_state["project_dct"]["model_build"]["build_button"] if \
    #     "build_button" in st.session_state["project_dct"]["model_build"].keys() else False
    # model_button = st.button('Build Model', on_click=reset_model_result_dct, key='model_build_button')
    # if model_button:
    if st.button(
        "Build Model",
        on_click=reset_model_result_dct,
        key="model_build_button",
    ):
        if iterations < 1:
            st.error("Please select the number of iterations")
            st.stop()
        st.session_state["project_dct"]["model_build"]["build_button"] = True
        st.session_state["iterations"] = iterations

        # Section 4 - Model
        # st.session_state['media_data'] = st.session_state['media_data'].fillna(method='ffill')
        st.session_state["media_data"] = st.session_state["media_data"].ffill()
        progress_bar = st.progress(0)  # Initialize the progress bar
        # time_remaining_text = st.empty()  # Create an empty space for time remaining text
        start_time = time.time()  # Record the start time
        progress_text = st.empty()
        # time_elapsed_text = st.empty()
        # for i, selected_features in enumerate(st.session_state["final_selection"][40000:40000 + int(iterations)]):
        # for i, selected_features in enumerate(st.session_state["final_selection"]):

        if is_panel:
            # Mixed-effects model: one fixed effect per selected feature plus a
            # random intercept per panel unit
            for i, selected_features in enumerate(
                st.session_state["final_selection"][0 : int(iterations)]
            ):  # srishti
                df = st.session_state["media_data"]

                fet = [var for var in selected_features if len(var) > 0]
                inp_vars_str = " + ".join(fet)  # new

                X = df[fet]
                y = df[target_col]
                ss = MinMaxScaler()
                X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)

                X[target_col] = y  # Sprint2
                X[panel_col] = df[panel_col]  # Sprint2

                X_train = X.iloc[:train_idx]
                X_test = X.iloc[train_idx:]
                y_train = y.iloc[:train_idx]
                y_test = y.iloc[train_idx:]

                print(X_train.shape)
                # model = sm.OLS(y_train, X_train).fit()
                md_str = target_col + " ~ " + inp_vars_str
                # md = smf.mixedlm("total_approved_accounts_revenue ~ {}".format(inp_vars_str),
                #                  data=X_train[[target_col] + fet],
                #                  groups=X_train[panel_col])
                md = smf.mixedlm(
                    md_str,
                    data=X_train[[target_col] + fet],
                    groups=X_train[panel_col],
                )
                mdf = md.fit()
                predicted_values = mdf.fittedvalues

                coefficients = mdf.fe_params.to_dict()
                model_positive = [
                    col for col in coefficients.keys() if coefficients[col] > 0
                ]

                pvalues = [var for var in list(mdf.pvalues) if var <= 0.06]

                if (len(model_positive) / len(selected_features)) > 0 and (
                    len(pvalues) / len(selected_features)
                ) >= 0:  # srishti - thresholds relaxed just for testing, revert later
                    # predicted_values = model.predict(X_train)
                    mape = mean_absolute_percentage_error(y_train, predicted_values)
                    r2 = r2_score(y_train, predicted_values)
                    adjr2 = 1 - (1 - r2) * (len(y_train) - 1) / (
                        len(y_train) - len(selected_features) - 1
                    )
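                    # The adjusted R^2 above follows
                    #   1 - (1 - R^2) * (n - 1) / (n - k - 1)
                    # with n = len(y_train) and k = len(selected_features); it
                    # is computed by hand because MixedLM results do not expose
                    # an rsquared_adj attribute.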

                    filename = os.path.join(save_path, f"model_{i}.pkl")
                    with open(filename, "wb") as f:
                        pickle.dump(mdf, f)
                    # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
                    #     model = pickle.load(file)

                    st.session_state["Model_results"]["Model_object"].append(filename)
                    st.session_state["Model_results"]["Model_iteration"].append(i)
                    st.session_state["Model_results"]["Feature_set"].append(fet)
                    st.session_state["Model_results"]["MAPE"].append(mape)
                    st.session_state["Model_results"]["R2"].append(r2)
                    st.session_state["Model_results"]["pos_count"].append(
                        len(model_positive)
                    )
                    st.session_state["Model_results"]["ADJR2"].append(adjr2)

                current_time = time.time()
                time_taken = current_time - start_time
                time_elapsed_minutes = time_taken / 60
                completed_iterations_text = f"{i + 1}/{iterations}"
                progress_bar.progress((i + 1) / int(iterations))
                progress_text.text(
                    f"Completed iterations: {completed_iterations_text}, Time Elapsed (min): {time_elapsed_minutes:.2f}"
                )
            st.write(
                f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
            )

        else:
            # Simple OLS when there is no panel column
            for i, selected_features in enumerate(
                st.session_state["final_selection"][0 : int(iterations)]
            ):  # srishti
                df = st.session_state["media_data"]

                fet = [var for var in selected_features if len(var) > 0]
                inp_vars_str = " + ".join(fet)

                X = df[fet]
                y = df[target_col]
                ss = MinMaxScaler()
                X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
                X = sm.add_constant(X)
                X_train = X.iloc[:130]
                X_test = X.iloc[130:]
                y_train = y.iloc[:130]
                y_test = y.iloc[130:]

                model = sm.OLS(y_train, X_train).fit()

                coefficients = model.params.to_list()
                model_positive = [coef for coef in coefficients if coef > 0]
                predicted_values = model.predict(X_train)
                pvalues = [var for var in list(model.pvalues) if var <= 0.06]

                # if (len(model_positive) / len(selected_features)) > 0.9 and (len(pvalues) / len(selected_features)) >= 0.8:
                if (len(model_positive) / len(selected_features)) > 0 and (
                    len(pvalues) / len(selected_features)
                ) >= 0.5:  # srishti - relaxed just for testing, revert later; VALID MODEL CRITERIA
                    # predicted_values = model.predict(X_train)
                    mape = mean_absolute_percentage_error(y_train, predicted_values)
                    adjr2 = model.rsquared_adj
                    r2 = model.rsquared

                    filename = os.path.join(save_path, f"model_{i}.pkl")
                    with open(filename, "wb") as f:
                        pickle.dump(model, f)
                    # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
                    #     model = pickle.load(file)

                    st.session_state["Model_results"]["Model_object"].append(filename)
                    st.session_state["Model_results"]["Model_iteration"].append(i)
                    st.session_state["Model_results"]["Feature_set"].append(fet)
                    st.session_state["Model_results"]["MAPE"].append(mape)
                    st.session_state["Model_results"]["R2"].append(r2)
                    st.session_state["Model_results"]["ADJR2"].append(adjr2)
                    st.session_state["Model_results"]["pos_count"].append(
                        len(model_positive)
                    )

                current_time = time.time()
                time_taken = current_time - start_time
                time_elapsed_minutes = time_taken / 60
                completed_iterations_text = f"{i + 1}/{iterations}"
                progress_bar.progress((i + 1) / int(iterations))
                progress_text.text(
                    f"Completed iterations: {completed_iterations_text}, Time Elapsed (min): {time_elapsed_minutes:.2f}"
                )
            st.write(
                f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
            )

        pd.DataFrame(st.session_state["Model_results"]).to_csv("model_output.csv")

    def to_percentage(value):
        return f"{value * 100:.1f}%"
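
    # Example: to_percentage(0.1234) -> "12.3%"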

    ## Section 5 - Select Model
    st.title("2. Select Models")
    show_results_default = (
        st.session_state["project_dct"]["model_build"]["show_results_check"]
        if st.session_state["project_dct"]["model_build"]["show_results_check"]
        is not None
        else False
    )
    if "tick" not in st.session_state:
        st.session_state["tick"] = False
    if st.checkbox(
        "Show results of top 10 models (based on MAPE and Adj. R2)",
        value=True,
    ):
        st.session_state["project_dct"]["model_build"]["show_results_check"] = True
        st.session_state["tick"] = True
        st.write(
            "Select one model iteration to generate performance metrics for it:"
        )
        data = pd.DataFrame(st.session_state["Model_results"])
        # Keep only models with the most positive coefficients, rank by Adj. R2
        data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
            drop=True
        )
        data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
        data.drop_duplicates(subset="Model_iteration", inplace=True)
        top_10 = data.head(10)
        top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
        top_10[["MAPE", "R2", "ADJR2"]] = np.round(
            top_10[["MAPE", "R2", "ADJR2"]], 4
        ).applymap(to_percentage)
        top_10_table = top_10[["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]]
        # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
        gd = GridOptionsBuilder.from_dataframe(top_10_table)
        gd.configure_pagination(enabled=True)

        gd.configure_selection(
            use_checkbox=True,
            selection_mode="single",
            pre_select_all_rows=False,
            pre_selected_rows=[1],
        )

        gridoptions = gd.build()

        table = AgGrid(
            top_10,
            gridOptions=gridoptions,
            update_mode=GridUpdateMode.SELECTION_CHANGED,
        )

        selected_rows = table.selected_rows
        # if st.session_state["selected_rows"] != selected_rows:
        #     st.session_state["build_rc_cb"] = False
        st.session_state["selected_rows"] = selected_rows

        # st.write(
        #     """
        #     ### Filter Results
        #
        #     Use the filters below to refine the displayed model results. This helps in isolating models that do not meet the required business criteria, ensuring only the most relevant models are considered for further analysis. If multiple models meet the criteria, select the first model, as it is considered the best-ranked based on evaluation criteria.
        #     """
        # )

        # data = pd.DataFrame(st.session_state["Model_results"])
        # coefficients_df, data_df = prepare_data_df(data)

        # # Define the structure of the empty DataFrame
        # filter_df_data = {
        #     "Channel Name": pd.Series([], dtype="str"),
        #     "Filter Condition": pd.Series([], dtype="str"),
        #     "Percent Contribution": pd.Series([], dtype="str"),
        # }
        # filter_df = pd.DataFrame(filter_df_data)

        # filter_df_editable = st.data_editor(
        #     filter_df,
        #     column_config={
        #         "Channel Name": st.column_config.SelectboxColumn(
        #             options=list(coefficients_df.columns),
        #             required=True,
        #             default="Base Sales",
        #         ),
        #         "Filter Condition": st.column_config.SelectboxColumn(
        #             options=["<", ">", "=", "<=", ">="],
        #             required=True,
        #             default=">",
        #         ),
        #         "Percent Contribution": st.column_config.NumberColumn(
        #             required=True, default=0
        #         ),
        #     },
        #     hide_index=True,
        #     use_container_width=True,
        #     num_rows="dynamic",
        # )

        # # Apply filters from filter_df_editable to data_df
        # if "filtered_df" not in st.session_state:
        #     st.session_state["filtered_df"] = data_df.copy()

        # if st.button("Filter", args=(data_df)):
        #     st.session_state["filtered_df"] = data_df.copy()
        #     for index, row in filter_df_editable.iterrows():
        #         channel_name = row["Channel Name"]
        #         condition = row["Filter Condition"]
        #         value = row["Percent Contribution"]

        #         if channel_name in st.session_state["filtered_df"].columns:
        #             # Construct the query string based on the condition
        #             query_string = f"`{channel_name}` {condition} {value}"
        #             st.session_state["filtered_df"] = st.session_state["filtered_df"].query(
        #                 query_string
        #             )

        # # After filtering, check if the DataFrame is empty
        # if st.session_state["filtered_df"].empty:
        #     # Display a warning message if no rows meet the filter criteria
        #     st.warning("No model meets the specified filter conditions", icon="⚠️")
        #     st.stop()  # Optionally stop further execution

        # # Output the filtered data
        # st.write("Select one model iteration to generate performance metrics for it:")
        # st.dataframe(st.session_state["filtered_df"], hide_index=True)

        #############################################################################################

        # top_10 = data.head(10)
        # top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
        # top_10[["MAPE", "R2", "ADJR2"]] = np.round(
        #     top_10[["MAPE", "R2", "ADJR2"]], 4
        # ).applymap(to_percentage)

        # top_10_table = top_10[
        #     ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
        #     + list(coefficients_df.columns)
        # ]
        # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]

        # gd = GridOptionsBuilder.from_dataframe(top_10_table)
        # gd.configure_pagination(enabled=True)

        # gd.configure_selection(
        #     use_checkbox=True,
        #     selection_mode="single",
        #     pre_select_all_rows=False,
        #     pre_selected_rows=[1],
        # )

        # gridoptions = gd.build()

        # table = AgGrid(
        #     top_10, gridOptions=gridoptions, update_mode=GridUpdateMode.SELECTION_CHANGED
        # )

        # selected_rows = table.selected_rows

        # gd = GridOptionsBuilder.from_dataframe(st.session_state["filtered_df"])
        # gd.configure_pagination(enabled=True)

        # gd.configure_selection(
        #     use_checkbox=True,
        #     selection_mode="single",
        #     pre_select_all_rows=False,
        #     pre_selected_rows=[1],
        # )

        # gridoptions = gd.build()

        # table = AgGrid(
        #     st.session_state["filtered_df"],
        #     gridOptions=gridoptions,
        #     update_mode=GridUpdateMode.SELECTION_CHANGED,
        # )

        # selected_rows_table = table.selected_rows

        # Dataframe
        # display_df = st.session_state.filtered_df.rename(columns={"Rank": "Model Number"})
        # st.dataframe(display_df, hide_index=True)

        # min_rank = min(st.session_state["filtered_df"]["Rank"])
        # max_rank = max(st.session_state["filtered_df"]["Rank"])
        # available_ranks = st.session_state["filtered_df"]["Rank"].unique()

        # # Get row number input from the user
        # rank_number = st.number_input(
        #     "Select model by Model Number:",
        #     min_value=min_rank,
        #     max_value=max_rank,
        #     value=min_rank,
        #     step=1,
        # )

        # # Get row
        # if rank_number not in available_ranks:
        #     st.warning("No model is available with selected Rank", icon="⚠️")
        #     st.stop()

        # Find the row that matches the selected rank
        # selected_rows = st.session_state["filtered_df"][
        #     st.session_state["filtered_df"]["Rank"] == rank_number
        # ]

        # selected_rows = [
        #     (selected_rows.to_dict(orient="records")[0] if not selected_rows.empty else {})
        # ]

        # if st.session_state["selected_rows"] != selected_rows:
        #     st.session_state["build_rc_cb"] = False
        st.session_state["selected_rows"] = selected_rows
        if "Model" not in st.session_state:
            st.session_state["Model"] = {}

        # Section 6 - Display Results

        if len(selected_rows) > 0:
            st.header("2.1 Results Summary")

            model_object = data[
                data["Model_iteration"] == selected_rows[0]["Model_iteration"]
            ]["Model_object"]
            features_set = data[
                data["Model_iteration"] == selected_rows[0]["Model_iteration"]
            ]["Feature_set"]

            with open(str(model_object.values[0]), "rb") as file:
                model = pickle.load(file)
            st.write(model.summary())
            st.header("2.2 Actual vs. Predicted Plot")

            if is_panel:
                df = st.session_state["media_data"]
                X = df[features_set.values[0]]
                y = df[target_col]

                ss = MinMaxScaler()
                X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)

                # Sprint2 changes
                X[target_col] = y  # new
                X[panel_col] = df[panel_col]
                X[date_col] = date

                X_train = X.iloc[:train_idx]
                X_test = X.iloc[train_idx:].reset_index(drop=True)
                y_train = y.iloc[:train_idx]
                y_test = y.iloc[train_idx:].reset_index(drop=True)

                test_spends = spends_data[train_idx:]  # Sprint3 - test spends for response curves
                random_eff_df = get_random_effects(media_data, panel_col, model)
                train_pred = model.fittedvalues
                test_pred = mdf_predict(X_test, model, random_eff_df)
                print("__" * 20, test_pred.isna().sum())

            else:
                df = st.session_state["media_data"]
                X = df[features_set.values[0]]
                y = df[target_col]

                ss = MinMaxScaler()
                X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
                X = sm.add_constant(X)

                X[date_col] = date

                X_train = X.iloc[:130]
                X_test = X.iloc[130:].reset_index(drop=True)
                y_train = y.iloc[:130]
                y_test = y.iloc[130:].reset_index(drop=True)

                test_spends = spends_data[130:]  # Sprint3 - test spends for response curves
                train_pred = model.predict(X_train[features_set.values[0] + ["const"]])
                test_pred = model.predict(X_test[features_set.values[0] + ["const"]])

            # save x test to test - srishti
            # x_test_to_save = X_test.copy()
            # x_test_to_save['Actuals'] = y_test
            # x_test_to_save['Predictions'] = test_pred
            #
            # x_train_to_save = X_train.copy()
            # x_train_to_save['Actuals'] = y_train
            # x_train_to_save['Predictions'] = train_pred
            #
            # x_train_to_save.to_csv('Test/x_train_to_save.csv', index=False)
            # x_test_to_save.to_csv('Test/x_test_to_save.csv', index=False)

            st.session_state["X"] = X_train
            st.session_state["features_set"] = features_set.values[0]
            print("**" * 20, "selected model features : ", features_set.values[0])
            metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
                X_train[date_col],
                y_train,
                train_pred,
                model,
                target_column=sel_target_col,
                is_panel=is_panel,
            )  # Sprint2

            st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)

            st.markdown("## 2.3 Residual Analysis")
            columns = st.columns(2)
            with columns[0]:
                fig = plot_residual_predicted(y_train, train_pred, X_train)  # Sprint2
                st.plotly_chart(fig)

            with columns[1]:
                st.empty()
                fig = qqplot(y_train, train_pred)  # Sprint2
                st.plotly_chart(fig)

            with columns[0]:
                fig = residual_distribution(y_train, train_pred)  # Sprint2
                st.pyplot(fig)

            vif_data = pd.DataFrame()
            # X = X.drop('const', axis=1)
            X_train_orig = (
                X_train.copy()
            )  # Sprint2 -- keep a copy of X_train; panel, target & date columns are dropped from X_train below
            del_col_list = list(
                set([target_col, panel_col, date_col]).intersection(set(X_train.columns))
            )
            X_train.drop(columns=del_col_list, inplace=True)  # Sprint2

            vif_data["Variable"] = X_train.columns
            vif_data["VIF"] = [
                variance_inflation_factor(X_train.values, i)
                for i in range(X_train.shape[1])
            ]
            vif_data.sort_values(by=["VIF"], ascending=False, inplace=True)
            vif_data = np.round(vif_data)
            vif_data["VIF"] = vif_data["VIF"].astype(float)
            st.header("2.4 Variance Inflation Factor (VIF)")
            # st.dataframe(vif_data)
            color_mapping = {
                "darkgreen": (vif_data["VIF"] < 3),
                "orange": (vif_data["VIF"] >= 3) & (vif_data["VIF"] <= 10),
                "darkred": (vif_data["VIF"] > 10),
            }

            # Create a horizontal bar plot
            fig, ax = plt.subplots()
            fig.set_figwidth(10)  # Adjust the width of the figure as needed

            # Sort the bars by descending VIF values
            vif_data = vif_data.sort_values(by="VIF", ascending=False)

            # Iterate through the color mapping and plot bars with corresponding colors
            for color, condition in color_mapping.items():
                subset = vif_data[condition]
                bars = ax.barh(subset["Variable"], subset["VIF"], color=color, label=color)

                # Add text annotations at the end of the bars
                for bar in bars:
                    width = bar.get_width()
                    ax.annotate(
                        f"{width}",
                        xy=(width, bar.get_y() + bar.get_height() / 2),
                        xytext=(5, 0),
                        textcoords="offset points",
                        va="center",
                    )

            # Customize the plot
            ax.set_xlabel("VIF Values")
            # ax.set_title('2.4 Variance Inflation Factor (VIF)')
            # ax.legend(loc='upper right')

            # Display the plot in Streamlit
            st.pyplot(fig)

            with st.expander("Results Summary Test data"):
                # ss = MinMaxScaler()
                # X_test = pd.DataFrame(ss.fit_transform(X_test), columns=X_test.columns)
                st.header("2.2 Actual vs. Predicted Plot")

                metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
                    X_test[date_col],
                    y_test,
                    test_pred,
                    model,
                    target_column=sel_target_col,
                    is_panel=is_panel,
                )  # Sprint2

                st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)

                st.markdown("## 2.3 Residual Analysis")
                columns = st.columns(2)
                with columns[0]:
                    # y_test (not the full y) to match the test predictions
                    fig = plot_residual_predicted(y_test, test_pred, X_test)  # Sprint2
                    st.plotly_chart(fig)

                with columns[1]:
                    st.empty()
                    fig = qqplot(y_test, test_pred)  # Sprint2
                    st.plotly_chart(fig)

                with columns[0]:
                    fig = residual_distribution(y_test, test_pred)  # Sprint2
                    st.pyplot(fig)

            value = False
            save_button_model = st.checkbox(
                "Save this model to tune", key="build_rc_cb"
            )  # , on_click=set_save())

            if save_button_model:
                mod_name = st.text_input("Enter model name")
                if len(mod_name) > 0:
                    mod_name = (
                        mod_name + "__" + target_col
                    )  # Sprint4 - adding target col to model name
                    if is_panel:
                        pred_train = model.fittedvalues
                        pred_test = mdf_predict(X_test, model, random_eff_df)
                    else:
                        st.session_state["features_set"] = st.session_state[
                            "features_set"
                        ] + ["const"]
                        pred_train = model.predict(
                            X_train_orig[st.session_state["features_set"]]
                        )
                        pred_test = model.predict(
                            X_test[st.session_state["features_set"]]
                        )

                    st.session_state["Model"][mod_name] = {
                        "Model_object": model,
                        "feature_set": st.session_state["features_set"],
                        "X_train": X_train_orig,
                        "X_test": X_test,
                        "y_train": y_train,
                        "y_test": y_test,
                        "pred_train": pred_train,
                        "pred_test": pred_test,
                    }
                    st.session_state["X_train"] = X_train_orig
                    st.session_state["X_test_spends"] = test_spends
                    st.session_state["saved_model_names"].append(mod_name)
                    # Sprint3 additions
                    if is_panel:
                        random_eff_df = get_random_effects(media_data, panel_col, model)
                        st.session_state["random_effects"] = random_eff_df

                    with open(
                        os.path.join(st.session_state["project_path"], "best_models.pkl"),
                        "wb",
                    ) as f:
                        pickle.dump(st.session_state["Model"], f)
                    st.success(
                        mod_name + " model saved! Proceed to the next page to tune the model"
                    )

                    # Sprint4 - add the formatted name of the target col to used response metrics
                    urm = st.session_state["used_response_metrics"]
                    urm.append(sel_target_col)
                    st.session_state["used_response_metrics"] = list(set(urm))
                    mod_name = ""
                    value = False

                    st.session_state["project_dct"]["model_build"]["session_state_saved"] = {}
                    for key in [
                        "Model",
                        "bin_dict",
                        "used_response_metrics",
                        "date",
                        "saved_model_names",
                        "media_data",
                        "X_test_spends",
                    ]:
                        st.session_state["project_dct"]["model_build"][
                            "session_state_saved"
                        ][key] = st.session_state[key]

                    project_dct_path = os.path.join(
                        st.session_state["project_path"], "project_dct.pkl"
                    )
                    with open(project_dct_path, "wb") as f:
                        pickle.dump(st.session_state["project_dct"], f)

                    update_db("4_Model_Build.py")

                    st.toast("💾 Saved Successfully!")
    else:
        st.session_state["project_dct"]["model_build"]["show_results_check"] = False
pages/5_Model_Tuning.py
ADDED
@@ -0,0 +1,917 @@
"""
MMO Build Sprint 3
date :
changes : capability to tune MixedLM as well as simple LR in the same page
"""

import os

import streamlit as st
import pandas as pd
from Eda_functions import format_numbers
import pickle
from utilities import set_header, load_local_css
import statsmodels.api as sm
import re
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from statsmodels.stats.outliers_influence import variance_inflation_factor
import yaml
from yaml import SafeLoader
import streamlit_authenticator as stauth

st.set_option("deprecation.showPyplotGlobalUse", False)
import statsmodels.formula.api as smf
from Data_prep_functions import *
import sqlite3
from utilities import update_db

# for i in ["model_tuned", "X_train_tuned", "X_test_tuned", "tuned_model_features", "tuned_model", "tuned_model_dict"]:

st.set_page_config(
    page_title="Model Tuning",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)
load_local_css("styles.css")
set_header()

# Carry session state across page switches (skip auth widgets and form-submit keys)
for k, v in st.session_state.items():
    # print(k, v)
    if k not in [
        "logout",
        "login",
        "config",
        "build_tuned_model",
    ] and not k.startswith("FormSubmitter"):
        st.session_state[k] = v
with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config
authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status:
    authenticator.logout("Logout", "main")
    is_state_initialized = st.session_state.get("initialized", False)

    if "project_dct" not in st.session_state:
        st.error("Please load a project from the Home page")
        st.stop()

    if not os.path.exists(
        os.path.join(st.session_state["project_path"], "best_models.pkl")
    ):
        st.error("Please save a model before tuning")
        st.stop()

    conn = sqlite3.connect(
        r"DB/User.db", check_same_thread=False
    )  # connection with the SQL db
    c = conn.cursor()

    if not is_state_initialized:
        if "session_name" not in st.session_state:
            st.session_state["session_name"] = None

    if (
        "session_state_saved"
        in st.session_state["project_dct"]["model_build"].keys()
    ):
        # Restore everything the Model Build page saved for this project
        for key in [
            "Model",
            "date",
            "saved_model_names",
            "media_data",
            "X_test_spends",
        ]:
            if key not in st.session_state:
                st.session_state[key] = st.session_state["project_dct"][
                    "model_build"
                ]["session_state_saved"][key]
        st.session_state["bin_dict"] = st.session_state["project_dct"][
            "model_build"
        ]["session_state_saved"]["bin_dict"]
        if (
            "used_response_metrics" not in st.session_state
            or st.session_state["used_response_metrics"] == []
        ):
            st.session_state["used_response_metrics"] = st.session_state[
                "project_dct"
            ]["model_build"]["session_state_saved"]["used_response_metrics"]
    else:
        st.error("Please load a session with a built model")
        st.stop()

    # if 'sel_model' not in st.session_state["project_dct"]["model_tuning"].keys():
    #     st.session_state["project_dct"]["model_tuning"]['sel_model'] = {}

    for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
        if key not in st.session_state["project_dct"]["model_tuning"].keys():
            st.session_state["project_dct"]["model_tuning"][key] = {}
    # Sprint3
    # is_panel = st.session_state['is_panel']
    # panel_col = 'markets'  # set the panel column
    date_col = "date"

    panel_col = [
        col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in st.session_state["bin_dict"]["Panel Level 1"]
    ][0]  # set the panel column
    is_panel = len(panel_col) > 0

    # flag indicating there is no tuned model yet

    # Sprint4 - model tuned dict
    if "Model_Tuned" not in st.session_state:
        st.session_state["Model_Tuned"] = {}

    st.title("1. Model Tuning")

    if "is_tuned_model" not in st.session_state:
        st.session_state["is_tuned_model"] = {}
    # Sprint4 - if used_response_metrics is not blank, select one of them;
    # otherwise the target is revenue by default
    if (
        "used_response_metrics" in st.session_state
        and st.session_state["used_response_metrics"] != []
    ):
        default_target_idx = (
            st.session_state["project_dct"]["model_tuning"].get("sel_target_col", None)
            if st.session_state["project_dct"]["model_tuning"].get("sel_target_col", None)
            is not None
            else st.session_state["used_response_metrics"][0]
        )

        def format_display(inp):
            return inp.title().replace("_", " ").strip()

        sel_target_col = st.selectbox(
            "Select the response metric",
            st.session_state["used_response_metrics"],
            index=st.session_state["used_response_metrics"].index(default_target_idx),
            format_func=format_display,
        )
        target_col = (
            sel_target_col.lower()
            .replace(" ", "_")
            .replace("-", "")
            .replace(":", "")
            .replace("__", "_")
        )
        st.session_state["project_dct"]["model_tuning"]["sel_target_col"] = sel_target_col

    else:
        sel_target_col = "Total Approved Accounts - Revenue"
        target_col = "total_approved_accounts_revenue"

    # Sprint4 - look through all saved models and only show those built for the
    # selected response metric (target_col)
    # saved_models = st.session_state['saved_model_names']
    with open(
        os.path.join(st.session_state["project_path"], "best_models.pkl"), "rb"
    ) as file:
        model_dict = pickle.load(file)

    saved_models = model_dict.keys()
    required_saved_models = [
        m.split("__")[0]
        for m in saved_models
        if m.split("__")[1] == target_col
    ]
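
    # Naming convention carried over from the Model Build page: models are
    # stored as "<name>__<response metric>", e.g. a model saved as
    # "my_model__total_approved_accounts_revenue" (name assumed for
    # illustration) appears here as "my_model" when that metric is selected.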

    if len(required_saved_models) > 0:
        default_model_idx = st.session_state["project_dct"]["model_tuning"][
            "sel_model"
        ].get(sel_target_col, required_saved_models[0])
        sel_model = st.selectbox(
            "Select the model to tune",
            required_saved_models,
            index=required_saved_models.index(default_model_idx),
        )
    else:
        default_model_idx = st.session_state["project_dct"]["model_tuning"][
            "sel_model"
        ].get(sel_target_col, 0)
        sel_model = st.selectbox("Select the model to tune", required_saved_models)

    st.session_state["project_dct"]["model_tuning"]["sel_model"][
        sel_target_col
    ] = default_model_idx

    sel_model_dict = model_dict[
        sel_model + "__" + target_col
    ]  # Sprint4 - get the model object of the selected model

    X_train = sel_model_dict["X_train"]
    X_test = sel_model_dict["X_test"]
    y_train = sel_model_dict["y_train"]
    y_test = sel_model_dict["y_test"]
    df = st.session_state["media_data"]

    if "selected_model" not in st.session_state:
        st.session_state["selected_model"] = 0

    st.markdown("### 1.1 Event Flags")
    st.markdown("Helps in quantifying the impact of specific occurrences of events")

    flag_expander_default = (
        st.session_state["project_dct"]["model_tuning"].get("flag_expander", None)
        if st.session_state["project_dct"]["model_tuning"].get("flag_expander", None)
        is not None
        else False
    )

    with st.expander("Apply Event Flags", flag_expander_default):
        st.session_state["project_dct"]["model_tuning"]["flag_expander"] = True

        model = sel_model_dict["Model_object"]
        date = st.session_state["date"]
        date = pd.to_datetime(date)
        X_train = sel_model_dict["X_train"]

        # features_set = model_dict[st.session_state["selected_model"]]['feature_set']
        features_set = sel_model_dict["feature_set"]

        col = st.columns(3)
        min_date = min(date)
        max_date = max(date)

        start_date_default = (
            st.session_state["project_dct"]["model_tuning"].get("start_date_default")
            if st.session_state["project_dct"]["model_tuning"].get("start_date_default")
            is not None
            else min_date
        )
        end_date_default = (
            st.session_state["project_dct"]["model_tuning"].get("end_date_default")
            if st.session_state["project_dct"]["model_tuning"].get("end_date_default")
            is not None
            else max_date
        )
        with col[0]:
            start_date = st.date_input(
                "Select Start Date",
                start_date_default,
                min_value=min_date,
                max_value=max_date,
            )
        with col[1]:
            # Keep the end-date default valid if the start date moved past it
            end_date_default = (
                end_date_default if end_date_default >= start_date else start_date
            )
            end_date = st.date_input(
                "Select End Date",
                end_date_default,
                min_value=max(min_date, start_date),
                max_value=max_date,
            )
        with col[2]:
            repeat_default = (
                st.session_state["project_dct"]["model_tuning"].get("repeat_default")
                if st.session_state["project_dct"]["model_tuning"].get("repeat_default")
                is not None
                else "No"
            )
            repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
            repeat = st.selectbox("Repeat Annually", ["Yes", "No"], index=repeat_default_idx)
        st.session_state["project_dct"]["model_tuning"]["start_date_default"] = start_date
        st.session_state["project_dct"]["model_tuning"]["end_date_default"] = end_date
        st.session_state["project_dct"]["model_tuning"]["repeat_default"] = repeat

        repeat = repeat == "Yes"

        if "Flags" not in st.session_state:
            st.session_state["Flags"] = {}
        if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
            st.session_state["Flags"] = st.session_state["project_dct"][
                "model_tuning"
            ]["flags"]
        # print("**" * 50)
        # print(y_train)
        # print("**" * 50)
        # print(model.fittedvalues)
        if is_panel:  # Sprint3
            met, line_values, fig_flag = plot_actual_vs_predicted(
                X_train[date_col],
                y_train,
                model.fittedvalues,
                model,
                target_column=sel_target_col,
                flag=(start_date, end_date),
                repeat_all_years=repeat,
                is_panel=True,
            )
            st.plotly_chart(fig_flag, use_container_width=True)

            # create the flag on the test window as well
            met, test_line_values, fig_flag = plot_actual_vs_predicted(
                X_test[date_col],
                y_test,
                sel_model_dict["pred_test"],
                model,
                target_column=sel_target_col,
                flag=(start_date, end_date),
                repeat_all_years=repeat,
                is_panel=True,
            )

        else:
            pred_train = model.predict(X_train[features_set])
            met, line_values, fig_flag = plot_actual_vs_predicted(
                X_train[date_col],
                y_train,
                pred_train,
                model,
                flag=(start_date, end_date),
                repeat_all_years=repeat,
                is_panel=False,
            )
            st.plotly_chart(fig_flag, use_container_width=True)

            pred_test = model.predict(X_test[features_set])
            met, test_line_values, fig_flag = plot_actual_vs_predicted(
                X_test[date_col],
                y_test,
                pred_test,
                model,
                flag=(start_date, end_date),
                repeat_all_years=repeat,
                is_panel=False,
            )
        flag_name = "f1_flag"
        flag_name = st.text_input("Enter Flag Name")
        # Sprint4 - add the selected target col to the flag name
        if st.button("Update flag"):
            st.session_state["Flags"][flag_name + "__" + target_col] = {}
            st.session_state["Flags"][flag_name + "__" + target_col][
|
406 |
+
"train"
|
407 |
+
] = line_values
|
408 |
+
st.session_state["Flags"][flag_name + "__" + target_col][
|
409 |
+
"test"
|
410 |
+
] = test_line_values
|
411 |
+
st.success(f'{flag_name + "__" + target_col} stored')
|
412 |
+
|
413 |
+
st.session_state["project_dct"]["model_tuning"]["flags"] = (
|
414 |
+
st.session_state["Flags"]
|
415 |
+
)
|
416 |
+
# Sprint4 - only show flag created for the particular target col
|
417 |
+
if st.session_state["Flags"] is None:
|
418 |
+
st.session_state["Flags"] = {}
|
419 |
+
target_model_flags = [
|
420 |
+
f.split("__")[0]
|
421 |
+
for f in st.session_state["Flags"].keys()
|
422 |
+
if f.split("__")[1] == target_col
|
423 |
+
]
|
424 |
+
options = list(target_model_flags)
|
425 |
+
selected_options = []
|
426 |
+
num_columns = 4
|
427 |
+
num_rows = -(-len(options) // num_columns)
|
428 |
+
|
429 |
+
tick = False
|
430 |
+
if st.checkbox(
|
431 |
+
"Select all",
|
432 |
+
value=st.session_state["project_dct"]["model_tuning"][
|
433 |
+
"select_all_flags_check"
|
434 |
+
].get(sel_target_col, False),
|
435 |
+
):
|
436 |
+
tick = True
|
437 |
+
st.session_state["project_dct"]["model_tuning"][
|
438 |
+
"select_all_flags_check"
|
439 |
+
][sel_target_col] = True
|
440 |
+
else:
|
441 |
+
st.session_state["project_dct"]["model_tuning"][
|
442 |
+
"select_all_flags_check"
|
443 |
+
][sel_target_col] = False
|
444 |
+
selection_defualts = st.session_state["project_dct"]["model_tuning"][
|
445 |
+
"selected_flags"
|
446 |
+
].get(sel_target_col, [])
|
447 |
+
selected_options = selection_defualts
|
448 |
+
for row in range(num_rows):
|
449 |
+
cols = st.columns(num_columns)
|
450 |
+
for col in cols:
|
451 |
+
if options:
|
452 |
+
option = options.pop(0)
|
453 |
+
option_default = (
|
454 |
+
True if option in selection_defualts else False
|
455 |
+
)
|
456 |
+
selected = col.checkbox(option, value=(tick or option_default))
|
457 |
+
if selected:
|
458 |
+
selected_options.append(option)
|
459 |
+
st.session_state["project_dct"]["model_tuning"]["selected_flags"][
|
460 |
+
sel_target_col
|
461 |
+
] = selected_options
|
462 |
+
|
463 |
+
st.markdown("### 1.2 Select Parameters to Apply")
|
464 |
+
parameters = st.columns(3)
|
465 |
+
with parameters[0]:
|
466 |
+
Trend = st.checkbox(
|
467 |
+
"**Trend**",
|
468 |
+
value=st.session_state["project_dct"]["model_tuning"].get(
|
469 |
+
"trend_check", False
|
470 |
+
),
|
471 |
+
)
|
472 |
+
st.markdown(
|
473 |
+
"Helps account for long-term trends or seasonality that could influence advertising effectiveness"
|
474 |
+
)
|
475 |
+
with parameters[1]:
|
476 |
+
week_number = st.checkbox(
|
477 |
+
"**Week_number**",
|
478 |
+
value=st.session_state["project_dct"]["model_tuning"].get(
|
479 |
+
"week_num_check", False
|
480 |
+
),
|
481 |
+
)
|
482 |
+
st.markdown(
|
483 |
+
"Assists in detecting and incorporating weekly patterns or seasonality"
|
484 |
+
)
|
485 |
+
with parameters[2]:
|
486 |
+
sine_cosine = st.checkbox(
|
487 |
+
"**Sine and Cosine Waves**",
|
488 |
+
value=st.session_state["project_dct"]["model_tuning"].get(
|
489 |
+
"sine_cosine_check", False
|
490 |
+
),
|
491 |
+
)
|
492 |
+
st.markdown(
|
493 |
+
"Helps in capturing cyclical patterns or seasonality in the data"
|
494 |
+
)
|
495 |
+
#
|
496 |
+
# def get_tuned_model():
|
497 |
+
# st.session_state['build_tuned_model']=True
|
498 |
+
|
499 |
+
if st.button(
|
500 |
+
"Build model with Selected Parameters and Flags",
|
501 |
+
key="build_tuned_model",use_container_width=True
|
502 |
+
):
|
503 |
+
new_features = features_set
|
504 |
+
st.header("2.1 Results Summary")
|
505 |
+
# date=list(df.index)
|
506 |
+
# df = df.reset_index(drop=True)
|
507 |
+
# X_train=df[features_set]
|
508 |
+
ss = MinMaxScaler()
|
509 |
+
if is_panel == True:
|
510 |
+
X_train_tuned = X_train[features_set]
|
511 |
+
# X_train_tuned = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
|
512 |
+
X_train_tuned[target_col] = X_train[target_col]
|
513 |
+
X_train_tuned[date_col] = X_train[date_col]
|
514 |
+
X_train_tuned[panel_col] = X_train[panel_col]
|
515 |
+
|
516 |
+
X_test_tuned = X_test[features_set]
|
517 |
+
# X_test_tuned = pd.DataFrame(ss.transform(X), columns=X.columns)
|
518 |
+
X_test_tuned[target_col] = X_test[target_col]
|
519 |
+
X_test_tuned[date_col] = X_test[date_col]
|
520 |
+
X_test_tuned[panel_col] = X_test[panel_col]
|
521 |
+
|
522 |
+
else:
|
523 |
+
X_train_tuned = X_train[features_set]
|
524 |
+
# X_train_tuned = pd.DataFrame(ss.fit_transform(X_train_tuned), columns=X_train_tuned.columns)
|
525 |
+
|
526 |
+
X_test_tuned = X_test[features_set]
|
527 |
+
# X_test_tuned = pd.DataFrame(ss.transform(X_test_tuned), columns=X_test_tuned.columns)
|
528 |
+
|
529 |
+
for flag in selected_options:
|
530 |
+
# Spirnt4 - added target_col in flag name
|
531 |
+
X_train_tuned[flag] = st.session_state["Flags"][
|
532 |
+
flag + "__" + target_col
|
533 |
+
]["train"]
|
534 |
+
X_test_tuned[flag] = st.session_state["Flags"][
|
535 |
+
flag + "__" + target_col
|
536 |
+
]["test"]
|
537 |
+
|
538 |
+
# test
|
539 |
+
# X_train_tuned.to_csv("Test/X_train_tuned_flag.csv",index=False)
|
540 |
+
# X_test_tuned.to_csv("Test/X_test_tuned_flag.csv",index=False)
|
541 |
+
|
542 |
+
# print("()()"*20,flag, len(st.session_state['Flags'][flag]))
|
543 |
+
    if Trend:
        st.session_state["project_dct"]["model_tuning"]["trend_check"] = True
        # Sprint3 - group by panel, calculate trend of each panel separately. Add trend to new feature set
        if is_panel:
            newdata = pd.DataFrame()
            panel_wise_end_point_train = {}
            for panel, groupdf in X_train_tuned.groupby(panel_col):
                groupdf.sort_values(date_col, inplace=True)
                groupdf["Trend"] = np.arange(1, len(groupdf) + 1, 1)
                newdata = pd.concat([newdata, groupdf])
                panel_wise_end_point_train[panel] = len(groupdf)
            X_train_tuned = newdata.copy()

            test_newdata = pd.DataFrame()
            for panel, test_groupdf in X_test_tuned.groupby(panel_col):
                test_groupdf.sort_values(date_col, inplace=True)
                start = panel_wise_end_point_train[panel] + 1
                end = start + len(test_groupdf)  # should be + 1? - Sprint4
                # print("??" * 20, panel, len(test_groupdf), len(np.arange(start, end, 1)), start)
                test_groupdf["Trend"] = np.arange(start, end, 1)
                test_newdata = pd.concat([test_newdata, test_groupdf])
            X_test_tuned = test_newdata.copy()

            new_features = new_features + ["Trend"]

        else:
            X_train_tuned["Trend"] = np.arange(1, len(X_train_tuned) + 1, 1)
            X_test_tuned["Trend"] = np.arange(
                len(X_train_tuned) + 1, len(X_train_tuned) + len(X_test_tuned) + 1, 1
            )
            new_features = new_features + ["Trend"]
    else:
        st.session_state["project_dct"]["model_tuning"]["trend_check"] = False

    if week_number:
        st.session_state["project_dct"]["model_tuning"]["week_num_check"] = True
        # Sprint3 - create week number from date column in X_train_tuned. Add week num to new feature set
        if is_panel:
            X_train_tuned[date_col] = pd.to_datetime(X_train_tuned[date_col])
            X_train_tuned["Week_number"] = X_train_tuned[date_col].dt.day_of_week
            if X_train_tuned["Week_number"].nunique() == 1:
                st.write("All dates in the data are of the same week day. Hence Week number can't be used.")
            else:
                X_test_tuned[date_col] = pd.to_datetime(X_test_tuned[date_col])
                X_test_tuned["Week_number"] = X_test_tuned[date_col].dt.day_of_week
                new_features = new_features + ["Week_number"]

        else:
            date = pd.to_datetime(date.values)
            X_train_tuned["Week_number"] = pd.to_datetime(X_train[date_col]).dt.day_of_week
            X_test_tuned["Week_number"] = pd.to_datetime(X_test[date_col]).dt.day_of_week
            new_features = new_features + ["Week_number"]
    else:
        st.session_state["project_dct"]["model_tuning"]["week_num_check"] = False
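    # Note (added): the block below builds a first-order Fourier pair for annual
    # seasonality. With frequency = 2*pi/365, sin(2*pi*t/365) and cos(2*pi*t/365)
    # jointly capture a yearly cycle whose amplitude and phase are learned through
    # the two regression coefficients; 365 assumes a daily index, so the frequency
    # would need adjusting for weekly data (as the inline comment also hints).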
    if sine_cosine:
        st.session_state["project_dct"]["model_tuning"]["sine_cosine_check"] = True
        # Sprint3 - create panel-wise sine cosine waves in X_train_tuned. Add to new feature set
        if is_panel:
            new_features = new_features + ["sine_wave", "cosine_wave"]
            newdata = pd.DataFrame()
            newdata_test = pd.DataFrame()
            groups = X_train_tuned.groupby(panel_col)
            frequency = 2 * np.pi / 365  # Adjust the frequency as needed

            train_panel_wise_end_point = {}
            for panel, groupdf in groups:
                num_samples = len(groupdf)
                train_panel_wise_end_point[panel] = num_samples
                days_since_start = np.arange(num_samples)
                sine_wave = np.sin(frequency * days_since_start)
                cosine_wave = np.cos(frequency * days_since_start)
                sine_cosine_df = pd.DataFrame({"sine_wave": sine_wave, "cosine_wave": cosine_wave})
                assert len(sine_cosine_df) == len(groupdf)
                # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
                groupdf["sine_wave"] = sine_wave
                groupdf["cosine_wave"] = cosine_wave
                newdata = pd.concat([newdata, groupdf])

            X_train_tuned = newdata.copy()

            test_groups = X_test_tuned.groupby(panel_col)
            for panel, test_groupdf in test_groups:
                num_samples = len(test_groupdf)
                start = train_panel_wise_end_point[panel]
                days_since_start = np.arange(start, start + num_samples, 1)
                # print("##", panel, num_samples, start, len(np.arange(start, start + num_samples, 1)))
                sine_wave = np.sin(frequency * days_since_start)
                cosine_wave = np.cos(frequency * days_since_start)
                sine_cosine_df = pd.DataFrame({"sine_wave": sine_wave, "cosine_wave": cosine_wave})
                assert len(sine_cosine_df) == len(test_groupdf)
                # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
                test_groupdf["sine_wave"] = sine_wave
                test_groupdf["cosine_wave"] = cosine_wave
                newdata_test = pd.concat([newdata_test, test_groupdf])

            X_test_tuned = newdata_test.copy()

        else:
            new_features = new_features + ["sine_wave", "cosine_wave"]

            num_samples = len(X_train_tuned)
            frequency = 2 * np.pi / 365  # Adjust the frequency as needed
            days_since_start = np.arange(num_samples)
            sine_wave = np.sin(frequency * days_since_start)
            cosine_wave = np.cos(frequency * days_since_start)
            sine_cosine_df = pd.DataFrame({"sine_wave": sine_wave, "cosine_wave": cosine_wave})
            # Concatenate the sine and cosine waves with the scaled X DataFrame
            X_train_tuned = pd.concat([X_train_tuned, sine_cosine_df], axis=1)

            test_num_samples = len(X_test_tuned)
            start = num_samples
            days_since_start = np.arange(start, start + test_num_samples, 1)
            sine_wave = np.sin(frequency * days_since_start)
            cosine_wave = np.cos(frequency * days_since_start)
            sine_cosine_df = pd.DataFrame({"sine_wave": sine_wave, "cosine_wave": cosine_wave})
            # Concatenate the sine and cosine waves with the scaled X DataFrame
            X_test_tuned = pd.concat([X_test_tuned, sine_cosine_df], axis=1)
    else:
        st.session_state["project_dct"]["model_tuning"]["sine_cosine_check"] = False
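    # Note (added): for panel data the refit below uses a statsmodels mixed-effects
    # model built from a Patsy formula string, with one random intercept per panel.
    # As an illustrative sketch with hypothetical column names (not from this app):
    #
    #     md_str = "revenue ~ tv_spend + Trend"              # target ~ feature terms
    #     md = smf.mixedlm(md_str, data=df, groups=df["dma"])
    #     fit = md.fit()  # fixed effects in fit.fe_params, intercepts in fit.random_effects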
    # model
    if selected_options:
        new_features = new_features + selected_options
    if is_panel:
        inp_vars_str = " + ".join(new_features)
        new_features = list(set(new_features))

        md_str = target_col + " ~ " + inp_vars_str
        md_tuned = smf.mixedlm(
            md_str,
            data=X_train_tuned[[target_col] + new_features],
            groups=X_train_tuned[panel_col],
        )
        model_tuned = md_tuned.fit()

        # plot act v pred for original model and tuned model
        metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
            X_train[date_col],
            y_train,
            model.fittedvalues,
            model,
            target_column=sel_target_col,
            is_panel=True,
        )
        metrics_table_tuned, line, actual_vs_predicted_plot_tuned = plot_actual_vs_predicted(
            X_train_tuned[date_col],
            X_train_tuned[target_col],
            model_tuned.fittedvalues,
            model_tuned,
            target_column=sel_target_col,
            is_panel=True,
        )

    else:
        new_features = list(set(new_features))
        model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
        metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
            date[:130],
            y_train,
            model.predict(X_train[features_set]),
            model,
            target_column=sel_target_col,
        )
        metrics_table_tuned, line, actual_vs_predicted_plot_tuned = plot_actual_vs_predicted(
            date[:130],
            y_train,
            model_tuned.predict(X_train_tuned),
            model_tuned,
            target_column=sel_target_col,
        )
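    # Note (added): plot_actual_vs_predicted returns a metrics table whose first
    # three rows hold MAPE, R2 and Adjusted R2, so the positional lookups below
    # compare the tuned model against the originally selected one.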
    mape = np.round(metrics_table.iloc[0, 1], 2)
    r2 = np.round(metrics_table.iloc[1, 1], 2)
    adjr2 = np.round(metrics_table.iloc[2, 1], 2)

    mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
    r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
    adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)

    parameters_ = st.columns(3)
    with parameters_[0]:
        st.metric("R2", r2_tuned, np.round(r2_tuned - r2, 2))
    with parameters_[1]:
        st.metric("Adjusted R2", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2))
    with parameters_[2]:
        st.metric("MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse")
    st.write(model_tuned.summary())

    X_train_tuned[date_col] = X_train[date_col]
    X_test_tuned[date_col] = X_test[date_col]
    X_train_tuned[target_col] = y_train
    X_test_tuned[target_col] = y_test

    st.header("2.2 Actual vs. Predicted Plot")
    # if is_panel:
    #     metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
    #         date, y_train, model.predict(X_train), model, target_column='Revenue', is_panel=True)
    # else:
    #     metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
    #         date, y_train, model.predict(X_train), model, target_column='Revenue')
    if is_panel:
        metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
            X_train_tuned[date_col],
            X_train_tuned[target_col],
            model_tuned.fittedvalues,
            model_tuned,
            target_column=sel_target_col,
            is_panel=True,
        )
    else:
        metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(
            X_train_tuned[date_col],
            X_train_tuned[target_col],
            model_tuned.predict(X_train_tuned[new_features]),
            model_tuned,
            target_column=sel_target_col,
            is_panel=False,
        )
    # plot_actual_vs_predicted(X_train[date_col], y_train, model.fittedvalues, model,
    #                          target_column='Revenue', is_panel=is_panel)

    st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)

    st.markdown("## 2.3 Residual Analysis")
    if is_panel:
        columns = st.columns(2)
        with columns[0]:
            fig = plot_residual_predicted(y_train, model_tuned.fittedvalues, X_train_tuned)
            st.plotly_chart(fig)

        with columns[1]:
            st.empty()
            fig = qqplot(y_train, model_tuned.fittedvalues)
            st.plotly_chart(fig)

        with columns[0]:
            fig = residual_distribution(y_train, model_tuned.fittedvalues)
            st.pyplot(fig)
    else:
        columns = st.columns(2)
        with columns[0]:
            fig = plot_residual_predicted(y_train, model_tuned.predict(X_train_tuned[new_features]), X_train)
            st.plotly_chart(fig)

        with columns[1]:
            st.empty()
            fig = qqplot(y_train, model_tuned.predict(X_train_tuned[new_features]))
            st.plotly_chart(fig)

        with columns[0]:
            fig = residual_distribution(y_train, model_tuned.predict(X_train_tuned[new_features]))
            st.pyplot(fig)

    # st.session_state['is_tuned_model'][target_col] = True
    # Sprint4 - save tuned model in a dict
    st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
        "Model_object": model_tuned,
        "feature_set": new_features,
        "X_train_tuned": X_train_tuned,
        "X_test_tuned": X_test_tuned,
    }

# Pending
# if st.session_state['build_tuned_model'] == True:
if st.session_state["Model_Tuned"] is not None:
    if st.button("Use This model for Media Planning", use_container_width=True):
        # save_model = st.button('Use this model to build response curves', key='saved_tuned_model')
        # if save_model:
        st.session_state["is_tuned_model"][target_col] = True
        with open(os.path.join(st.session_state["project_path"], "tuned_model.pkl"), "wb") as f:
            # pickle.dump(st.session_state['tuned_model'], f)
            pickle.dump(st.session_state["Model_Tuned"], f)  # Sprint4

        st.session_state["project_dct"]["model_tuning"]["session_state_saved"] = {}
        for key in [
            "bin_dict",
            "used_response_metrics",
            "is_tuned_model",
            "media_data",
            "X_test_spends",
        ]:
            st.session_state["project_dct"]["model_tuning"]["session_state_saved"][key] = st.session_state[key]

        project_dct_path = os.path.join(st.session_state["project_path"], "project_dct.pkl")
        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)

        update_db("5_Model_Tuning.py")

        st.success(sel_model + "__" + target_col + " Tuned saved!")
pages/6_AI_Model_Results.py
ADDED
@@ -0,0 +1,828 @@
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
import streamlit as st
import pandas as pd
import statsmodels.api as sm
import sys
import os
from utilities import set_header, load_local_css, load_authenticator
import seaborn as sns
import matplotlib.pyplot as plt
import sweetviz as sv
import tempfile
from sklearn.preprocessing import MinMaxScaler
from st_aggrid import AgGrid
from st_aggrid import GridOptionsBuilder, GridUpdateMode
import re
import pickle
from sklearn.metrics import r2_score, mean_absolute_percentage_error
from Data_prep_functions import plot_actual_vs_predicted
import sqlite3
from utilities import update_db

sys.setrecursionlimit(10**6)

original_stdout = sys.stdout
sys.stdout = open("temp_stdout.txt", "w")
sys.stdout.close()
sys.stdout = original_stdout

st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()

# TODO:
#   1. Add non panel model support
#   2. EDA Function

for k, v in st.session_state.items():
    if k not in ["logout", "login", "config"] and not k.startswith("FormSubmitter"):
        st.session_state[k] = v

authenticator = st.session_state.get("authenticator")
if authenticator is None:
    authenticator = load_authenticator()

name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status == True:
    is_state_initialized = st.session_state.get("initialized", False)
    if not is_state_initialized:
        if "session_name" not in st.session_state:
            st.session_state["session_name"] = None

    if "project_dct" not in st.session_state:
        st.error("Please load a project from Home page")
        st.stop()

    conn = sqlite3.connect(r"DB/User.db", check_same_thread=False)  # connection with sql db
    c = conn.cursor()

    if not os.path.exists(os.path.join(st.session_state["project_path"], "tuned_model.pkl")):
        st.error("Please save a tuned model")
        st.stop()

    if (
        "session_state_saved" in st.session_state["project_dct"]["model_tuning"].keys()
        and st.session_state["project_dct"]["model_tuning"]["session_state_saved"] != []
    ):
        for key in ["used_response_metrics", "media_data", "bin_dict"]:
            if key not in st.session_state:
                st.session_state[key] = st.session_state["project_dct"]["model_tuning"][
                    "session_state_saved"
                ][key]
        st.session_state["bin_dict"] = st.session_state["project_dct"]["model_build"][
            "session_state_saved"
        ]["bin_dict"]

    media_data = st.session_state["media_data"]

    st.write(media_data.columns)

    panel_col = [
        col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in st.session_state["bin_dict"]["Panel Level 1"]
    ][0]  # set the panel column
    is_panel = True if len(panel_col) > 0 else False
    date_col = "date"

    def plot_residual_predicted(actual, predicted, df_):
        df_["Residuals"] = actual - pd.Series(predicted)
        df_["StdResidual"] = (df_["Residuals"] - df_["Residuals"].mean()) / df_["Residuals"].std()

        # Create a Plotly scatter plot
        fig = px.scatter(df_, x=predicted, y="StdResidual", opacity=0.5, color_discrete_sequence=["#11B6BD"])

        # Add horizontal lines
        fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
        fig.add_hline(y=2, line_color="red")
        fig.add_hline(y=-2, line_color="red")

        fig.update_xaxes(title="Predicted")
        fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")

        # Set the same width and height for both figures
        fig.update_layout(title="Residuals over Predicted Values", autosize=False, width=600, height=400)

        return fig

    def residual_distribution(actual, predicted):
        Residuals = actual - pd.Series(predicted)

        # Create a Seaborn distribution plot
        sns.set(style="whitegrid")
        plt.figure(figsize=(6, 4))
        sns.histplot(Residuals, kde=True, color="#11B6BD")

        plt.title("Distribution of Residuals")
        plt.xlabel("Residuals")
        plt.ylabel("Probability Density")

        return plt

    def qqplot(actual, predicted):
        Residuals = actual - pd.Series(predicted)
        Residuals = pd.Series(Residuals)
        Resud_std = (Residuals - Residuals.mean()) / Residuals.std()

        # Create a QQ plot using Plotly with custom colors
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=sm.ProbPlot(Resud_std).theoretical_quantiles,
                y=sm.ProbPlot(Resud_std).sample_quantiles,
                mode="markers",
                marker=dict(size=5, color="#11B6BD"),
                name="QQ Plot",
            )
        )

        # Add the 45-degree reference line
        diagonal_line = go.Scatter(
            x=[-2, 2],  # Adjust the x values as needed to fit the range of your data
            y=[-2, 2],  # Adjust the y values accordingly
            mode="lines",
            line=dict(color="red"),  # Customize the line color and style
            name=" ",
        )
        fig.add_trace(diagonal_line)

        # Customize the layout
        fig.update_layout(
            title="QQ Plot of Residuals",
            title_x=0.5,
            autosize=False,
            width=600,
            height=400,
            xaxis_title="Theoretical Quantiles",
            yaxis_title="Sample Quantiles",
        )

        return fig
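    # Note (added): statsmodels MixedLM exposes `result.random_effects` as a dict
    # mapping each group (panel) to its estimated random-effect vector; with a
    # random-intercept-only model, `.values[0]` below is that panel's intercept
    # deviation from the common fixed intercept.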
    def get_random_effects(media_data, panel_col, mdf):
        random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
        for i, market in enumerate(media_data[panel_col].unique()):
            print(i, end="\r")
            intercept = mdf.random_effects[market].values[0]
            random_eff_df.loc[i, "random_effect"] = intercept
            random_eff_df.loc[i, panel_col] = market

        return random_eff_df
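    # Note (added): `mdf.predict(X)` returns only the fixed-effect part of a
    # MixedLM prediction, so the helper below adds each panel's random intercept
    # back by hand: y_hat[i] = x[i] @ beta_fixed + random_intercept[panel(i)].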
    def mdf_predict(X_df, mdf, random_eff_df):
        X = X_df.copy()
        X = pd.merge(X, random_eff_df[[panel_col, "random_effect"]], on=panel_col, how="left")
        X["pred_fixed_effect"] = mdf.predict(X)

        X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
        X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
        return X
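    # Note (added): the adjusted R2 computed below is the standard
    #     adj_R2 = 1 - (1 - R2) * (n - 1) / (n - p - 1)
    # with n observations and p features in the selected feature set.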
    def metrics_df_panel(model_dict):
        metrics_df = pd.DataFrame(
            columns=["Model", "R2", "ADJR2", "Train Mape", "Test Mape", "Summary", "Model_object"]
        )
        i = 0
        for key in model_dict.keys():
            target = key.split("__")[1]
            metrics_df.at[i, "Model"] = target
            y = model_dict[key]["X_train_tuned"][target]

            random_df = get_random_effects(media_data, panel_col, model_dict[key]["Model_object"])
            pred = mdf_predict(
                model_dict[key]["X_train_tuned"], model_dict[key]["Model_object"], random_df
            )["pred"]

            ytest = model_dict[key]["X_test_tuned"][target]
            predtest = mdf_predict(
                model_dict[key]["X_test_tuned"], model_dict[key]["Model_object"], random_df
            )["pred"]

            metrics_df.at[i, "R2"] = r2_score(y, pred)
            metrics_df.at[i, "ADJR2"] = 1 - (1 - metrics_df.loc[i, "R2"]) * (len(y) - 1) / (
                len(y) - len(model_dict[key]["feature_set"]) - 1
            )
            metrics_df.at[i, "Train Mape"] = mean_absolute_percentage_error(y, pred)
            metrics_df.at[i, "Test Mape"] = mean_absolute_percentage_error(ytest, predtest)
            metrics_df.at[i, "Summary"] = model_dict[key]["Model_object"].summary()
            metrics_df.at[i, "Model_object"] = model_dict[key]["Model_object"]
            i += 1
        metrics_df = np.round(metrics_df, 2)
        return metrics_df

    with open(
        os.path.join(st.session_state["project_path"], "final_df_transformed.pkl"), "rb"
    ) as f:
        data = pickle.load(f)
    transformed_data = data["final_df_transformed"]
    with open(os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb") as f:
        data = pickle.load(f)
    st.session_state["bin_dict"] = data["bin_dict"]
    with open(os.path.join(st.session_state["project_path"], "tuned_model.pkl"), "rb") as file:
        tuned_model_dict = pickle.load(file)
    feature_set_dct = {
        key.split("__")[1]: key_dict["feature_set"] for key, key_dict in tuned_model_dict.items()
    }

    # The above part should be modified so that we are fetching the feature set from the saved model
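    # Note (added): the OLS variant below attributes the response across drivers
    # by scaling each column by its coefficient and expressing the column sums as
    # percentage shares of the grand total (coef * feature value, normalized to 100).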
    def contributions(X, model, target):
        X1 = X.copy()
        for j, col in enumerate(X1.columns):
            X1[col] = X1[col] * model.params.values[j]

        contributions = np.round((X1.sum() / sum(X1.sum()) * 100).sort_values(ascending=False), 2)
        contributions = (
            pd.DataFrame(contributions, columns=target)
            .reset_index()
            .rename(columns={"index": "Channel"})
        )
        contributions["Channel"] = [re.split(r"_imp|_cli", col)[0] for col in contributions["Channel"]]

        return contributions

    if "contribution_df" not in st.session_state:
        st.session_state["contribution_df"] = None
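    # Note (added): the panel variant decomposes each prediction into
    # coef * feature contributions per driver, folds the intercepts (fixed +
    # random panel effect) together with the Trend/Week_number/sine/cosine terms
    # into a single "base" contribution, and rescales everything so each response
    # metric's contributions sum to 100%.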
    def contributions_panel(model_dict):
        media_data = st.session_state["media_data"]
        contribution_df = pd.DataFrame(columns=["Channel"])
        for key in model_dict.keys():
            best_feature_set = model_dict[key]["feature_set"]
            model = model_dict[key]["Model_object"]
            target = key.split("__")[1]
            X_train = model_dict[key]["X_train_tuned"]
            contri_df = pd.DataFrame()

            y = []
            y_pred = []

            random_eff_df = get_random_effects(media_data, panel_col, model)
            random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
            random_eff_df["panel_effect"] = random_eff_df["random_effect"] + random_eff_df["fixed_effect"]

            coef_df = pd.DataFrame(model.fe_params)
            coef_df.reset_index(inplace=True)
            coef_df.columns = ["feature", "coef"]

            x_train_contribution = X_train.copy()
            x_train_contribution = mdf_predict(x_train_contribution, model, random_eff_df)

            x_train_contribution = pd.merge(
                x_train_contribution,
                random_eff_df[[panel_col, "panel_effect"]],
                on=panel_col,
                how="left",
            )

            for i in range(len(coef_df))[1:]:
                coef = coef_df.loc[i, "coef"]
                col = coef_df.loc[i, "feature"]
                x_train_contribution[str(col) + "_contr"] = coef * x_train_contribution[col]

            # x_train_contribution['sum_contributions'] = x_train_contribution.filter(regex="contr").sum(axis=1)
            # x_train_contribution['sum_contributions'] = x_train_contribution['sum_contributions'] + x_train_contribution['panel_effect']

            base_cols = ["panel_effect"] + [
                c
                for c in x_train_contribution.filter(regex="contr").columns
                if c in ["Week_number_contr", "Trend_contr", "sine_wave_contr", "cosine_wave_contr"]
            ]
            x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
            x_train_contribution.drop(columns=base_cols, inplace=True)
            # x_train_contribution.to_csv("Test/smr_x_train_contribution.csv", index=False)

            contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))
            contri_df.reset_index(inplace=True)
            contri_df.columns = ["Channel", target]
            contri_df["Channel"] = contri_df["Channel"].str.split("(_impres|_clicks)").apply(lambda c: c[0])
            contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
            contri_df["Channel"].replace("base_contr", "base", inplace=True)
            contribution_df = pd.merge(contribution_df, contri_df, on="Channel", how="outer")
        # st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
        return contribution_df
    metrics_table = metrics_df_panel(tuned_model_dict)

    st.title("AI Model Results")

    st.header("Contribution Overview")

    options = st.session_state["used_response_metrics"]
    st.write(options)

    options = [
        opt.lower().replace(" ", "_").replace("-", "").replace(":", "").replace("__", "_")
        for opt in options
    ]

    default_options = (
        st.session_state["project_dct"]["saved_model_results"].get("selected_options")
        if st.session_state["project_dct"]["saved_model_results"].get("selected_options") is not None
        else [options[-1]]
    )
    for i in default_options:
        if i not in options:
            st.write(i)
            default_options.remove(i)

    def format_display(inp):
        return inp.title().replace("_", " ").strip()

    contribution_selections = st.multiselect(
        "Select the Response Metrics to compare contributions",
        options,
        default=default_options,
        format_func=format_display,
    )
    trace_data = []

    st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
    st.write(st.session_state["contribution_df"].columns)
    # for selection in contribution_selections:
    #     trace = go.Bar(
    #         x=st.session_state["contribution_df"]["Channel"],
    #         y=st.session_state["contribution_df"][selection],
    #         name=selection,
    #         text=np.round(st.session_state["contribution_df"][selection], 0).astype(int).astype(str) + "%",
    #         textposition="outside",
    #     )
    #     trace_data.append(trace)
    #
    # layout = go.Layout(
    #     title="Metrics Contribution by Channel",
    #     xaxis=dict(title="Channel Name"),
    #     yaxis=dict(title="Metrics Contribution"),
    #     barmode="group",
    # )
    # fig = go.Figure(data=trace_data, layout=layout)
    # st.plotly_chart(fig, use_container_width=True)
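    # Note (added): the plot built below draws one bar per selected response
    # metric and channel, placing "Base Sales" first and ordering the remaining
    # channels by their average contribution across the selected metrics.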
    def create_grouped_bar_plot(contribution_df, contribution_selections):
        # Extract the 'Channel' names
        channel_names = contribution_df["Channel"].tolist()

        # Dictionary to store all contributions except 'const' and 'base'
        all_contributions = {name: [] for name in channel_names if name not in ["const", "base"]}

        # Dictionary to store base sales for each selection
        base_sales_dict = {}

        # Accumulate contributions for each channel from each selection
        for selection in contribution_selections:
            contributions = contribution_df[selection].values.astype(float)
            base_sales = 0  # Initialize base sales for the current selection

            for channel_name, contribution in zip(channel_names, contributions):
                if channel_name in all_contributions:
                    all_contributions[channel_name].append(contribution)
                elif channel_name == "base":
                    base_sales = contribution  # Capture base sales for the current selection

            # Store base sales for each selection
            base_sales_dict[selection] = base_sales

        # Calculate the average of contributions and sort by this average
        sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
        sorted_channel_names = [name for name, _ in sorted_channels]
        sorted_channel_names = ["Base Sales"] + sorted_channel_names  # Adding 'Base Sales' at the start

        trace_data = []
        max_value = 0  # Initialize max_value to find the highest bar for y-axis adjustment

        # Create traces for the grouped bar chart
        for selection in contribution_selections:
            display_name = sorted_channel_names
            display_contribution = [base_sales_dict[selection]] + [
                np.mean(all_contributions[name]) for name in sorted_channel_names[1:]
            ]  # Start with base sales for the current selection

            # Generating text labels for each bar
            text_values = [f"{val}%" for val in np.round(display_contribution, 0).astype(int)]

            # Find the max value for y-axis calculation
            max_contribution = max(display_contribution)
            if max_contribution > max_value:
                max_value = max_contribution

            # Create a bar trace for each selection
            trace = go.Bar(
                x=display_name,
                y=display_contribution,
                name=selection,
                text=text_values,
                textposition="outside",
            )
            trace_data.append(trace)

        # Define layout for the bar chart
        layout = go.Layout(
            title="Metrics Contribution by Channel",
            xaxis=dict(title="Channel Name"),
            yaxis=dict(
                title="Metrics Contribution", range=[0, max_value * 1.2]
            ),  # Set y-axis 20% higher than the max bar
            barmode="group",
            plot_bgcolor="white",
        )

        # Create the figure with trace data and layout
        fig = go.Figure(data=trace_data, layout=layout)

        return fig

    # Display the chart in Streamlit
    st.plotly_chart(
        create_grouped_bar_plot(st.session_state["contribution_df"], contribution_selections),
        use_container_width=True,
    )

    ############################################ Waterfall Chart ############################################

    # # Initialize a Plotly figure
    # fig = go.Figure()
    #
    # for selection in contribution_selections:
    #     # Ensure contributions are numeric
    #     contributions = st.session_state["contribution_df"][selection].values.astype(float).tolist()
    #     channel_names = st.session_state["contribution_df"]["Channel"].tolist()
    #
    #     display_name, display_contribution, base_contribution = [], [], 0
    #     for channel_name, contribution in zip(channel_names, contributions):
    #         if channel_name != "const" and channel_name != "base":
    #             display_name.append(channel_name)
    #             display_contribution.append(contribution)
    #         else:
    #             base_contribution = contribution
    #
    #     display_name = ["Base Sales"] + display_name
    #     display_contribution = [base_contribution] + display_contribution
    #
    #     # Generating text labels for each bar, ensuring operations are compatible with string formats
    #     text_values = [f"{val}%" for val in np.round(display_contribution, 0).astype(int)]
    #
    #     fig.add_trace(
    #         go.Waterfall(
    #             orientation="v",
    #             measure=["relative"] * len(display_contribution),
    #             x=display_name,
    #             text=text_values,
    #             textposition="outside",
    #             y=display_contribution,
    #             increasing={"marker": {"color": "green"}},
    #             decreasing={"marker": {"color": "red"}},
    #             totals={"marker": {"color": "blue"}},
    #             name=selection,
    #         )
    #     )
    #
    # fig.update_layout(
    #     title="Metrics Contribution by Channel",
    #     xaxis={"title": "Channel Name"},
    #     yaxis={"title": "Metrics Contribution"},
    #     height=600,
    # )
    #
    # # Displaying the waterfall chart in Streamlit
    # st.plotly_chart(fig, use_container_width=True)
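    # Note (added): the function below renders the same sorted contributions as a
    # Plotly Waterfall (all-"relative" measures), superseding the commented-out
    # inline version above.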
    def preprocess_and_plot(contribution_df, contribution_selections):
        # Extract the 'Channel' names
        channel_names = contribution_df["Channel"].tolist()

        # Dictionary to store all contributions except 'const' and 'base'
        all_contributions = {name: [] for name in channel_names if name not in ["const", "base"]}

        # Dictionary to store base sales for each selection
        base_sales_dict = {}

        # Accumulate contributions for each channel from each selection
        for selection in contribution_selections:
            contributions = contribution_df[selection].values.astype(float)
            base_sales = 0  # Initialize base sales for the current selection

            for channel_name, contribution in zip(channel_names, contributions):
                if channel_name in all_contributions:
                    all_contributions[channel_name].append(contribution)
                elif channel_name == "base":
                    base_sales = contribution  # Capture base sales for the current selection

            # Store base sales for each selection
            base_sales_dict[selection] = base_sales

        # Calculate the average of contributions and sort by this average
        sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
        sorted_channel_names = [name for name, _ in sorted_channels]
        sorted_channel_names = ["Base Sales"] + sorted_channel_names  # Adding 'Base Sales' at the start

        # Initialize a Plotly figure
        fig = go.Figure()

        for selection in contribution_selections:
            display_name = ["Base Sales"] + sorted_channel_names[1:]  # Channel names for the plot
            display_contribution = [base_sales_dict[selection]]  # Start with base sales for the current selection

            # Append average contributions for other channels
            for name in sorted_channel_names[1:]:
                display_contribution.append(np.mean(all_contributions[name]))

            # Generating text labels for each bar
            text_values = [f"{val}%" for val in np.round(display_contribution, 0).astype(int)]

            # Add a waterfall trace for each selection
            fig.add_trace(
                go.Waterfall(
                    orientation="v",
                    measure=["relative"] * len(display_contribution),
                    x=display_name,
                    text=text_values,
                    textposition="outside",
                    y=display_contribution,
                    increasing={"marker": {"color": "green"}},
                    decreasing={"marker": {"color": "red"}},
                    totals={"marker": {"color": "blue"}},
                    name=selection,
                )
            )

        # Update layout of the figure
        fig.update_layout(
            title="Metrics Contribution by Channel",
            xaxis={"title": "Channel Name"},
            yaxis=dict(title="Metrics Contribution", range=[0, 100 * 1.2]),
        )

        return fig

    # Displaying the waterfall chart
    st.plotly_chart(
        preprocess_and_plot(st.session_state["contribution_df"], contribution_selections),
        use_container_width=True,
    )

    ############################################ Waterfall Chart ############################################
    st.header("Analysis of Models Result")
    # st.markdown()
    previous_selection = st.session_state["project_dct"]["saved_model_results"].get("model_grid_sel", [1])
    # st.write(np.round(metrics_table, 2))
    gd_table = metrics_table.iloc[:, :-2]

    gd = GridOptionsBuilder.from_dataframe(gd_table)
    # gd.configure_pagination(enabled=True)
    gd.configure_selection(
        use_checkbox=True,
        selection_mode="single",
        pre_select_all_rows=False,
        pre_selected_rows=previous_selection,
    )

    gridoptions = gd.build()
    table = AgGrid(gd_table, gridOptions=gridoptions, fit_columns_on_grid_load=True, height=200)
    # table = metrics_table.iloc[:, :-2]
    # table.insert(0, "Select", False)
    # selection_table = st.data_editor(table, column_config={"Select": st.column_config.CheckboxColumn(required=True)})
    if len(table.selected_rows) > 0:
        st.session_state["project_dct"]["saved_model_results"]["model_grid_sel"] = (
            table.selected_rows[0]["_selectedRowNodeInfo"]["nodeRowIndex"]
        )
    if len(table.selected_rows) == 0:
        st.warning("Click on the checkbox to view comprehensive results of the selected model.")
        st.stop()
    else:
        target_column = table.selected_rows[0]["Model"]
        feature_set = feature_set_dct[target_column]

    model = metrics_table[metrics_table["Model"] == target_column]["Model_object"].iloc[0]
    target = metrics_table[metrics_table["Model"] == target_column]["Model"].iloc[0]
    st.header("Model Summary")
    st.write(model.summary())

    sel_dict = tuned_model_dict[[k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]]
    X_train = sel_dict["X_train_tuned"]
    y_train = X_train[target]
    random_effects = get_random_effects(media_data, panel_col, model)
    pred = mdf_predict(X_train, model, random_effects)["pred"]

    X_test = sel_dict["X_test_tuned"]
    y_test = X_test[target]
    predtest = mdf_predict(X_test, model, random_effects)["pred"]
    metrics_table_train, _, fig_train = plot_actual_vs_predicted(
        X_train[date_col],
        y_train,
        pred,
        model,
        target_column=target_column,
        flag=None,
        repeat_all_years=False,
        is_panel=is_panel,
    )

    metrics_table_test, _, fig_test = plot_actual_vs_predicted(
        X_test[date_col],
        y_test,
        predtest,
        model,
        target_column=target_column,
        flag=None,
        repeat_all_years=False,
        is_panel=is_panel,
    )

    metrics_table_train = metrics_table_train.set_index("Metric").transpose()
    metrics_table_train.index = ["Train"]
    metrics_table_test = metrics_table_test.set_index("Metric").transpose()
    metrics_table_test.index = ["Test"]
    metrics_table = np.round(pd.concat([metrics_table_train, metrics_table_test]), 2)

    st.markdown("Result Overview")
    st.dataframe(np.round(metrics_table, 2), use_container_width=True)

    st.subheader("Actual vs Predicted Plot Train")

    st.plotly_chart(fig_train, use_container_width=True)
    st.subheader("Actual vs Predicted Plot Test")
    st.plotly_chart(fig_test, use_container_width=True)

    st.markdown("## Residual Analysis")
    columns = st.columns(2)

    Xtrain1 = X_train.copy()
    with columns[0]:
        fig = plot_residual_predicted(y_train, model.predict(Xtrain1), Xtrain1)
        st.plotly_chart(fig)

    with columns[1]:
        st.empty()
        fig = qqplot(y_train, model.predict(X_train))
        st.plotly_chart(fig)

    with columns[0]:
        fig = residual_distribution(y_train, model.predict(X_train))
        st.pyplot(fig)

    update_db("6_AI_Model_Result.py")


elif auth_status == False:
    st.error("Username/Password is incorrect")
    try:
        username_forgot_pw, email_forgot_password, random_password = authenticator.forgot_password(
            "Forgot password"
        )
        if username_forgot_pw:
            st.success("New password sent securely")
            # Random password to be transferred to the user securely
        elif username_forgot_pw == False:
            st.error("Username not found")
    except Exception as e:
        st.error(e)
pages/7_Current_Media_Performance.py
ADDED
@@ -0,0 +1,573 @@
"""
MMO Build Sprint 3
additions : contributions calculated using the tuned Mixed LM model
pending : contribution calculations using 1. the untuned Mixed LM model, 2. the tuned OLS model, 3. the untuned OLS model

MMO Build Sprint 4
additions : response metrics selection
pending : contribution calculations using 1. the untuned Mixed LM model, 2. the tuned OLS model, 3. the untuned OLS model
"""
import streamlit as st
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import pickle
import os

from utilities_with_panel import load_local_css, set_header
import yaml
from yaml import SafeLoader
import streamlit_authenticator as stauth
import sqlite3
from utilities import update_db
st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()
for k, v in st.session_state.items():
    # print(k, v)
    if k not in [
        "logout",
        "login",
        "config",
        "build_tuned_model",
    ] and not k.startswith("FormSubmitter"):
        st.session_state[k] = v
with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config
authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

if auth_status == True:
    authenticator.logout("Logout", "main")
    is_state_initiaized = st.session_state.get("initialized", False)

    if "project_dct" not in st.session_state:
        st.error("Please load a project from Home page")
        st.stop()

    conn = sqlite3.connect(
        r"DB/User.db", check_same_thread=False
    )  # connection with sql db
    c = conn.cursor()

    if not os.path.exists(
        os.path.join(st.session_state["project_path"], "tuned_model.pkl")
    ):
        st.error("Please save a tuned model")
        st.stop()

    if (
        "session_state_saved"
        in st.session_state["project_dct"]["model_tuning"].keys()
        and st.session_state["project_dct"]["model_tuning"][
            "session_state_saved"
        ]
        != []
    ):
        for key in [
            "used_response_metrics",
            "is_tuned_model",
            "media_data",
            "X_test_spends",
        ]:
            st.session_state[key] = st.session_state["project_dct"][
                "model_tuning"
            ]["session_state_saved"][key]
    elif (
        "session_state_saved"
        in st.session_state["project_dct"]["model_build"].keys()
        and st.session_state["project_dct"]["model_build"][
            "session_state_saved"
        ]
        != []
    ):
        for key in [
            "used_response_metrics",
            "date",
            "saved_model_names",
            "media_data",
            "X_test_spends",
        ]:
            st.session_state[key] = st.session_state["project_dct"][
                "model_build"
            ]["session_state_saved"][key]
    else:
        st.error("Please tune a model first")
    st.session_state["bin_dict"] = st.session_state["project_dct"][
        "model_build"
    ]["session_state_saved"]["bin_dict"]
    st.session_state["media_data"].columns = [
        c.lower() for c in st.session_state["media_data"].columns
    ]

    from utilities_with_panel import (
        overview_test_data_prep_panel,
        overview_test_data_prep_nonpanel,
        initialize_data,
        create_channel_summary,
        create_contribution_pie,
        create_contribuion_stacked_plot,
        create_channel_spends_sales_plot,
        format_numbers,
        channel_name_formating,
    )

    import plotly.graph_objects as go
    import streamlit_authenticator as stauth
    import yaml
    from yaml import SafeLoader
    import time

    def get_random_effects(media_data, panel_col, mdf):
        random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
        for i, market in enumerate(media_data[panel_col].unique()):
            print(i, end="\r")
            intercept = mdf.random_effects[market].values[0]
            random_eff_df.loc[i, "random_effect"] = intercept
            random_eff_df.loc[i, panel_col] = market

        return random_eff_df

    def process_train_and_test(train, test, features, panel_col, target_col):
        X1 = train[features]

        ss = MinMaxScaler()
        X1 = pd.DataFrame(ss.fit_transform(X1), columns=X1.columns)

        X1[panel_col] = train[panel_col]
        X1[target_col] = train[target_col]

        if test is not None:
            X2 = test[features]
            X2 = pd.DataFrame(ss.transform(X2), columns=X2.columns)
            X2[panel_col] = test[panel_col]
            X2[target_col] = test[target_col]
            return X1, X2
        return X1

    def mdf_predict(X_df, mdf, random_eff_df):
        X = X_df.copy()
        X = pd.merge(
            X,
            random_eff_df[[panel_col, "random_effect"]],
            on=panel_col,
            how="left",
        )
        X["pred_fixed_effect"] = mdf.predict(X)

        X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
        X.to_csv("Test/merged_df_contri.csv", index=False)
        X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)

        return X
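    # --- Editor's note: illustrative sketch, not part of the original upload.
    # mdf_predict composes each prediction as
    #     pred = (fixed intercept + panel random intercept) + sum_j coef_j * x_j,
    # which is the identity the contribution columns below rely on.
    # Toy numbers (assumed, for illustration only):
    _coefs = {"tv_spend": 0.4, "search_spend": 0.2}   # fixed-effect betas
    _row = {"tv_spend": 100.0, "search_spend": 50.0}  # one observation
    _panel_effect = 2.0 + 0.3                         # fixed + random intercept
    _contribs = {c: _coefs[c] * _row[c] for c in _coefs}
    assert abs(_panel_effect + sum(_contribs.values()) - 52.3) < 1e-9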
    # target='Revenue'

    # is_panel=False
    # is_panel = st.session_state['is_panel']
    panel_col = [
        col.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
        for col in st.session_state["bin_dict"]["Panel Level 1"]
    ][
        0
    ]  # set the panel column
    is_panel = True if len(panel_col) > 0 else False
    date_col = "date"

    # Sprint4 - if used_response_metrics is not blank, select one of them; otherwise the target defaults to revenue
    if (
        "used_response_metrics" in st.session_state
        and st.session_state["used_response_metrics"] != []
    ):
        sel_target_col = st.selectbox(
            "Select the response metric",
            st.session_state["used_response_metrics"],
        )
        target_col = (
            sel_target_col.lower()
            .replace(" ", "_")
            .replace("-", "")
            .replace(":", "")
            .replace("__", "_")
        )
    else:
        sel_target_col = "Total Approved Accounts - Revenue"
        target_col = "total_approved_accounts_revenue"

    target = sel_target_col

    # Sprint4 - Look through all saved tuned models; only show saved models of the selected response metric (target_col)
    # saved_models = st.session_state['saved_model_names']
    # Sprint4 - get the model obj of the selected model
    # st.write(sel_model_dict)

    # Sprint3 - Contribution
    if is_panel:
        # read tuned mixedLM model
        # if st.session_state["tuned_model"] is not None :
        if st.session_state["is_tuned_model"][target_col] == True:  # Sprint4
            with open(
                os.path.join(
                    st.session_state["project_path"], "tuned_model.pkl"
                ),
                "rb",
            ) as file:
                model_dict = pickle.load(file)
            saved_models = list(model_dict.keys())
            # st.write(saved_models)
            required_saved_models = [
                m.split("__")[0]
                for m in saved_models
                if m.split("__")[1] == target_col
            ]
            sel_model = st.selectbox(
                "Select the model to review", required_saved_models
            )
            sel_model_dict = model_dict[sel_model + "__" + target_col]

            model = sel_model_dict["Model_object"]
            X_train = sel_model_dict["X_train_tuned"]
            X_test = sel_model_dict["X_test_tuned"]
            best_feature_set = sel_model_dict["feature_set"]

        else:  # if non tuned model to be used # Pending
            with open(
                os.path.join(
                    st.session_state["project_path"], "best_models.pkl"
                ),
                "rb",
            ) as file:
                model_dict = pickle.load(file)
            # st.write(model_dict)
            saved_models = list(model_dict.keys())
            required_saved_models = [
                m.split("__")[0]
                for m in saved_models
                if m.split("__")[1] == target_col
            ]
            sel_model = st.selectbox(
                "Select the model to review", required_saved_models
            )
            sel_model_dict = model_dict[sel_model + "__" + target_col]
            # st.write(sel_model, sel_model_dict)
            model = sel_model_dict["Model_object"]
            X_train = sel_model_dict["X_train"]
            X_test = sel_model_dict["X_test"]
            best_feature_set = sel_model_dict["feature_set"]

        # Calculate contributions

        with open(
            os.path.join(st.session_state["project_path"], "data_import.pkl"),
            "rb",
        ) as f:
            data = pickle.load(f)

        # Accessing the loaded objects
        st.session_state["orig_media_data"] = data["final_df"]

        st.session_state["orig_media_data"].columns = [
            col.lower()
            .replace(".", "_")
            .replace("@", "_")
            .replace(" ", "_")
            .replace("-", "")
            .replace(":", "")
            .replace("__", "_")
            for col in st.session_state["orig_media_data"].columns
        ]

        media_data = st.session_state["media_data"]

        # st.session_state['orig_media_data']=st.session_state["media_data"]

        # st.write(media_data)

        contri_df = pd.DataFrame()

        y = []
        y_pred = []

        random_eff_df = get_random_effects(media_data, panel_col, model)
        random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
        random_eff_df["panel_effect"] = (
            random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
        )
        # random_eff_df.to_csv("Test/random_eff_df_contri.csv", index=False)

        coef_df = pd.DataFrame(model.fe_params)
        coef_df.reset_index(inplace=True)
        coef_df.columns = ["feature", "coef"]

        # coef_df.reset_index().to_csv("Test/coef_df_contri1.csv",index=False)
        # print(model.fe_params)

        x_train_contribution = X_train.copy()
        x_test_contribution = X_test.copy()

        # preprocessing not needed since X_train is already preprocessed
        # X1, X2 = process_train_and_test(x_train_contribution, x_test_contribution, best_feature_set, panel_col, target_col)
        # x_train_contribution[best_feature_set] = X1[best_feature_set]
        # x_test_contribution[best_feature_set] = X2[best_feature_set]

        x_train_contribution = mdf_predict(
            x_train_contribution, model, random_eff_df
        )
        x_test_contribution = mdf_predict(
            x_test_contribution, model, random_eff_df
        )

        x_train_contribution = pd.merge(
            x_train_contribution,
            random_eff_df[[panel_col, "panel_effect"]],
            on=panel_col,
            how="left",
        )
        x_test_contribution = pd.merge(
            x_test_contribution,
            random_eff_df[[panel_col, "panel_effect"]],
            on=panel_col,
            how="left",
        )

        for i in range(len(coef_df))[1:]:
            coef = coef_df.loc[i, "coef"]
            col = coef_df.loc[i, "feature"]
            x_train_contribution[str(col) + "_contr"] = (
                coef * x_train_contribution[col]
            )
            # Fixed: the test contribution previously multiplied the train
            # column by mistake; it must use the test data.
            x_test_contribution[str(col) + "_contr"] = (
                coef * x_test_contribution[col]
            )
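        # --- Editor's note: illustrative sketch, not part of the original upload.
        # Once "sum_contributions" is built below, it should match the "pred"
        # column from mdf_predict up to floating-point noise; a cheap guard
        # against a coefficient/column mismatch would be:
        #
        #     mismatch = (
        #         x_train_contribution["sum_contributions"]
        #         - x_train_contribution["pred"]
        #     ).abs().max()
        #     assert mismatch < 1e-6, f"decomposition off by {mismatch}"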
        x_train_contribution["sum_contributions"] = (
            x_train_contribution.filter(regex="contr").sum(axis=1)
        )
        x_train_contribution["sum_contributions"] = (
            x_train_contribution["sum_contributions"]
            + x_train_contribution["panel_effect"]
        )

        x_test_contribution["sum_contributions"] = x_test_contribution.filter(
            regex="contr"
        ).sum(axis=1)
        x_test_contribution["sum_contributions"] = (
            x_test_contribution["sum_contributions"]
            + x_test_contribution["panel_effect"]
        )

        # # test
        x_train_contribution.to_csv(
            "Test/x_train_contribution.csv", index=False
        )
        x_test_contribution.to_csv("Test/x_test_contribution.csv", index=False)
        #
        # st.session_state['orig_media_data'].to_csv("Test/transformed_data.csv",index=False)
        # st.session_state['X_test_spends'].to_csv("Test/test_spends.csv",index=False)
        # # st.write(st.session_state['orig_media_data'].columns)

        # st.write(date_col,panel_col)
        # st.write(x_test_contribution)

        overview_test_data_prep_panel(
            x_test_contribution,
            st.session_state["orig_media_data"],
            st.session_state["X_test_spends"],
            date_col,
            panel_col,
            target_col,
        )

    else:  # NON PANEL
        if st.session_state["is_tuned_model"][target_col] == True:  # Sprint4
            with open(
                os.path.join(
                    st.session_state["project_path"], "tuned_model.pkl"
                ),
                "rb",
            ) as file:
                model_dict = pickle.load(file)
            saved_models = list(model_dict.keys())
            required_saved_models = [
                m.split("__")[0]
                for m in saved_models
                if m.split("__")[1] == target_col
            ]
            sel_model = st.selectbox(
                "Select the model to review", required_saved_models
            )
            sel_model_dict = model_dict[sel_model + "__" + target_col]

            model = sel_model_dict["Model_object"]
            X_train = sel_model_dict["X_train_tuned"]
            X_test = sel_model_dict["X_test_tuned"]
            best_feature_set = sel_model_dict["feature_set"]

        else:  # Sprint4
            with open(
                os.path.join(
                    st.session_state["project_path"], "best_models.pkl"
                ),
                "rb",
            ) as file:
                model_dict = pickle.load(file)
            saved_models = list(model_dict.keys())
            required_saved_models = [
                m.split("__")[0]
                for m in saved_models
                if m.split("__")[1] == target_col
            ]
            sel_model = st.selectbox(
                "Select the model to review", required_saved_models
            )
            sel_model_dict = model_dict[sel_model + "__" + target_col]

            model = sel_model_dict["Model_object"]
            X_train = sel_model_dict["X_train"]
            X_test = sel_model_dict["X_test"]
            best_feature_set = sel_model_dict["feature_set"]

        x_train_contribution = X_train.copy()
        x_test_contribution = X_test.copy()

        x_train_contribution["pred"] = model.predict(
            x_train_contribution[best_feature_set]
        )
        x_test_contribution["pred"] = model.predict(
            x_test_contribution[best_feature_set]
        )

        for num, i in enumerate(model.params.values):
            col = best_feature_set[num]
            x_train_contribution[col + "_contr"] = X_train[col] * i
            x_test_contribution[col + "_contr"] = X_test[col] * i

        x_test_contribution.to_csv(
            "Test/x_test_contribution_non_panel.csv", index=False
        )
        overview_test_data_prep_nonpanel(
            x_test_contribution,
            st.session_state["orig_media_data"].copy(),
            st.session_state["X_test_spends"].copy(),
            date_col,
            target_col,
        )
    # for k, v in st.session_state.items():
    #     if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
    #         st.session_state[k] = v

    # authenticator = st.session_state.get('authenticator')

    # if authenticator is None:
    #     authenticator = load_authenticator()

    # name, authentication_status, username = authenticator.login('Login', 'main')
    # auth_status = st.session_state['authentication_status']

    # if auth_status:
    #     authenticator.logout('Logout', 'main')

    # is_state_initiaized = st.session_state.get('initialized',False)
    # if not is_state_initiaized:

    initialize_data(target_col)
    scenario = st.session_state["scenario"]
    raw_df = st.session_state["raw_df"]
    st.header("Overview of previous spends")

    # st.write(scenario.actual_total_spends)
    # st.write(scenario.actual_total_sales)
    columns = st.columns((1, 1, 3))

    with columns[0]:
        st.metric(
            label="Spends",
            value=format_numbers(float(scenario.actual_total_spends)),
        )
    # print(f"##################### {scenario.actual_total_sales} ##################")
    with columns[1]:
        st.metric(
            label=target,
            value=format_numbers(
                float(scenario.actual_total_sales), include_indicator=False
            ),
        )

    actual_summary_df = create_channel_summary(scenario)
    actual_summary_df["Channel"] = actual_summary_df["Channel"].apply(
        channel_name_formating
    )

    columns = st.columns((2, 1))
    with columns[0]:
        with st.expander("Channel wise overview"):
            st.markdown(
                actual_summary_df.style.set_table_styles(
                    [
                        {
                            "selector": "th",
                            "props": [("background-color", "#11B6BD")],
                        },
                        {
                            "selector": "tr:nth-child(even)",
                            "props": [("background-color", "#11B6BD")],
                        },
                    ]
                ).to_html(),
                unsafe_allow_html=True,
            )

    st.markdown("<hr>", unsafe_allow_html=True)
    ##############################

    st.plotly_chart(
        create_contribution_pie(scenario), use_container_width=True
    )
    st.markdown("<hr>", unsafe_allow_html=True)

    ################################
    st.plotly_chart(
        create_contribuion_stacked_plot(scenario), use_container_width=True
    )
    st.markdown("<hr>", unsafe_allow_html=True)
    #######################################

    selected_channel_name = st.selectbox(
        "Channel",
        st.session_state["channels_list"] + ["non media"],
        format_func=channel_name_formating,
    )
    selected_channel = scenario.channels.get(selected_channel_name, None)

    st.plotly_chart(
        create_channel_spends_sales_plot(selected_channel),
        use_container_width=True,
    )

    st.markdown("<hr>", unsafe_allow_html=True)

    if st.checkbox("Save this session", key="save"):
        project_dct_path = os.path.join(
            st.session_state["project_path"], "project_dct.pkl"
        )
        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)
    update_db("7_Current_Media_Performance.py")
pages/8_Response_Curves.py
ADDED
@@ -0,0 +1,596 @@
import streamlit as st
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
from utilities import (
    channel_name_formating,
    load_authenticator,
    initialize_data,
    fetch_actual_data,
)
from sklearn.metrics import r2_score
from collections import OrderedDict
from classes import class_from_dict, class_to_dict
import pickle
import json
import sqlite3
from utilities import update_db

for k, v in st.session_state.items():
    if k not in ["logout", "login", "config"] and not k.startswith(
        "FormSubmitter"
    ):
        st.session_state[k] = v


def s_curve(x, K, b, a, x0):
    return K / (1 + b * np.exp(-a * (x - x0)))
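# --- Editor's note: illustrative sketch, not part of the original upload.
# The logistic response curve above saturates at K; b sets the initial
# offset, a the steepness, and x0 the inflection point. The values below
# are assumptions chosen only to show the shape:
_K, _b, _a, _x0 = 100.0, 1.0, 1.0, 5.0
for _spend in (0.0, 5.0, 10.0):
    _resp = _K / (1 + _b * np.exp(-_a * (_spend - _x0)))
    # spend=0 -> ~0.7, spend=x0 -> K/2 = 50.0, spend=10 -> ~99.3
    print(f"spend={_spend:5.1f} -> response={_resp:6.1f}")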
def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    if "saved_scenarios" not in st.session_state:
        # Fixed: this previously replaced the entire session state with
        # `st.session_state = OrderedDict()`; only the scenario store
        # should be initialized.
        st.session_state["saved_scenarios"] = OrderedDict()

    # st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    st.session_state["scenario_input"] = ""
    print(type(st.session_state["saved_scenarios"]))
    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)


def reset_curve_parameters(
    metrics=None, panel=None, selected_channel_name=None
):
    del st.session_state["K"]
    del st.session_state["b"]
    del st.session_state["a"]
    del st.session_state["x0"]

    if (
        metrics is not None
        and panel is not None
        and selected_channel_name is not None
    ):
        if f"{metrics}#@{panel}#@{selected_channel_name}" in list(
            st.session_state["update_rcs"].keys()
        ):
            del st.session_state["update_rcs"][
                f"{metrics}#@{panel}#@{selected_channel_name}"
            ]


def update_response_curve(
    K_updated,
    b_updated,
    a_updated,
    x0_updated,
    metrics=None,
    panel=None,
    selected_channel_name=None,
):
    print(
        "[DEBUG] update_response_curves: ",
        st.session_state["project_dct"]["scenario_planner"].keys(),
    )
    st.session_state["project_dct"]["scenario_planner"][unique_key].channels[
        selected_channel_name
    ].response_curve_params = {
        "K": st.session_state["K"],
        "b": st.session_state["b"],
        "a": st.session_state["a"],
        "x0": st.session_state["x0"],
    }

    # if (
    #     metrics is not None
    #     and panel is not None
    #     and selected_channel_name is not None
    # ):
    #     st.session_state["update_rcs"][
    #         f"{metrics}#@{panel}#@{selected_channel_name}"
    #     ] = {
    #         "K": K_updated,
    #         "b": b_updated,
    #         "a": a_updated,
    #         "x0": x0_updated,
    #     }

    # st.session_state["scenario"].channels[
    #     selected_channel_name
    # ].response_curve_params = {
    #     "K": K_updated,
    #     "b": b_updated,
    #     "a": a_updated,
    #     "x0": x0_updated,
    # }
# authenticator = st.session_state.get('authenticator')
# if authenticator is None:
#     authenticator = load_authenticator()

# name, authentication_status, username = authenticator.login('Login', 'main')
# auth_status = st.session_state.get('authentication_status')

# if auth_status == True:
#     is_state_initiaized = st.session_state.get('initialized',False)
#     if not is_state_initiaized:
#         print("Scenario page state reloaded")

import pandas as pd


@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")

    if "Panel" in raw_data_mmm_df.columns:
        panel = list(set(raw_data_mmm_df["Panel"]))
    else:
        raw_data_mmm_df = None
        panel = None

    return panel


import glob
import os


def get_excel_names(directory):
    # Create a list to hold the final parts of the filenames
    last_portions = []

    # Patterns to match Excel files (.xlsx and .xls) that contain @#
    patterns = [
        os.path.join(directory, "*@#*.xlsx"),
        os.path.join(directory, "*@#*.xls"),
    ]

    # Process each pattern
    for pattern in patterns:
        files = glob.glob(pattern)

        # Extract the last portion after @# for each file
        for file in files:
            base_name = os.path.basename(file)
            last_portion = base_name.split("@#")[-1]
            last_portion = last_portion.replace(".xlsx", "").replace(
                ".xls", ""
            )  # Removing extensions
            last_portions.append(last_portion)

    return last_portions


def name_formating(channel_name):
    # Replace underscores with spaces
    name_mod = channel_name.replace("_", " ")

    # Capitalize the first letter of each word
    name_mod = name_mod.title()

    return name_mod


def fetch_panel_data():
    print("DEBUG fetch_panel_data: running... ")
    file_selected = f"./metrics_level_data/Overview_data_test_panel@#{st.session_state['response_metrics_selectbox']}.xlsx"
    panel_selected = st.session_state["panel_selected_selectbox"]
    print(panel_selected)
    # Both branches currently fetch the same way; kept separate as in source.
    if panel_selected == "Aggregated":
        (
            st.session_state["actual_input_df"],
            st.session_state["actual_contribution_df"],
        ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)
    else:
        (
            st.session_state["actual_input_df"],
            st.session_state["actual_contribution_df"],
        ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)

    unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
    print("unique_key")
    if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
        if panel_selected == "Aggregated":
            initialize_data(
                panel=panel_selected,
                target_file=file_selected,
                updated_rcs={},
                metrics=metrics_selected,
            )
            panel = None
        else:
            initialize_data(
                panel=panel_selected,
                target_file=file_selected,
                updated_rcs={},
                metrics=metrics_selected,
            )
        st.session_state["project_dct"]["scenario_planner"][unique_key] = (
            st.session_state["scenario"]
        )
        # print(
        #     "DEBUG fetch_panel_data: ",
        #     st.session_state["project_dct"]["scenario_planner"][
        #         unique_key
        #     ].keys(),
        # )

    else:
        st.session_state["scenario"] = st.session_state["project_dct"][
            "scenario_planner"
        ][unique_key]
    st.session_state["rcs"] = {}
    st.session_state["powers"] = {}

    for channel_name, _channel in st.session_state["project_dct"][
        "scenario_planner"
    ][unique_key].channels.items():
        st.session_state["rcs"][
            channel_name
        ] = _channel.response_curve_params
        st.session_state["powers"][channel_name] = _channel.power

    if "K" in st.session_state:
        del st.session_state["K"]

    if "b" in st.session_state:
        del st.session_state["b"]

    if "a" in st.session_state:
        del st.session_state["a"]

    if "x0" in st.session_state:
        del st.session_state["x0"]
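# --- Editor's note: illustrative sketch, not part of the original upload.
# get_excel_names() keys workbooks by the token after "@#", so a directory
# like the (assumed) one below maps to the response-metric list shown in
# the selectbox:
#
#     metrics_level_data/Overview_data_test_panel@#revenue.xlsx     -> "revenue"
#     metrics_level_data/Overview_data_test_panel@#app_installs.xls -> "app_installs"
#
# and fetch_panel_data() rebuilds the path from the chosen metric with
#     f"./metrics_level_data/Overview_data_test_panel@#{metric}.xlsx"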
if "project_dct" not in st.session_state:
    st.error("Please load a project from home")
    st.stop()

database_file = r"DB\User.db"

conn = sqlite3.connect(
    database_file, check_same_thread=False
)  # connection with sql db
c = conn.cursor()

st.subheader("Build Response Curves")


if "update_rcs" not in st.session_state:
    st.session_state["update_rcs"] = {}

st.session_state["first_time"] = True

col1, col2, col3 = st.columns([1, 1, 1])

directory = "metrics_level_data"
metrics_list = get_excel_names(directory)


metrics_selected = col1.selectbox(
    "Response Metrics",
    metrics_list,
    on_change=fetch_panel_data,
    format_func=name_formating,
    key="response_metrics_selectbox",
)


file_selected = (
    f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
)

panel_list = panel_fetch(file_selected)
final_panel_list = ["Aggregated"] + panel_list

panel_selected = col3.selectbox(
    "Panel",
    final_panel_list,
    on_change=fetch_panel_data,
    key="panel_selected_selectbox",
)


is_state_initiaized = st.session_state.get("initialized_rcs", False)
print(is_state_initiaized)
if not is_state_initiaized:
    print("DEBUG.....", "Here")
    fetch_panel_data()
    # if panel_selected == "Aggregated":
    #     initialize_data(panel=panel_selected, target_file=file_selected)
    #     panel = None
    # else:
    #     initialize_data(panel=panel_selected, target_file=file_selected)

    st.session_state["initialized_rcs"] = True

# channels_list = st.session_state["channels_list"]
unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
chanel_list_final = list(
    st.session_state["project_dct"]["scenario_planner"][
        unique_key
    ].channels.keys()
) + ["Others"]


selected_channel_name = col2.selectbox(
    "Channel",
    chanel_list_final,
    format_func=channel_name_formating,
    on_change=reset_curve_parameters,
    key="selected_channel_name_selectbox",
)


rcs = st.session_state["rcs"]

if "K" not in st.session_state:
    st.session_state["K"] = rcs[selected_channel_name]["K"]

if "b" not in st.session_state:
    st.session_state["b"] = rcs[selected_channel_name]["b"]


if "a" not in st.session_state:
    st.session_state["a"] = rcs[selected_channel_name]["a"]

if "x0" not in st.session_state:
    st.session_state["x0"] = rcs[selected_channel_name]["x0"]


x = st.session_state["actual_input_df"][selected_channel_name].values
y = st.session_state["actual_contribution_df"][selected_channel_name].values


power = np.ceil(np.log(x.max()) / np.log(10)) - 3

print(f"DEBUG BUILD RCS: {selected_channel_name}")
print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")

# fig = px.scatter(x, s_curve(x/10**power,
#                             st.session_state['K'],
#                             st.session_state['b'],
#                             st.session_state['a'],
#                             st.session_state['x0']))

x_plot = np.linspace(0, 5 * max(x), 50)

fig = px.scatter(x=x, y=y)
fig.add_trace(
    go.Scatter(
        x=x_plot,
        y=s_curve(
            x_plot / 10**power,
            st.session_state["K"],
            st.session_state["b"],
            st.session_state["a"],
            st.session_state["x0"],
        ),
        line=dict(color="red"),
        name="Modified",
    ),
)

fig.add_trace(
    go.Scatter(
        x=x_plot,
        y=s_curve(
            x_plot / 10**power,
            rcs[selected_channel_name]["K"],
            rcs[selected_channel_name]["b"],
            rcs[selected_channel_name]["a"],
            rcs[selected_channel_name]["x0"],
        ),
        line=dict(color="rgba(0, 255, 0, 0.4)"),
        name="Actual",
    ),
)

fig.update_layout(title_text="Response Curve", showlegend=True)
fig.update_annotations(font_size=10)
fig.update_xaxes(title="Spends")
fig.update_yaxes(title="Revenue")

st.plotly_chart(fig, use_container_width=True)
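# --- Editor's note: illustrative sketch, not part of the original upload.
# `power` rescales raw spends before they enter s_curve, so the fitted
# a/x0 parameters live on a small, well-conditioned scale. For example
# (assumed max spend):
_x_max = 2_500_000.0
_power = np.ceil(np.log(_x_max) / np.log(10)) - 3  # ceil(6.4) - 3 = 4
assert _power == 4
# s_curve then sees x / 10**4, i.e. a 2,500,000 spend becomes 250.0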
r2 = r2_score(
    y,
    s_curve(
        x / 10**power,
        st.session_state["K"],
        st.session_state["b"],
        st.session_state["a"],
        st.session_state["x0"],
    ),
)

r2_actual = r2_score(
    y,
    s_curve(
        x / 10**power,
        rcs[selected_channel_name]["K"],
        rcs[selected_channel_name]["b"],
        rcs[selected_channel_name]["a"],
        rcs[selected_channel_name]["x0"],
    ),
)

columns = st.columns((1, 1, 2))
with columns[0]:
    st.metric("R2 Modified", round(r2, 2))
with columns[1]:
    st.metric("R2 Actual", round(r2_actual, 2))


st.markdown("#### Set Parameters", unsafe_allow_html=True)
columns = st.columns(4)

if "updated_parms" not in st.session_state:
    st.session_state["updated_parms"] = {
        "K_updated": 0,
        "b_updated": 0,
        "a_updated": 0,
        "x0_updated": 0,
    }

with columns[0]:
    st.session_state["updated_parms"]["K_updated"] = st.number_input(
        "K", key="K", format="%0.5f"
    )
with columns[1]:
    st.session_state["updated_parms"]["b_updated"] = st.number_input(
        "b", key="b", format="%0.5f"
    )
with columns[2]:
    st.session_state["updated_parms"]["a_updated"] = st.number_input(
        "a", key="a", step=0.0001, format="%0.5f"
    )
with columns[3]:
    st.session_state["updated_parms"]["x0_updated"] = st.number_input(
        "x0", key="x0", format="%0.5f"
    )

# st.session_state["project_dct"]["scenario_planner"]["K_number_input"] = (
#     st.session_state["updated_parms"]["K_updated"]
# )
# st.session_state["project_dct"]["scenario_planner"]["b_number_input"] = (
#     st.session_state["updated_parms"]["b_updated"]
# )
# st.session_state["project_dct"]["scenario_planner"]["a_number_input"] = (
#     st.session_state["updated_parms"]["a_updated"]
# )
# st.session_state["project_dct"]["scenario_planner"]["x0_number_input"] = (
#     st.session_state["updated_parms"]["x0_updated"]
# )

update_col, reset_col = st.columns([1, 1])
if update_col.button(
    "Update Parameters",
    on_click=update_response_curve,
    args=(
        st.session_state["updated_parms"]["K_updated"],
        st.session_state["updated_parms"]["b_updated"],
        st.session_state["updated_parms"]["a_updated"],
        st.session_state["updated_parms"]["x0_updated"],
        metrics_selected,
        panel_selected,
        selected_channel_name,
    ),
    use_container_width=True,
):
    st.session_state["rcs"][selected_channel_name]["K"] = st.session_state[
        "updated_parms"
    ]["K_updated"]
    st.session_state["rcs"][selected_channel_name]["b"] = st.session_state[
        "updated_parms"
    ]["b_updated"]
    st.session_state["rcs"][selected_channel_name]["a"] = st.session_state[
        "updated_parms"
    ]["a_updated"]
    st.session_state["rcs"][selected_channel_name]["x0"] = st.session_state[
        "updated_parms"
    ]["x0_updated"]

reset_col.button(
    "Reset Parameters",
    on_click=reset_curve_parameters,
    args=(metrics_selected, panel_selected, selected_channel_name),
    use_container_width=True,
)

st.divider()
save_col, down_col = st.columns([1, 1])


with save_col:
    file_name = st.text_input(
        "rcs download file name",
        key="file_name_input",
        placeholder="File name",
        label_visibility="collapsed",
    )
down_col.download_button(
    label="Download response curves",
    data=json.dumps(rcs),
    file_name=f"{file_name}.json",
    mime="application/json",
    disabled=len(file_name) == 0,
    use_container_width=True,
)


def s_curve_derivative(x, K, b, a, x0):
    # Derivative of the S-curve function
    return (
        a
        * b
        * K
        * np.exp(-a * (x - x0))
        / ((1 + b * np.exp(-a * (x - x0))) ** 2)
    )
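# --- Editor's note: illustrative sketch, not part of the original upload.
# The closed form above follows from the chain rule:
#     d/dx [ K / (1 + b*exp(-a*(x-x0))) ]
#         = K*a*b*exp(-a*(x-x0)) / (1 + b*exp(-a*(x-x0)))**2
# A central-difference spot check with toy parameters (assumed values):
_K, _b, _a, _x0, _xv, _h = 100.0, 1.0, 0.5, 5.0, 6.0, 1e-6
_analytic = s_curve_derivative(_xv, _K, _b, _a, _x0)
_numeric = (
    s_curve(_xv + _h, _K, _b, _a, _x0) - s_curve(_xv - _h, _K, _b, _a, _x0)
) / (2 * _h)
assert abs(_analytic - _numeric) < 1e-5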
# Parameters of the S-curve
K = st.session_state["K"]
b = st.session_state["b"]
a = st.session_state["a"]
x0 = st.session_state["x0"]

# # Optimized spend value obtained from the tool
# optimized_spend = st.number_input(
#     "value of x"
# )  # Replace this with your optimized spend value

# # Calculate the slope at the optimized spend value
# slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)

# st.write("Slope ", slope_at_optimized_spend)


# Initialize a list to hold our rows
rows = []

# Iterate over the dictionary of updated response curves
for key, value in st.session_state["update_rcs"].items():
    # Split the key into its components
    metrics, panel, channel_name = key.split("#@")
    # Create a new row with the components and the values
    row = {
        "Metrics": name_formating(metrics),
        "Panel": name_formating(panel),
        "Channel Name": channel_name,
        "K": value["K"],
        "b": value["b"],
        "a": value["a"],
        "x0": value["x0"],
    }
    # Append the row to our list
    rows.append(row)

# Convert the list of rows into a DataFrame
updated_parms_df = pd.DataFrame(rows)

if len(list(st.session_state["update_rcs"].keys())) > 0:
    st.markdown("#### Updated Parameters", unsafe_allow_html=True)
    st.dataframe(updated_parms_df, hide_index=True)
else:
    st.info("No parameters are updated")

update_db("8_Build_Response_Curves.py")
pages/9_Scenario_Planner.py
ADDED
@@ -0,0 +1,1715 @@
import streamlit as st
from numerize.numerize import numerize
import numpy as np
from functools import partial
from collections import OrderedDict
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from utilities import (
    format_numbers,
    load_local_css,
    set_header,
    initialize_data,
    load_authenticator,
    send_email,
    channel_name_formating,
)
from classes import class_from_dict, class_to_dict
import pickle
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
import re
import pandas as pd
import plotly.express as px
import logging
from utilities import update_db
import sqlite3


st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()

for k, v in st.session_state.items():
    if k not in ["logout", "login", "config"] and not k.startswith(
        "FormSubmitter"
    ):
        st.session_state[k] = v
# ======================================================== #
# ======================= Functions ====================== #
# ======================================================== #


def optimize(key, status_placeholder):
    """
    Optimize the spends for the sales
    """

    channel_list = [
        key
        for key, value in st.session_state["optimization_channels"].items()
        if value
    ]

    if len(channel_list) > 0:
        scenario = st.session_state["scenario"]
        if key.lower() == "media spends":
            with status_placeholder:
                with st.spinner("Optimizing"):
                    result = st.session_state["scenario"].optimize(
                        st.session_state["total_spends_change"], channel_list
                    )
        # elif key.lower() == "revenue":
        else:
            with status_placeholder:
                with st.spinner("Optimizing"):
                    result = st.session_state["scenario"].optimize_spends(
                        st.session_state["total_sales_change"], channel_list
                    )
        for channel_name, modified_spends in result:
            st.session_state[channel_name] = numerize(
                modified_spends
                * scenario.channels[channel_name].conversion_rate,
                1,
            )
            prev_spends = (
                st.session_state["scenario"]
                .channels[channel_name]
                .actual_total_spends
            )
            st.session_state[f"{channel_name}_change"] = round(
                100 * (modified_spends - prev_spends) / prev_spends, 2
            )
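# --- Editor's note: illustrative sketch, not part of the original upload.
# optimize() writes two widget values per channel: the numerized absolute
# spend and the percent delta against actuals. The bookkeeping is just
# (toy numbers, assumed):
_prev_spends = 80_000.0
_modified_spends = 100_000.0
_pct_change = round(100 * (_modified_spends - _prev_spends) / _prev_spends, 2)
assert _pct_change == 25.0  # shown next to the channel's spend input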
88 |
+
def save_scenario(scenario_name):
|
89 |
+
"""
|
90 |
+
Save the current scenario with the mentioned name in the session state
|
91 |
+
|
92 |
+
Parameters
|
93 |
+
----------
|
94 |
+
scenario_name
|
95 |
+
Name of the scenario to be saved
|
96 |
+
"""
|
97 |
+
if "saved_scenarios" not in st.session_state:
|
98 |
+
st.session_state = OrderedDict()
|
99 |
+
|
100 |
+
# st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
|
101 |
+
st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
|
102 |
+
st.session_state["scenario"]
|
103 |
+
)
|
104 |
+
st.session_state["scenario_input"] = ""
|
105 |
+
# print(type(st.session_state['saved_scenarios']))
|
106 |
+
with open("../saved_scenarios.pkl", "wb") as f:
|
107 |
+
pickle.dump(st.session_state["saved_scenarios"], f)
|
108 |
+
|
109 |
+
|
110 |
+
def update_sales_abs_slider():
    actual_sales = st.session_state["scenario"].actual_total_sales
    if validate_input(st.session_state["total_sales_change_abs_slider"]):
        modified_sales = extract_number_for_string(
            st.session_state["total_sales_change_abs_slider"]
        )
        st.session_state["total_sales_change"] = round(
            ((modified_sales / actual_sales) - 1) * 100
        )
        st.session_state["total_sales_change_abs"] = numerize(
            modified_sales, 1
        )

    st.session_state["project_dct"]["scenario_planner"][
        "total_sales_change"
    ] = st.session_state.total_sales_change

def update_sales_abs():
    actual_sales = st.session_state["scenario"].actual_total_sales
    if validate_input(st.session_state["total_sales_change_abs"]):
        modified_sales = extract_number_for_string(
            st.session_state["total_sales_change_abs"]
        )
        st.session_state["total_sales_change"] = round(
            ((modified_sales / actual_sales) - 1) * 100
        )
        st.session_state["total_sales_change_abs_slider"] = numerize(
            modified_sales, 1
        )

def update_sales():
    # print("DEBUG: running update_sales")
    # st.session_state["project_dct"]["scenario_planner"][
    #     "total_sales_change"
    # ] = st.session_state.total_sales_change
    # st.session_state["total_spends_change"] = st.session_state[
    #     "total_sales_change"
    # ]

    st.session_state["total_sales_change_abs"] = numerize(
        (1 + st.session_state["total_sales_change"] / 100)
        * st.session_state["scenario"].actual_total_sales,
        1,
    )
    st.session_state["total_sales_change_abs_slider"] = numerize(
        (1 + st.session_state["total_sales_change"] / 100)
        * st.session_state["scenario"].actual_total_sales,
        1,
    )
    # update_spends()

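# How the three sales widgets stay in sync (illustrative numbers, not app
# data): with actual_total_sales = 2.0M, typing "2.2M" into the absolute box
# yields round((2.2M / 2.0M - 1) * 100) = 10, so the percent input shows 10
# and the slider is rewritten to "2.2M". Editing any one of the three widgets
# rewrites the other two through these callbacks.
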
def update_all_spends_abs_slider():
    actual_spends = st.session_state["scenario"].actual_total_spends
    if validate_input(st.session_state["total_spends_change_abs_slider"]):
        modified_spends = extract_number_for_string(
            st.session_state["total_spends_change_abs_slider"]
        )
        st.session_state["total_spends_change"] = round(
            ((modified_spends / actual_spends) - 1) * 100
        )
        st.session_state["total_spends_change_abs"] = numerize(
            modified_spends, 1
        )

    st.session_state["project_dct"]["scenario_planner"][
        "total_spends_change"
    ] = st.session_state.total_spends_change

    update_all_spends()


# def update_all_spends_abs_slider():
#     actual_spends = _scenario.actual_total_spends
#     if validate_input(st.session_state["total_spends_change_abs_slider"]):
#         print("#" * 100)
#         print(st.session_state["total_spends_change_abs_slider"])
#         print("#" * 100)

#         modified_spends = extract_number_for_string(
#             st.session_state["total_spends_change_abs_slider"]
#         )
#         st.session_state["total_spends_change"] = (
#             (modified_spends / actual_spends) - 1
#         ) * 100
#         st.session_state["total_spends_change_abs"] = st.session_state[
#             "total_spends_change_abs_slider"
#         ]

#         update_all_spends()

def update_all_spends_abs():
    print("DEBUG: ", "inside update_all_spends_abs")
    # print(st.session_state["total_spends_change_abs_slider_options"])

    actual_spends = st.session_state["scenario"].actual_total_spends
    if validate_input(st.session_state["total_spends_change_abs"]):
        modified_spends = extract_number_for_string(
            st.session_state["total_spends_change_abs"]
        )
        st.session_state["total_spends_change"] = (
            (modified_spends / actual_spends) - 1
        ) * 100
        st.session_state["total_spends_change_abs_slider"] = numerize(
            extract_number_for_string(
                st.session_state["total_spends_change_abs"]
            ),
            1,
        )

    st.session_state["project_dct"]["scenario_planner"][
        "total_spends_change"
    ] = st.session_state.total_spends_change

    # print(
    #     "DEBUG UPDATE_ALL_SPENDS_ABS: ",
    #     st.session_state["total_spends_change"],
    # )
    update_all_spends()

def update_spends():
    print("update_spends")
    st.session_state["total_spends_change_abs"] = numerize(
        (1 + st.session_state["total_spends_change"] / 100)
        * st.session_state["scenario"].actual_total_spends,
        1,
    )
    st.session_state["total_spends_change_abs_slider"] = numerize(
        (1 + st.session_state["total_spends_change"] / 100)
        * st.session_state["scenario"].actual_total_spends,
        1,
    )

    st.session_state["project_dct"]["scenario_planner"][
        "total_spends_change"
    ] = st.session_state.total_spends_change

    update_all_spends()

def update_all_spends():
    """
    Updates spends for all the channels with the given overall spends change
    """
    percent_change = st.session_state["total_spends_change"]
    print("runs update_all")
    for channel_name in list(
        st.session_state["project_dct"]["scenario_planner"][
            unique_key
        ].channels.keys()
    ):
        st.session_state[f"{channel_name}_percent"] = percent_change
        channel = st.session_state["scenario"].channels[channel_name]
        current_spends = channel.actual_total_spends
        modified_spends = (1 + percent_change / 100) * current_spends
        st.session_state["scenario"].update(channel_name, modified_spends)
        st.session_state[channel_name] = numerize(
            modified_spends * channel.conversion_rate, 1
        )
        st.session_state[f"{channel_name}_change"] = percent_change

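# Worked example for the loop above (illustrative numbers): with
# total_spends_change = 10 and a channel whose actual_total_spends is 100K at
# a conversion rate of 1.2, modified_spends = 1.10 * 100K = 110K, the
# scenario is updated with 110K, and the channel's text box is rewritten to
# numerize(110K * 1.2) = "132K".
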
def extract_number_for_string(string_input):
    string_input = string_input.upper()
    if string_input.endswith("K"):
        return float(string_input[:-1]) * 10**3
    elif string_input.endswith("M"):
        return float(string_input[:-1]) * 10**6
    elif string_input.endswith("B"):
        return float(string_input[:-1]) * 10**9

def validate_input(string_input):
    # Character class fixed to [KMB]: the original [K|M|B] also accepted a
    # literal "|", which extract_number_for_string cannot parse.
    pattern = r"\d+\.?\d*[KMB]$"
    match = re.match(pattern, string_input)
    if match is None:
        return False
    return True

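# Illustrative behavior of the two helpers above (example values, not app
# data):
#
#   validate_input("1.5M")             # True
#   validate_input("1500")             # False -- a K/M/B suffix is required
#   extract_number_for_string("1.5M")  # 1500000.0
#   extract_number_for_string("250K")  # 250000.0
#
# extract_number_for_string returns None for an unsuffixed string, which is
# why every call site gates the raw text with validate_input first.
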
def update_data_by_percent(channel_name):
    prev_spends = (
        st.session_state["scenario"].channels[channel_name].actual_total_spends
        * st.session_state["scenario"].channels[channel_name].conversion_rate
    )
    modified_spends = prev_spends * (
        1 + st.session_state[f"{channel_name}_percent"] / 100
    )

    st.session_state[channel_name] = numerize(modified_spends, 1)

    st.session_state["scenario"].update(
        channel_name,
        modified_spends
        / st.session_state["scenario"].channels[channel_name].conversion_rate,
    )

def update_data(channel_name):
    """
    Updates the spends for the given channel
    """
    print("runs update_data")
    if validate_input(st.session_state[channel_name]):
        modified_spends = extract_number_for_string(
            st.session_state[channel_name]
        )

        prev_spends = (
            st.session_state["scenario"]
            .channels[channel_name]
            .actual_total_spends
            * st.session_state["scenario"]
            .channels[channel_name]
            .conversion_rate
        )
        st.session_state[f"{channel_name}_percent"] = round(
            100 * (modified_spends - prev_spends) / prev_spends, 2
        )
        st.session_state["scenario"].update(
            channel_name,
            modified_spends
            / st.session_state["scenario"]
            .channels[channel_name]
            .conversion_rate,
        )
    # st.session_state['scenario'].update(channel_name, modified_spends)
    # else:
    #     try:
    #         modified_spends = float(st.session_state[channel_name])
    #         prev_spends = st.session_state['scenario'].channels[channel_name].actual_total_spends * st.session_state['scenario'].channels[channel_name].conversion_rate
    #         st.session_state[f'{channel_name}_change'] = round(100*(modified_spends - prev_spends) / prev_spends,2)
    #         st.session_state['scenario'].update(channel_name, modified_spends/st.session_state['scenario'].channels[channel_name].conversion_rate)
    #         st.session_state[f'{channel_name}'] = numerize(modified_spends,1)
    #     except ValueError:
    #         st.write('Invalid input')

def select_channel_for_optimization(channel_name):
    """
    Marks the given channel for optimization
    """
    st.session_state["optimization_channels"][channel_name] = st.session_state[
        f"{channel_name}_selected"
    ]


def select_all_channels_for_optimization():
    """
    Marks all the channels for optimization
    """
    # print(
    #     "DEBUG: select_all_channels_for_opt",
    #     st.session_state["optimze_all_channels"],
    # )

    for channel_name in st.session_state["optimization_channels"].keys():
        st.session_state[f"{channel_name}_selected"] = st.session_state[
            "optimze_all_channels"
        ]
        st.session_state["optimization_channels"][channel_name] = (
            st.session_state["optimze_all_channels"]
        )

def update_penalty():
    """
    Updates the penalty flag for sales calculation
    """
    st.session_state["scenario"].update_penalty(
        st.session_state["apply_penalty"]
    )

def reset_optimization():
    print("DEBUG: ", "Running reset_optimization")
    for channel_name in list(
        st.session_state["project_dct"]["scenario_planner"][
            unique_key
        ].channels.keys()
    ):
        st.session_state[f"{channel_name}_selected"] = False
        # st.session_state[f"{channel_name}_change"] = 0
    st.session_state["optimze_all_channels"] = False
    st.session_state["initialized"] = False
    del st.session_state["total_sales_change_abs_slider"]
    del st.session_state["total_sales_change_abs"]
    del st.session_state["total_sales_change"]

def reset_scenario():
    print("[DEBUG]: reset_scenario")
    # def reset_scenario(panel_selected, file_selected, updated_rcs):
    #     # print(st.session_state['default_scenario_dict'])
    #     st.session_state['scenario'] = class_from_dict(st.session_state['default_scenario_dict'])
    #     for channel in st.session_state['scenario'].channels.values():
    #         st.session_state[channel.name] = float(channel.actual_total_spends * channel.conversion_rate)
    for channel_name in list(
        st.session_state["project_dct"]["scenario_planner"][
            unique_key
        ].channels.keys()
    ):
        st.session_state[f"{channel_name}_selected"] = False
        # st.session_state[f"{channel_name}_change"] = 0
    st.session_state["optimze_all_channels"] = False
    st.session_state["initialized"] = False

    del st.session_state["optimization_channels"]
    panel_selected = st.session_state.get("panel_selected", 0)
    file_selected = st.session_state["file_selected"]
    update_rcs = st.session_state.get("update_rcs", None)

    # print(f"## [DEBUG] [SCENARIO PLANNER][RESET SCENARIO]: {}")
    del st.session_state["project_dct"]["scenario_planner"][
        f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"
    ]
    del st.session_state["total_sales_change_abs_slider"]
    del st.session_state["total_sales_change_abs"]
    del st.session_state["total_sales_change"]
    # if panel_selected == "Aggregated":
    #     initialize_data(
    #         panel=panel_selected,
    #         target_file=file_selected,
    #         updated_rcs=updated_rcs,
    #         metrics=metrics_selected,
    #     )
    #     panel = None
    # else:
    #     initialize_data(
    #         panel=panel_selected,
    #         target_file=file_selected,
    #         updated_rcs=updated_rcs,
    #         metrics=metrics_selected,
    #     )
    # st.session_state["total_spends_change"] = 0
    # update_all_spends()

def format_number(num):
    if num >= 1_000_000:
        return f"{num / 1_000_000:.2f}M"
    elif num >= 1_000:
        return f"{num / 1_000:.0f}K"
    else:
        return f"{num:.2f}"

def summary_plot(data, x, y, title, text_column):
    fig = px.bar(
        data,
        x=x,
        y=y,
        orientation="h",
        title=title,
        text=text_column,
        color="Channel_name",
    )

    # Convert text_column to numeric values
    data[text_column] = pd.to_numeric(data[text_column], errors="coerce")

    # Update the format of the displayed text based on magnitude
    fig.update_traces(
        texttemplate="%{text:.2s}",
        textposition="outside",
        hovertemplate="%{x:.2s}",
    )

    fig.update_layout(
        xaxis_title=x, yaxis_title="Channel Name", showlegend=False
    )
    return fig

def s_curve(x, K, b, a, x0):
    # Generalized logistic response curve used for all channels.
    return K / (1 + b * np.exp(-a * (x - x0)))

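# The response curve above is a generalized logistic,
#     y = K / (1 + b * exp(-a * (x - x0))),
# where K is the saturation ceiling, x0 the inflection point, a the growth
# rate, and b a shape offset. Quick check with K=100, b=1, a=1, x0=0:
# s_curve(0, 100, 1, 1, 0) = 100 / (1 + 1) = 50, i.e. half of saturation at
# the inflection point.
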
def find_segment_value(x, roi, mroi):
    start_value = x[0]
    end_value = x[len(x) - 1]

    # Condition for green region: Both MROI and ROI > 1
    green_condition = (roi > 1) & (mroi > 1)
    left_indices = np.where(green_condition)[0]
    left_value = x[left_indices[0]] if left_indices.size > 0 else x[0]

    right_indices = np.where(green_condition)[0]
    right_value = x[right_indices[-1]] if right_indices.size > 0 else x[0]

    return start_value, end_value, left_value, right_value

def calculate_rgba(
    start_value, end_value, left_value, right_value, current_channel_spends
):
    # Initialize alpha to None for clarity
    alpha = None

    # Determine the color and calculate relative_position and alpha based on the point's position
    if start_value <= current_channel_spends <= left_value:
        color = "yellow"
        relative_position = (current_channel_spends - start_value) / (
            left_value - start_value
        )
        alpha = 0.8 - (
            0.6 * relative_position
        )  # Alpha decreases from start to end

    elif left_value < current_channel_spends <= right_value:
        color = "green"
        relative_position = (current_channel_spends - left_value) / (
            right_value - left_value
        )
        alpha = 0.8 - (
            0.6 * relative_position
        )  # Alpha decreases from start to end

    elif right_value < current_channel_spends <= end_value:
        color = "red"
        relative_position = (current_channel_spends - right_value) / (
            end_value - right_value
        )
        alpha = 0.2 + (
            0.6 * relative_position
        )  # Alpha increases from start to end

    else:
        # Default case, if the spends are outside the defined ranges
        return "rgba(136, 136, 136, 0.5)"  # Grey for values outside the range

    # Ensure alpha is within the intended range in case of any calculation overshoot
    alpha = max(0.2, min(alpha, 0.8))

    # Define color codes for RGBA
    color_codes = {
        "yellow": "255, 255, 0",  # RGB for yellow
        "green": "0, 128, 0",  # RGB for green
        "red": "255, 0, 0",  # RGB for red
    }

    rgba = f"rgba({color_codes[color]}, {alpha})"
    return rgba

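# Example of the zone coloring above (illustrative numbers): with
# start=0, left=2.5K, right=3.4K, end=10K, a current spend of 3.0K falls in
# the green band (both ROI and marginal ROI > 1) at relative position
# (3.0K - 2.5K) / (3.4K - 2.5K) ~= 0.56, so alpha ~= 0.8 - 0.6 * 0.56 ~= 0.47
# and the box renders as "rgba(0, 128, 0, 0.47)".
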
def debug_temp(x_test, power, K, b, a, x0):
    print("*" * 100)
    # Calculate the count of bins
    count_lower_bin = sum(1 for x in x_test if x <= 2524)
    count_center_bin = sum(1 for x in x_test if x > 2524 and x <= 3377)
    count_ = sum(1 for x in x_test if x > 3377)

    print(
        f"""
        lower : {count_lower_bin}
        center : {count_center_bin}
        upper : {count_}
        """
    )

# @st.cache
def plot_response_curves():
    cols = 4
    rows = (
        len(channels_list) // cols
        if len(channels_list) % cols == 0
        else len(channels_list) // cols + 1
    )
    rcs = st.session_state["rcs"]
    shapes = []
    fig = make_subplots(rows=rows, cols=cols, subplot_titles=channels_list)
    for i in range(0, len(channels_list)):
        col = channels_list[i]
        x_actual = st.session_state["scenario"].channels[col].actual_spends
        # x_modified = st.session_state["scenario"].channels[col].modified_spends

        power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3

        K = rcs[col]["K"]
        b = rcs[col]["b"]
        a = rcs[col]["a"]
        x0 = rcs[col]["x0"]

        x_plot = np.linspace(0, 5 * x_actual.sum(), 50)

        x, y, marginal_roi = [], [], []
        for x_p in x_plot:
            x.append(x_p * x_actual / x_actual.sum())

        for index in range(len(x_plot)):
            y.append(s_curve(x[index] / 10**power, K, b, a, x0))

        for index in range(len(x_plot)):
            marginal_roi.append(
                a
                * y[index]
                * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
            )

        x = (
            np.sum(x, axis=1)
            * st.session_state["scenario"].channels[col].conversion_rate
        )
        y = np.sum(y, axis=1)
        marginal_roi = (
            np.average(marginal_roi, axis=1)
            / st.session_state["scenario"].channels[col].conversion_rate
        )

        roi = y / np.maximum(x, np.finfo(float).eps)

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                name=col,
                customdata=np.stack((roi, marginal_roi), axis=-1),
                hovertemplate="Spend:%{x:$.2s}<br>Sale:%{y:$.2s}<br>ROI:%{customdata[0]:.3f}<br>MROI:%{customdata[1]:.3f}",
                line=dict(color="blue"),
            ),
            row=1 + (i) // cols,
            col=i % cols + 1,
        )

        x_optimal = (
            st.session_state["scenario"].channels[col].modified_total_spends
            * st.session_state["scenario"].channels[col].conversion_rate
        )
        y_optimal = (
            st.session_state["scenario"].channels[col].modified_total_sales
        )

        # if col == "Paid_social_others":
        #     debug_temp(x_optimal * x_actual / x_actual.sum(), power, K, b, a, x0)

        fig.add_trace(
            go.Scatter(
                x=[x_optimal],
                y=[y_optimal],
                name=col,
                legendgroup=col,
                showlegend=False,
                marker=dict(color=["black"]),
            ),
            row=1 + (i) // cols,
            col=i % cols + 1,
        )

        shapes.append(
            go.layout.Shape(
                type="line",
                x0=0,
                y0=y_optimal,
                x1=x_optimal,
                y1=y_optimal,
                line_width=1,
                line_dash="dash",
                line_color="black",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        shapes.append(
            go.layout.Shape(
                type="line",
                x0=x_optimal,
                y0=0,
                x1=x_optimal,
                y1=y_optimal,
                line_width=1,
                line_dash="dash",
                line_color="black",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        start_value, end_value, left_value, right_value = find_segment_value(
            x,
            roi,
            marginal_roi,
        )

        # Adding background colors
        y_max = y.max() * 1.3  # 30% extra space above the max

        # Yellow region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=start_value,
                y0=0,
                x1=left_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(255, 255, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Green region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=left_value,
                y0=0,
                x1=right_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(0, 255, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Red region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=right_value,
                y0=0,
                x1=end_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(255, 0, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

    fig.update_layout(
        # height=1000,
        # width=1000,
        title_text=f"Response Curves (X: Spends Vs Y: {target})",
        showlegend=False,
        shapes=shapes,
    )
    fig.update_annotations(font_size=10)
    # fig.update_xaxes(title="Spends")
    # fig.update_yaxes(title=target)
    fig.update_yaxes(
        gridcolor="rgba(136, 136, 136, 0.5)", gridwidth=0.5, griddash="dash"
    )

    return fig

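# Subplot placement sketch for the grid above: with cols = 4 and, say, 10
# channels, rows = 10 // 4 + 1 = 3, and channel i lands at row 1 + i // 4,
# column i % 4 + 1, so channel index 5 sits at row 2, column 2. The
# xref/yref strings "x{i+1}" / "y{i+1}" pin each dashed line and shaded
# region to that channel's own axes.
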
# ======================================================== #
# ==================== HTML Components =================== #
# ======================================================== #


def generate_spending_header(heading):
    return st.markdown(
        f"""<h2 class="spends-header">{heading}</h2>""", unsafe_allow_html=True
    )

def save_checkpoint():
    project_dct_path = os.path.join(
        st.session_state["project_path"], "project_dct.pkl"
    )

    try:
        pickle.dumps(st.session_state["project_dct"])
        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)
    except Exception:
        # with warning_placeholder:
        st.toast("Unknown Issue, please reload the page.")


def reset_checkpoint():
    st.session_state["project_dct"]["scenario_planner"] = {}
    save_checkpoint()

# ======================================================== #
# =================== Session variables ================== #
# ======================================================== #

with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config

authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

import os
import glob

def get_excel_names(directory):
    # Create a list to hold the final parts of the filenames
    last_portions = []

    # Patterns to match Excel files (.xlsx and .xls) that contain @#
    patterns = [
        os.path.join(directory, "*@#*.xlsx"),
        os.path.join(directory, "*@#*.xls"),
    ]

    # Process each pattern
    for pattern in patterns:
        files = glob.glob(pattern)

        # Extracting the last portion after @# for each file
        for file in files:
            base_name = os.path.basename(file)
            last_portion = base_name.split("@#")[-1]
            last_portion = last_portion.replace(".xlsx", "").replace(
                ".xls", ""
            )  # Removing extensions
            last_portions.append(last_portion)

    return last_portions

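# Filename parsing example for get_excel_names (hypothetical file following
# the "@#" convention used on this page): "Overview_data_test_panel@#revenue.xlsx"
# gives base_name.split("@#")[-1] = "revenue.xlsx", and stripping the
# extension leaves "revenue" as the metric name shown in the selectbox.
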
def name_formating(channel_name):
    # Replace underscores with spaces
    name_mod = channel_name.replace("_", " ")

    # Capitalize the first letter of each word
    name_mod = name_mod.title()

    return name_mod


@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")

    if "Panel" in raw_data_mmm_df.columns:
        panel = list(set(raw_data_mmm_df["Panel"]))
    else:
        raw_data_mmm_df = None
        panel = None

    return panel

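# panel_fetch is wrapped in st.cache_resource, so the "RAW DATA MMM" sheet is
# read once per file path and reruns reuse the cached result. Note that it
# returns None when the sheet has no "Panel" column; the caller below guards
# against that before building the panel selectbox options.
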
if auth_status is True:
    authenticator.logout("Logout", "main")

    if "project_dct" not in st.session_state:
        st.error("Please load a project from home")
        st.stop()

    database_file = r"DB\User.db"

    conn = sqlite3.connect(
        database_file, check_same_thread=False
    )  # connection with sql db
    c = conn.cursor()

    with st.sidebar:
        st.button("Save checkpoint", on_click=save_checkpoint)
        st.button("Reset Checkpoint", on_click=reset_checkpoint)

    warning_placeholder = st.empty()

    st.header("Scenario Planner")

    st.markdown("**Simulation**")

    # st.subheader("Simulation")
    col1, col2 = st.columns([1, 1])

    # Get metric and panel from last saved state
    if "last_saved_metric" not in st.session_state:
        st.session_state["last_saved_metric"] = st.session_state[
            "project_dct"
        ]["scenario_planner"].get("metric_selected", 0)

    if "last_saved_panel" not in st.session_state:
        st.session_state["last_saved_panel"] = st.session_state["project_dct"][
            "scenario_planner"
        ].get("panel_selected", 0)

    # Response Metrics
    directory = "metrics_level_data"
    metrics_list = get_excel_names(directory)
    metrics_selected = col1.selectbox(
        "Response Metrics",
        metrics_list,
        format_func=name_formating,
        index=st.session_state["last_saved_metric"],
        on_change=reset_optimization,
        key="metric_selected",
    )
|
927 |
+
target = name_formating(metrics_selected)
|
928 |
+
|
929 |
+
file_selected = f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
|
930 |
+
# print(f"[DEBUG]: {metrics_selected}")
|
931 |
+
# print(f"[DEBUG]: {file_selected}")
|
932 |
+
st.session_state["file_selected"] = file_selected
|
933 |
+
# Panel List
|
934 |
+
panel_list = panel_fetch(file_selected)
|
935 |
+
panel_list_final = ["Aggregated"] + panel_list
|
936 |
+
|
937 |
+
# Panel Selected
|
938 |
+
panel_selected = col2.selectbox(
|
939 |
+
"Panel",
|
940 |
+
panel_list_final,
|
941 |
+
on_change=reset_optimization,
|
942 |
+
key="panel_selected",
|
943 |
+
index=st.session_state["last_saved_panel"],
|
944 |
+
)
|
945 |
+
|
946 |
+
unique_key = f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"
|
947 |
+
|
948 |
+
if "update_rcs" in st.session_state:
|
949 |
+
updated_rcs = st.session_state["update_rcs"]
|
950 |
+
else:
|
951 |
+
updated_rcs = None
|
952 |
+
|
953 |
+
    if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
        if panel_selected == "Aggregated":
            initialize_data(
                panel=panel_selected,
                target_file=file_selected,
                updated_rcs=updated_rcs,
                metrics=metrics_selected,
            )
            panel = None
        else:
            initialize_data(
                panel=panel_selected,
                target_file=file_selected,
                updated_rcs=updated_rcs,
                metrics=metrics_selected,
            )
        st.session_state["project_dct"]["scenario_planner"][unique_key] = (
            st.session_state["scenario"]
        )

    else:
        st.session_state["scenario"] = st.session_state["project_dct"][
            "scenario_planner"
        ][unique_key]
        st.session_state["rcs"] = {}
        st.session_state["powers"] = {}
        if "optimization_channels" not in st.session_state:
            st.session_state["optimization_channels"] = {}

        for channel_name, _channel in st.session_state["project_dct"][
            "scenario_planner"
        ][unique_key].channels.items():
            st.session_state[channel_name] = numerize(
                _channel.modified_total_spends, 1
            )
            st.session_state["rcs"][
                channel_name
            ] = _channel.response_curve_params
            st.session_state["powers"][channel_name] = _channel.power
            if channel_name not in st.session_state["optimization_channels"]:
                st.session_state["optimization_channels"][channel_name] = False

    if "first_time" not in st.session_state:
        st.session_state["first_time"] = True
        st.session_state["first_run_scenario"] = True

    # Check if state is initialized
    is_state_initialized = st.session_state.get("initialized", False)

    # if not is_state_initialized:
    #     print("running initialize...")
    #     # initialize_data()
    #     if panel_selected == "Aggregated":
    #         initialize_data(
    #             panel=panel_selected,
    #             target_file=file_selected,
    #             updated_rcs=updated_rcs,
    #             metrics=metrics_selected,
    #         )
    #         panel = None
    #     else:
    #         initialize_data(
    #             panel=panel_selected,
    #             target_file=file_selected,
    #             updated_rcs=updated_rcs,
    #             metrics=metrics_selected,
    #         )
    #     st.session_state["initialized"] = True
    #     st.session_state["first_time"] = False

    # Channels List
    channels_list = list(
        st.session_state["project_dct"]["scenario_planner"][
            unique_key
        ].channels.keys()
    )
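    # Scenario objects are cached per (metric, panel) pair under unique_key,
    # e.g. "revenue-Aggregated" (hypothetical values): the first visit builds
    # the scenario via initialize_data and stores it in project_dct, and
    # later visits restore it together with each channel's response-curve
    # parameters ("rcs") and scaling power.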
    # ======================================================== #
    # ========================== UI ========================== #
    # ======================================================== #

    main_header = st.columns((2, 2))
    sub_header = st.columns((1, 1, 1, 1))
    # _scenario = st.session_state["scenario"]

    st.session_state.total_spends_change = round(
        (
            st.session_state["scenario"].modified_total_spends
            / st.session_state["scenario"].actual_total_spends
            - 1
        )
        * 100
    )

    if "total_sales_change" not in st.session_state:
        st.session_state.total_sales_change = round(
            (
                st.session_state["scenario"].modified_total_sales
                / st.session_state["scenario"].actual_total_sales
                - 1
            )
            * 100
        )

    st.session_state["total_spends_change_abs"] = numerize(
        st.session_state["scenario"].modified_total_spends,
        1,
    )
    if "total_sales_change_abs" not in st.session_state:
        st.session_state["total_sales_change_abs"] = numerize(
            st.session_state["scenario"].modified_total_sales,
            1,
        )

    # if "total_spends_change_abs_slider" not in st.session_state:
    st.session_state.total_spends_change_abs_slider = numerize(
        st.session_state["scenario"].modified_total_spends, 1
    )

    if "total_sales_change_abs_slider" not in st.session_state:
        st.session_state.total_sales_change_abs_slider = numerize(
            st.session_state["scenario"].actual_total_sales, 1
        )

    st.session_state["allow_sales_update"] = True

    st.session_state["allow_spends_update"] = True

    # if "panel_selected" not in st.session_state:
    #     st.session_state["panel_selected"] = 0

    with main_header[0]:
        st.subheader("Actual")

    with main_header[-1]:
        st.subheader("Simulated")

    with sub_header[0]:
        st.metric(
            label="Spends",
            value=format_numbers(
                st.session_state["scenario"].actual_total_spends
            ),
        )

    with sub_header[1]:
        st.metric(
            label=target,
            value=format_numbers(
                float(st.session_state["scenario"].actual_total_sales),
                include_indicator=False,
            ),
        )

    with sub_header[2]:
        st.metric(
            label="Spends",
            value=format_numbers(
                st.session_state["scenario"].modified_total_spends
            ),
            delta=numerize(st.session_state["scenario"].delta_spends, 1),
        )

    with sub_header[3]:
        st.metric(
            label=target,
            value=format_numbers(
                float(st.session_state["scenario"].modified_total_sales),
                include_indicator=False,
            ),
            delta=numerize(st.session_state["scenario"].delta_sales, 1),
        )
with st.expander("Channel Spends Simulator", expanded=True):
|
1127 |
+
_columns1 = st.columns((2, 2, 1, 1))
|
1128 |
+
with _columns1[0]:
|
1129 |
+
optimization_selection = st.selectbox(
|
1130 |
+
"Optimize",
|
1131 |
+
options=["Media Spends", target],
|
1132 |
+
key="optimization_key_value",
|
1133 |
+
)
|
1134 |
+
|
1135 |
+
with _columns1[1]:
|
1136 |
+
st.markdown("#")
|
1137 |
+
# if st.checkbox(
|
1138 |
+
# label="Optimize all Channels",
|
1139 |
+
# key="optimze_all_channels",
|
1140 |
+
# value=False,
|
1141 |
+
# # on_change=select_all_channels_for_optimization,
|
1142 |
+
# ):
|
1143 |
+
# select_all_channels_for_optimization()
|
1144 |
+
|
1145 |
+
st.checkbox(
|
1146 |
+
label="Optimize all Channels",
|
1147 |
+
key="optimze_all_channels",
|
1148 |
+
on_change=select_all_channels_for_optimization,
|
1149 |
+
)
|
1150 |
+
|
1151 |
+
with _columns1[2]:
|
1152 |
+
st.markdown("#")
|
1153 |
+
# st.button(
|
1154 |
+
# "Optimize",
|
1155 |
+
# on_click=optimize,
|
1156 |
+
# args=(st.session_state["optimization_key_value"]),
|
1157 |
+
# use_container_width=True,
|
1158 |
+
# )
|
1159 |
+
|
1160 |
+
optimize_placeholder = st.empty()
|
1161 |
+
|
1162 |
+
with _columns1[3]:
|
1163 |
+
st.markdown("#")
|
1164 |
+
st.button(
|
1165 |
+
"Reset",
|
1166 |
+
on_click=reset_scenario,
|
1167 |
+
# args=(panel_selected, file_selected, updated_rcs),
|
1168 |
+
use_container_width=True,
|
1169 |
+
)
|
1170 |
+
|
1171 |
+
        _columns2 = st.columns((2, 2, 2))
        if st.session_state["optimization_key_value"] == "Media Spends":

            # update_spends()

            with _columns2[0]:
                spend_input = st.text_input(
                    "Absolute",
                    key="total_spends_change_abs",
                    # label_visibility="collapsed",
                    on_change=update_all_spends_abs,
                )

            with _columns2[1]:
                st.number_input(
                    "Percent Change",
                    key="total_spends_change",
                    min_value=-50,
                    max_value=50,
                    step=1,
                    on_change=update_spends,
                )

            with _columns2[2]:
                scenario = st.session_state["project_dct"]["scenario_planner"][
                    unique_key
                ]
                min_value = round(scenario.actual_total_spends * 0.5)
                max_value = round(scenario.actual_total_spends * 1.5)
                st.session_state["total_spends_change_abs_slider_options"] = [
                    numerize(value, 1)
                    for value in range(min_value, max_value + 1, int(1e4))
                ]

                st.select_slider(
                    "Absolute Slider",
                    options=st.session_state[
                        "total_spends_change_abs_slider_options"
                    ],
                    key="total_spends_change_abs_slider",
                    on_change=update_all_spends_abs_slider,
                )
        elif st.session_state["optimization_key_value"] == target:
            # update_sales()

            with _columns2[0]:
                sales_input = st.text_input(
                    "Absolute",
                    key="total_sales_change_abs",
                    on_change=update_sales_abs,
                )

            with _columns2[1]:
                st.number_input(
                    "Percent Change",
                    key="total_sales_change",
                    min_value=-50,
                    max_value=50,
                    step=1,
                    on_change=update_sales,
                )

            with _columns2[2]:
                min_value = round(
                    st.session_state["scenario"].actual_total_sales * 0.5
                )
                max_value = round(
                    st.session_state["scenario"].actual_total_sales * 1.5
                )
                st.session_state["total_sales_change_abs_slider_options"] = [
                    numerize(value, 1)
                    for value in range(min_value, max_value + 1, int(1e5))
                ]

                st.select_slider(
                    "Absolute Slider",
                    options=st.session_state[
                        "total_sales_change_abs_slider_options"
                    ],
                    key="total_sales_change_abs_slider",
                    on_change=update_sales_abs_slider,
                )

        if (
            not st.session_state["allow_sales_update"]
            and optimization_selection == target
        ):
            st.warning("Invalid Input")

        if (
            not st.session_state["allow_spends_update"]
            and optimization_selection == "Media Spends"
        ):
            st.warning("Invalid Input")

        status_placeholder = st.empty()

        # if optimize_placeholder.button("Optimize", use_container_width=True):
        #     optimize(st.session_state["optimization_key_value"], status_placeholder)
        #     st.rerun()

        optimize_placeholder.button(
            "Optimize",
            on_click=optimize,
            args=(
                st.session_state["optimization_key_value"],
                status_placeholder,
            ),
            use_container_width=True,
        )
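        # Note: args must be a tuple. The commented-out variant further up
        # passed args=(st.session_state["optimization_key_value"]) -- parens
        # without a trailing comma -- which is just a string and would be
        # unpacked character by character when the callback fires.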
        st.markdown(
            """<hr class="spends-heading-seperator">""", unsafe_allow_html=True
        )
        _columns = st.columns((2.5, 2, 1.5, 1.5, 1))
        with _columns[0]:
            generate_spending_header("Channel")
        with _columns[1]:
            generate_spending_header("Spends Input")
        with _columns[2]:
            generate_spending_header("Spends")
        with _columns[3]:
            generate_spending_header(target)
        with _columns[4]:
            generate_spending_header("Optimize")

        st.markdown(
            """<hr class="spends-heading-seperator">""", unsafe_allow_html=True
        )

        if "acutual_predicted" not in st.session_state:
            # NOTE: "acutual_predicted" is a long-standing misspelling of
            # "actual_predicted"; the key is kept as-is because it is used
            # consistently throughout this page.
            st.session_state["acutual_predicted"] = {
                "Channel_name": [],
                "Actual_spend": [],
                "Optimized_spend": [],
                "Delta": [],
            }
        for i, channel_name in enumerate(channels_list):
            _channel_class = st.session_state["scenario"].channels[
                channel_name
            ]

            st.session_state[f"{channel_name}_percent"] = round(
                (
                    _channel_class.modified_total_spends
                    / _channel_class.actual_total_spends
                    - 1
                )
                * 100
            )

            _columns = st.columns((2.5, 1.5, 1.5, 1.5, 1))
            with _columns[0]:
                st.write(channel_name_formating(channel_name))
                bin_placeholder = st.container()

            with _columns[1]:
                channel_bounds = _channel_class.bounds
                channel_spends = float(_channel_class.actual_total_spends)
                min_value = float(
                    (1 + channel_bounds[0] / 100) * channel_spends
                )
                max_value = float(
                    (1 + channel_bounds[1] / 100) * channel_spends
                )
                # print("##########", st.session_state[channel_name])
                spend_input = st.text_input(
                    channel_name,
                    key=channel_name,
                    label_visibility="collapsed",
                    on_change=partial(update_data, channel_name),
                )
                if not validate_input(spend_input):
                    st.error("Invalid input")

                channel_name_current = f"{channel_name}_change"

                st.number_input(
                    "Percent Change",
                    key=f"{channel_name}_percent",
                    step=1,
                    on_change=partial(update_data_by_percent, channel_name),
                )

            with _columns[2]:
                # spends
                current_channel_spends = float(
                    _channel_class.modified_total_spends
                    * _channel_class.conversion_rate
                )
                actual_channel_spends = float(
                    _channel_class.actual_total_spends
                    * _channel_class.conversion_rate
                )
                spends_delta = float(
                    _channel_class.delta_spends
                    * _channel_class.conversion_rate
                )
                st.session_state["acutual_predicted"]["Channel_name"].append(
                    channel_name
                )
                st.session_state["acutual_predicted"]["Actual_spend"].append(
                    actual_channel_spends
                )
                st.session_state["acutual_predicted"][
                    "Optimized_spend"
                ].append(current_channel_spends)
                st.session_state["acutual_predicted"]["Delta"].append(
                    spends_delta
                )
                ## REMOVE
                st.metric(
                    "Spends",
                    format_numbers(current_channel_spends),
                    delta=numerize(spends_delta, 1),
                    label_visibility="collapsed",
                )

            with _columns[3]:
                # sales
                current_channel_sales = float(
                    _channel_class.modified_total_sales
                )
                actual_channel_sales = float(_channel_class.actual_total_sales)
                sales_delta = float(_channel_class.delta_sales)
                st.metric(
                    target,
                    format_numbers(
                        current_channel_sales, include_indicator=False
                    ),
                    delta=numerize(sales_delta, 1),
                    label_visibility="collapsed",
                )

            with _columns[4]:
                # if st.checkbox(
                #     label="select for optimization",
                #     key=f"{channel_name}_selected",
                #     value=False,
                #     # on_change=partial(select_channel_for_optimization, channel_name),
                #     label_visibility="collapsed",
                # ):
                #     select_channel_for_optimization(channel_name)

                st.checkbox(
                    label="select for optimization",
                    key=f"{channel_name}_selected",
                    value=False,
                    on_change=partial(
                        select_channel_for_optimization, channel_name
                    ),
                    label_visibility="collapsed",
                )

            st.markdown(
                """<hr class="spends-child-seperator">""",
                unsafe_allow_html=True,
            )

            # Bins
            col = channels_list[i]
            x_actual = st.session_state["scenario"].channels[col].actual_spends
            x_modified = (
                st.session_state["scenario"].channels[col].modified_spends
            )

            x_total = x_modified.sum()
            power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3

            updated_rcs_key = (
                f"{metrics_selected}#@{panel_selected}#@{channel_name}"
            )

            if updated_rcs and updated_rcs_key in list(updated_rcs.keys()):
                K = updated_rcs[updated_rcs_key]["K"]
                b = updated_rcs[updated_rcs_key]["b"]
                a = updated_rcs[updated_rcs_key]["a"]
                x0 = updated_rcs[updated_rcs_key]["x0"]
            else:
                K = st.session_state["rcs"][col]["K"]
                b = st.session_state["rcs"][col]["b"]
                a = st.session_state["rcs"][col]["a"]
                x0 = st.session_state["rcs"][col]["x0"]

            x_plot = np.linspace(0, 5 * x_actual.sum(), 200)

            # Append current_channel_spends to the end of x_plot
            x_plot = np.append(x_plot, current_channel_spends)

            x, y, marginal_roi = [], [], []
            for x_p in x_plot:
                x.append(x_p * x_actual / x_actual.sum())

            for index in range(len(x_plot)):
                y.append(s_curve(x[index] / 10**power, K, b, a, x0))

            for index in range(len(x_plot)):
                marginal_roi.append(
                    a
                    * y[index]
                    * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
                )

            x = (
                np.sum(x, axis=1)
                * st.session_state["scenario"].channels[col].conversion_rate
            )
            y = np.sum(y, axis=1)
            marginal_roi = (
                np.average(marginal_roi, axis=1)
                / st.session_state["scenario"].channels[col].conversion_rate
            )

            roi = y / np.maximum(x, np.finfo(float).eps)

            roi_current, marginal_roi_current = roi[-1], marginal_roi[-1]
            x, y, roi, marginal_roi = (
                x[:-1],
                y[:-1],
                roi[:-1],
                marginal_roi[:-1],
            )  # Drop data for current spends

            start_value, end_value, left_value, right_value = (
                find_segment_value(
                    x,
                    roi,
                    marginal_roi,
                )
            )

            rgba = calculate_rgba(
                start_value,
                end_value,
                left_value,
                right_value,
                current_channel_spends,
            )

            with bin_placeholder:
                st.markdown(
                    f"""
                    <div style="
                        border-radius: 12px;
                        background-color: {rgba};
                        padding: 10px;
                        text-align: center;
                        color: #006EC0;
                        ">
                        <p style="margin: 0; font-size: 20px;">ROI: {round(roi_current,1)}</p>
                        <p style="margin: 0; font-size: 20px;">Marginal ROI: {round(marginal_roi_current,1)}</p>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
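            # The current spend level is appended as the last point of x_plot
            # so its ROI and marginal ROI come out of the same vectorized
            # pass; that pair is read off as roi[-1] / marginal_roi[-1] and
            # then dropped before the arrays go to the zone helpers.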
st.session_state["project_dct"]["scenario_planner"]["scenario"] = (
|
1530 |
+
st.session_state["scenario"]
|
1531 |
+
)
|
1532 |
+
|
1533 |
+
with st.expander("See Response Curves", expanded=True):
|
1534 |
+
fig = plot_response_curves()
|
1535 |
+
st.plotly_chart(fig, use_container_width=True)
|
1536 |
+
|
1537 |
+
def update_optimization_bounds(channel_name, bound_type):
|
1538 |
+
index = 0 if bound_type == "lower" else 1
|
1539 |
+
update_key = (
|
1540 |
+
f"{channel_name}_b_lower"
|
1541 |
+
if bound_type == "lower"
|
1542 |
+
else f"{channel_name}_b_upper"
|
1543 |
+
)
|
1544 |
+
st.session_state["project_dct"]["scenario_planner"][
|
1545 |
+
unique_key
|
1546 |
+
].channels[channel_name].bounds[index] = st.session_state[update_key]
|
1547 |
+
|
1548 |
+
def update_optimization_bounds_all(bound_type):
|
1549 |
+
index = 0 if bound_type == "lower" else 1
|
1550 |
+
update_key = (
|
1551 |
+
f"all_b_lower" if bound_type == "lower" else f"all_b_upper"
|
1552 |
+
)
|
1553 |
+
|
1554 |
+
for channel_name, _channel in st.session_state["project_dct"][
|
1555 |
+
"scenario_planner"
|
1556 |
+
][unique_key].channels.items():
|
1557 |
+
_channel.bounds[index] = st.session_state[update_key]
|
1558 |
+
|
1559 |
+
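    # Bounds are stored per channel as [lower, upper] percent offsets around
    # actual spends, e.g. [-10, 10] lets the optimizer move a channel's
    # budget between 90% and 110% of its actual level. The two callbacks
    # above write either one channel's bound or the same bound for every
    # channel.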
with st.expander("Optimization setup"):
|
1560 |
+
bounds_placeholder = st.container()
|
1561 |
+
with bounds_placeholder:
|
1562 |
+
st.subheader("Optimization Bounds")
|
1563 |
+
with st.container():
|
1564 |
+
bounds_columns = st.columns((1, 0.35, 0.35, 1))
|
1565 |
+
with bounds_columns[0]:
|
1566 |
+
st.write("##")
|
1567 |
+
st.write("Update all channels")
|
1568 |
+
|
1569 |
+
with bounds_columns[1]:
|
1570 |
+
st.number_input(
|
1571 |
+
"Lower",
|
1572 |
+
min_value=-100,
|
1573 |
+
max_value=500,
|
1574 |
+
key=f"all_b_lower",
|
1575 |
+
# label_visibility="hidden",
|
1576 |
+
on_change=update_optimization_bounds_all,
|
1577 |
+
args=("lower",),
|
1578 |
+
step=5,
|
1579 |
+
value=-10,
|
1580 |
+
)
|
1581 |
+
|
1582 |
+
with bounds_columns[2]:
|
1583 |
+
st.number_input(
|
1584 |
+
"Higher",
|
1585 |
+
value=10,
|
1586 |
+
min_value=-100,
|
1587 |
+
max_value=500,
|
1588 |
+
key=f"all_b_upper",
|
1589 |
+
# label_visibility="hidden",
|
1590 |
+
on_change=update_optimization_bounds_all,
|
1591 |
+
args=("upper",),
|
1592 |
+
step=5,
|
1593 |
+
)
|
1594 |
+
st.divider()
|
1595 |
+
|
1596 |
+
st.write("#### Channel wise bounds")
|
1597 |
+
# st.divider()
|
1598 |
+
# bounds_columns = st.columns((1, 0.35, 0.35, 1))
|
1599 |
+
|
1600 |
+
# with bounds_columns[0]:
|
1601 |
+
# st.write("Channel")
|
1602 |
+
# with bounds_columns[1]:
|
1603 |
+
# st.write("Lower")
|
1604 |
+
# with bounds_columns[2]:
|
1605 |
+
# st.write("Upper")
|
1606 |
+
# st.divider()
|
1607 |
+
|
1608 |
+
for channel_name, _channel in st.session_state["project_dct"][
|
1609 |
+
"scenario_planner"
|
1610 |
+
][unique_key].channels.items():
|
1611 |
+
st.session_state[f"{channel_name}_b_lower"] = _channel.bounds[0]
|
1612 |
+
st.session_state[f"{channel_name}_b_upper"] = _channel.bounds[1]
|
1613 |
+
with bounds_placeholder:
|
1614 |
+
with st.container():
|
1615 |
+
bounds_columns = st.columns((1, 0.35, 0.35, 1))
|
1616 |
+
with bounds_columns[0]:
|
1617 |
+
st.write("##")
|
1618 |
+
st.write(channel_name)
|
1619 |
+
with bounds_columns[1]:
|
1620 |
+
st.number_input(
|
1621 |
+
"Lower",
|
1622 |
+
min_value=-100,
|
1623 |
+
max_value=500,
|
1624 |
+
key=f"{channel_name}_b_lower",
|
1625 |
+
label_visibility="hidden",
|
1626 |
+
on_change=update_optimization_bounds,
|
1627 |
+
args=(
|
1628 |
+
channel_name,
|
1629 |
+
"lower",
|
1630 |
+
),
|
1631 |
+
)
|
1632 |
+
|
1633 |
+
with bounds_columns[2]:
|
1634 |
+
st.number_input(
|
1635 |
+
"Higher",
|
1636 |
+
min_value=-100,
|
1637 |
+
max_value=500,
|
1638 |
+
key=f"{channel_name}_b_upper",
|
1639 |
+
label_visibility="hidden",
|
1640 |
+
on_change=update_optimization_bounds,
|
1641 |
+
args=(
|
1642 |
+
channel_name,
|
1643 |
+
"upper",
|
1644 |
+
),
|
1645 |
+
)
|
1646 |
+
|
1647 |
+
    st.divider()
    _columns = st.columns(2)
    with _columns[0]:
        st.subheader("Save Scenario")
        scenario_name = st.text_input(
            "Scenario name",
            key="scenario_input",
            placeholder="Scenario name",
            label_visibility="collapsed",
        )
        st.button(
            "Save",
            on_click=lambda: save_scenario(scenario_name),
            disabled=len(st.session_state["scenario_input"]) == 0,
        )

    summary_df = pd.DataFrame(st.session_state["acutual_predicted"])
    summary_df.drop_duplicates(
        subset="Channel_name", keep="last", inplace=True
    )

    summary_df_sorted = summary_df.sort_values(by="Delta", ascending=False)
    summary_df_sorted["Delta_percent"] = np.round(
        (
            (
                summary_df_sorted["Optimized_spend"]
                / summary_df_sorted["Actual_spend"]
            )
            - 1
        )
        * 100,
        2,
    )

    with open("summary_df.pkl", "wb") as f:
        pickle.dump(summary_df_sorted, f)
    # st.dataframe(summary_df_sorted)
    # ___columns=st.columns(3)
    # with ___columns[2]:
    #     fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent')
    #     st.plotly_chart(fig,use_container_width=True)
    # with ___columns[0]:
    #     fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend')
    #     st.plotly_chart(fig,use_container_width=True)
    # with ___columns[1]:
    #     fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend')
    #     st.plotly_chart(fig,use_container_width=True)
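    # Delta_percent above is the percent change of planned vs. actual spend,
    # 100 * (Optimized_spend / Actual_spend - 1); e.g. 110K planned against
    # 100K actual gives +10.0. drop_duplicates(keep="last") keeps only each
    # channel's most recent row, since the render loop appends on every rerun.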
elif auth_status == False:
    st.error("Username/Password is incorrect")

if auth_status != True:
    try:
        username_forgot_pw, email_forgot_password, random_password = (
            authenticator.forgot_password("Forgot password")
        )
        if username_forgot_pw:
            st.session_state["config"]["credentials"]["usernames"][
                username_forgot_pw
            ]["password"] = stauth.Hasher([random_password]).generate()[0]
            send_email(email_forgot_password, random_password)
            st.success("New password sent securely")
            # Random password to be transferred to user securely
        elif username_forgot_pw == False:
            st.error("Username not found")
    except Exception as e:
        st.error(e)

update_db("9_Scenario_Planner.py")