import os
import time
import warnings
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from config_streamlit import DATA_PATH, PLOT_COLOR, TRAIN_RATIO
from huggingface_hub import hf_hub_download
from lightgbm_model.scripts.config_lightgbm import FEATURES
from lightgbm_model.scripts.model_loader_wrapper import load_lightgbm_model
from streamlit_simulation.utils_streamlit import load_data as load_data_raw
from transformer_model.scripts.config_transformer import (FORECAST_HORIZON,
SEQ_LEN)
from transformer_model.scripts.utils.informer_dataset_class import \
InformerDataset
from transformer_model.scripts.utils.model_loader_wrapper import \
load_model_and_dataset
# ============================== Layout ==============================
# Warning filter and page configuration
warnings.filterwarnings("ignore", category=FutureWarning)
st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")

# CSS part: markup handed to Streamlit as raw HTML
page_css = f"""
"""
st.markdown(page_css, unsafe_allow_html=True)

st.title("Electricity Consumption Forecast: Hourly Simulation")
st.write("Welcome to the simulation interface!")

# Introductory info box shown below the greeting
overview_text = (
    "**Simulation Overview:**\n\n"
    "This dashboard provides an hourly electricity consumption forecast using two different models: "
    "**LightGBM** and a **Transformer (moment-based)**. Both models generate a fresh prediction at every time step "
    "(i.e., every simulated hour).\n\n"
    "Note: Since this app runs on a limited CPU on Hugging Face Spaces, the Transformer model may respond slower "
    "compared to local execution. On a standard local CPU, performance is significantly better."
)
st.info(overview_text)
# ============================== Session State Init ===============================
def init_session_state():
    """Seed ``st.session_state`` with the simulation's initial values.

    Only keys that are not present yet are written, so values stored by an
    earlier script run are left untouched.
    """
    initial_values = (
        ("is_running", False),
        ("start_index", 0),
        ("true_vals", []),
        ("pred_vals", []),
        ("true_timestamps", []),
        ("pred_timestamps", []),
        ("last_fig", None),
        ("valid_pos", 0),
        ("first_plot_shown", False),
    )
    for name, initial in initial_values:
        if name in st.session_state:
            continue
        st.session_state[name] = initial
init_session_state()
# ============================== Loaders Cache ==============================
HF_REPO = "dlaj/energy-forecasting-files"
HF_FILENAME = "data/processed/energy_consumption_aggregated_cleaned.csv"
# Prefer the local data file; otherwise download it from the Hugging Face Hub
CSV_PATH = (
    DATA_PATH
    if os.path.exists(DATA_PATH)
    else hf_hub_download(
        repo_id=HF_REPO,
        filename=HF_FILENAME,
        repo_type="dataset",
        cache_dir="hf_cache",  # Optional
    )
)
@st.cache_resource
def load_cached_lightgbm_model():
    """Load the LightGBM model once and reuse it on later calls.

    ``st.cache_resource`` is used instead of ``st.cache_data`` because the
    return value is a model object: ``cache_data`` would serialize it and hand
    out a fresh copy on every call, whereas ``cache_resource`` returns the
    cached object itself. This also matches how the transformer model is
    cached in this file.

    Returns:
        Whatever ``load_lightgbm_model()`` returns.
    """
    return load_lightgbm_model()
@st.cache_resource
def load_transformer_model_and_dataset():
    """Return the result of ``load_model_and_dataset()``, cached as a resource.

    The caller in this file unpacks the result as ``(model, dataset, device)``.
    """
    loaded = load_model_and_dataset()
    return loaded
@st.cache_data
def load_data():
    """Return the result of ``load_data_raw()``, cached via ``st.cache_data``."""
    frame = load_data_raw()
    return frame
# ============================== Utility Functions ==============================
def predict_transformer_step(model, dataset, idx, device):
    """Run the transformer on one dataset sample and return a single forecast.

    Args:
        model: Callable invoked as ``model(x_enc=..., input_mask=...)``; its
            result must expose a ``forecast`` tensor.
        dataset: Indexable object whose items unpack to
            ``(timeseries, _, input_mask)`` and which provides ``n_channels``
            and ``scaler`` attributes.
        idx: Position of the sample inside ``dataset``.
        device: Torch device the input tensors are moved to.

    Returns:
        The first value of ``forecast[:, 0, :]`` after passing it through
        ``dataset.scaler.inverse_transform``, as a Python ``float``.
    """
    series, _, mask = dataset[idx]

    # Add a leading batch dimension and move both inputs to the target device.
    x_enc = torch.tensor(series, dtype=torch.float32).unsqueeze(0).to(device)
    input_mask = torch.tensor(mask, dtype=torch.bool).unsqueeze(0).to(device)

    with torch.no_grad():
        result = model(x_enc=x_enc, input_mask=input_mask)

    forecast = result.forecast[:, 0, :].cpu().numpy().flatten()

    # Rescale: the forecast goes into column 0 of a zero-filled array with
    # ``n_channels`` columns, which is then handed to the dataset's scaler.
    padded = np.zeros((forecast.size, dataset.n_channels))
    padded[:, 0] = forecast
    rescaled = dataset.scaler.inverse_transform(padded)
    return float(rescaled[0, 0])
def init_simulation_layout():
    """Build the empty placeholders used for the plot and the info panel.

    Returns:
        ``(plot_title, plot_container, x_axis_label, info_container)``: four
        ``empty()`` placeholders. The first three are created in the wide
        first column, the last one in the narrow third column; the middle
        column only acts as a gap.
    """
    plot_col, _gap, info_col = st.columns([3, 0.2, 1])
    # Placeholders are created in the same order as they are returned.
    title_slot = plot_col.empty()
    figure_slot = plot_col.empty()
    axis_label_slot = plot_col.empty()
    info_slot = info_col.empty()
    return title_slot, figure_slot, axis_label_slot, info_slot
def create_prediction_plot(
    pred_timestamps,
    pred_vals,
    true_timestamps,
    true_vals,
    window_hours,
    y_min=None,
    y_max=None,
):
    """Build the matplotlib figure comparing predictions with actual values.

    Args:
        pred_timestamps: X values of the prediction line.
        pred_vals: Y values of the prediction line.
        true_timestamps: X values of the actual line.
        true_vals: Y values of the actual line; the line is skipped when this
            is empty.
        window_hours: Only the last ``window_hours`` entries of each sequence
            are drawn.
        y_min: Lower y-axis limit (``None`` leaves it to matplotlib).
        y_max: Upper y-axis limit (``None`` leaves it to matplotlib).

    Returns:
        The created figure. It is also stored in
        ``st.session_state.last_fig``.
    """
    recent = slice(-window_hours, None)

    fig, ax = plt.subplots(
        figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR
    )
    ax.set_facecolor(PLOT_COLOR)

    # Collect the lines to draw: the prediction always, the actual values
    # only once there are any.
    lines = [
        (
            pred_timestamps,
            pred_vals,
            {"label": "Prediction", "color": "#EF233C", "linestyle": "--"},
        )
    ]
    if true_vals:
        lines.append(
            (true_timestamps, true_vals, {"label": "Actual", "color": "#0077B6"})
        )
    for xs, ys, line_style in lines:
        ax.plot(xs[recent], ys[recent], **line_style)

    ax.set_ylabel("Consumption (MW)", fontsize=8)
    ax.legend(fontsize=8, loc="upper left", bbox_to_anchor=(0, 0.95))
    ax.yaxis.grid(True, linestyle=":", linewidth=0.5, alpha=0.7)
    ax.set_ylim(y_min, y_max)

    # One major tick per day, labelled as month-day.
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
    ax.tick_params(axis="x", labelrotation=0, labelsize=5)
    ax.tick_params(axis="y", labelsize=5)

    # Hide the frame around the axes.
    for border in ax.spines.values():
        border.set_visible(False)

    st.session_state.last_fig = fig
    return fig
def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
    """Displays the simulation plot and metrics in the UI.

    Writes into the module-level placeholders ``plot_title``,
    ``plot_container`` and ``info_container``, so those must have been
    assigned before this function is called.

    Args:
        timestamp: Value shown after "Time:" in the info panel.
        prediction: Current prediction in MW, or ``None`` to show a dash.
        actual: Latest actual value in MW, or ``None`` to show a dash.
        progress: Value forwarded to ``st.progress``.
        fig: Matplotlib figure to display.
        paused: When true, "(Paused)" is appended to the plot title.
    """
    title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
    # NOTE(review): the string literals of this call were unterminated in the
    # file (a syntax error); any HTML wrapper that used to surround the title
    # is not recoverable from here - restore the styling markup if available.
    plot_title.markdown(f"{title}", unsafe_allow_html=True)
    plot_container.pyplot(fig)

    with info_container.container():
        st.markdown(
            f"Time: {timestamp}",
            unsafe_allow_html=True,
        )
        st.metric(
            "Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–"
        )
        st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
        st.caption("Simulation Progress")
        st.progress(progress)

        if len(st.session_state.true_vals) > 1:
            true_arr = np.array(st.session_state.true_vals)
            # The newest prediction is dropped: the simulation loops append
            # each actual value one step after the matching prediction.
            pred_arr = np.array(st.session_state.pred_vals[:-1])
            min_len = min(len(true_arr), len(pred_arr))
            if min_len >= 1:
                actuals = true_arr[:min_len]
                errors = np.abs(actuals - pred_arr[:min_len])
                # Replace zero actuals by a tiny value to avoid dividing by zero.
                safe_actuals = np.where(actuals == 0, 1e-10, actuals)
                mape = np.mean(errors / safe_actuals) * 100
                mae = np.mean(errors)
                max_error = np.max(errors)
                st.divider()
                st.markdown(
                    "Interim Metrics",
                    unsafe_allow_html=True,
                )
                st.metric("MAPE (so far)", f"{mape:.2f} %")
                st.metric("MAE (so far)", f"{mae:,.0f} MW")
                st.metric("Max Error", f"{max_error:,.0f} MW")
# ============================== Data Preparation ==============================
df_full = load_data()

# Positional train/test split: the test part is everything after the first
# TRAIN_RATIO share of the rows
train_size = int(len(df_full) * TRAIN_RATIO)
test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)

# Let the test data start at the first row whose time of day is 00:00
midnight = pd.Timestamp("00:00:00").time()
is_midnight = test_df_raw["date"].dt.time == midnight
first_full_day_index = test_df_raw.loc[is_midnight].index[0]
test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)

# Earliest and latest calendar day, used as bounds for the date pickers
test_dates = test_df_full["date"]
min_date = test_dates.min().date()
max_date = test_dates.max().date()
# ============================== UI Controls ==============================
model_options = ["LightGBM", "Transformer Model (moments)"]
date_bounds = {"min_value": min_date, "max_value": max_date}

with st.sidebar:
    st.header("⚙️ Simulation Settings")

    st.subheader("General Settings")
    model_choice = st.selectbox("Choose prediction model", model_options)
    if model_choice == "Transformer Model (moments)":
        st.caption(
            "⚠️ Note: Transformer model runs slower without GPU. (Use Speed = 10)"
        )
    window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
    # The plot window is handled in hours
    window_hours = 24 * window_days
    speed = st.slider("Speed", 1, 10, 5)

    st.subheader("Date Range")
    start_date = st.date_input("Start Date", value=min_date, **date_bounds)
    end_date = st.date_input("End Date", value=max_date, **date_bounds)
# ============================== Data Preparation (filtered) ==============================
# Keep only the rows whose calendar day lies inside the selected range (inclusive)
row_days = test_df_full["date"].dt.date
in_selected_range = (row_days >= start_date) & (row_days <= end_date)
test_df_filtered = test_df_full[in_selected_range].reset_index(drop=True)

# Number of rows in the window, used for the progress bar
total_steps_ui = len(test_df_filtered)
# ============================== Buttons ==============================
st.markdown("### Start Simulation")
col1, col2, col3 = st.columns([1, 1, 4])

with col1:
    # The same button toggles between starting and pausing
    play_pause_text = "⏸️ Pause" if st.session_state.is_running else "▶️ Start"
    if st.button(play_pause_text, use_container_width=True):
        st.session_state.is_running = not st.session_state.is_running
        st.rerun()

with col2:
    reset_button = st.button("🔄 Reset", use_container_width=True)

# Reset logic: clear everything collected so far, stop the simulation and rerun
if reset_button:
    cleared_state = {
        "start_index": 0,
        "pred_vals": [],
        "true_vals": [],
        "pred_timestamps": [],
        "true_timestamps": [],
        "last_fig": None,
        "is_running": False,
        "valid_pos": 0,
        "first_plot_shown": False,
    }
    for state_key, state_value in cleared_state.items():
        st.session_state[state_key] = state_value
    st.rerun()
# Auto-reset on critical parameter change while running
selection_changed = (
    start_date != st.session_state.get("last_start_date")
    or end_date != st.session_state.get("last_end_date")
    or model_choice != st.session_state.get("last_model_choice")
)

# Track current selections for change detection.
# This must happen BEFORE the rerun below: st.rerun() stops the script right
# away, so storing the selections afterwards would leave the old values in
# place, the change would be detected again on the next run, and the app
# would reset and rerun endlessly.
st.session_state.last_start_date = start_date
st.session_state.last_end_date = end_date
st.session_state.last_model_choice = model_choice

if st.session_state.is_running and selection_changed:
    # Discard everything collected so far but keep the simulation running,
    # so it starts over with the new selection.
    st.session_state.start_index = 0
    st.session_state.pred_vals = []
    st.session_state.true_vals = []
    st.session_state.pred_timestamps = []
    st.session_state.true_timestamps = []
    st.session_state.last_fig = None
    st.session_state.valid_pos = 0
    st.session_state.first_plot_shown = False
    st.rerun()
# ============================== Paused Mode ==============================
# While paused, redraw the last figure together with the latest values.
if not st.session_state.is_running and st.session_state.last_fig is not None:
    st.write("Simulation paused...")
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
    timestamp = (
        st.session_state.pred_timestamps[-1]
        if st.session_state.pred_timestamps
        else "–"
    )
    prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
    actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None

    # The Transformer loop advances ``valid_pos`` while the LightGBM loop
    # advances ``start_index``, so pick the counter that matches the model.
    if model_choice == "Transformer Model (moments)":
        steps_done = st.session_state.valid_pos
    else:
        steps_done = st.session_state.start_index

    # Guard against an empty date window (division by zero) and keep the
    # value inside the 0.0-1.0 range accepted by the progress bar.
    if total_steps_ui > 0:
        progress = min(max(steps_done / total_steps_ui, 0.0), 1.0)
    else:
        progress = 0.0

    render_simulation_view(
        timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True
    )
# ============================== initialize values ==============================
# LightGBM works on a copy of the date-filtered test data prepared above
if model_choice == "LightGBM":
    test_df = test_df_filtered.copy()

# Short names for the lists kept in session state; both simulations append
# predictions and ground truths to them
pred_vals, true_vals = st.session_state.pred_vals, st.session_state.true_vals
pred_timestamps, true_timestamps = (
    st.session_state.pred_timestamps,
    st.session_state.true_timestamps,
)
# ============================== LightGBM Simulation ==============================
# Steps through the filtered test data one row at a time, predicting each row
# and redrawing the view after every step.
if model_choice == "LightGBM" and st.session_state.is_running:
    model = load_cached_lightgbm_model()
    st.write("Simulation started...")
    st.markdown('', unsafe_allow_html=True)
    plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()

    # The y-axis limits depend only on the selected date window, so they are
    # computed once instead of on every step.
    y_axis_min = test_df_filtered["consumption_MW"].min() - 2000
    y_axis_max = test_df_filtered["consumption_MW"].max() + 2000

    for i in range(st.session_state.start_index, len(test_df)):
        if not st.session_state.is_running:
            break
        current = test_df.iloc[i]
        timestamp = current["date"]
        features = current[FEATURES].values.reshape(1, -1)
        prediction = model.predict(features)[0]
        pred_vals.append(prediction)
        pred_timestamps.append(timestamp)

        # The actual value is recorded one step late: at step i the value of
        # row i - 1 is added.
        prev_actual = None
        if i >= 1:
            previous = test_df.iloc[i - 1]
            prev_actual = previous["consumption_MW"]
            true_vals.append(prev_actual)
            true_timestamps.append(previous["date"])

        fig = create_prediction_plot(
            pred_timestamps,
            pred_vals,
            true_timestamps,
            true_vals,
            window_hours,
            y_min=y_axis_min,
            y_max=y_axis_max,
        )
        render_simulation_view(
            timestamp,
            prediction,
            prev_actual,
            i / len(test_df),
            fig,
        )
        plt.close(fig)  # free the figure's memory

        # Remember where to continue after a pause
        st.session_state.start_index = i + 1
        time.sleep(1 / (speed + 1e-9))
    st.success("Simulation completed!")
# ============================== Transformer Simulation ==============================
# Steps through the dataset indices that fall into the selected date range,
# predicting each one and redrawing the view after every step.
spinner_placeholder = st.empty()
if model_choice == "Transformer Model (moments)":
    if st.session_state.is_running:
        st.write("Simulation started (Transformer)...")
        st.markdown('', unsafe_allow_html=True)
        if not st.session_state.first_plot_shown:
            spinner_placeholder.markdown("Running first prediction – please wait...")
        plot_title, plot_container, x_axis_label, info_container = (
            init_simulation_layout()
        )

        # Model, dataset and device
        model, test_dataset, device = load_transformer_model_and_dataset()
        scaler = test_dataset.scaler
        n_channels = test_dataset.n_channels

        # NOTE(review): the training dataset is built and the CSV is re-read
        # on every script run while the simulation is running - consider
        # caching these two lookups.
        test_start_idx = (
            len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON))
            + SEQ_LEN
        )
        # Timestamp of the CSV row at test_start_idx; all later timestamps
        # are derived from it by adding whole hours.
        base_timestamp = pd.read_csv(CSV_PATH, parse_dates=["date"])["date"].iloc[
            test_start_idx
        ]

        # Step 1: find the first hourly offset at which the time is 00:00
        # NOTE(review): this loop only terminates if base_timestamp lies
        # exactly on a full hour - confirm for the data in use.
        offset = 0
        while (base_timestamp + pd.Timedelta(hours=offset)).time() != pd.Timestamp(
            "00:00:00"
        ).time():
            offset += 1

        # New start index of the simulation
        start_index = offset

        # Initialise the session state value if needed
        if "start_index" not in st.session_state or st.session_state.start_index == 0:
            st.session_state.start_index = start_index

        # Dataset indices whose timestamp falls into the selected date range
        valid_indices = [
            idx
            for idx in range(start_index, len(test_dataset))
            if start_date <= (base_timestamp + pd.Timedelta(hours=idx)).date() <= end_date
        ]

        # Total number of steps, for the progress display
        total_steps = len(valid_indices)

        # Current position inside valid_indices (not the global dataset index)
        if "valid_pos" not in st.session_state:
            st.session_state.valid_pos = 0

        # The y-axis limits depend only on the selected date window, so they
        # are computed once instead of on every step.
        y_axis_min = test_df_filtered["consumption_MW"].min() - 2000
        y_axis_max = test_df_filtered["consumption_MW"].max() + 2000

        # Main loop: continue at the stored position and visit only valid indices
        for i in valid_indices[st.session_state.valid_pos :]:
            if not st.session_state.is_running:
                break
            current_pred = predict_transformer_step(model, test_dataset, i, device)
            current_time = base_timestamp + pd.Timedelta(hours=i)
            pred_vals.append(current_pred)
            pred_timestamps.append(current_time)

            actual_val = None
            if i >= 1:
                # First forecast value of the previous sample
                prev_actual = test_dataset[i - 1][1][0, 0]
                # Rescale via a zero-filled row with n_channels columns
                dummy_actual = np.zeros((1, n_channels))
                dummy_actual[:, 0] = prev_actual
                actual_val = scaler.inverse_transform(dummy_actual)[0, 0]
                true_time = current_time - pd.Timedelta(hours=1)
                # Skip actual values that lie before the selected start date
                if true_time >= pd.to_datetime(start_date):
                    true_vals.append(actual_val)
                    true_timestamps.append(true_time)

            # Build the plot
            fig = create_prediction_plot(
                pred_timestamps,
                pred_vals,
                true_timestamps,
                true_vals,
                window_hours,
                y_min=y_axis_min,
                y_max=y_axis_max,
            )
            # Render only once there is at least one actual value to compare with
            if len(pred_vals) >= 2 and len(true_vals) >= 1:
                render_simulation_view(
                    current_time,
                    current_pred,
                    actual_val,
                    st.session_state.valid_pos / total_steps,
                    fig,
                )
                if not st.session_state.first_plot_shown:
                    spinner_placeholder.empty()
                    st.session_state.first_plot_shown = True
            plt.close(fig)  # free the figure's memory

            st.session_state.valid_pos += 1
            time.sleep(1 / (speed + 1e-9))
        st.success("Simulation completed!")
# ============================== Scroll Sync ==============================
# Markup handed to Streamlit as raw HTML at the end of the page
scroll_sync_markup = """
"""
st.markdown(scroll_sync_markup, unsafe_allow_html=True)