"""Streamlit app that replays an hourly electricity-consumption forecast.

The user picks a model (LightGBM or a transformer), a date window inside the
test split, and the app steps through that window hour by hour, plotting
prediction vs. actual together with running error metrics.
"""
import os
import pickle
import sys
import time
import warnings

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from sklearn.preprocessing import StandardScaler

from config_streamlit import (
    MODEL_PATH_LIGHTGBM, DATA_PATH, TRAIN_RATIO,
    TEXT_COLOR, HEADER_COLOR, ACCENT_COLOR,
    BUTTON_BG, BUTTON_HOVER_BG, BG_COLOR,
    INPUT_BG, PROGRESS_COLOR, PLOT_COLOR,
)
from lightgbm_model.scripts.config_lightgbm import FEATURES
from transformer_model.scripts.utils.informer_dataset_class import InformerDataset
from transformer_model.scripts.training.load_basis_model import load_moment_model
from transformer_model.scripts.config_transformer import CHECKPOINT_DIR, FORECAST_HORIZON, SEQ_LEN

# ============================== Layout ==============================
# Streamlit & warnings config
warnings.filterwarnings("ignore", category=FutureWarning)
st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")

# Custom CSS.
# NOTE(review): the <style> markup that belongs inside this f-string is missing
# from the reviewed copy of the file (only whitespace is left) -- restore it
# from version control.
st.markdown(f""" """, unsafe_allow_html=True)

st.title("Electricity Consumption Forecast: Hourly Simulation")
st.write("Welcome to the simulation interface!")


# ============================== Session State Init ==============================
def init_session_state():
    """Populate ``st.session_state`` with the default simulation values.

    Keys that already exist are left untouched, so calling this on every
    rerun preserves an in-progress simulation.
    """
    defaults = {
        "is_running": False,   # toggled by the Start/Pause button
        "start_index": 0,      # resume row of the LightGBM loop
        "true_vals": [],
        "pred_vals": [],
        "true_timestamps": [],
        "pred_timestamps": [],
        "last_fig": None,      # last rendered figure, re-shown while paused
        "valid_pos": 0,        # resume position in the transformer's valid-index list
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def reset_simulation_state(stop_running=True):
    """Discard all accumulated simulation results and rewind to the start.

    Args:
        stop_running: If True (Reset button), also stop the simulation.
            If False (automatic reset after a parameter change), leave
            ``is_running`` unchanged.
    """
    st.session_state.start_index = 0
    st.session_state.pred_vals = []
    st.session_state.true_vals = []
    st.session_state.pred_timestamps = []
    st.session_state.true_timestamps = []
    st.session_state.last_fig = None
    st.session_state.valid_pos = 0
    if stop_running:
        st.session_state.is_running = False


init_session_state()


# ============================== Loaders ==============================
@st.cache_resource
def load_lightgbm_model():
    """Load the pickled LightGBM model from ``MODEL_PATH_LIGHTGBM``.

    Uses ``st.cache_resource`` so a single model instance is shared across
    reruns. ``st.cache_data`` would pickle/unpickle a fresh copy per call.
    """
    # pickle.load can execute arbitrary code: MODEL_PATH_LIGHTGBM must point
    # at a trusted file.
    with open(MODEL_PATH_LIGHTGBM, "rb") as f:
        return pickle.load(f)


@st.cache_resource
def load_transformer_model_and_dataset():
    """Build the transformer, load its fine-tuned weights and the test dataset.

    Returns:
        tuple: ``(model, test_dataset, device)``.
            - ``model`` is in eval mode on ``device`` (CUDA if available,
              otherwise CPU).
            - ``test_dataset`` uses the train split's scaler.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = load_moment_model()
    checkpoint_path = hf_hub_download(
        repo_id="dlaj/energy-forecasting-files",
        filename="transformer_model/model_final.pth",
        repo_type="dataset",
    )
    # NOTE(review): torch.load unpickles the file, and this checkpoint is
    # downloaded from an external Hub repo. If it is a plain state_dict,
    # consider torch.load(..., weights_only=True).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()

    # Datasets. The test split reuses the train split's scaler so that
    # predictions are inverse-transformed the same way as during training.
    train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
    test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
    test_dataset.scaler = train_dataset.scaler

    return model, test_dataset, device


@st.cache_data
def load_data():
    """Fetch the cleaned, aggregated consumption CSV from the Hub and read it.

    Returns:
        pandas.DataFrame: The data, with the ``date`` column parsed as datetimes.
    """
    csv_path = hf_hub_download(
        repo_id="dlaj/energy-forecasting-files",
        filename="data/processed/energy_consumption_aggregated_cleaned.csv",
        repo_type="dataset",
    )
    return pd.read_csv(csv_path, parse_dates=["date"])


# ============================== Utility Functions ==============================
def predict_transformer_step(model, dataset, idx, device):
    """Performs a single prediction step with the transformer model.

    Args:
        model: Called as ``model(x_enc=..., input_mask=...)``. Must return an
            object with a ``forecast`` tensor.
        dataset: ``dataset[idx]`` must yield ``(timeseries, target, input_mask)``.
            The dataset must also expose ``scaler`` and ``n_channels``.
        idx: Sample index into ``dataset``.
        device: Torch device the model lives on.

    Returns:
        float: The first forecast value of channel 0, inverse-scaled back to
        original units.
    """
    timeseries, _, input_mask = dataset[idx]
    # Add a batch dimension of size 1.
    timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
    input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x_enc=timeseries, input_mask=input_mask)
    # NOTE(review): index 0 on axis 1 presumably selects the target channel of
    # a (batch, channel, horizon) forecast -- confirm against the model.
    pred = output.forecast[:, 0, :].cpu().numpy().flatten()

    # Inverse-scale. The scaler works on all channels, so place the prediction
    # in column 0 of a zero matrix and read column 0 back afterwards.
    dummy = np.zeros((len(pred), dataset.n_channels))
    dummy[:, 0] = pred
    pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]
    return float(pred_original[0])


def init_simulation_layout():
    """Create the two-column simulation layout and return its placeholders.

    Returns:
        tuple: ``(plot_title, plot_container, x_axis_label, info_container)``,
        all ``st.empty`` placeholders. The first three are stacked in the wide
        left column; the last sits in the narrow right column.
    """
    col1, _spacer, col2 = st.columns([3, 0.2, 1])
    plot_title = col1.empty()
    plot_container = col1.empty()
    x_axis_label = col1.empty()
    info_container = col2.empty()
    return plot_title, plot_container, x_axis_label, info_container


def create_prediction_plot(pred_timestamps, pred_vals, true_timestamps, true_vals,
                           window_hours, y_min=None, y_max=None):
    """Generates the matplotlib figure for plotting prediction vs. actual.

    Only the last ``window_hours`` points of each series are drawn. The
    actual series is drawn only when ``true_vals`` is non-empty.
    ``y_min``/``y_max`` of None leave that limit on autoscale.

    Side effect: stores the figure in ``st.session_state.last_fig`` so that
    paused mode can show it again.

    Returns:
        matplotlib.figure.Figure: The new figure. The caller is responsible
        for closing it.
    """
    fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR)
    ax.set_facecolor(PLOT_COLOR)

    ax.plot(pred_timestamps[-window_hours:], pred_vals[-window_hours:],
            label="Prediction", color="#EF233C", linestyle="--")
    if true_vals:
        ax.plot(true_timestamps[-window_hours:], true_vals[-window_hours:],
                label="Actual", color="#0077B6")

    ax.set_ylabel("Consumption (MW)", fontsize=8, color=TEXT_COLOR)
    ax.legend(
        fontsize=8,
        loc="upper left",
        bbox_to_anchor=(0, 0.95),
        facecolor=INPUT_BG,
        edgecolor=ACCENT_COLOR,
        labelcolor=TEXT_COLOR,
    )
    ax.yaxis.grid(True, linestyle=':', linewidth=0.5, alpha=0.7)
    ax.set_ylim(y_min, y_max)

    # One tick per day, labelled month-day.
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
    ax.tick_params(axis="x", labelrotation=0, labelsize=5, colors=TEXT_COLOR)
    ax.tick_params(axis="y", labelsize=5, colors=TEXT_COLOR)

    for spine in ax.spines.values():
        spine.set_visible(False)

    st.session_state.last_fig = fig
    return fig


def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
    """Displays the simulation plot and metrics in the UI.

    Writes into the module-level placeholders ``plot_title``,
    ``plot_container``, ``x_axis_label`` and ``info_container``. The caller
    must bind these (via ``init_simulation_layout``) before calling.

    Args:
        timestamp: Time of the current step, shown in the info panel.
        prediction: Current prediction in MW, or None (shown as "–").
        actual: Latest actual value in MW, or None (shown as "–").
        progress: Fraction in [0, 1], passed to ``st.progress``.
        fig: Matplotlib figure to display.
        paused: If True, adds "(Paused)" to the plot title.
    """
    title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
    # NOTE(review): the HTML markup in the markdown strings of this function is
    # missing from the reviewed copy of the file (only line breaks are left) --
    # restore it from version control.
    plot_title.markdown(
        f"\n"
        f"{title}\n",
        unsafe_allow_html=True,
    )
    plot_container.pyplot(fig)
    # NOTE(review): this call does not go through a placeholder, so every
    # simulation step appends another element to the page.
    st.markdown("\n", unsafe_allow_html=True)
    x_axis_label.markdown(
        f"\n"
        f"Time\n",
        unsafe_allow_html=True,
    )

    with info_container.container():
        st.markdown("\n", unsafe_allow_html=True)
        st.markdown(f"Time: {timestamp}", unsafe_allow_html=True)
        st.metric("Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–")
        st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
        st.caption("Simulation Progress")
        st.progress(progress)

        # Only start showing metrics once there are 2 actual values.
        if len(st.session_state.true_vals) > 1:
            true_arr = np.array(st.session_state.true_vals)
            # The newest prediction has no matching actual yet, so drop it.
            pred_arr = np.array(st.session_state.pred_vals[:-1])
            min_len = min(len(true_arr), len(pred_arr))

            if min_len >= 1:
                errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
                # Replace zero actuals with a tiny value to avoid division by zero.
                mape = np.mean(errors / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])) * 100
                mae = np.mean(errors)
                max_error = np.max(errors)

                st.divider()
                st.markdown(f"Interim Metrics", unsafe_allow_html=True)
                st.metric("MAPE (so far)", f"{mape:.2f} %")
                st.metric("MAE (so far)", f"{mae:,.0f} MW")
                st.metric("Max Error", f"{max_error:,.0f} MW")


# ============================== Data Preparation ==============================
df_full = load_data()

# Split Train/Test
train_size = int(len(df_full) * TRAIN_RATIO)
test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)

# Start at the first midnight (00:00) so the simulation begins on a full day.
midnight = pd.Timestamp("00:00:00").time()
first_full_day_index = test_df_raw[test_df_raw["date"].dt.time == midnight].index[0]
test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)

# Bounds for the date pickers
min_date = test_df_full["date"].min().date()
max_date = test_df_full["date"].max().date()

# ============================== UI Controls ==============================
st.markdown("### Simulation Settings")
col1, col2 = st.columns([1, 1])

with col1:
    st.markdown("**General Settings**")
    model_choice = st.selectbox("Choose prediction model", ["LightGBM", "Transformer Model (moments)"])
    # Must match the selectbox option above exactly, including the space.
    if model_choice == "Transformer Model (moments)":
        st.caption("⚠️ Note: Transformer model runs slower without GPU. \n(Use Speed = 10)")
    window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
    window_hours = window_days * 24
    speed = st.slider("Speed", 1, 10, 5)

with col2:
    st.markdown(f"**Date Range** (from {min_date} to {max_date})")
    start_date = st.date_input("Start Date", value=min_date, min_value=min_date, max_value=max_date)
    end_date = st.date_input("End Date", value=max_date, min_value=min_date, max_value=max_date)

# An inverted range would give an empty window (and a later division by zero).
if start_date > end_date:
    st.error("Start date must not be after end date.")
    st.stop()

# ============================== Data Preparation (filtered) ==============================
# Final filtered date window (both ends inclusive)
test_df_filtered = test_df_full[
    (test_df_full["date"].dt.date >= start_date) & (test_df_full["date"].dt.date <= end_date)
].reset_index(drop=True)

# For progression bar
total_steps_ui = len(test_df_filtered)

# ============================== Buttons ==============================
st.markdown("### Start Simulation")
col1, col2, col3 = st.columns([1, 1, 14])

with col1:
    play_pause_text = "⏸️ Pause" if st.session_state.is_running else "▶️ Start"
    if st.button(play_pause_text):
        st.session_state.is_running = not st.session_state.is_running
        st.rerun()

with col2:
    reset_button = st.button("🔄 Reset")

# Reset logic
if reset_button:
    reset_simulation_state(stop_running=True)
    st.rerun()

# Auto-reset when a critical parameter changes. This applies while running
# and while paused: otherwise resuming would continue stale results on a
# different window/model. It is skipped on the very first run, before any
# selection has been recorded.
if "last_start_date" in st.session_state and (
    start_date != st.session_state.get("last_start_date")
    or end_date != st.session_state.get("last_end_date")
    or model_choice != st.session_state.get("last_model_choice")
):
    reset_simulation_state(stop_running=False)
    # Record the new selections BEFORE rerunning. Otherwise the change is
    # detected again on every rerun and the app loops forever.
    st.session_state.last_start_date = start_date
    st.session_state.last_end_date = end_date
    st.session_state.last_model_choice = model_choice
    st.rerun()
# Track current selections for change detection
st.session_state.last_start_date = start_date
st.session_state.last_end_date = end_date
st.session_state.last_model_choice = model_choice

# ============================== Paused Mode ==============================
if not st.session_state.is_running and st.session_state.last_fig is not None:
    st.write("Simulation paused...")
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    timestamp = st.session_state.pred_timestamps[-1] if st.session_state.pred_timestamps else "–"
    prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
    actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None

    # The LightGBM loop advances ``start_index``. The transformer loop advances
    # ``valid_pos``; its ``start_index`` only holds the offset to the first
    # midnight.
    if model_choice == "Transformer Model (moments)":
        steps_done = st.session_state.valid_pos
    else:
        steps_done = st.session_state.start_index
    # Guard against an empty window, and clamp to st.progress's [0, 1] range
    # (total_steps_ui only approximates the transformer's step count).
    progress = min(steps_done / total_steps_ui, 1.0) if total_steps_ui else 0.0

    render_simulation_view(timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True)

# ============================== initialize values ==============================
# LightGBM uses the filtered test data from above.
if model_choice == "LightGBM":
    test_df = test_df_filtered.copy()

# Shared state references for storing predictions and ground truths.
# These alias the session-state lists, so appends persist across reruns.
true_vals = st.session_state.true_vals
pred_vals = st.session_state.pred_vals
true_timestamps = st.session_state.true_timestamps
pred_timestamps = st.session_state.pred_timestamps

# ============================== LightGBM Simulation ==============================
if model_choice == "LightGBM" and st.session_state.is_running:
    model = load_lightgbm_model()
    st.write("Simulation started...")
    # NOTE(review): the markup in this string is missing from the reviewed copy
    # of the file -- restore it from version control.
    st.markdown('\n', unsafe_allow_html=True)
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    # The y-axis limits depend only on the selected window, so compute them once.
    y_min = test_df_filtered["consumption_MW"].min() - 2000
    y_max = test_df_filtered["consumption_MW"].max() + 2000

    for i in range(st.session_state.start_index, len(test_df)):
        if not st.session_state.is_running:
            break

        current = test_df.iloc[i]
        timestamp = current["date"]
        features = current[FEATURES].values.reshape(1, -1)
        prediction = model.predict(features)[0]
        pred_vals.append(prediction)
        pred_timestamps.append(timestamp)

        # The actual value of the previous hour becomes known at this step.
        prev_actual = None
        if i >= 1:
            prev_row = test_df.iloc[i - 1]
            prev_actual = prev_row["consumption_MW"]
            true_vals.append(prev_actual)
            true_timestamps.append(prev_row["date"])

        # Commit progress before touching the UI. A button click interrupts the
        # script at the next st.* call; if progress were committed afterwards,
        # this step would be replayed on resume and duplicate its values.
        st.session_state.start_index = i + 1

        fig = create_prediction_plot(
            pred_timestamps, pred_vals, true_timestamps, true_vals,
            window_hours, y_min=y_min, y_max=y_max,
        )
        render_simulation_view(timestamp, prediction, prev_actual, i / len(test_df), fig)
        plt.close(fig)  # free memory

        time.sleep(1 / (speed + 1e-9))

    st.success("Simulation completed!")

# ============================== Transformer Simulation ==============================
# Must match the selectbox option exactly, including the space.
if model_choice == "Transformer Model (moments)":
    if st.session_state.is_running:
        st.write("Simulation started (Transformer)...")
        # NOTE(review): the markup in this string is missing from the reviewed
        # copy of the file -- restore it from version control.
        st.markdown('\n', unsafe_allow_html=True)
        plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

        # Access model, dataset (already scaled) and device
        model, test_dataset, device = load_transformer_model_and_dataset()
        scaler = test_dataset.scaler
        n_channels = test_dataset.n_channels

        # The dataset no longer carries timestamps, so recover the timestamp of
        # test sample 0 from the raw CSV.
        # NOTE(review): this rebuilds the train dataset and re-reads DATA_PATH on
        # every rerun; consider caching both.
        test_start_idx = len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON)) + SEQ_LEN
        base_timestamp = pd.read_csv(DATA_PATH, parse_dates=["date"])["date"].iloc[test_start_idx]

        # Step 1: find the hour offset at which the time is 00:00.
        midnight = pd.Timestamp("00:00:00").time()
        offset = 0
        while (base_timestamp + pd.Timedelta(hours=offset)).time() != midnight:
            offset += 1

        # New start index for the simulation
        start_index = offset

        # Initialise session state if needed
        if "start_index" not in st.session_state or st.session_state.start_index == 0:
            st.session_state.start_index = start_index

        # Dataset indices whose timestamp lies inside the selected date range
        valid_indices = [
            i for i in range(start_index, len(test_dataset))
            if start_date <= (base_timestamp + pd.Timedelta(hours=i)).date() <= end_date
        ]
        total_steps = len(valid_indices)

        # The y-axis limits depend only on the selected window, so compute them once.
        y_min = test_df_filtered["consumption_MW"].min() - 2000
        y_max = test_df_filtered["consumption_MW"].max() + 2000

        # Main loop: iterate only over the valid indices, resuming at valid_pos
        # (a position in valid_indices, not a global dataset index).
        for i in valid_indices[st.session_state.valid_pos:]:
            if not st.session_state.is_running:
                break

            current_pred = predict_transformer_step(model, test_dataset, i, device)
            current_time = base_timestamp + pd.Timedelta(hours=i)
            pred_vals.append(current_pred)
            pred_timestamps.append(current_time)

            actual_val = None
            if i >= 1:
                # First target value of the previous sample, i.e. the actual
                # for the hour before current_time (still scaled).
                prev_actual = test_dataset[i - 1][1][0, 0]
                # Inverse-scale via column 0 of a zero row (see predict_transformer_step).
                dummy_actual = np.zeros((1, n_channels))
                dummy_actual[:, 0] = prev_actual
                actual_val = scaler.inverse_transform(dummy_actual)[0, 0]

                true_time = current_time - pd.Timedelta(hours=1)
                if true_time >= pd.to_datetime(start_date):
                    true_vals.append(actual_val)
                    true_timestamps.append(true_time)

            # Commit progress before touching the UI (see the LightGBM loop).
            progress = st.session_state.valid_pos / total_steps
            st.session_state.valid_pos += 1

            # Create plot
            fig = create_prediction_plot(
                pred_timestamps, pred_vals, true_timestamps, true_vals,
                window_hours, y_min=y_min, y_max=y_max,
            )
            if len(pred_vals) >= 2 and len(true_vals) >= 1:
                render_simulation_view(current_time, current_pred, actual_val, progress, fig)
            plt.close(fig)  # free memory

            time.sleep(1 / (speed + 1e-9))

        st.success("Simulation completed!")

# ============================== Scroll Sync ==============================
# NOTE(review): the script/markup inside this string is missing from the
# reviewed copy of the file -- restore it from version control.
st.markdown(""" """, unsafe_allow_html=True)