Spaces:
Sleeping
Sleeping
File size: 2,845 Bytes
8cc5633 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
# plot_metrics.py
import json
import os
import matplotlib.pyplot as plt
import pandas as pd
from transformer_model.scripts.config_transformer import RESULTS_DIR
# === Plot 1: Training Metrics ===
# Load the per-epoch metrics written to RESULTS_DIR/training_metrics.json.
# The JSON object must contain the lists "train_losses", "test_mses" and
# "test_maes"; a missing key raises KeyError.
training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
with open(training_metrics_path, "r", encoding="utf-8") as f:
    metrics = json.load(f)
train_losses = metrics["train_losses"]
test_mses = metrics["test_mses"]
test_maes = metrics["test_maes"]

plt.figure(figsize=(10, 6))
# Epochs are numbered from 1. Each series gets its own x-range so lists of
# different lengths still plot without a shape mismatch.
plt.plot(
    range(1, len(train_losses) + 1), train_losses, label="Train Loss", color="blue"
)
plt.plot(range(1, len(test_mses) + 1), test_mses, label="Test MSE", color="red")
plt.plot(range(1, len(test_maes) + 1), test_maes, label="Test MAE", color="green")
plt.xlabel("Epoch")
plt.ylabel("Loss / Metric")
plt.title("Training Loss vs Test Metrics")
plt.legend()
plt.grid(True)
# Same layout handling as the comparison plots, so labels are not clipped.
plt.tight_layout()
plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
plt.savefig(plot_path)
print(f"[Saved] Training metrics plot: {plot_path}")
plt.show()
# Release the figure; under non-interactive backends show() is a no-op and
# figures would otherwise accumulate in memory.
plt.close()
# === Plot 2: Predictions vs Ground Truth (Full Range) ===
# Read the test-set comparison table. "Timestamp" is parsed into datetimes so
# the x-axis is rendered as time. df_comparison is reused by the next plot.
comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])

full_fig, full_ax = plt.subplots(figsize=(15, 6))
# One entry per plotted series: (column, extra keyword arguments for plot()).
full_series = (
    ("True Consumption (MW)", {"label": "True", "color": "darkblue"}),
    (
        "Predicted Consumption (MW)",
        {"label": "Predicted", "color": "red", "linestyle": "--"},
    ),
)
for column_name, plot_kwargs in full_series:
    full_ax.plot(df_comparison["Timestamp"], df_comparison[column_name], **plot_kwargs)

full_ax.set_title("Energy Consumption: Predictions vs Ground Truth")
full_ax.set_xlabel("Time")
full_ax.set_ylabel("Consumption (MW)")
full_ax.legend()
full_ax.grid(True)
full_fig.tight_layout()

plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
full_fig.savefig(plot_path)
print(f"[Saved] Full range comparison plot: {plot_path}")
plt.show()
# === Plot 3: Predictions vs Ground Truth (First Month) ===
# Zoom into one calendar month beginning at the earliest timestamp in
# df_comparison (loaded for the full-range plot above). The window is
# half-open, [start, start + 1 month), so the first sample of the following
# month is excluded. No lower-bound test is needed because the start is the
# column minimum. Rows with a missing (NaT) timestamp compare False and are
# dropped.
first_month_start = df_comparison["Timestamp"].min()
first_month_end = first_month_start + pd.DateOffset(months=1)
df_first_month = df_comparison[df_comparison["Timestamp"] < first_month_end]

plt.figure(figsize=(15, 6))
plt.plot(
    df_first_month["Timestamp"],
    df_first_month["True Consumption (MW)"],
    label="True",
    color="darkblue",
)
plt.plot(
    df_first_month["Timestamp"],
    df_first_month["Predicted Consumption (MW)"],
    label="Predicted",
    color="red",
    linestyle="--",
)
plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
plt.xlabel("Time")
plt.ylabel("Consumption (MW)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
plt.savefig(plot_path)
print(f"[Saved] 1-Month comparison plot: {plot_path}")
plt.show()
# Release the figure so it does not linger under non-interactive backends.
plt.close()
|