3v324v23 committed
Commit c689089
1 Parent(s): 0d7166f
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .huggingface.yml +2 -0
  2. Hugging_face_app.zip → lightgbm_model/model/lightgbm_final_model.pkl +2 -2
  3. lightgbm_model/results/X_test.csv +0 -0
  4. lightgbm_model/results/X_train.csv +0 -0
  5. lightgbm_model/results/lightgbm_eval_result.pkl +3 -0
  6. lightgbm_model/results/y_test.csv +0 -0
  7. lightgbm_model/scripts/__init__.py +2 -0
  8. lightgbm_model/scripts/__pycache__/__init__.cpython-311.pyc +0 -0
  9. lightgbm_model/scripts/__pycache__/config_lightgbm.cpython-311.pyc +0 -0
  10. lightgbm_model/scripts/config_lightgbm.py +36 -0
  11. lightgbm_model/scripts/eval/__pycache__/eval_lightgbm.cpython-311.pyc +0 -0
  12. lightgbm_model/scripts/eval/eval_lightgbm.py +107 -0
  13. lightgbm_model/scripts/train/__pycache__/train_lightgbm.cpython-311.pyc +0 -0
  14. lightgbm_model/scripts/train/train_lightgbm.py +78 -0
  15. requirements.txt +31 -0
  16. setup.py +7 -0
  17. streamlit_simulation/__init__.py +2 -0
  18. streamlit_simulation/__pycache__/config_streamlit.cpython-311.pyc +0 -0
  19. streamlit_simulation/__pycache__/config_streamlit.cpython-312.pyc +0 -0
  20. streamlit_simulation/app.py +535 -0
  21. streamlit_simulation/config_streamlit.py +26 -0
  22. transformer_model/results/evaluation_metrics.json +1 -0
  23. transformer_model/results/test_results.csv +0 -0
  24. transformer_model/results/training_metrics.json +1 -0
  25. transformer_model/scripts/__init__.py +2 -0
  26. transformer_model/scripts/__pycache__/__init__.cpython-311.pyc +0 -0
  27. transformer_model/scripts/__pycache__/check_device.cpython-311.pyc +0 -0
  28. transformer_model/scripts/__pycache__/config.cpython-311.pyc +0 -0
  29. transformer_model/scripts/__pycache__/config_transformer.cpython-311.pyc +0 -0
  30. transformer_model/scripts/__pycache__/create_dataloaders.cpython-311.pyc +0 -0
  31. transformer_model/scripts/__pycache__/informer_dataset_class.cpython-311.pyc +0 -0
  32. transformer_model/scripts/__pycache__/load_basis_model.cpython-311.pyc +0 -0
  33. transformer_model/scripts/config_transformer.py +31 -0
  34. transformer_model/scripts/evaluation/__init__.py +1 -0
  35. transformer_model/scripts/evaluation/__pycache__/__init__.cpython-311.pyc +0 -0
  36. transformer_model/scripts/evaluation/__pycache__/evaluate.cpython-311.pyc +0 -0
  37. transformer_model/scripts/evaluation/__pycache__/plot_metrics.cpython-311.pyc +0 -0
  38. transformer_model/scripts/evaluation/evaluate.py +124 -0
  39. transformer_model/scripts/evaluation/plot_metrics.py +77 -0
  40. transformer_model/scripts/training/__init__.py +1 -0
  41. transformer_model/scripts/training/__pycache__/__init__.cpython-311.pyc +0 -0
  42. transformer_model/scripts/training/__pycache__/load_basis_model.cpython-311.pyc +0 -0
  43. transformer_model/scripts/training/__pycache__/train.cpython-311.pyc +0 -0
  44. transformer_model/scripts/training/load_basis_model.py +67 -0
  45. transformer_model/scripts/training/train.py +202 -0
  46. transformer_model/scripts/utils/__init__.py +1 -0
  47. transformer_model/scripts/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  48. transformer_model/scripts/utils/__pycache__/check_device.cpython-311.pyc +0 -0
  49. transformer_model/scripts/utils/__pycache__/create_dataloaders.cpython-311.pyc +0 -0
  50. transformer_model/scripts/utils/__pycache__/informer_dataset_class.cpython-311.pyc +0 -0
.huggingface.yml ADDED
@@ -0,0 +1,2 @@
+ sdk: streamlit
+ app_file: streamlit_simulation/app.py
Hugging_face_app.zip → lightgbm_model/model/lightgbm_final_model.pkl RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:2c15d038b5708a8b42682dcca58c3b94c9eeccd2352c2ba30beed6ed61585e84
- size 8060127
+ oid sha256:52777b05bde0cc4665aac0d18993701769c84edaf0ffe9cb3b82049fd779b56d
+ size 1534227
lightgbm_model/results/X_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
lightgbm_model/results/X_train.csv ADDED
The diff for this file is too large to render. See raw diff
 
lightgbm_model/results/lightgbm_eval_result.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c08c03ffc33d679292809af3f37f2f5d10d971c2243a962d5de5bd8b1415a7f5
+ size 76208
lightgbm_model/results/y_test.csv ADDED
The diff for this file is too large to render. See raw diff
 
lightgbm_model/scripts/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # __init__.py
+
lightgbm_model/scripts/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (196 Bytes). View file
 
lightgbm_model/scripts/__pycache__/config_lightgbm.cpython-311.pyc ADDED
Binary file (1.34 kB). View file
 
lightgbm_model/scripts/config_lightgbm.py ADDED
@@ -0,0 +1,36 @@
+ # config.py
+ import os
+
+ # === Paths ===
+ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ DATA_PATH = os.path.join(BASE_DIR, "..", "data", "processed", "energy_consumption_aggregated_cleaned.csv")
+ RESULTS_DIR = os.path.join(BASE_DIR, "results")
+ MODEL_DIR = os.path.join(BASE_DIR, "model")
+
+ # === Feature definition ===
+ FEATURES = [
+     "hour_sin", "hour_cos",
+     "weekday_sin", "weekday_cos",
+     "rolling_mean_6h",
+     "month_sin", "month_cos",
+     "temperature_c",
+     "consumption_last_week",
+     "consumption_yesterday",
+     "consumption_last_hour"
+ ]
+ TARGET = "consumption_MW"
+
+ # === Hyperparameters for LightGBM ===
+ LIGHTGBM_PARAMS = {
+     'learning_rate': 0.05,
+     'num_leaves': 15,
+     'max_depth': 5,
+     'lambda_l1': 1.0,
+     'lambda_l2': 0.0,
+     'min_split_gain': 0.0,
+     'n_estimators': 1000,
+     'objective': 'regression'}
+
+ # === Early Stopping ===
+ EARLY_STOPPING_ROUNDS = 50
+
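
The engineered columns listed in FEATURES (cyclical encodings, lag features, and a 6-hour rolling mean) come from a preprocessing step that is not part of this commit. A minimal sketch of how such columns are commonly derived with pandas/NumPy, assuming hourly data and the date / consumption_MW columns used elsewhere in the repo:

import numpy as np
import pandas as pd

df = pd.read_csv("energy_consumption_aggregated_cleaned.csv", parse_dates=["date"])  # path is illustrative
df["hour_sin"] = np.sin(2 * np.pi * df["date"].dt.hour / 24)
df["hour_cos"] = np.cos(2 * np.pi * df["date"].dt.hour / 24)
df["weekday_sin"] = np.sin(2 * np.pi * df["date"].dt.weekday / 7)
df["weekday_cos"] = np.cos(2 * np.pi * df["date"].dt.weekday / 7)
df["month_sin"] = np.sin(2 * np.pi * (df["date"].dt.month - 1) / 12)
df["month_cos"] = np.cos(2 * np.pi * (df["date"].dt.month - 1) / 12)
df["consumption_last_hour"] = df["consumption_MW"].shift(1)
df["consumption_yesterday"] = df["consumption_MW"].shift(24)
df["consumption_last_week"] = df["consumption_MW"].shift(24 * 7)
df["rolling_mean_6h"] = df["consumption_MW"].shift(1).rolling(6).mean()
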
lightgbm_model/scripts/eval/__pycache__/eval_lightgbm.cpython-311.pyc ADDED
Binary file (7.07 kB). View file
 
lightgbm_model/scripts/eval/eval_lightgbm.py ADDED
@@ -0,0 +1,107 @@
1
+ # eval_model.py
2
+
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import pickle
8
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
9
+ from lightgbm_model.scripts.config_lightgbm import RESULTS_DIR, MODEL_DIR, DATA_PATH
10
+ from joblib import load
11
+
+ # === Prepare results folder ===
+ os.makedirs(RESULTS_DIR, exist_ok=True)
+
+ # === Load model and eval_result ===
+ # Load model
+ with open(os.path.join(MODEL_DIR, "lightgbm_final_model.pkl"), "rb") as f:
+ model = pickle.load(f)
+
+ # Load eval results
+ with open(os.path.join(RESULTS_DIR, "lightgbm_eval_result.pkl"), "rb") as f:
+ eval_result = pickle.load(f)
23
+ X_train = pd.read_csv(os.path.join(RESULTS_DIR, "X_train.csv"))
24
+ X_test = pd.read_csv(os.path.join(RESULTS_DIR, "X_test.csv"))
25
+ y_test = pd.read_csv(os.path.join(RESULTS_DIR, "y_test.csv"))
26
+
27
+ # === Lernkurve ===
28
+ train_rmse = eval_result['training']['rmse']
29
+ valid_rmse = eval_result['valid_1']['rmse']
30
+
31
+ plt.figure(figsize=(10, 5))
32
+ plt.plot(train_rmse, label='Train RMSE')
33
+ plt.plot(valid_rmse, label='Valid RMSE')
34
+ plt.axvline(model.best_iteration_, color='gray', linestyle='--', label='Best Iteration')
35
+ plt.xlabel("Boosting Round")
36
+ plt.ylabel("RMSE")
37
+ plt.title("LightGBM Learning Curve")
38
+ plt.legend()
39
+ plt.tight_layout()
40
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_learning_curve.png"))
41
+ #plt.show()
42
+
43
+ # === Metriken berechnen ===
44
+ y_pred = model.predict(X_test)
45
+ mae = mean_absolute_error(y_test, y_pred)
46
+ rmse = np.sqrt(mean_squared_error(y_test, y_pred))
47
+ mape = np.mean(np.abs((y_test.values.flatten() - y_pred) / np.where(y_test.values.flatten() == 0, 1e-10, y_test.values.flatten()))) * 100
48
+
49
+ print(f"Test MAPE: {mape:.5f} %")
50
+ print(f"Test MAE: {mae:.5f}")
51
+ print(f"Test RMSE: {rmse:.5f}")
52
+
53
+ # === Feature Importance ===
54
+ feature_importance = pd.DataFrame({
55
+ "Feature": X_train.columns,
56
+ "Importance": model.feature_importances_
57
+ }).sort_values(by="Importance", ascending=False)
58
+
59
+ plt.figure(figsize=(10, 6))
60
+ plt.barh(feature_importance["Feature"], feature_importance["Importance"])
61
+ plt.xlabel("Feature Importance")
62
+ plt.title("LightGBM Feature Importance")
63
+ plt.gca().invert_yaxis()
64
+ plt.tight_layout()
65
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_feature_importance.png"))
66
+ #plt.show()
67
+
68
+ # === Vergleichsplots ===
69
+ results_df = pd.DataFrame({
70
+ "True Consumption (MW)": y_test.values.flatten(),
71
+ "Predicted Consumption (MW)": y_pred
72
+ })
73
+
+ # Append timestamps
+ full_df = pd.read_csv(DATA_PATH)
+ test_dates = full_df.iloc[int(len(full_df)*0.8):]["date"].reset_index(drop=True)
+ results_df["Timestamp"] = pd.to_datetime(test_dates)
+
+ # Full plot
80
+ plt.figure(figsize=(15, 6))
81
+ plt.plot(results_df["Timestamp"], results_df["True Consumption (MW)"], label="True", color="darkblue")
82
+ plt.plot(results_df["Timestamp"], results_df["Predicted Consumption (MW)"], label="Predicted", color="red", linestyle="--")
83
+ plt.title("Predicted vs True Consumption")
84
+ plt.xlabel("Timestamp")
85
+ plt.ylabel("Consumption (MW)")
86
+ plt.legend()
87
+ plt.tight_layout()
88
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_comparison_plot.png"))
89
+ #plt.show()
90
+
91
+ # Subset Plot
92
+ subset = results_df.iloc[:len(results_df) // 10]
93
+ plt.figure(figsize=(15, 6))
94
+ plt.plot(subset["Timestamp"], subset["True Consumption (MW)"], label="True", color="darkblue")
95
+ plt.plot(subset["Timestamp"], subset["Predicted Consumption (MW)"], label="Predicted", color="red", linestyle="--")
96
+ plt.title("Predicted vs True (First decile)")
97
+ plt.xlabel("Timestamp")
98
+ plt.ylabel("Consumption (MW)")
99
+ plt.legend()
100
+ plt.tight_layout()
101
+ plt.savefig(os.path.join(RESULTS_DIR, "lightgbm_prediction_with_timestamp.png"))
102
+ #plt.show()
103
+
104
+
+ # === End message ===
106
+ print("\nEvaluation completed.")
107
+ print(f"All Plots stored in:\n→ {RESULTS_DIR}")
lightgbm_model/scripts/train/__pycache__/train_lightgbm.cpython-311.pyc ADDED
Binary file (3.44 kB). View file
 
lightgbm_model/scripts/train/train_lightgbm.py ADDED
@@ -0,0 +1,78 @@
1
+ # train_lightgbm.py
2
+
3
+ import os
4
+ import pickle
5
+ import pandas as pd
6
+ import numpy as np
7
+ import lightgbm as lgb
8
+ from lightgbm import LGBMRegressor, early_stopping, record_evaluation
9
+
10
+ from lightgbm_model.scripts.config_lightgbm import (
11
+ DATA_PATH,
12
+ FEATURES,
13
+ TARGET,
14
+ LIGHTGBM_PARAMS,
15
+ EARLY_STOPPING_ROUNDS,
16
+ RESULTS_DIR,
17
+ MODEL_DIR
18
+ )
19
+
20
+ # === Load Data ===
21
+ df = pd.read_csv(DATA_PATH)
22
+
23
+ # Drop date (used later for plots only)
24
+ df = df.drop(columns=["date"], errors="ignore")
25
+
26
+ # === Time-based Split (70% train, 10% valid, 20% test) ===
27
+ train_size = int(len(df) * 0.7)
28
+ valid_size = int(len(df) * 0.1)
29
+ df_train = df.iloc[:train_size]
30
+ df_valid = df.iloc[train_size:train_size + valid_size]
31
+ df_test = df.iloc[train_size + valid_size:]
32
+
33
+ X_train, y_train = df_train[FEATURES], df_train[TARGET]
34
+ X_valid, y_valid = df_valid[FEATURES], df_valid[TARGET]
35
+ X_test, y_test = df_test[FEATURES], df_test[TARGET]
36
+
37
+
38
+ # === Init LightGBM model ===
39
+ eval_result = {}
40
+
41
+ model = LGBMRegressor(
42
+ **LIGHTGBM_PARAMS,
43
+ verbosity=-1
44
+ )
45
+
46
+ model.fit(
47
+ X_train,
48
+ y_train,
49
+ eval_set=[(X_train, y_train), (X_valid, y_valid)],
50
+ eval_metric="rmse",
51
+ callbacks=[
52
+ early_stopping(EARLY_STOPPING_ROUNDS),
53
+ record_evaluation(eval_result)
54
+ ]
55
+ )
56
+
57
+ # === Save model ===
58
+ os.makedirs(MODEL_DIR, exist_ok=True)
59
+ model_path = os.path.join(MODEL_DIR, "lightgbm_final_model.pkl")
60
+
61
+ with open(model_path, "wb") as f:
62
+ pickle.dump(model, f)
63
+
64
+ # === Save evaluation results ===
65
+ os.makedirs(RESULTS_DIR, exist_ok=True)
66
+ eval_result_path = os.path.join(RESULTS_DIR, "lightgbm_eval_result.pkl")
67
+
68
+ with open(eval_result_path, "wb") as f:
69
+ pickle.dump(eval_result, f)
70
+
71
+ print(f"Model saved to: {model_path}")
72
+ print(f"Eval results saved to: {eval_result_path}")
73
+
74
+ # === Save data for evaluation ===
75
+ X_train.to_csv(os.path.join(RESULTS_DIR, "X_train.csv"), index=False)
76
+ X_test.to_csv(os.path.join(RESULTS_DIR, "X_test.csv"), index=False)
77
+ y_test.to_csv(os.path.join(RESULTS_DIR, "y_test.csv"), index=False)
78
+
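
After training, the pickled model and the saved test split can be reloaded for a quick sanity check. A minimal sketch using the paths from config_lightgbm.py (the snippet itself is not part of this commit):

import os
import pickle
import pandas as pd
from lightgbm_model.scripts.config_lightgbm import MODEL_DIR, RESULTS_DIR

with open(os.path.join(MODEL_DIR, "lightgbm_final_model.pkl"), "rb") as f:
    model = pickle.load(f)

X_test = pd.read_csv(os.path.join(RESULTS_DIR, "X_test.csv"))
print(model.predict(X_test.head()))  # predictions for the first few test rows
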
requirements.txt ADDED
@@ -0,0 +1,31 @@
+ # =============================
+ # Requirements for Energy Prediction Project
+ # =============================
+
+ # Python 3.11 environment recommended, since momentfm doesn't work with later versions
+
+ # Moment Foundation Model (forecasting backbone)
+ momentfm @ git+https://github.com/moment-timeseries-foundation-model/moment.git@37a8bde4eb3dd340bebc9b54a3b893bcba62cd4f
+
+ # === Core Python stack ===
+ numpy==1.25.2 # Numerical operations
+ pandas==2.2.2 # Data manipulation and analysis
+ matplotlib==3.10.0 # Plotting and visualizations
+
+ # === Machine Learning ===
+ scikit-learn==1.6.1 # Evaluation metrics and preprocessing utilities
+ torch==2.6.0 # PyTorch with CUDA 12.4 (GPU support)
+ torchvision==0.21.0+cu124 # Optional (can support visual tasks, not critical here)
+ torchaudio==2.6.0+cu124 # Optional (comes with torch install, can stay)
+
+ # === Utilities ===
+ tqdm==4.67.1 # Progress bars
+ ipywidgets>=8.0 # Enables tqdm progress bars in Jupyter/Colab
+ pprintpp==0.4.0 # Prettier print formatting for nested dicts (used for model output check)
+
+ # === LightGBM ===
+ lightgbm==4.3.0 # Boosted Trees for tabular modeling (used for baseline and feature selection)
+
+ # === Streamlit App ===
+ streamlit>=1.30.0
+ plotly>=5.0.0
setup.py ADDED
@@ -0,0 +1,7 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="energy_prediction",
+     version="0.1",
+     packages=find_packages(),
+ )
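
setup.py is what makes the absolute package imports used throughout the scripts resolvable once the project is installed (for example in editable mode via pip install -e ., which is an assumption and not shown in this commit). A quick import check:

# Sanity check after installing the project as a package:
from lightgbm_model.scripts.config_lightgbm import FEATURES, TARGET
from transformer_model.scripts.config_transformer import SEQ_LEN, FORECAST_HORIZON
print(len(FEATURES), TARGET, SEQ_LEN, FORECAST_HORIZON)
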
streamlit_simulation/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # __init__.py
+
streamlit_simulation/__pycache__/config_streamlit.cpython-311.pyc ADDED
Binary file (1.26 kB). View file
 
streamlit_simulation/__pycache__/config_streamlit.cpython-312.pyc ADDED
Binary file (948 Bytes). View file
 
streamlit_simulation/app.py ADDED
@@ -0,0 +1,535 @@
1
+ import sys
2
+ import os
3
+ import streamlit as st
4
+ import pickle
5
+ import pandas as pd
6
+ import time
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.dates as mdates
10
+ import warnings
11
+ import torch
12
+
13
+ from config_streamlit import (MODEL_PATH_LIGHTGBM, DATA_PATH, TRAIN_RATIO,
14
+ TEXT_COLOR, HEADER_COLOR, ACCENT_COLOR,
15
+ BUTTON_BG, BUTTON_HOVER_BG, BG_COLOR,
16
+ INPUT_BG, PROGRESS_COLOR, PLOT_COLOR
17
+ )
18
+ from lightgbm_model.scripts.config_lightgbm import FEATURES
19
+ from transformer_model.scripts.utils.informer_dataset_class import InformerDataset
20
+ from transformer_model.scripts.training.load_basis_model import load_moment_model
21
+ from transformer_model.scripts.config_transformer import CHECKPOINT_DIR, FORECAST_HORIZON, SEQ_LEN
22
+ from sklearn.preprocessing import StandardScaler
23
+
24
+
25
+ # ============================== Layout ==============================
26
+
27
+ # Streamlit & warnings config
28
+ warnings.filterwarnings("ignore", category=FutureWarning)
29
+ st.set_page_config(page_title="Electricity Consumption Forecast", layout="wide")
30
+
31
+ #CSS part
32
+ st.markdown(f"""
33
+ <style>
34
+ body, .block-container {{
35
+ background-color: {BG_COLOR} !important;
36
+ }}
37
+
38
+ html, body, [class*="css"] {{
39
+ color: {TEXT_COLOR} !important;
40
+ font-family: 'sans-serif';
41
+ }}
42
+
43
+ h1, h2, h3, h4, h5, h6 {{
44
+ color: {HEADER_COLOR} !important;
45
+ }}
46
+
47
+ .stButton > button {{
48
+ background-color: {BUTTON_BG};
49
+ color: {TEXT_COLOR};
50
+ border: 1px solid {ACCENT_COLOR};
51
+ }}
52
+
53
+ .stButton > button:hover {{
54
+ background-color: {BUTTON_HOVER_BG};
55
+ }}
56
+
57
+ .stSelectbox div[data-baseweb="select"],
58
+ .stDateInput input {{
59
+ background-color: {INPUT_BG} !important;
60
+ color: {TEXT_COLOR} !important;
61
+ }}
62
+
63
+ [data-testid="stMetricLabel"],
64
+ [data-testid="stMetricValue"] {{
65
+ color: {TEXT_COLOR} !important;
66
+ }}
67
+
68
+ .stMarkdown p {{
69
+ color: {TEXT_COLOR} !important;
70
+ }}
71
+
72
+ .stDataFrame tbody tr td {{
73
+ color: {TEXT_COLOR} !important;
74
+ }}
75
+
76
+ .stProgress > div > div {{
77
+ background-color: {PROGRESS_COLOR} !important;
78
+ }}
79
+
+ /* All label texts for inputs/sliders */
+ label {{
+ color: {TEXT_COLOR} !important;
+ }}
+
+ /* Text in selectbox option fields */
+ .stSelectbox label, .stSelectbox div {{
+ color: {TEXT_COLOR} !important;
+ }}
+
+ /* Match DateInput styling to the selectbox */
+ .stDateInput input {{
92
+ background-color: #f2f6fa !important;
93
+ color: {TEXT_COLOR} !important;
94
+ border: none !important;
95
+ border-radius: 5px !important;
96
+ }}
97
+
98
+ </style>
99
+ """, unsafe_allow_html=True)
100
+
101
+ st.title("Electricity Consumption Forecast: Hourly Simulation")
102
+ st.write("Welcome to the simulation interface!")
103
+
104
+ # ============================== Session State Init ==============================
105
+ def init_session_state():
106
+ defaults = {
107
+ "is_running": False,
108
+ "start_index": 0,
109
+ "true_vals": [],
110
+ "pred_vals": [],
111
+ "true_timestamps": [],
112
+ "pred_timestamps": [],
113
+ "last_fig": None,
114
+ "valid_pos": 0
115
+ }
116
+ for key, value in defaults.items():
117
+ if key not in st.session_state:
118
+ st.session_state[key] = value
119
+
120
+ init_session_state()
121
+
122
+ # ============================== Loaders ==============================
123
+
124
+ @st.cache_data
125
+ def load_lightgbm_model():
126
+ with open(MODEL_PATH_LIGHTGBM, "rb") as f:
127
+ return pickle.load(f)
128
+
129
+ @st.cache_resource
130
+ def load_transformer_model_and_dataset():
131
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
132
+
133
+ # Load model
134
+ model = load_moment_model()
135
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
136
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
137
+ model.to(device)
138
+ model.eval()
139
+
140
+ # Datasets
141
+ train_dataset = InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON, random_seed=13)
142
+ test_dataset = InformerDataset(data_split="test", forecast_horizon=FORECAST_HORIZON, random_seed=13)
143
+ test_dataset.scaler = train_dataset.scaler
144
+
145
+ return model, test_dataset, device
146
+
147
+ @st.cache_data
148
+ def load_data():
149
+ df = pd.read_csv(DATA_PATH, parse_dates=["date"])
150
+ return df
151
+
152
+
153
+ # ============================== Utility Functions ==============================
154
+
155
+ def predict_transformer_step(model, dataset, idx, device):
156
+ """Performs a single prediction step with the transformer model."""
157
+ timeseries, _, input_mask = dataset[idx]
158
+ timeseries = torch.tensor(timeseries, dtype=torch.float32).unsqueeze(0).to(device)
159
+ input_mask = torch.tensor(input_mask, dtype=torch.bool).unsqueeze(0).to(device)
160
+
161
+ with torch.no_grad():
162
+ output = model(x_enc=timeseries, input_mask=input_mask)
163
+
164
+ pred = output.forecast[:, 0, :].cpu().numpy().flatten()
165
+
+ # Rescale back to original units
167
+ dummy = np.zeros((len(pred), dataset.n_channels))
168
+ dummy[:, 0] = pred
169
+ pred_original = dataset.scaler.inverse_transform(dummy)[:, 0]
170
+
171
+ return float(pred_original[0])
172
+
173
+
174
+ def init_simulation_layout():
175
+ """Creates layout containers for plot and info sections."""
176
+ col1, spacer, col2 = st.columns([3, 0.2, 1])
177
+ plot_title = col1.empty()
178
+ plot_container = col1.empty()
179
+ x_axis_label = col1.empty()
180
+ info_container = col2.empty()
181
+ return plot_title, plot_container, x_axis_label, info_container
182
+
183
+
184
+ def create_prediction_plot(pred_timestamps, pred_vals, true_timestamps, true_vals, window_hours, y_min=None, y_max=None):
185
+ """Generates the matplotlib figure for plotting prediction vs. actual."""
186
+ fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True, facecolor=PLOT_COLOR)
187
+ ax.set_facecolor(PLOT_COLOR)
188
+
189
+ ax.plot(pred_timestamps[-window_hours:], pred_vals[-window_hours:], label="Prediction", color="#EF233C", linestyle="--")
190
+ if true_vals:
191
+ ax.plot(true_timestamps[-window_hours:], true_vals[-window_hours:], label="Actual", color="#0077B6")
192
+
193
+ ax.set_ylabel("Consumption (MW)", fontsize=8, color=TEXT_COLOR)
194
+ ax.legend(
195
+ fontsize=8,
196
+ loc="upper left",
197
+ bbox_to_anchor=(0, 0.95),
198
+ facecolor= INPUT_BG, # INPUT_BG
199
+ edgecolor= ACCENT_COLOR, # ACCENT_COLOR
200
+ labelcolor= TEXT_COLOR # TEXT_COLOR
201
+ )
202
+ ax.yaxis.grid(True, linestyle=':', linewidth=0.5, alpha=0.7)
203
+ ax.set_ylim(y_min, y_max)
204
+ ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
205
+ ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d"))
206
+ ax.tick_params(axis="x", labelrotation=0, labelsize=5, colors=TEXT_COLOR)
207
+ ax.tick_params(axis="y", labelsize=5, colors=TEXT_COLOR)
208
+ #fig.patch.set_facecolor('#e6ecf0') # outer area
209
+
210
+ for spine in ax.spines.values():
211
+ spine.set_visible(False)
212
+
213
+ st.session_state.last_fig = fig
214
+ return fig
215
+
216
+
217
+ def render_simulation_view(timestamp, prediction, actual, progress, fig, paused=False):
218
+ """Displays the simulation plot and metrics in the UI."""
219
+ title = "Actual vs. Prediction (Paused)" if paused else "Actual vs. Prediction"
220
+ plot_title.markdown(
221
+ f"<div style='text-align: center; font-size: 20pt; font-weight: bold; color: {TEXT_COLOR}; margin-bottom: -0.7rem; margin-top: 0rem;'>"
222
+ f"{title}</div>",
223
+ unsafe_allow_html=True
224
+ )
225
+ plot_container.pyplot(fig)
226
+
227
+ st.markdown("<div style='margin-bottom: 0.5rem;'></div>", unsafe_allow_html=True)
228
+ x_axis_label.markdown(
229
+ f"<div style='text-align: center; font-size: 14pt; color: {TEXT_COLOR}; margin-top: -0.5rem;'>"
230
+ f"Time</div>",
231
+ unsafe_allow_html=True
232
+ )
233
+
234
+ with info_container.container():
235
+ st.markdown("<div style='margin-top: 5rem;'></div>", unsafe_allow_html=True)
236
+ st.markdown(
237
+ f"<span style='font-size: 24px; font-weight: 600; color: {HEADER_COLOR} !important;'>Time: {timestamp}</span>",
238
+ unsafe_allow_html=True
239
+ )
240
+
241
+ st.metric("Prediction", f"{prediction:,.0f} MW" if prediction is not None else "–")
242
+ st.metric("Actual", f"{actual:,.0f} MW" if actual is not None else "–")
243
+ st.caption("Simulation Progress")
244
+ st.progress(progress)
245
+
246
+ if len(st.session_state.true_vals) > 1:
247
+ true_arr = np.array(st.session_state.true_vals)
248
+ pred_arr = np.array(st.session_state.pred_vals[:-1])
249
+
+ min_len = min(len(true_arr), len(pred_arr))  # only compute once at least two actual values exist
251
+ if min_len >= 1:
252
+ errors = np.abs(true_arr[:min_len] - pred_arr[:min_len])
253
+ mape = np.mean(errors / np.where(true_arr[:min_len] == 0, 1e-10, true_arr[:min_len])) * 100
254
+ mae = np.mean(errors)
255
+ max_error = np.max(errors)
256
+
257
+ st.divider()
258
+ st.markdown(
259
+ f"<span style='font-size: 24px; font-weight: 600; color: {HEADER_COLOR} !important;'>Interim Metrics</span>",
260
+ unsafe_allow_html=True
261
+ )
262
+ st.metric("MAPE (so far)", f"{mape:.2f} %")
263
+ st.metric("MAE (so far)", f"{mae:,.0f} MW")
264
+ st.metric("Max Error", f"{max_error:,.0f} MW")
265
+
266
+
267
+
268
+ # ============================== Data Preparation ==============================
269
+
270
+ df_full = load_data()
271
+
272
+ # Split Train/Test
273
+ train_size = int(len(df_full) * TRAIN_RATIO)
274
+ test_df_raw = df_full.iloc[train_size:].reset_index(drop=True)
275
+
276
+ # Start at first full hour (00:00)
277
+ first_full_day_index = test_df_raw[test_df_raw["date"].dt.time == pd.Timestamp("00:00:00").time()].index[0]
278
+ test_df_full = test_df_raw.iloc[first_full_day_index:].reset_index(drop=True)
279
+
280
+ # Select simulation window via date picker
281
+ min_date = test_df_full["date"].min().date()
282
+ max_date = test_df_full["date"].max().date()
283
+
284
+ # ============================== UI Controls ==============================
285
+
286
+ st.markdown("### Simulation Settings")
287
+ col1, col2 = st.columns([1, 1])
288
+
289
+ with col1:
290
+ st.markdown("**General Settings**")
291
+ model_choice = st.selectbox("Choose prediction model", ["LightGBM", "Transformer Model (moments)"])
+ if model_choice == "Transformer Model (moments)":
293
+ st.caption("⚠️ Note: Transformer model runs slower without GPU. (Use Speed = 10)")
294
+ window_days = st.selectbox("Display window (days)", options=[3, 5, 7], index=0)
295
+ window_hours = window_days * 24
296
+ speed = st.slider("Speed", 1, 10, 5)
297
+
298
+ with col2:
299
+ st.markdown(f"**Date Range** (from {min_date} to {max_date})")
300
+ start_date = st.date_input("Start Date", value=min_date, min_value=min_date, max_value=max_date)
301
+ end_date = st.date_input("End Date", value=max_date, min_value=min_date, max_value=max_date)
302
+
303
+
304
+ # ============================== Data Preparation (filtered) ==============================
305
+
306
+ # final filtered date window
307
+ test_df_filtered = test_df_full[
308
+ (test_df_full["date"].dt.date >= start_date) &
309
+ (test_df_full["date"].dt.date <= end_date)
310
+ ].reset_index(drop=True)
311
+
312
+ # For progression bar
313
+ total_steps_ui = len(test_df_filtered)
314
+
315
+ # ============================== Buttons ==============================
316
+
317
+ st.markdown("### Start Simulation")
318
+ col1, col2, col3 = st.columns([1, 1, 14])
319
+ with col1:
320
+ play_pause_text = "▶️ Start" if not st.session_state.is_running else "⏸️ Pause"
321
+ if st.button(play_pause_text):
322
+ st.session_state.is_running = not st.session_state.is_running
323
+ st.rerun()
324
+ with col2:
325
+ reset_button = st.button("🔄 Reset")
326
+
327
+ # Reset logic
328
+ if reset_button:
329
+ st.session_state.start_index = 0
330
+ st.session_state.pred_vals = []
331
+ st.session_state.true_vals = []
332
+ st.session_state.pred_timestamps = []
333
+ st.session_state.true_timestamps = []
334
+ st.session_state.last_fig = None
335
+ st.session_state.is_running = False
336
+ st.session_state.valid_pos = 0
337
+ st.rerun()
338
+
339
+ # Auto-reset on critical parameter change while running
340
+ if st.session_state.is_running and (
341
+ start_date != st.session_state.get("last_start_date") or
342
+ end_date != st.session_state.get("last_end_date") or
343
+ model_choice != st.session_state.get("last_model_choice")
344
+ ):
345
+ st.session_state.start_index = 0
346
+ st.session_state.pred_vals = []
347
+ st.session_state.true_vals = []
348
+ st.session_state.pred_timestamps = []
349
+ st.session_state.true_timestamps = []
350
+ st.session_state.last_fig = None
351
+ st.session_state.valid_pos = 0
352
+ st.rerun()
353
+
354
+ # Track current selections for change detection
355
+ st.session_state.last_start_date = start_date
356
+ st.session_state.last_end_date = end_date
357
+ st.session_state.last_model_choice = model_choice
358
+
359
+
360
+ # ============================== Paused Mode ==============================
361
+
362
+ if not st.session_state.is_running and st.session_state.last_fig is not None:
363
+ st.write("Simulation paused...")
364
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
365
+
366
+ timestamp = st.session_state.pred_timestamps[-1] if st.session_state.pred_timestamps else "–"
367
+ prediction = st.session_state.pred_vals[-1] if st.session_state.pred_vals else None
368
+ actual = st.session_state.true_vals[-1] if st.session_state.true_vals else None
369
+ progress = st.session_state.start_index / total_steps_ui
370
+
371
+ render_simulation_view(timestamp, prediction, actual, progress, st.session_state.last_fig, paused=True)
372
+
373
+
374
+ # ============================== initialize values ==============================
375
+
+ # If LightGBM is selected, use the test data prepared above
377
+ if model_choice == "LightGBM":
378
+ test_df = test_df_filtered.copy()
379
+
380
+ #Shared state references for storing predictions and ground truths
381
+
382
+ true_vals = st.session_state.true_vals
383
+ pred_vals = st.session_state.pred_vals
384
+ true_timestamps = st.session_state.true_timestamps
385
+ pred_timestamps = st.session_state.pred_timestamps
386
+
387
+ # ============================== LightGBM Simulation ==============================
388
+
389
+ if model_choice == "LightGBM" and st.session_state.is_running:
390
+ model = load_lightgbm_model()
391
+ st.write("Simulation started...")
392
+ st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
393
+
394
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
395
+
396
+ for i in range(st.session_state.start_index, len(test_df)):
397
+ if not st.session_state.is_running:
398
+ break
399
+
400
+ current = test_df.iloc[i]
401
+ timestamp = current["date"]
402
+ features = current[FEATURES].values.reshape(1, -1)
403
+ prediction = model.predict(features)[0]
404
+
405
+ pred_vals.append(prediction)
406
+ pred_timestamps.append(timestamp)
407
+
408
+ if i >= 1:
409
+ prev_actual = test_df.iloc[i - 1]["consumption_MW"]
410
+ prev_time = test_df.iloc[i - 1]["date"]
411
+ true_vals.append(prev_actual)
412
+ true_timestamps.append(prev_time)
413
+
414
+ fig = create_prediction_plot(
415
+ pred_timestamps, pred_vals,
416
+ true_timestamps, true_vals,
417
+ window_hours,
418
+ y_min= test_df_filtered["consumption_MW"].min() - 2000,
419
+ y_max= test_df_filtered["consumption_MW"].max() + 2000
420
+ )
421
+
422
+ render_simulation_view(timestamp, prediction, prev_actual if i >= 1 else None, i / len(test_df), fig)
423
+
+ plt.close(fig)  # free memory
425
+
426
+ st.session_state.start_index = i + 1
427
+ time.sleep(1 / (speed + 1e-9))
428
+
429
+ st.success("Simulation completed!")
430
+
431
+
432
+
433
+ # ============================== Transformer Simulation ==============================
434
+
+ if model_choice == "Transformer Model (moments)":
436
+ if st.session_state.is_running:
437
+ st.write("Simulation started (Transformer)...")
438
+ st.markdown('<div id="simulation"></div>', unsafe_allow_html=True)
439
+
440
+ plot_title, plot_container, x_axis_label, info_container = init_simulation_layout()
441
+
+ # Access model, dataset, and device
+ model, test_dataset, device = load_transformer_model_and_dataset()
+ data = test_dataset.data  # already scaled
445
+ scaler = test_dataset.scaler
446
+ n_channels = test_dataset.n_channels
447
+
448
+ test_start_idx = len(InformerDataset(data_split="train", forecast_horizon=FORECAST_HORIZON)) + SEQ_LEN
+ base_timestamp = pd.read_csv(DATA_PATH, parse_dates=["date"])["date"].iloc[test_start_idx]  # get the original timestamp for later, since the date column is no longer in the dataset
450
+
+ # Step 1: find the index at which the hour is 00:00
452
+ offset = 0
453
+ while (base_timestamp + pd.Timedelta(hours=offset)).time() != pd.Timestamp("00:00:00").time():
454
+ offset += 1
455
+
+ # New start index for the simulation
457
+ start_index = offset
458
+
+ # Initialize session state if needed
460
+ if "start_index" not in st.session_state or st.session_state.start_index == 0:
461
+ st.session_state.start_index = start_index
462
+
463
+
+ # Prepare: list of valid indices i within the selected date range
465
+ valid_indices = []
466
+ for i in range(start_index, len(test_dataset)):
467
+ timestamp = base_timestamp + pd.Timedelta(hours=i)
468
+ if start_date <= timestamp.date() <= end_date:
469
+ valid_indices.append(i)
470
+
+ # Progress indicator
+ total_steps = len(valid_indices)
+
+ # Current position within the valid list (not the global dataset index!)
475
+ if "valid_pos" not in st.session_state:
476
+ st.session_state.valid_pos = 0
477
+
+ # Main loop: iterate only over the valid indices
479
+ for relative_idx, i in enumerate(valid_indices[st.session_state.valid_pos:]):
480
+
481
+ #for i in range(st.session_state.start_index, len(test_dataset)):
482
+ if not st.session_state.is_running:
483
+ break
484
+
485
+ current_pred = predict_transformer_step(model, test_dataset, i, device)
486
+ current_time = base_timestamp + pd.Timedelta(hours=i)
487
+
488
+ pred_vals.append(current_pred)
489
+ pred_timestamps.append(current_time)
490
+
491
+ if i >= 1:
+ prev_actual = test_dataset[i - 1][1][0, 0]  # first forecast value of the previous row
+ # Rescale back to original units
494
+ dummy_actual = np.zeros((1, n_channels))
495
+ dummy_actual[:, 0] = prev_actual
496
+ actual_val = scaler.inverse_transform(dummy_actual)[0, 0]
497
+
498
+ true_time = current_time - pd.Timedelta(hours=1)
499
+
500
+ if true_time >= pd.to_datetime(start_date):
501
+ true_vals.append(actual_val)
502
+ true_timestamps.append(true_time)
503
+
+ # Create plot
505
+ fig = create_prediction_plot(
506
+ pred_timestamps, pred_vals,
507
+ true_timestamps, true_vals,
508
+ window_hours,
509
+ y_min= test_df_filtered["consumption_MW"].min() - 2000,
510
+ y_max= test_df_filtered["consumption_MW"].max() + 2000
511
+ )
512
+ if len(pred_vals) >= 2 and len(true_vals) >= 1:
513
+ render_simulation_view(current_time, current_pred, actual_val if i >= 1 else None, st.session_state.valid_pos / total_steps, fig)
514
+
+ plt.close(fig)  # free memory
516
+
517
+ st.session_state.valid_pos += 1
518
+ time.sleep(1 / (speed + 1e-9))
519
+
520
+ st.success("Simulation completed!")
521
+
522
+
523
+ # ============================== Scroll Sync ==============================
524
+
525
+ st.markdown("""
526
+ <script>
527
+ window.addEventListener("message", (event) => {
528
+ if (event.data.type === "save_scroll") {
529
+ const pyScroll = event.data.scrollY;
530
+ window.parent.postMessage({type: "streamlit:setComponentValue", value: pyScroll}, "*");
531
+ }
532
+ });
533
+ </script>
534
+ """, unsafe_allow_html=True)
535
+
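
Both the transformer prediction and the recovered actual value above are de-normalized by writing the single target channel into column 0 of a zero-padded array before calling scaler.inverse_transform. A minimal standalone sketch of that pattern (the fitted scaler and values are illustrative):

import numpy as np
from sklearn.preprocessing import StandardScaler

# Illustrative fit: a scaler trained on n_channels columns, with the target in column 0.
n_channels = 3
scaler = StandardScaler().fit(np.random.rand(100, n_channels))

pred_scaled = np.array([0.42])                 # single-channel model output (normalized)
dummy = np.zeros((len(pred_scaled), n_channels))
dummy[:, 0] = pred_scaled                      # place the target channel into column 0
pred_original = scaler.inverse_transform(dummy)[:, 0]
print(pred_original)
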
streamlit_simulation/config_streamlit.py ADDED
@@ -0,0 +1,26 @@
+ # config_streamlit
+ import os
+
+ # Base directory → points to the project root
+ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+
+ # Model paths
+ MODEL_PATH_LIGHTGBM = os.path.join(BASE_DIR, "lightgbm_model", "model", "lightgbm_final_model.pkl")
+ MODEL_PATH_TRANSFORMER = os.path.join(BASE_DIR, "transformer_model", "model", "checkpoints", "model_final.pth")
+
+ # Data path
+ DATA_PATH = os.path.join(BASE_DIR, "data", "processed", "energy_consumption_aggregated_cleaned.csv")
+
+ # Color palette for Streamlit layout
+ TEXT_COLOR = "#004080" # Primary text color (clean dark blue)
+ HEADER_COLOR = "#002855" # Accent color for headings
+ ACCENT_COLOR = "#9bb2cc" # For borders, highlights, etc.
+ BUTTON_BG = "#dee7f0" # Background color for buttons
+ BUTTON_HOVER_BG = "#cbd9e6" # Hover color for buttons
+ BG_COLOR = "#ffffff" # Page background
+ INPUT_BG = "#f2f6fa" # Background for select boxes, inputs
+ PROGRESS_COLOR = "#0077B6" # Progress bar color
+ PLOT_COLOR = "white" # Plot background color
+
+ # Constants
+ TRAIN_RATIO = 0.7 # Train/test split ratio used by both models
transformer_model/results/evaluation_metrics.json ADDED
@@ -0,0 +1 @@
+ {"RMSE": 3933.5735661100834, "MAPE": 2.3222167044878006, "R2": 0.97211754322052}
transformer_model/results/test_results.csv ADDED
The diff for this file is too large to render. See raw diff
 
transformer_model/results/training_metrics.json ADDED
@@ -0,0 +1 @@
+ {"train_losses": [0.17907951894225974, 0.11743136870736444, 0.10286305829986463, 0.095653260748457, 0.09064765630698786, 0.08855325479177233, 0.08623282216515275, 0.08489166740133372, 0.08422152720884994], "test_mses": [0.07641124725341797, 0.050424233078956604, 0.03807574138045311, 0.032122015953063965, 0.026808083057403564, 0.02273257076740265, 0.02027367614209652, 0.018922727555036545, 0.017820490524172783], "test_maes": [0.1691250056028366, 0.1388522833585739, 0.12234506011009216, 0.11616843193769455, 0.10695459693670273, 0.09815964102745056, 0.09287288039922714, 0.0910905972123146, 0.0890081524848938]}
transformer_model/scripts/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # __init__.py
+
transformer_model/scripts/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (199 Bytes). View file
 
transformer_model/scripts/__pycache__/check_device.cpython-311.pyc ADDED
Binary file (1.94 kB). View file
 
transformer_model/scripts/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.16 kB). View file
 
transformer_model/scripts/__pycache__/config_transformer.cpython-311.pyc ADDED
Binary file (1.19 kB). View file
 
transformer_model/scripts/__pycache__/create_dataloaders.cpython-311.pyc ADDED
Binary file (1.94 kB). View file
 
transformer_model/scripts/__pycache__/informer_dataset_class.cpython-311.pyc ADDED
Binary file (5.33 kB). View file
 
transformer_model/scripts/__pycache__/load_basis_model.cpython-311.pyc ADDED
Binary file (2.84 kB). View file
 
transformer_model/scripts/config_transformer.py ADDED
@@ -0,0 +1,31 @@
+ # config.py
+ import os
+
+ # Base Directory
+ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+
+ # Data paths
+ DATA_PATH = os.path.join(BASE_DIR, "..", "data", "processed", "energy_consumption_aggregated_cleaned.csv")
+
+ # Other paths
+ CHECKPOINT_DIR = os.path.join(BASE_DIR, "model", "checkpoints")
+ RESULTS_DIR = os.path.join(BASE_DIR, "results")
+
+
+ # ========== Model Settings ==========
+ SEQ_LEN = 512 # Input sequence length (number of time steps the model sees)
+ FORECAST_HORIZON = 1 # Number of future steps the model should predict
+ HEAD_DROPOUT = 0.1 # Dropout in the head to prevent overfitting
+ WEIGHT_DECAY = 0.0 # L2 regularization (0 means off)
+
+ # ========== Training Settings ==========
+ MAX_EPOCHS = 9 # Optimal number of epochs based on performance curve
+ BATCH_SIZE = 32 # Batch size for training and evaluation
+ LEARNING_RATE = 1e-4 # Base learning rate
+ MAX_LR = 1e-4 # Max LR for OneCycleLR scheduler
+ GRAD_CLIP = 5.0 # Gradient clipping threshold
+
+ # ========== Freezing Strategy ==========
+ FREEZE_ENCODER = True
+ FREEZE_EMBEDDER = True
+ FREEZE_HEAD = False # only unfreeze the last forecasting head for fine-tuning
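
For orientation, these settings correspond to the tensor shapes used elsewhere in the repo (compare the dummy forward pass in load_basis_model.py): inputs of shape (batch, channels, SEQ_LEN) and forecast targets of shape (batch, channels, FORECAST_HORIZON). A minimal sketch, assuming a single input channel:

import torch
from transformer_model.scripts.config_transformer import BATCH_SIZE, SEQ_LEN, FORECAST_HORIZON

x_enc = torch.randn(BATCH_SIZE, 1, SEQ_LEN)             # model input window
target = torch.randn(BATCH_SIZE, 1, FORECAST_HORIZON)   # ground-truth forecast
print(x_enc.shape, target.shape)
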
transformer_model/scripts/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
+ # __init__
transformer_model/scripts/evaluation/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (180 Bytes). View file
 
transformer_model/scripts/evaluation/__pycache__/evaluate.cpython-311.pyc ADDED
Binary file (7.78 kB). View file
 
transformer_model/scripts/evaluation/__pycache__/plot_metrics.cpython-311.pyc ADDED
Binary file (4.71 kB). View file
 
transformer_model/scripts/evaluation/evaluate.py ADDED
@@ -0,0 +1,124 @@
1
+ # evaluate.py
2
+
3
+ import os
4
+ import json
5
+ import torch
6
+ import logging
7
+ import numpy as np
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+
11
+ from sklearn.metrics import mean_squared_error, r2_score
12
+
13
+ from transformer_model.scripts.config_transformer import BASE_DIR, RESULTS_DIR, CHECKPOINT_DIR, DATA_PATH, FORECAST_HORIZON, SEQ_LEN
14
+ from transformer_model.scripts.training.load_basis_model import load_moment_model
15
+ from transformer_model.scripts.utils.informer_dataset_class import InformerDataset
16
+ from momentfm.utils.utils import control_randomness
17
+ from transformer_model.scripts.utils.check_device import check_device
18
+
19
+
20
+ # Setup logging
21
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
22
+
23
+ def evaluate():
24
+ control_randomness(seed=13)
25
+ # Set device
26
+ device, backend, scaler = check_device()
27
+ logging.info(f"Evaluation is running on: {backend} ({device})")
28
+
29
+ # Load final model
30
+ model = load_moment_model()
31
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
32
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device))
33
+ model.to(device)
34
+ model.eval()
35
+ logging.info(f"Loaded final model from: {checkpoint_path}")
36
+
37
+ # Recreate training dataset to get the fitted scaler
38
+ train_dataset = InformerDataset(
39
+ data_split="train",
40
+ random_seed=13,
41
+ forecast_horizon=FORECAST_HORIZON
42
+ )
43
+
44
+ # Use its scaler in the test dataset
45
+ test_dataset = InformerDataset(
46
+ data_split="test",
47
+ random_seed=13,
48
+ forecast_horizon=FORECAST_HORIZON
49
+ )
50
+
51
+ test_dataset.scaler = train_dataset.scaler
52
+
53
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
54
+
55
+ trues, preds = [], []
56
+
57
+ with torch.no_grad():
58
+ for timeseries, forecast, input_mask in tqdm(test_loader, desc="Evaluating on test set"):
59
+ timeseries = timeseries.float().to(device)
60
+ forecast = forecast.float().to(device)
+ input_mask = input_mask.to(device)  # important: the mask must be on the same device
62
+
63
+ output = model(x_enc=timeseries, input_mask=input_mask)
64
+
65
+ trues.append(forecast.cpu().numpy())
66
+ preds.append(output.forecast.cpu().numpy())
67
+
68
+
69
+ trues = np.concatenate(trues, axis=0)
70
+ preds = np.concatenate(preds, axis=0)
71
+
72
+ # Extract only first feature (consumption)
73
+ true_values = trues[:, 0, :]
74
+ pred_values = preds[:, 0, :]
75
+
76
+ # Inverse normalization
77
+ n_features = test_dataset.n_channels
78
+ true_reshaped = np.column_stack([true_values.flatten()] + [np.zeros_like(true_values.flatten())] * (n_features - 1))
79
+ pred_reshaped = np.column_stack([pred_values.flatten()] + [np.zeros_like(pred_values.flatten())] * (n_features - 1))
80
+
81
+ true_original = test_dataset.scaler.inverse_transform(true_reshaped)[:, 0]
82
+ pred_original = test_dataset.scaler.inverse_transform(pred_reshaped)[:, 0]
83
+
84
+
+ # Build timestamp index: the date column is dropped inside InformerDataset, so we go back to the original CSV and use the index where the test data begins to recover the dates
86
+ csv_path = os.path.join(DATA_PATH)
87
+ df = pd.read_csv(csv_path, parse_dates=["date"])
88
+
89
+ train_len = len(train_dataset)
90
+ test_start_idx = train_len + SEQ_LEN
91
+ start_timestamp = df["date"].iloc[test_start_idx]
92
+ logging.info(f"[DEBUG] timestamp: {start_timestamp}")
93
+
94
+ timestamps = [start_timestamp + pd.Timedelta(hours=i) for i in range(len(true_original))]
95
+
96
+ df = pd.DataFrame({
97
+ "Timestamp": timestamps,
98
+ "True Consumption (MW)": true_original,
99
+ "Predicted Consumption (MW)": pred_original
100
+ })
101
+
102
+ # Save results to CSV
103
+ os.makedirs(RESULTS_DIR, exist_ok=True)
104
+ results_path = os.path.join(RESULTS_DIR, "test_results.csv")
105
+ df.to_csv(results_path, index=False)
106
+ logging.info(f"Saved prediction results to: {results_path}")
107
+
108
+ # Evaluation metrics
109
+ mse = mean_squared_error(df["True Consumption (MW)"], df["Predicted Consumption (MW)"])
110
+ rmse = np.sqrt(mse)
111
+ mape = np.mean(np.abs((df["True Consumption (MW)"] - df["Predicted Consumption (MW)"]) / df["True Consumption (MW)"])) * 100
112
+ r2 = r2_score(df["True Consumption (MW)"], df["Predicted Consumption (MW)"])
113
+
114
+ # Save metrics to JSON
115
+ metrics = {"RMSE": float(rmse), "MAPE": float(mape), "R2": float(r2)}
116
+ metrics_path = os.path.join(RESULTS_DIR, "evaluation_metrics.json")
117
+ with open(metrics_path, "w") as f:
118
+ json.dump(metrics, f)
119
+
120
+ logging.info(f"Saved evaluation metrics to: {metrics_path}")
121
+ logging.info(f"RMSE: {rmse:.3f} | MAPE: {mape:.2f}% | R²: {r2:.3f}")
122
+
123
+ if __name__ == "__main__":
124
+ evaluate()
transformer_model/scripts/evaluation/plot_metrics.py ADDED
@@ -0,0 +1,77 @@
1
+ # plot_metrics.py
2
+
3
+ import os
4
+ import json
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ from transformer_model.scripts.config_transformer import RESULTS_DIR
8
+
9
+ # === Plot 1: Training Metrics ===
10
+
11
+ # Load training metrics
12
+ training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
13
+ with open(training_metrics_path, "r") as f:
14
+ metrics = json.load(f)
15
+
16
+ train_losses = metrics["train_losses"]
17
+ test_mses = metrics["test_mses"]
18
+ test_maes = metrics["test_maes"]
19
+
20
+ plt.figure(figsize=(10, 6))
21
+ plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", color="blue")
22
+ plt.plot(range(1, len(test_mses) + 1), test_mses, label="Test MSE", color="red")
23
+ plt.plot(range(1, len(test_maes) + 1), test_maes, label="Test MAE", color="green")
24
+ plt.xlabel("Epoch")
25
+ plt.ylabel("Loss / Metric")
26
+ plt.title("Training Loss vs Test Metrics")
27
+ plt.legend()
28
+ plt.grid(True)
29
+
30
+ plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
31
+ plt.savefig(plot_path)
32
+ print(f"[Saved] Training metrics plot: {plot_path}")
33
+ plt.show()
34
+
35
+
36
+ # === Plot 2: Predictions vs Ground Truth (Full Range) ===
37
+
38
+ # Load comparison results
39
+ comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
40
+ df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])
41
+
42
+ plt.figure(figsize=(15, 6))
43
+ plt.plot(df_comparison["Timestamp"], df_comparison["True Consumption (MW)"], label="True", color="darkblue")
44
+ plt.plot(df_comparison["Timestamp"], df_comparison["Predicted Consumption (MW)"], label="Predicted", color="red", linestyle="--")
45
+ plt.title("Energy Consumption: Predictions vs Ground Truth")
46
+ plt.xlabel("Time")
47
+ plt.ylabel("Consumption (MW)")
48
+ plt.legend()
49
+ plt.grid(True)
50
+ plt.tight_layout()
51
+
52
+ plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
53
+ plt.savefig(plot_path)
54
+ print(f"[Saved] Full range comparison plot: {plot_path}")
55
+ plt.show()
56
+
57
+
58
+ # === Plot 3: Predictions vs Ground Truth (First Month) ===
59
+
60
+ first_month_start = df_comparison["Timestamp"].min()
61
+ first_month_end = first_month_start + pd.Timedelta(days=25)
62
+ df_first_month = df_comparison[(df_comparison["Timestamp"] >= first_month_start) & (df_comparison["Timestamp"] <= first_month_end)]
63
+
64
+ plt.figure(figsize=(15, 6))
65
+ plt.plot(df_first_month["Timestamp"], df_first_month["True Consumption (MW)"], label="True", color="darkblue")
66
+ plt.plot(df_first_month["Timestamp"], df_first_month["Predicted Consumption (MW)"], label="Predicted", color="red", linestyle="--")
67
+ plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
68
+ plt.xlabel("Time")
69
+ plt.ylabel("Consumption (MW)")
70
+ plt.legend()
71
+ plt.grid(True)
72
+ plt.tight_layout()
73
+
74
+ plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
75
+ plt.savefig(plot_path)
76
+ print(f"[Saved] 1-Month comparison plot: {plot_path}")
77
+ plt.show()
transformer_model/scripts/training/__init__.py ADDED
@@ -0,0 +1 @@
+ # __init__
transformer_model/scripts/training/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (208 Bytes). View file
 
transformer_model/scripts/training/__pycache__/load_basis_model.cpython-311.pyc ADDED
Binary file (2.91 kB). View file
 
transformer_model/scripts/training/__pycache__/train.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
transformer_model/scripts/training/load_basis_model.py ADDED
@@ -0,0 +1,67 @@
1
+ # load_basis_model.py
2
+ # Load and initialize the base MOMENT model before finetuning
3
+
4
+ import torch
5
+ import logging
6
+ from momentfm import MOMENTPipeline
7
+ from transformer_model.scripts.config_transformer import (
8
+ FORECAST_HORIZON,
9
+ FREEZE_ENCODER,
10
+ FREEZE_EMBEDDER,
11
+ FREEZE_HEAD,
12
+ WEIGHT_DECAY,
13
+ HEAD_DROPOUT,
14
+ SEQ_LEN
15
+ )
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
19
+
20
+
21
+ def load_moment_model():
22
+ """
23
+ Loads and configures the MOMENT model for forecasting.
24
+ """
25
+ logging.info("Loading MOMENT model...")
26
+ model = MOMENTPipeline.from_pretrained(
27
+ "AutonLab/MOMENT-1-large",
28
+ model_kwargs={
29
+ 'task_name': 'forecasting',
30
+ 'forecast_horizon': FORECAST_HORIZON, # default = 1
31
+ 'head_dropout': HEAD_DROPOUT, # default = 0.1
32
+ 'weight_decay': WEIGHT_DECAY, # default = 0.0
33
+ 'freeze_encoder': FREEZE_ENCODER, # default = True
34
+ 'freeze_embedder': FREEZE_EMBEDDER, # default = True
35
+ 'freeze_head': FREEZE_HEAD # default = False
36
+ }
37
+ )
38
+
39
+ model.init()
40
+ logging.info("Model initialized successfully.")
41
+ return model
42
+
43
+
44
+ def print_trainable_params(model):
45
+ """
46
+ Logs all trainable (unfrozen) parameters of the model.
47
+ """
48
+ logging.info("Unfrozen parameters:")
49
+ for name, param in model.named_parameters():
50
+ if param.requires_grad:
51
+ logging.info(f" {name}")
52
+
53
+
54
+ def test_dummy_forward(model):
55
+ """
56
+ Performs a dummy forward pass to verify the model runs without error.
57
+ """
+ logging.info("Running a dummy forward pass with random tensors to check that the model runs.")
59
+ dummy_x = torch.randn(16, 1, SEQ_LEN)
60
+ output = model(x_enc=dummy_x)
61
+ logging.info("Dummy forward pass successful.")
62
+
63
+
64
+ if __name__ == "__main__":
65
+ model = load_moment_model()
66
+ print_trainable_params(model)
67
+ test_dummy_forward(model)
transformer_model/scripts/training/train.py ADDED
@@ -0,0 +1,202 @@
1
+ # train.py
2
+
3
+ import os
4
+ import json
5
+ import time
6
+ import logging
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+ from sklearn.metrics import mean_squared_error, mean_absolute_error
11
+
12
+ from transformer_model.scripts.config_transformer import (
13
+ BASE_DIR,
14
+ MAX_EPOCHS,
15
+ BATCH_SIZE,
16
+ LEARNING_RATE,
17
+ MAX_LR,
18
+ GRAD_CLIP,
19
+ FORECAST_HORIZON,
20
+ CHECKPOINT_DIR,
21
+ RESULTS_DIR
22
+ )
23
+
24
+ from transformer_model.scripts.training.load_basis_model import load_moment_model
25
+ from transformer_model.scripts.utils.create_dataloaders import create_dataloaders
26
+ from transformer_model.scripts.utils.check_device import check_device
27
+ from momentfm.utils.utils import control_randomness
28
+
29
+
30
+ # === Setup logging ===
31
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
32
+
33
+
34
+ def train():
35
+ # Start timing
36
+ start_time = time.time()
37
+
38
+ # Setup device (CUDA / DirectML / CPU) and AMP scaler
39
+ device, backend, scaler = check_device()
40
+
41
+ # Load base model
42
+ model = load_moment_model().to(device)
43
+
44
+ # Set random seeds for reproducibility
45
+ control_randomness(seed=13)
46
+
47
+ # Setup loss function and optimizer
48
+ criterion = torch.nn.MSELoss().to(device)
49
+ optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
50
+
51
+ # Load data
52
+ train_loader, test_loader = create_dataloaders()
53
+
54
+ # Setup learning rate scheduler (OneCycle policy)
55
+ total_steps = len(train_loader) * MAX_EPOCHS
56
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
57
+ optimizer,
58
+ max_lr=MAX_LR,
59
+ total_steps=total_steps,
60
+ pct_start=0.3
61
+ )
62
+
63
+ # Ensure output folders exist
64
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
65
+ os.makedirs(RESULTS_DIR, exist_ok=True)
66
+
67
+ # Store metrics
68
+ train_losses, test_mses, test_maes = [], [], []
69
+
70
+ best_mae = float("inf")
71
+ best_epoch = None
72
+ no_improve_epochs = 0
73
+ patience = 5
74
+
75
+ for epoch in range(MAX_EPOCHS):
76
+ model.train()
77
+ epoch_losses = []
78
+
79
+ for timeseries, forecast, input_mask in tqdm(train_loader, desc=f"Epoch {epoch}"):
80
+ timeseries = timeseries.float().to(device)
81
+ input_mask = input_mask.to(device)
82
+ forecast = forecast.float().to(device)
83
+
84
+ # Zero gradients
85
+ optimizer.zero_grad(set_to_none=True)
86
+
87
+ # Forward pass (with AMP if enabled)
88
+ if scaler:
89
+ with torch.amp.autocast(device_type="cuda"):
90
+ output = model(x_enc=timeseries, input_mask=input_mask)
91
+ loss = criterion(output.forecast, forecast)
92
+ else:
93
+ output = model(x_enc=timeseries, input_mask=input_mask)
94
+ loss = criterion(output.forecast, forecast)
95
+
96
+ # Backward pass + optimization
97
+ if scaler:
98
+ scaler.scale(loss).backward()
99
+ scaler.unscale_(optimizer)
100
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
101
+ scaler.step(optimizer)
102
+ scaler.update()
103
+ else:
104
+ loss.backward()
105
+ torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
106
+ optimizer.step()
107
+
108
+ epoch_losses.append(loss.item())
109
+
110
+ average_train_loss = np.mean(epoch_losses)
111
+ train_losses.append(average_train_loss)
112
+ logging.info(f"Epoch {epoch}: Train Loss = {average_train_loss:.4f}")
113
+
114
+ # === Evaluation ===
115
+ model.eval()
116
+ trues, preds = [], []
117
+
118
+ with torch.no_grad():
119
+ for timeseries, forecast, input_mask in test_loader:
120
+ timeseries = timeseries.float().to(device)
121
+ input_mask = input_mask.to(device)
122
+ forecast = forecast.float().to(device)
123
+
124
+ if scaler:
125
+ with torch.amp.autocast(device_type="cuda"):
126
+ output = model(x_enc=timeseries, input_mask=input_mask)
127
+ else:
128
+ output = model(x_enc=timeseries, input_mask=input_mask)
129
+
130
+ trues.append(forecast.detach().cpu().numpy())
131
+ preds.append(output.forecast.detach().cpu().numpy())
132
+
133
+ trues = np.concatenate(trues, axis=0)
134
+ preds = np.concatenate(preds, axis=0)
135
+
136
+
137
+ # Reshape for sklearn metrics
138
+ trues_2d = trues.reshape(trues.shape[0], -1)
139
+ preds_2d = preds.reshape(preds.shape[0], -1)
140
+
141
+ mse = mean_squared_error(trues_2d, preds_2d)
142
+ mae = mean_absolute_error(trues_2d, preds_2d)
143
+
144
+ test_mses.append(mse)
145
+ test_maes.append(mae)
146
+ logging.info(f"Epoch {epoch}: Test MSE = {mse:.4f}, MAE = {mae:.4f}")
147
+
148
+ # === Early Stopping Check ===
149
+ if mae < best_mae:
150
+ best_mae = mae
151
+ best_epoch = epoch
152
+ no_improve_epochs = 0
153
+
154
+ # Save best model
155
+ best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
156
+ torch.save(model.state_dict(), best_model_path)
157
+ logging.info(f"New best model saved to: {best_model_path} (MAE: {best_mae:.4f})")
158
+ else:
159
+ no_improve_epochs += 1
160
+ logging.info(f"No improvement in MAE for {no_improve_epochs} epoch(s).")
161
+
162
+ if no_improve_epochs >= patience:
163
+ logging.info("Early stopping triggered.")
164
+ break
165
+
166
+ # Save checkpoint
167
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch}.pth")
168
+ torch.save(model.state_dict(), checkpoint_path)
169
+
170
+ scheduler.step()
171
+
172
+ logging.info(f"Best model was at epoch {best_epoch} with MAE: {best_mae:.4f}")
173
+
174
+ # Save final model
175
+ final_model_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
176
+ torch.save(model.state_dict(), final_model_path)
177
+ logging.info(f"Final model saved to: {final_model_path}")
178
+ logging.info(f"Final Test MSE: {test_mses[-1]:.4f}, MAE: {test_maes[-1]:.4f}")
179
+
180
+ # Save training metrics
181
+ metrics = {
182
+ "train_losses": [float(x) for x in train_losses],
183
+ "test_mses": [float(x) for x in test_mses],
184
+ "test_maes": [float(x) for x in test_maes]
185
+ }
186
+
187
+ metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
188
+ with open(metrics_path, "w") as f:
189
+ json.dump(metrics, f)
190
+ logging.info(f"Training metrics saved to: {metrics_path}")
191
+
192
+ # Done
193
+ elapsed = time.time() - start_time
194
+ logging.info(f"Training complete in {elapsed / 60:.2f} minutes.")
195
+
196
+
197
+ # === Entry Point ===
198
+ if __name__ == "__main__":
199
+ try:
200
+ train()
201
+ except Exception as e:
202
+ logging.error(f"Training failed: {e}")
transformer_model/scripts/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ # __init__
transformer_model/scripts/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (205 Bytes). View file
 
transformer_model/scripts/utils/__pycache__/check_device.cpython-311.pyc ADDED
Binary file (1.95 kB). View file
 
transformer_model/scripts/utils/__pycache__/create_dataloaders.cpython-311.pyc ADDED
Binary file (1.95 kB). View file
 
transformer_model/scripts/utils/__pycache__/informer_dataset_class.cpython-311.pyc ADDED
Binary file (5.4 kB). View file