3v324v23 commited on
Commit
ab78aff
·
1 Parent(s): 322ce5c
streamlit_simulation/app.py CHANGED
@@ -251,8 +251,8 @@ def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=
251
 
252
  st.metric("Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–")
253
  st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
254
- st.caption("Simulation Progress")
255
- st.progress(progress)
256
 
257
  if len(st.session_state.true_vals) > 1:
258
  true_arr = np.array(st.session_state.true_vals)
 
251
 
252
  st.metric("Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–")
253
  st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
254
+ #st.caption("Simulation Progress")
255
+ #st.progress(progress)
256
 
257
  if len(st.session_state.true_vals) > 1:
258
  true_arr = np.array(st.session_state.true_vals)
streamlit_simulation/app_backup_hug.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import streamlit as st
4
+ import pickle
5
+ import pandas as pd
6
+ import time
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.dates as mdates
10
+ import warnings
11
+ import torch
12
+
13
+ from config_streamlit import (MODEL_PATH_LIGHTGBM, DATA_PATH, TRAIN_RATIO,
14
+ TEXT_COLOR, HEADER_COLOR, ACCENT_COLOR,
15
+ BUTTON_BG, BUTTON_HOVER_BG, BG_COLOR,
16
+ INPUT_BG, PROGRESS_COLOR, PLOT_COLOR
17
+ )
18
+ from lightgbm_model.scripts.config_lightgbm import FEATURES
19
+ from transformer_model.scripts.utils.informer_dataset_class import InformerDataset
20
+ from transformer_model.scripts.training.load_basis_model import load_moment_model
21
+ from transformer_model.scripts.config_transformer import CHECKPOINT_DIR, FORECAST_HORIZON, SEQ_LEN
22
+ from sklearn.preprocessing import StandardScaler
23
+
24
+ from huggingface_hub import hf_hub_download
25
+
26
+
27
+ # ============================== Layout ==============================
28
+
29
+ # Streamlit & warnings config
30
+ warnings.filterwarnings("ignore", category=FutureWarning)
31
+ st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")
32
+
33
+ #CSS part
34
+ st.markdown(f"""
35
+ <style>
36
+ body, .block-container {{
37
+ background-color: {BG_COLOR} !important;
38
+ }}
39
+
40
+ html, body, [class*="css"] {{
41
+ color: {TEXT_COLOR} !important;
42
+ font-family: 'sans-serif';
43
+ }}
44
+
45
+ h1, h2, h3, h4, h5, h6 {{
46
+ color: {HEADER_COLOR} !important;
47
+ }}
48
+
49
+ .stButton > button {{
50
+ background-color: {BUTTON_BG};
51
+ color: {TEXT_COLOR};
52
+ border: 1px solid {ACCENT_COLOR};
53
+ }}
54
+
55
+ .stButton > button:hover {{
56
+ background-color: {BUTTON_HOVER_BG};
57
+ }}
58
+
59
+ .stSelectbox div[data-baseweb="select"],
60
+ .stDateInput input {{
61
+ background-color: {INPUT_BG} !important;
62
+ color: {TEXT_COLOR} !important;
63
+ }}
64
+
65
+ [data-testid="stMetricLabel"],
66
+ [data-testid="stMetricValue"] {{
67
+ color: {TEXT_COLOR} !important;
68
+ }}
69
+
70
+ .stMarkdown p {{
71
+ color: {TEXT_COLOR} !important;
72
+ }}
73
+
74
+ .stDataFrame tbody tr td {{
75
+ color: {TEXT_COLOR} !important;
76
+ }}
77
+
78
+ .stProgress > div > div {{
79
+ background-color: {PROGRESS_COLOR} !important;
80
+ }}
81
+
82
+ /* Alle Label-Texte für Inputs/Sliders */
83
+ label {{
84
+ color: {TEXT_COLOR} !important;
85
+ }}
86
+
87
+ /* Text in selectbox-Optionsfeldern */
88
+ .stSelectbox label, .stSelectbox div {{
89
+ color: {TEXT_COLOR} !important;
90
+ }}
91
+
92
+ /* DateInput angleichen an Selectbox */
93
+ .stDateInput input {{
94
+ background-color: #f2f6fa !important;
95
+ color: {TEXT_COLOR} !important;
96
+ border: none !important;
97
+ border-radius: 5px !important;
98
+ }}
99
+
100
+ </style>
101
+ """, unsafe_allow_html=True)
102
+
103
+ st.title("Electricity Consumption Forecast: Hourly Simulation")
104
+ st.write("Welcome to the simulation interface!")
105
+
106
+ # ============================== Session State Init ==============================
107
+ def init_session_state():
108
+ defaults = {
109
+ "is_running": False,
110
+ "start_index": 0,
111
+ "true_vals": [],
112
+ "pred_vals": [],
113
+ "true_timestamps": [],
114
+ "pred_timestamps": [],
115
+ "last_fig": None,
116
+ "valid_pos": 0
117
+ }
118
+ for key, value in defaults.items():
119
+ if key not in st.session_state:
120
+ st.session_state[key] = value
121
+
122
+ init_session_state()
123
+
124
+ # ============================== Loaders ==============================
125
+
126
+ @st.cache_data
127
+ def load_lightgbm_model():
128
+ with open(MODEL_PATH_LIGHTGBM, "rb") as f:
129
+ return pickle.load(f)
130
+
131
+ @st.cache_resource
132
+ def load_transformer_model_and_dataset():
133
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134
+
135
+ # Load model
136
+ model = load_moment_model()
137
+ checkpoint_path = hf_hub_download(
138
+ repo_id="dlaj/energy-forecasting-files",
139
+ filename="transformer_model/model_final.pth",
140
+ repo_type="dataset"
141
+ )
142
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
143
+ model.to(device)
144
+ model.eval()
145
+
146
+ # Datasets
147
+ train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
148
+ test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
149
+ test_dataset.scaler = train_dataset.scaler
150
+
151
+ return model, test_dataset, device
152
+
153
+ @st.cache_data
154
+ def load_data():
155
+ csv_path = hf_hub_download(
156
+ repo_id="dlaj/energy-forecasting-files",
157
+ filename="data/processed/energy_consumption_aggregated_cleaned.csv",
158
+ repo_type="dataset"
159
+ )
160
+ df = pd.read_csv(csv_path, parse_dates=["date"])
161
+ return df
162
+
163
+
164
+ # ============================== Utility Functions ==============================
165
+
166
+ def predict_transformer_step(model, dataset, idx, device):
167
+ """Performs a single prediction step with the transformer model."""
168
+ timeseries, _, input_mask = dataset[idx]
169
+ timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
170
+ input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)
171
+
172
+ with torch.no_grad():
173
+ output = model(x_enc=timeseries, input_mask=input_mask)
174
+
175
+ pred = output.forecast[:, 0, :].cpu().numpy().flatten()
176
+
177
+ # Rückskalieren
178
+ dummy = np.zeros((len(pred), dataset.n_channels))
179
+ dummy[:, 0] = pred
180
+ pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]
181
+
182
+ return float(pred_original[0])
183
+
184
+
185
+ def init_simulation_layout():
186
+ col1, spacer, col2 = st.columns([3, 0.2, 1])
187
+ plot_title = col1.empty()
188
+ plot_container = col1.empty()
189
+ x_axis_label = col1.empty()
190
+ info_container = col2.empty()
191
+ return plot_title, plot_container, x_axis_label, info_container
192
+
193
+
194
+
195
+ def create_prediction_plot(pred_timestamps, pred_vals, true_timestamps, true_vals, window_hours, y_min=None, y_max=None):
196
+ """Generates the matplotlib figure for plotting prediction vs. actual."""
197
+ fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR)
198
+ ax.set_facecolor(PLOT_COLOR)
199
+
200
+ ax.plot(pred_timestamps[-window_hours:], pred_vals[-window_hours:], label="Prediction", color="#EF233C", linestyle="--")
201
+ if true_vals:
202
+ ax.plot(true_timestamps[-window_hours:], true_vals[-window_hours:], label="Actual", color="#0077B6")
203
+
204
+ ax.set_ylabel("Consumption (MW)", fontsize=8, color=TEXT_COLOR)
205
+ ax.legend(
206
+ fontsize=8,
207
+ loc="upper left",
208
+ bbox_to_anchor=(0, 0.95),
209
+ facecolor= INPUT_BG, # INPUT_BG
210
+ edgecolor= ACCENT_COLOR, # ACCENT_COLOR
211
+ labelcolor= TEXT_COLOR # TEXT_COLOR
212
+ )
213
+ ax.yaxis.grid(True, linestyle=':', linewidth=0.5, alpha=0.7)
214
+ ax.set_ylim(y_min, y_max)
215
+ ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
216
+ ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
217
+ ax.tick_params(axis="x", labelrotation=0, labelsize=5, colors=TEXT_COLOR)
218
+ ax.tick_params(axis="y", labelsize=5, colors=TEXT_COLOR)
219
+ #fig.patch.set_facecolor('#e6ecf0') # outer area
220
+
221
+ for spine in ax.spines.values():
222
+ spine.set_visible(False)
223
+
224
+ st.session_state.last_fig = fig
225
+ return fig
226
+
227
+
228
+ def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
229
+ """Displays the simulation plot and metrics in the UI."""
230
+ title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
231
+ plot_title.markdown(
232
+ f"<div style='text-align: center; font-size: 20pt; font-weight: bold; color: {TEXT_COLOR}; margin-bottom: -0.7rem; margin-top: 0rem;'>"
233
+ f"{title}</div>",
234
+ unsafe_allow_html=True
235
+ )
236
+ plot_container.pyplot(fig)
237
+
238
+ st.markdown("<div style='margin-bottom: 0.5rem;'></div>", unsafe_allow_html=True)
239
+ x_axis_label.markdown(
240
+ f"<div style='text-align: center; font-size: 14pt; color: {TEXT_COLOR}; margin-top: -0.5rem;'>"
241
+ f"Time</div>",
242
+ unsafe_allow_html=True
243
+ )
244
+
245
+ with info_container.container():
246
+ st.markdown("<div style='margin-top: 5rem;'></div>", unsafe_allow_html=True)
247
+ st.markdown(
248
+ f"<span style='font-size: 24px; font-weight: 600; color: {HEADER_COLOR} !important;'>Time: {timestamp}</span>",
249
+ unsafe_allow_html=True
250
+ )
251
+
252
+ st.metric("Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–")
253
+ st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
254
+ st.caption("Simulation Progress")
255
+ st.progress(progress)
256
+
257
+ if len(st.session_state.true_vals) > 1:
258
+ true_arr = np.array(st.session_state.true_vals)
259
+ pred_arr = np.array(st.session_state.pred_vals[:-1])
260
+
261
+ min_len = min(len(true_arr), len(pred_arr)) #just start if there are 2 actual values
262
+ if min_len >= 1:
263
+ errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
264
+ mape = np.mean(errors / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])) * 100
265
+ mae = np.mean(errors)
266
+ max_error = np.max(errors)
267
+
268
+ st.divider()
269
+ st.markdown(
270
+ f"<span style='font-size: 24px; font-weight: 600; color: {HEADER_COLOR} !important;'>Interim Metrics</span>",
271
+ unsafe_allow_html=True
272
+ )
273
+ st.metric("MAPE (so far)", f"{mape:.2f} %")
274
+ st.metric("MAE (so far)", f"{mae:,.0f} MW")
275
+ st.metric("Max Error", f"{max_error:,.0f} MW")
276
+
277
+
278
+
279
+ # ============================== Data Preparation ==============================
280
+
281
+ df_full = load_data()
282
+
283
+ # Split Train/Test
284
+ train_size = int(len(df_full) * TRAIN_RATIO)
285
+ test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)
286
+
287
+ # Start at first full hour (00:00)
288
+ first_full_day_index = test_df_raw[test_df_raw["date"].dt.time == pd.Timestamp("00:00:00").time()].index[0]
289
+ test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)
290
+
291
+ # Select simulation window via date picker
292
+ min_date = test_df_full["date"].min().date()
293
+ max_date = test_df_full["date"].max().date()
294
+
295
+ # ============================== UI Controls ==============================
296
+
297
+ st.markdown("### Simulation Settings")
298
+ col1, col2 = st.columns([1, 1])
299
+
300
+ with col1:
301
+ st.markdown("**General Settings**")
302
+ model_choice = st.selectbox("Choose prediction model", ["LightGBM", "Transformer Model (moments)"])
303
+ if model_choice == "Transformer Model(moments)":
304
+ st.caption("⚠️ Note: Transformer model runs slower without GPU. (Use Speed = 10)")
305
+ window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
306
+ window_hours = window_days * 24
307
+ speed = st.slider("Speed", 1, 10, 5)
308
+
309
+ with col2:
310
+ st.markdown(f"**Date Range** (from {min_date} to {max_date})")
311
+ start_date = st.date_input("Start Date", value=min_date, min_value=min_date, max_value=max_date)
312
+ end_date = st.date_input("End Date", value=max_date, min_value=min_date, max_value=max_date)
313
+
314
+
315
+ # ============================== Data Preparation (filtered) ==============================
316
+
317
+ # final filtered date window
318
+ test_df_filtered = test_df_full[
319
+ (test_df_full["date"].dt.date >= start_date) &
320
+ (test_df_full["date"].dt.date <= end_date)
321
+ ].reset_index(drop=True)
322
+
323
+ # For progression bar
324
+ total_steps_ui = len(test_df_filtered)
325
+
326
+ # ============================== Buttons ==============================
327
+
328
+ st.markdown("### Start Simulation")
329
+ col1, col2, col3 = st.columns([1, 1, 14])
330
+ with col1:
331
+ play_pause_text = "▶️ Start" if not st.session_state.is_running else "⏸️ Pause"
332
+ if st.button(play_pause_text):
333
+ st.session_state.is_running = not st.session_state.is_running
334
+ st.rerun()
335
+ with col2:
336
+ reset_button = st.button("🔄 Reset")
337
+
338
+ # Reset logic
339
+ if reset_button:
340
+ st.session_state.start_index = 0
341
+ st.session_state.pred_vals = []
342
+ st.session_state.true_vals = []
343
+ st.session_state.pred_timestamps = []
344
+ st.session_state.true_timestamps = []
345
+ st.session_state.last_fig = None
346
+ st.session_state.is_running = False
347
+ st.session_state.valid_pos = 0
348
+ st.rerun()
349
+
350
+ # Auto-reset on critical parameter change while running
351
+ if st.session_state.is_running and (
352
+ start_date != st.session_state.get("last_start_date") or
353
+ end_date != st.session_state.get("last_end_date") or
354
+ model_choice != st.session_state.get("last_model_choice")
355
+ ):
356
+ st.session_state.start_index = 0
357
+ st.session_state.pred_vals = []
358
+ st.session_state.true_vals = []
359
+ st.session_state.pred_timestamps = []
360
+ st.session_state.true_timestamps = []
361
+ st.session_state.last_fig = None
362
+ st.session_state.valid_pos = 0
363
+ st.rerun()
364
+
365
+ # Track current selections for change detection
366
+ st.session_state.last_start_date = start_date
367
+ st.session_state.last_end_date = end_date
368
+ st.session_state.last_model_choice = model_choice
369
+
370
+
371
+ # ============================== Paused Mode ==============================
372
+
373
+ if not st.session_state.is_running and st.session_state.last_fig is not None:
374
+ st.write("Simulation paused...")
375
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
376
+
377
+ timestamp = st.session_state.pred_timestamps[-1] if st.session_state.pred_timestamps else "–"
378
+ prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
379
+ actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None
380
+ progress = st.session_state.start_index / total_steps_ui
381
+
382
+ render_simulation_view(timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True)
383
+
384
+
385
+ # ============================== initialize values ==============================
386
+
387
+ #if lightGbm use testdata from above
388
+ if model_choice == "LightGBM":
389
+ test_df = test_df_filtered.copy()
390
+
391
+ #Shared state references for storing predictions and ground truths
392
+
393
+ true_vals = st.session_state.true_vals
394
+ pred_vals = st.session_state.pred_vals
395
+ true_timestamps = st.session_state.true_timestamps
396
+ pred_timestamps = st.session_state.pred_timestamps
397
+
398
+ # ============================== LightGBM Simulation ==============================
399
+
400
+ if model_choice == "LightGBM" and st.session_state.is_running:
401
+ model = load_lightgbm_model()
402
+ st.write("Simulation started...")
403
+ st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
404
+
405
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
406
+
407
+ for i in range(st.session_state.start_index, len(test_df)):
408
+ if not st.session_state.is_running:
409
+ break
410
+
411
+ current = test_df.iloc[i]
412
+ timestamp = current["date"]
413
+ features = current[FEATURES].values.reshape(1, -1)
414
+ prediction = model.predict(features)[0]
415
+
416
+ pred_vals.append(prediction)
417
+ pred_timestamps.append(timestamp)
418
+
419
+ if i >= 1:
420
+ prev_actual = test_df.iloc[i - 1]["consumption_MW"]
421
+ prev_time = test_df.iloc[i - 1]["date"]
422
+ true_vals.append(prev_actual)
423
+ true_timestamps.append(prev_time)
424
+
425
+ fig = create_prediction_plot(
426
+ pred_timestamps, pred_vals,
427
+ true_timestamps, true_vals,
428
+ window_hours,
429
+ y_min= test_df_filtered["consumption_MW"].min() - 2000,
430
+ y_max= test_df_filtered["consumption_MW"].max() + 2000
431
+ )
432
+
433
+ render_simulation_view(timestamp, prediction, prev_actual if i >= 1 else None, i / len(test_df), fig)
434
+
435
+ plt.close(fig) # Speicher freigeben
436
+
437
+ st.session_state.start_index = i + 1
438
+ time.sleep(1 / (speed + 1e-9))
439
+
440
+ st.success("Simulation completed!")
441
+
442
+
443
+
444
+ # ============================== Transformer Simulation ==============================
445
+
446
+ if model_choice == "Transformer Model(moments)":
447
+ if st.session_state.is_running:
448
+ st.write("Simulation started (Transformer)...")
449
+ st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
450
+
451
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
452
+
453
+ # Zugriff auf Modell, Dataset, Device
454
+ model, test_dataset, device = load_transformer_model_and_dataset()
455
+ data = test_dataset.data # bereits skaliert
456
+ scaler = test_dataset.scaler
457
+ n_channels = test_dataset.n_channels
458
+
459
+ test_start_idx = len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON)) + SEQ_LEN
460
+ base_timestamp = pd.read_csv(DATA_PATH, parse_dates=["date"])["date"].iloc[test_start_idx] #get original timestamp for later, cause not in dataset anymore
461
+
462
+ # Schritt 1: Finde Index, ab dem Stunde = 00:00 ist
463
+ offset = 0
464
+ while (base_timestamp + pd.Timedelta(hours=offset)).time() != pd.Timestamp("00:00:00").time():
465
+ offset += 1
466
+
467
+ # Neuer Startindex in der Simulation
468
+ start_index = offset
469
+
470
+ # Session-State bei Bedarf initial setzen
471
+ if "start_index" not in st.session_state or st.session_state.start_index == 0:
472
+ st.session_state.start_index = start_index
473
+
474
+
475
+ # Vorbereiten: Liste der gültigen i-Werte im gewünschten Zeitraum
476
+ valid_indices = []
477
+ for i in range(start_index, len(test_dataset)):
478
+ timestamp = base_timestamp + pd.Timedelta(hours=i)
479
+ if start_date <= timestamp.date() <= end_date:
480
+ valid_indices.append(i)
481
+
482
+ # Fortschrittsanzeige
483
+ total_steps = len(valid_indices)
484
+
485
+ # Aktueller Fortschritt in der Liste (nicht: globaler Dataset-Index!)
486
+ if "valid_pos" not in st.session_state:
487
+ st.session_state.valid_pos = 0
488
+
489
+ # Hauptschleife: Nur noch über gültige Indizes iterieren
490
+ for relative_idx, i in enumerate(valid_indices[st.session_state.valid_pos:]):
491
+
492
+ #for i in range(st.session_state.start_index, len(test_dataset)):
493
+ if not st.session_state.is_running:
494
+ break
495
+
496
+ current_pred = predict_transformer_step(model, test_dataset, i, device)
497
+ current_time = base_timestamp + pd.Timedelta(hours=i)
498
+
499
+ pred_vals.append(current_pred)
500
+ pred_timestamps.append(current_time)
501
+
502
+ if i >= 1:
503
+ prev_actual = test_dataset[i - 1][1][0, 0] # erster Forecast-Wert der letzten Zeile
504
+ # Rückskalieren
505
+ dummy_actual = np.zeros((1, n_channels))
506
+ dummy_actual[:, 0] = prev_actual
507
+ actual_val = scaler.inverse_transform(dummy_actual)[0, 0]
508
+
509
+ true_time = current_time - pd.Timedelta(hours=1)
510
+
511
+ if true_time >= pd.to_datetime(start_date):
512
+ true_vals.append(actual_val)
513
+ true_timestamps.append(true_time)
514
+
515
+ # Plot erzeugen
516
+ fig = create_prediction_plot(
517
+ pred_timestamps, pred_vals,
518
+ true_timestamps, true_vals,
519
+ window_hours,
520
+ y_min= test_df_filtered["consumption_MW"].min() - 2000,
521
+ y_max= test_df_filtered["consumption_MW"].max() + 2000
522
+ )
523
+ if len(pred_vals) >= 2 and len(true_vals) >= 1:
524
+ render_simulation_view(current_time, current_pred, actual_val if i >= 1 else None, st.session_state.valid_pos / total_steps, fig)
525
+
526
+ plt.close(fig) # Speicher freigeben
527
+
528
+ st.session_state.valid_pos += 1
529
+ time.sleep(1 / (speed + 1e-9))
530
+
531
+ st.success("Simulation completed!")
532
+
533
+
534
+ # ============================== Scroll Sync ==============================
535
+
536
+ st.markdown("""
537
+ <script>
538
+ window.addEventListener("message", (event) => {
539
+ if (event.data.type === "save_scroll") {
540
+ const pyScroll = event.data.scrollY;
541
+ window.parent.postMessage({type: "streamlit:setComponentValue", value: pyScroll}, "*");
542
+ }
543
+ });
544
+ </script>
545
+ """, unsafe_allow_html=True)
546
+