jarif commited on
Commit
690d2e2
·
verified ·
1 Parent(s): c814bf0

Upload streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +184 -40
src/streamlit_app.py CHANGED
@@ -1,40 +1,184 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ import pickle
6
+ from tensorflow.keras.models import load_model
7
+
8
+ # Streamlit page configuration
9
+ st.set_page_config(
10
+ page_title="Power Consumption Predictor",
11
+ layout="centered",
12
+ initial_sidebar_state="collapsed"
13
+ )
14
+
15
+ # Custom CSS for eye-catching design
16
+ st.markdown("""
17
+ <style>
18
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@400;700&display=swap');
19
+ .main {background-color: #ffffff;}
20
+ .stTitle {color: #003087; font-family: 'Roboto', sans-serif; text-align: center; margin-bottom: 10px; font-size: 32px; font-weight: 700;}
21
+ .stSubheader {color: #003087; font-family: 'Roboto', sans-serif; font-size: 22px; font-weight: 700; margin-top: 10px; margin-bottom: 10px;}
22
+ .stMarkdown {font-family: 'Roboto', sans-serif; color: #212529; font-size: 16px;}
23
+ .stDataFrame {
24
+ background-color: #ffffff;
25
+ border-radius: 12px;
26
+ padding: 15px;
27
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
28
+ }
29
+ .stButton>button {
30
+ background-color: #007bff;
31
+ color: white;
32
+ border-radius: 10px;
33
+ padding: 12px 30px;
34
+ font-size: 18px;
35
+ font-family: 'Roboto', sans-serif;
36
+ font-weight: 700;
37
+ display: block;
38
+ margin: 15px auto;
39
+ border: none;
40
+ transition: all 0.3s ease;
41
+ }
42
+ .stButton>button:hover {
43
+ background: linear-gradient(45deg, #0056b3, #007bff);
44
+ transform: scale(1.05);
45
+ box-shadow: 0 4px 8px rgba(0,0,0,0.2);
46
+ }
47
+ .stNumberInput label {
48
+ color: #007bff;
49
+ font-family: 'Roboto', sans-serif;
50
+ font-weight: 700;
51
+ font-size: 16px;
52
+ }
53
+ .stNumberInput input {
54
+ background-color: #ffffff;
55
+ color: #212529;
56
+ border: 2px solid #007bff;
57
+ border-radius: 8px;
58
+ padding: 10px;
59
+ font-family: 'Roboto', sans-serif;
60
+ font-size: 14px;
61
+ caret-color: #212529;
62
+ }
63
+ .stNumberInput input:focus {
64
+ outline: none;
65
+ border-color: #003087;
66
+ box-shadow: 0 0 8px rgba(0,123,255,0.3);
67
+ }
68
+ </style>
69
+ """, unsafe_allow_html=True)
70
+
71
+ # Load model and scalers
72
+ try:
73
+ model = load_model('my_model.keras')
74
+ scaler_X = pickle.load(open('scaler_X.pkl', 'rb'))
75
+ scaler_y = pickle.load(open('scaler_y.pkl', 'rb'))
76
+ except Exception as e:
77
+ st.error(f"Failed to load model or scalers: {str(e)}. Ensure 'my_model.keras', 'scaler_X.pkl', and 'scaler_y.pkl' are in E:\\grid\\. "
78
+ "This error may occur if the TensorFlow version used to save the model differs from your installed version. "
79
+ "Try installing TensorFlow 2.17.0 or the version used to save the model (e.g., `pip install tensorflow==2.17.0`).")
80
+ st.stop()
81
+
82
+ # Main app layout
83
+ st.title("Power Consumption Predictor")
84
+ st.markdown("""
85
+ Enter values for one timestep to predict power consumption for Zone1, Zone2, and Zone3.
86
+ Results will be displayed as a vibrant bar plot and a clear table.
87
+ """)
88
+
89
+ # Input section
90
+ st.subheader("Enter Timestep Data")
91
+ st.markdown("""
92
+ **Instructions**:
93
+ - Enter values for the 8 features below (default values are provided).
94
+ - **Hour**: 0 to 23 (e.g., 14 for 2 PM).
95
+ - **DayOfWeek**: 0 to 6 (0 = Monday, 6 = Sunday).
96
+ - **Month**: 1 to 12 (e.g., 7 for July).
97
+ - **Other features**: Use reasonable values (e.g., Temperature in °C, Humidity as a fraction).
98
+ - Click "Predict" to see results.
99
+ """)
100
+
101
+ # Vertical form for input
102
+ with st.container():
103
+ feature_names = ['Hour', 'DayOfWeek', 'Month', 'Temperature', 'Humidity', 'WindSpeed', 'GeneralDiffuseFlows', 'DiffuseFlows']
104
+ default_values = [0, 6, 1, 6.559, 73.8, 0.083, 0.051, 0.119] # From dataset
105
+ user_input = []
106
+ for i, (name, default) in enumerate(zip(feature_names, default_values)):
107
+ if name in ['Hour', 'DayOfWeek', 'Month']:
108
+ value = st.number_input(
109
+ f"{name}",
110
+ min_value=0,
111
+ max_value=23 if name == 'Hour' else 6 if name == 'DayOfWeek' else 12,
112
+ value=int(default),
113
+ step=1,
114
+ key=f"input_{i}"
115
+ )
116
+ user_input.append(value)
117
+ else:
118
+ value = st.number_input(
119
+ f"{name}",
120
+ value=float(default),
121
+ step=0.01,
122
+ format="%.6f",
123
+ key=f"input_{i}"
124
+ )
125
+ user_input.append(value)
126
+
127
+ # Predict button
128
+ if st.button("Predict", key="predict_button"):
129
+ try:
130
+ # Replicate input for 24 timesteps
131
+ custom_raw_data = np.array([user_input] * 24).reshape(1, 24, 8)
132
+
133
+ # Selective scaling
134
+ features_to_scale = ['Temperature', 'Humidity', 'WindSpeed', 'GeneralDiffuseFlows', 'DiffuseFlows']
135
+ scale_indices = [3, 4, 5, 6, 7]
136
+ custom_scaled = custom_raw_data.copy()
137
+ custom_2d_to_scale = custom_raw_data[:, :, scale_indices].reshape(-1, len(scale_indices))
138
+ custom_scaled_2d = scaler_X.transform(custom_2d_to_scale)
139
+ custom_scaled[:, :, scale_indices] = custom_scaled_2d.reshape(1, 24, len(scale_indices))
140
+
141
+ # Predict
142
+ y_pred_scaled = model.predict(custom_scaled)
143
+ if isinstance(y_pred_scaled, list):
144
+ y_pred_combined = np.concatenate(y_pred_scaled, axis=1)
145
+ else:
146
+ y_pred_combined = y_pred_scaled
147
+ y_pred_original = scaler_y.inverse_transform(y_pred_combined)
148
+
149
+ # Store predictions
150
+ labels = ['PowerConsumption_Zone1', 'PowerConsumption_Zone2', 'PowerConsumption_Zone3']
151
+ st.session_state.pred_df = pd.DataFrame(y_pred_original, columns=labels, index=['User Input'])
152
+ st.session_state.predictions = y_pred_original
153
+
154
+ except Exception as e:
155
+ st.error(f"Error processing input: {str(e)}")
156
+
157
+ # Display predictions if available
158
+ if 'predictions' in st.session_state and st.session_state.predictions is not None:
159
+ st.markdown("### Predicted Power Consumption")
160
+ fig = px.bar(
161
+ st.session_state.pred_df.reset_index().melt(id_vars='index', value_vars=labels, var_name='Zone', value_name='Power Consumption'),
162
+ x='index', y='Power Consumption', color='Zone', barmode='group',
163
+ title='Predicted Power Consumption by Zone',
164
+ labels={'index': 'Sample', 'Power Consumption': 'Power Consumption (Original Scale)'},
165
+ color_discrete_sequence=['#007bff', '#28a745', '#dc3545']
166
+ )
167
+ fig.update_layout(
168
+ plot_bgcolor='white',
169
+ paper_bgcolor='white',
170
+ font=dict(family='Roboto', size=12, color='#212529'),
171
+ title_font=dict(size=18, family='Roboto', color='#003087'),
172
+ xaxis_title="Sample",
173
+ yaxis_title="Power Consumption (Original Scale)",
174
+ legend_title="Zones",
175
+ margin=dict(l=40, r=40, t=60, b=40)
176
+ )
177
+ st.plotly_chart(fig, use_container_width=True)
178
+
179
+ st.markdown("### Prediction Table")
180
+ st.dataframe(st.session_state.pred_df.style.format("{:.4f}").set_caption("Predicted Power Consumption (Original Scale)"))
181
+
182
+ # Footer
183
+ st.markdown("---")
184
+ st.markdown("**Made by Sadik Al Jarif**")