"""Streamlit app: hourly electricity-consumption forecast simulation (LightGBM / Transformer)."""

import os
import time
import warnings

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
from huggingface_hub import hf_hub_download
from lightgbm_model.scripts.config_lightgbm import FEATURES
from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
from streamlit_simulation.utils_streamlit import load_data as load_data_raw
from transformer_model.scripts.config_transformer import FORECAST_HORIZON, SEQ_LEN
from transformer_model.scripts.utils.informer_dataset_class import InformerDataset
from transformer_model.scripts.utils.model_loader_wrapper import (
    load_model_and_dataset,
)

# ============================== Layout ==============================

# Streamlit & warnings config
warnings.filterwarnings("ignore", category=FutureWarning)
st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")

# Global CSS injection.
# NOTE(review): the content of this string is empty in this copy of the file
# (the f-prefix suggests it once interpolated values) - confirm the CSS is intact.
st.markdown(
    f""" """,
    unsafe_allow_html=True,
)

st.title("Electricity Consumption Forecast: Hourly Simulation")
st.write("Welcome to the simulation interface!")
st.info(
    "**Simulation Overview:**\n\n"
    "This dashboard provides an hourly electricity consumption forecast using two different models: "
    "**LightGBM** and a **Transformer (moment-based)**. Both models generate a fresh prediction at every time step "
    "(i.e., every simulated hour).\n\n"
    "Note: Since this app runs on a limited CPU on Hugging Face Spaces, the Transformer model may respond slower "
    "compared to local execution. On a standard local CPU, performance is significantly better."
)

# ============================== Session State Init ===============================


def _simulation_defaults():
    """Return a fresh dict with the default value of every simulation state key.

    A new dict (with new empty lists) is built on every call, so a reset never
    re-assigns a list object that has already been appended to.
    """
    return {
        "is_running": False,
        "start_index": 0,
        "true_vals": [],
        "pred_vals": [],
        "true_timestamps": [],
        "pred_timestamps": [],
        "last_fig": None,
        "valid_pos": 0,
        "first_plot_shown": False,
    }


def init_session_state():
    """Populate any missing session-state keys with their default values."""
    for key, value in _simulation_defaults().items():
        if key not in st.session_state:
            st.session_state[key] = value


def _reset_simulation_state(*, stop=False):
    """Discard all accumulated simulation results and positions.

    Args:
        stop: If True, also set ``is_running`` back to False.
    """
    for key, value in _simulation_defaults().items():
        if key == "is_running" and not stop:
            continue
        st.session_state[key] = value


init_session_state()

# ============================== Loaders Cache ==============================

HF_REPO = "dlaj/energy-forecasting-files"
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"

# Use the local CSV if it exists; otherwise download it from the Hugging Face Hub.
if os.path.exists(DATA_PATH):
    CSV_PATH = DATA_PATH
else:
    CSV_PATH = hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_FILENAME,
        repo_type="dataset",
        cache_dir="hf_cache",  # Optional
    )


@st.cache_resource
def load_cached_lightgbm_model():
    """Load the LightGBM model once per server process.

    ``st.cache_resource`` hands back the same object on every rerun, whereas
    ``st.cache_data`` would pickle and copy the model on each call.
    """
    return load_lightgbm_model()


@st.cache_resource
def load_transformer_model_and_dataset():
    """Load ``(model, test_dataset, device)`` for the transformer once per process."""
    return load_model_and_dataset()


@st.cache_data
def load_data():
    """Load the raw data frame via ``load_data_raw`` (cached)."""
    return load_data_raw()


@st.cache_data
def _load_transformer_base_timestamp(csv_path):
    """Return the original timestamp of transformer test sample 0.

    The dataset does not carry dates, so the timestamp is read from the raw
    CSV at row ``len(train dataset) + SEQ_LEN``. The result is cached because
    computing it rebuilds the train dataset and reads the CSV.
    """
    train_len = len(
        InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON)
    )
    test_start_idx = train_len + SEQ_LEN
    dates = pd.read_csv(csv_path, usecols=["date"], parse_dates=["date"])["date"]
    return dates.iloc[test_start_idx]


# ============================== Utility Functions ==============================


def _hours_until_midnight(timestamp, max_hours=24):
    """Return the smallest h >= 0 such that ``timestamp + h hours`` is at 00:00.

    Only ``max_hours`` offsets are tried. A timestamp that is not on a full
    hour can therefore not loop forever; in that case 0 is returned.
    """
    midnight = pd.Timestamp("00:00:00").time()
    return next(
        (
            hours
            for hours in range(max_hours)
            if (timestamp + pd.Timedelta(hours=hours)).time() == midnight
        ),
        0,
    )


def predict_transformer_step(model, dataset, idx, device):
    """Run the transformer on sample ``idx`` and return one forecast value.

    The first value of ``output.forecast[:, 0, :]`` is returned after undoing
    the scaling with ``dataset.scaler``.
    """
    timeseries, _, input_mask = dataset[idx]
    timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
    input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x_enc=timeseries, input_mask=input_mask)
    pred = output.forecast[:, 0, :].cpu().numpy().flatten()

    # Undo scaling. The scaler works on all channels, so pad the other
    # channels with zeros and keep only channel 0 afterwards.
    dummy = np.zeros((len(pred), dataset.n_channels))
    dummy[:, 0] = pred
    pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]
    return float(pred_original[0])


def init_simulation_layout():
    """Create the placeholders for the plot column and the info column.

    Returns:
        ``(plot_title, plot_container, x_axis_label, info_container)``, all
        ``st.empty()`` placeholders.
    """
    col1, _spacer, col2 = st.columns([3, 0.2, 1])
    plot_title = col1.empty()
    plot_container = col1.empty()
    x_axis_label = col1.empty()
    info_container = col2.empty()
    return plot_title, plot_container, x_axis_label, info_container


def create_prediction_plot(
    pred_timestamps,
    pred_vals,
    true_timestamps,
    true_vals,
    window_hours,
    y_min=None,
    y_max=None,
):
    """Build the prediction-vs-actual figure for the last ``window_hours`` points.

    The figure is also stored in ``st.session_state.last_fig``, so it can be
    shown again while the simulation is paused.
    """
    fig, ax = plt.subplots(
        figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR
    )
    ax.set_facecolor(PLOT_COLOR)

    ax.plot(
        pred_timestamps[-window_hours:],
        pred_vals[-window_hours:],
        label="Prediction",
        color="#EF233C",
        linestyle="--",
    )
    if true_vals:
        ax.plot(
            true_timestamps[-window_hours:],
            true_vals[-window_hours:],
            label="Actual",
            color="#0077B6",
        )

    ax.set_ylabel("Consumption (MW)", fontsize=8)
    ax.legend(fontsize=8, loc="upper left", bbox_to_anchor=(0, 0.95))
    ax.yaxis.grid(True, linestyle=":", linewidth=0.5, alpha=0.7)
    ax.set_ylim(y_min, y_max)

    # One x tick per day, labelled month-day.
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
    ax.tick_params(axis="x", labelrotation=0, labelsize=5)
    ax.tick_params(axis="y", labelsize=5)

    for spine in ax.spines.values():
        spine.set_visible(False)

    st.session_state.last_fig = fig
    return fig


def _interim_metrics(true_vals, pred_vals):
    """Return ``(mape, mae, max_error)`` over the predictions whose actual is known.

    ``true_vals[k]`` is paired with ``pred_vals[k]``. The newest prediction is
    dropped because its actual is not known yet. Returns None when there is
    nothing to compare.
    """
    true_arr = np.array(true_vals)
    pred_arr = np.array(pred_vals[:-1])
    n = min(len(true_arr), len(pred_arr))
    if n < 1:
        return None
    true_arr, pred_arr = true_arr[:n], pred_arr[:n]
    errors = np.abs(true_arr - pred_arr)
    # Replace zero actuals to avoid dividing by zero.
    denominator = np.where(true_arr == 0, 1e-10, true_arr)
    mape = np.mean(errors / denominator) * 100
    return mape, np.mean(errors), np.max(errors)


def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
    """Draw the plot and the metrics panel into the current layout placeholders.

    Uses the module-level ``plot_title``, ``plot_container`` and
    ``info_container`` created by ``init_simulation_layout()``. ``progress``
    is clamped to [0, 1], the range ``st.progress`` accepts.
    """
    title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
    # NOTE(review): the HTML wrapper around the title seems to have been lost in
    # this copy of the file (only line breaks remain) - restore from VCS.
    plot_title.markdown(
        "\n" f"{title}\n",
        unsafe_allow_html=True,
    )
    plot_container.pyplot(fig)

    with info_container.container():
        st.markdown(
            f"Time: {timestamp}",
            unsafe_allow_html=True,
        )
        st.metric(
            "Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–"
        )
        st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
        st.caption("Simulation Progress")
        st.progress(min(max(progress, 0.0), 1.0))

        if len(st.session_state.true_vals) > 1:
            metrics = _interim_metrics(
                st.session_state.true_vals, st.session_state.pred_vals
            )
            if metrics is not None:
                mape, mae, max_error = metrics
                st.divider()
                st.markdown(
                    "Interim Metrics",
                    unsafe_allow_html=True,
                )
                st.metric("MAPE (so far)", f"{mape:.2f} %")
                st.metric("MAE (so far)", f"{mae:,.0f} MW")
                st.metric("Max Error", f"{max_error:,.0f} MW")


# ============================== Data Preparation ==============================

df_full = load_data()

# Split Train/Test
train_size = int(len(df_full) * TRAIN_RATIO)
test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)

# Start at the first midnight (00:00) so the simulation begins on a full day.
first_full_day_index = test_df_raw[
    test_df_raw["date"].dt.time == pd.Timestamp("00:00:00").time()
].index[0]
test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)

# Bounds for the date pickers
min_date = test_df_full["date"].min().date()
max_date = test_df_full["date"].max().date()

# ============================== UI Controls ==============================

with st.sidebar:
    st.header("⚙️ Simulation Settings")

    st.subheader("General Settings")
    model_choice = st.selectbox(
        "Choose prediction model", ["LightGBM", "Transformer Model (moments)"]
    )
    if model_choice == "Transformer Model (moments)":
        # NOTE(review): markup inside this caption seems to have been lost in
        # this copy of the file (only a line break remains) - confirm against VCS.
        st.caption(
            "⚠️ Note: Transformer model runs slower without GPU. \n(Use Speed = 10)"
        )

    window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
    window_hours = window_days * 24

    speed = st.slider("Speed", 1, 10, 5)

    st.subheader("Date Range")
    start_date = st.date_input(
        "Start Date", value=min_date, min_value=min_date, max_value=max_date
    )
    end_date = st.date_input(
        "End Date", value=max_date, min_value=min_date, max_value=max_date
    )

# ============================== Data Preparation (filtered) ==============================

# Final filtered date window
test_df_filtered = test_df_full[
    (test_df_full["date"].dt.date >= start_date)
    & (test_df_full["date"].dt.date <= end_date)
].reset_index(drop=True)

# Number of rows in the selected window (for the progress bar)
total_steps_ui = len(test_df_filtered)

# Fixed y-axis limits for the selected window, shared by both simulations.
# These are loop-invariant, so they are computed once here.
y_axis_min = test_df_filtered["consumption_MW"].min() - 2000
y_axis_max = test_df_filtered["consumption_MW"].max() + 2000

# ============================== Buttons ==============================

st.markdown("### Start Simulation")
col1, col2, _col3 = st.columns([1, 1, 4])

with col1:
    play_pause_text = "▶️ Start" if not st.session_state.is_running else "⏸️ Pause"
    if st.button(play_pause_text, use_container_width=True):
        st.session_state.is_running = not st.session_state.is_running
        st.rerun()

with col2:
    reset_button = st.button("🔄 Reset", use_container_width=True)

# Manual reset: clear everything and stop.
if reset_button:
    _reset_simulation_state(stop=True)
    st.rerun()

# Auto-reset when a parameter that invalidates the accumulated series changes.
# The new selection is stored *before* st.rerun(). Otherwise the next run would
# see the same mismatch again and rerun forever. The reset also applies while
# paused, so resuming never mixes old results with positions that index the
# newly filtered data.
current_selection = {
    "last_start_date": start_date,
    "last_end_date": end_date,
    "last_model_choice": model_choice,
}
selection_changed = "last_model_choice" in st.session_state and any(
    st.session_state.get(key) != value for key, value in current_selection.items()
)
for key, value in current_selection.items():
    st.session_state[key] = value

if selection_changed:
    _reset_simulation_state()
    st.rerun()

# ============================== Paused Mode ==============================

if not st.session_state.is_running and st.session_state.last_fig is not None:
    st.write("Simulation paused...")
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    timestamp = (
        st.session_state.pred_timestamps[-1]
        if st.session_state.pred_timestamps
        else "–"
    )
    prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
    actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None

    # LightGBM advances start_index; the transformer loop advances valid_pos.
    # NOTE(review): the running transformer view divides by len(valid_indices);
    # total_steps_ui is assumed to be close to it (the result is clamped anyway).
    if model_choice == "Transformer Model (moments)":
        steps_done = st.session_state.valid_pos
    else:
        steps_done = st.session_state.start_index
    # Guard against an empty date window (e.g. start date after end date).
    progress = steps_done / total_steps_ui if total_steps_ui else 0.0

    render_simulation_view(
        timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True
    )

# ============================== Shared state ==============================

# References to the session-state lists; the simulations append to them in place.
true_vals = st.session_state.true_vals
pred_vals = st.session_state.pred_vals
true_timestamps = st.session_state.true_timestamps
pred_timestamps = st.session_state.pred_timestamps

# ============================== LightGBM Simulation ==============================

if model_choice == "LightGBM" and st.session_state.is_running:
    model = load_cached_lightgbm_model()
    test_df = test_df_filtered.copy()

    st.write("Simulation started...")
    # NOTE(review): markup in this string seems to have been lost - restore from VCS.
    st.markdown('\n', unsafe_allow_html=True)
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    n_rows = len(test_df)
    for i in range(st.session_state.start_index, n_rows):
        if not st.session_state.is_running:
            break

        current = test_df.iloc[i]
        timestamp = current["date"]
        features = current[FEATURES].values.reshape(1, -1)
        prediction = model.predict(features)[0]
        pred_vals.append(prediction)
        pred_timestamps.append(timestamp)

        # Reveal the previous hour's actual consumption alongside this prediction.
        prev_actual = None
        if i >= 1:
            previous = test_df.iloc[i - 1]
            prev_actual = previous["consumption_MW"]
            true_vals.append(prev_actual)
            true_timestamps.append(previous["date"])

        fig = create_prediction_plot(
            pred_timestamps,
            pred_vals,
            true_timestamps,
            true_vals,
            window_hours,
            y_min=y_axis_min,
            y_max=y_axis_max,
        )
        render_simulation_view(timestamp, prediction, prev_actual, i / n_rows, fig)
        plt.close(fig)  # release pyplot's reference to the figure

        st.session_state.start_index = i + 1
        time.sleep(1 / (speed + 1e-9))

    st.success("Simulation completed!")

# ============================== Transformer Simulation ==============================

spinner_placeholder = st.empty()

if model_choice == "Transformer Model (moments)" and st.session_state.is_running:
    st.write("Simulation started (Transformer)...")
    # NOTE(review): markup in this string seems to have been lost - restore from VCS.
    st.markdown('\n', unsafe_allow_html=True)

    if not st.session_state.first_plot_shown:
        spinner_placeholder.markdown("Running first prediction – please wait...")

    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    # Model, dataset (already scaled) and device
    model, test_dataset, device = load_transformer_model_and_dataset()
    scaler = test_dataset.scaler
    n_channels = test_dataset.n_channels

    # Original timestamp of dataset sample 0 (the dataset itself has no dates)
    base_timestamp = _load_transformer_base_timestamp(CSV_PATH)

    # First dataset index whose timestamp falls on 00:00
    start_index = _hours_until_midnight(base_timestamp)

    # Dataset indices whose timestamp lies inside the selected date range
    valid_indices = [
        idx
        for idx in range(start_index, len(test_dataset))
        if start_date <= (base_timestamp + pd.Timedelta(hours=idx)).date() <= end_date
    ]
    total_steps = len(valid_indices)

    # valid_pos is a position in valid_indices, not a global dataset index.
    for i in valid_indices[st.session_state.valid_pos :]:
        if not st.session_state.is_running:
            break

        current_pred = predict_transformer_step(model, test_dataset, i, device)
        current_time = base_timestamp + pd.Timedelta(hours=i)
        pred_vals.append(current_pred)
        pred_timestamps.append(current_time)

        actual_val = None
        if i >= 1:
            # First forecast-target value of the previous sample (still scaled)
            prev_actual = test_dataset[i - 1][1][0, 0]
            # Undo scaling (pad the other channels with zeros)
            dummy_actual = np.zeros((1, n_channels))
            dummy_actual[:, 0] = prev_actual
            actual_val = scaler.inverse_transform(dummy_actual)[0, 0]

            true_time = current_time - pd.Timedelta(hours=1)
            # Do not record the hour before the selected start date.
            if true_time >= pd.to_datetime(start_date):
                true_vals.append(actual_val)
                true_timestamps.append(true_time)

        fig = create_prediction_plot(
            pred_timestamps,
            pred_vals,
            true_timestamps,
            true_vals,
            window_hours,
            y_min=y_axis_min,
            y_max=y_axis_max,
        )

        if len(pred_vals) >= 2 and len(true_vals) >= 1:
            render_simulation_view(
                current_time,
                current_pred,
                actual_val,
                st.session_state.valid_pos / total_steps,
                fig,
            )
            if not st.session_state.first_plot_shown:
                spinner_placeholder.empty()
                st.session_state.first_plot_shown = True

        plt.close(fig)  # release pyplot's reference to the figure

        st.session_state.valid_pos += 1
        time.sleep(1 / (speed + 1e-9))

    st.success("Simulation completed!")

# ============================== Scroll Sync ==============================

# NOTE(review): the content of this string is empty in this copy of the file -
# confirm the scroll-sync markup/script is intact.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)