dlaj commited on
Commit
76582d6
·
verified ·
1 Parent(s): ae596a8

Update streamlit_simulation/app.py

Browse files
Files changed (1) hide show
  1. streamlit_simulation/app.py +26 -26
streamlit_simulation/app.py CHANGED
@@ -183,18 +183,18 @@ def predict_transformer_step(model, dataset, idx, device):
183
 
184
 
185
  def init_simulation_layout():
186
- plot_title = st.empty()
187
- plot_container = st.empty()
188
- x_axis_label = st.empty()
189
- info_container = st.container()
 
190
  return plot_title, plot_container, x_axis_label, info_container
191
 
192
 
193
 
194
  def create_prediction_plot(pred_timestamps, pred_vals, true_timestamps, true_vals, window_hours, y_min=None, y_max=None):
195
  """Generates the matplotlib figure for plotting prediction vs. actual."""
196
- fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=False, facecolor=PLOT_COLOR)
197
- fig.tight_layout()
198
  ax.set_facecolor(PLOT_COLOR)
199
 
200
  ax.plot(pred_timestamps[-window_hours:], pred_vals[-window_hours:], label="Prediction", color="#EF233C", linestyle="--")
@@ -233,7 +233,7 @@ def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=
233
  f"{title}</div>",
234
  unsafe_allow_html=True
235
  )
236
- plot_container.pyplot(fig, use_container_width=True)
237
 
238
  st.markdown("<div style='margin-bottom: 0.5rem;'></div>", unsafe_allow_html=True)
239
  x_axis_label.markdown(
@@ -254,25 +254,25 @@ def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=
254
  st.caption("Simulation Progress")
255
  st.progress(progress)
256
 
257
- #if len(st.session_state.true_vals) > 1:
258
- # true_arr = np.array(st.session_state.true_vals)
259
- #pred_arr = np.array(st.session_state.pred_vals[:-1])
260
-
261
- # min_len = min(len(true_arr), len(pred_arr)) #just start if there are 2 actual values
262
- # if min_len >= 1:
263
- # errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
264
- #mape = np.mean(errors / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])) * 100
265
- #mae = np.mean(errors)
266
- #max_error = np.max(errors)
267
-
268
- #st.divider()
269
- #st.markdown(
270
- # f"<span style='font-size: 24px; font-weight: 600; color: {HEADER_COLOR} !important;'>Interim Metrics</span>",
271
- # unsafe_allow_html=True
272
- # )
273
- #st.metric("MAPE (so far)", f"{mape:.2f} %")
274
- # st.metric("MAE (so far)", f"{mae:,.0f} MW")
275
- #st.metric("Max Error", f"{max_error:,.0f} MW")
276
 
277
 
278
 
 
183
 
184
 
185
  def init_simulation_layout():
186
+ col1, spacer, col2 = st.columns([3, 0.2, 1])
187
+ plot_title = col1.empty()
188
+ plot_container = col1.empty()
189
+ x_axis_label = col1.empty()
190
+ info_container = col2.empty()
191
  return plot_title, plot_container, x_axis_label, info_container
192
 
193
 
194
 
195
  def create_prediction_plot(pred_timestamps, pred_vals, true_timestamps, true_vals, window_hours, y_min=None, y_max=None):
196
  """Generates the matplotlib figure for plotting prediction vs. actual."""
197
+ fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR)
 
198
  ax.set_facecolor(PLOT_COLOR)
199
 
200
  ax.plot(pred_timestamps[-window_hours:], pred_vals[-window_hours:], label="Prediction", color="#EF233C", linestyle="--")
 
233
  f"{title}</div>",
234
  unsafe_allow_html=True
235
  )
236
+ plot_container.pyplot(fig)
237
 
238
  st.markdown("<div style='margin-bottom: 0.5rem;'></div>", unsafe_allow_html=True)
239
  x_axis_label.markdown(
 
254
  st.caption("Simulation Progress")
255
  st.progress(progress)
256
 
257
+ if len(st.session_state.true_vals) > 1:
258
+ true_arr = np.array(st.session_state.true_vals)
259
+ pred_arr = np.array(st.session_state.pred_vals[:-1])
260
+
261
+ min_len = min(len(true_arr), len(pred_arr)) #just start if there are 2 actual values
262
+ if min_len >= 1:
263
+ errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
264
+ mape = np.mean(errors / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])) * 100
265
+ mae = np.mean(errors)
266
+ max_error = np.max(errors)
267
+
268
+ st.divider()
269
+ st.markdown(
270
+ f"<span style='font-size: 24px; font-weight: 600; color: {HEADER_COLOR} !important;'>Interim Metrics</span>",
271
+ unsafe_allow_html=True
272
+ )
273
+ st.metric("MAPE (so far)", f"{mape:.2f} %")
274
+ st.metric("MAE (so far)", f"{mae:,.0f} MW")
275
+ st.metric("Max Error", f"{max_error:,.0f} MW")
276
 
277
 
278