Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import plotly.express as px | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from utilities_with_panel import channel_name_formating, load_authenticator, initialize_data | |
| from sklearn.metrics import r2_score | |
| from collections import OrderedDict | |
| from classes import class_from_dict,class_to_dict | |
| import pickle | |
| import json | |
| import pandas as pd | |
| from utilities import ( | |
| load_local_css, | |
| set_header, | |
| channel_name_formating, | |
| ) | |
# Re-assign every session-state entry to itself, skipping the authentication
# keys and form-submitter entries. Presumably this keeps widget-bound values
# alive when switching between pages — NOTE(review): confirm intent.
for state_key, state_value in st.session_state.items():
    is_reserved = state_key in ('logout', 'login', 'config') or state_key.startswith('FormSubmitter')
    if not is_reserved:
        st.session_state[state_key] = state_value
def s_curve(x, K, b, a, x0):
    """Evaluate the logistic response curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Works element-wise on NumPy arrays as well as on scalars.

    Parameters
    ----------
    x : float or array-like
        Input value(s) at which to evaluate the curve.
    K : float
        Upper asymptote (the curve's value as ``x`` grows when ``a > 0``).
    b, a, x0 : float
        Shape parameters: ``b`` scales the exponential term, ``a`` is its rate
        and ``x0`` shifts it along the x-axis.
    """
    decay = np.exp(-a * (x - x0))
    return K / (1 + b * decay)
def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state.

    The scenario in ``st.session_state['scenario']`` is converted with
    ``class_to_dict`` and stored under ``scenario_name`` in
    ``st.session_state['saved_scenarios']``. That entry is created as an
    ``OrderedDict`` on first use. The scenario-name input is then cleared, and
    the whole saved-scenarios mapping is pickled to ``../saved_scenarios.pkl``.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    if 'saved_scenarios' not in st.session_state:
        # Create only this entry. Rebinding st.session_state itself would
        # discard every other session key and break the lookup below.
        st.session_state['saved_scenarios'] = OrderedDict()
    st.session_state['saved_scenarios'][scenario_name] = class_to_dict(st.session_state['scenario'])
    # Clear the name text box (bound to this key) once the scenario is saved.
    st.session_state['scenario_input'] = ""
    with open('../saved_scenarios.pkl', 'wb') as f:
        pickle.dump(st.session_state['saved_scenarios'], f)
def reset_curve_parameters():
    """Discard the edited curve parameters (K, b, a, x0) from the session state.

    Used as a widget callback. Once the keys are gone, the page re-seeds them
    from the selected channel's stored response-curve parameters on the next
    rerun. Keys that are already absent are skipped, so firing the callback
    twice (or before the keys were ever seeded) does not raise KeyError.
    """
    for param_name in ('K', 'b', 'a', 'x0'):
        if param_name in st.session_state:
            del st.session_state[param_name]
def update_response_curve():
    """Write the edited curve parameters back to the selected channel.

    Reads K, b, a and x0 from the session state and passes them, as one dict,
    to ``update_response_curves`` on the scenario's channel object for the
    module-level ``selected_channel_name``.
    """
    channel = st.session_state['scenario'].channels[selected_channel_name]
    new_params = {name: st.session_state[name] for name in ('K', 'b', 'a', 'x0')}
    channel.update_response_curves(new_params)
| # authenticator = st.session_state.get('authenticator') | |
| # if authenticator is None: | |
| # authenticator = load_authenticator() | |
| # name, authentication_status, username = authenticator.login('Login', 'main') | |
| # auth_status = st.session_state.get('authentication_status') | |
| # if auth_status == True: | |
| # is_state_initiaized = st.session_state.get('initialized',False) | |
| # if not is_state_initiaized: | |
| # print("Scenario page state reloaded") | |
| # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default | |
# Page setup: wide layout, local stylesheet, shared header.
# set_page_config stays first because Streamlit expects it before any
# element is rendered.
st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()
def panel_fetch(file_selected):
    """
    Return the distinct panel (market) values found in an Excel workbook.

    Parameters
    ----------
    file_selected
        Path (or anything ``pd.read_excel`` accepts) of the workbook. Its
        "RAW DATA MMM" sheet is read.

    Returns
    -------
    list or None
        Unique values of the "Panel" column, in order of first appearance, or
        ``None`` when the sheet has no "Panel" column.
    """
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
    if "Panel" not in raw_data_mmm_df.columns:
        return None
    # dict.fromkeys de-duplicates like a set but keeps first-appearance order.
    # Iterating a set of strings gives a different order on every process
    # start (hash randomisation), which reshuffled the market dropdown.
    return list(dict.fromkeys(raw_data_mmm_df["Panel"]))
# Workbook holding the raw MMM data for the chosen metric.
metrics_selected = 'Revenue'
file_selected = f"Overview_data_test_panel@#{metrics_selected}.xlsx"
panel_list = panel_fetch(file_selected)

# Pick the response metric: one of the metrics used elsewhere in the app when
# any are recorded, otherwise the revenue target.
if st.session_state.get('used_response_metrics'):
    sel_target_col = st.selectbox("Select the response metric", st.session_state['used_response_metrics'])
    # Normalise the display name into the column-style identifier.
    target_col = sel_target_col.lower().replace(" ", "_").replace('-', '').replace(':', '').replace("__", "_")
else:
    sel_target_col = 'Total Approved Accounts - Revenue'
    target_col = 'total_approved_accounts_revenue'

st.subheader("Build response curves")
# panel_fetch returns None when the workbook has no "Panel" column. In that
# case only the total market is offered, instead of crashing on None + list.
st.session_state['selected_markets'] = st.selectbox(
    "Select Markets",
    ["Total Market"] + (panel_list or []),
    index=0,
)
initialize_data(target_col, st.session_state['selected_markets'])

channels_list = st.session_state['channels_list']
# Switching channel clears the edited K/b/a/x0 so they are re-seeded from the
# newly selected channel.
selected_channel_name = st.selectbox(
    'Channel',
    channels_list,
    format_func=channel_name_formating,
    on_change=reset_curve_parameters,
)
# Stored response-curve parameters for every channel of the current scenario.
rcs = {
    channel_name: st.session_state['scenario'].channels[channel_name].response_curve_params
    for channel_name in channels_list
}

# Seed the editable parameters from the selected channel unless they are
# already present. reset_curve_parameters removes them so they get re-seeded.
for param_name in ('K', 'b', 'a', 'x0'):
    if param_name not in st.session_state:
        st.session_state[param_name] = rcs[selected_channel_name][param_name]
# Observed spends (x) and contributions (y) of the selected channel.
x = st.session_state['actual_input_df'][selected_channel_name].values
y = st.session_state['actual_contribution_df'][selected_channel_name].values

# The S-curve is evaluated on spends divided by 10**power. The parameters are
# presumably fitted on the same scaled axis (TODO confirm against the fitting
# code). The natural-log ratio is kept on purpose: np.log10 can round
# differently near exact powers of ten and would shift the scale.
max_spend = x.max()
if max_spend > 0:
    power = np.ceil(np.log(max_spend) / np.log(10)) - 3
else:
    # log(0) is -inf, which would make every prediction inf/nan and make
    # r2_score raise. Spends are assumed non-negative, so here every spend is
    # zero and the scale cannot affect x / scale; use no scaling.
    power = 0

curve_params = (
    st.session_state['K'],
    st.session_state['b'],
    st.session_state['a'],
    st.session_state['x0'],
)
spends_sorted = np.sort(x)

fig = px.scatter(x=x, y=y)
# Draw the fitted curve over sorted spends so the line runs left to right.
fig.add_trace(go.Scatter(x=spends_sorted,
                         y=s_curve(spends_sorted / 10**power, *curve_params),
                         line=dict(color='red')))
fig.update_layout(title_text="Response Curve", showlegend=False)
fig.update_annotations(font_size=10)
fig.update_xaxes(title='Spends')
fig.update_yaxes(title=sel_target_col)
st.plotly_chart(fig, use_container_width=True)

# Goodness of fit of the current parameters against the observed contributions.
r2 = r2_score(y, s_curve(x / 10**power, *curve_params))
st.metric('R2', round(r2, 2))
# One number input per curve parameter. Each is bound via `key` to the
# session-state entry of the same name that the update/reset callbacks use.
columns = st.columns(4)
for column, param_name in zip(columns, ('K', 'b', 'a', 'x0')):
    with column:
        if param_name == 'a':
            # 'a' gets a finer step than the default.
            st.number_input(param_name, key=param_name, step=0.0001, format="%0.5f")
        else:
            st.number_input(param_name, key=param_name, format="%0.5f")

st.button('Update parameters', on_click=update_response_curve)
st.button('Reset parameters', on_click=reset_curve_parameters)

# Saving is only enabled once a scenario name has been typed.
scenario_name = st.text_input('Scenario name', key='scenario_input', placeholder='Scenario name', label_visibility='collapsed')
save_disabled = not st.session_state['scenario_input']
st.button('Save', on_click=lambda: save_scenario(scenario_name), disabled=save_disabled)

# Export the stored response curves of all channels as JSON; enabled once a
# file name has been typed.
file_name = st.text_input('rcs download file name', key='file_name_input', placeholder='file name', label_visibility='collapsed')
st.download_button(
    label="Download response curves",
    data=json.dumps(rcs),
    file_name=f"{file_name}.json",
    mime="application/json",
    disabled=not file_name,
)
def s_curve_derivative(x, K, b, a, x0):
    """Derivative with respect to ``x`` of ``K / (1 + b * exp(-a * (x - x0)))``.

    Returns ``a * b * K * e / (1 + b * e) ** 2`` with ``e = exp(-a * (x - x0))``,
    evaluated element-wise for NumPy arrays.
    """
    growth = np.exp(-a * (x - x0))
    denominator = (1 + b * growth) ** 2
    return a * b * K * growth / denominator
# Parameters of the S-curve
K = st.session_state['K']
b = st.session_state['b']
a = st.session_state['a']
x0 = st.session_state['x0']

# Spend at which to report the slope (e.g. the value proposed by the optimiser).
# NOTE(review): assumed to be a raw spend in the same units as the plot's
# x-axis — confirm against the optimiser's output.
optimized_spend = st.number_input('value of x')

# The curve is evaluated on spend / 10**power, the same scaling used for the
# plotted curve and the R2 score. By the chain rule, the slope with respect to
# raw spend is the derivative at the scaled spend divided by that factor.
spend_scale = 10 ** power
slope_at_optimized_spend = s_curve_derivative(optimized_spend / spend_scale, K, b, a, x0) / spend_scale
st.write("Slope ", slope_at_optimized_spend)