# Spaces: Running
# plot_metrics.py
"""Plot training curves and test-set predictions stored under ``RESULTS_DIR``.

Reads ``training_metrics.json`` and ``test_results.csv`` from ``RESULTS_DIR``,
saves three PNG figures back into that directory and displays each one with
``plt.show()``.
"""
import os
import json

import pandas as pd
import matplotlib.pyplot as plt

from transformer_model.scripts.config_transformer import RESULTS_DIR

# === Plot 1: Training Metrics ===
# Load training metrics. The file must contain the keys "train_losses",
# "test_mses" and "test_maes"; each is presumably one value per epoch
# (matches the "Epoch" x-axis below — confirm against the training script).
training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
with open(training_metrics_path, "r", encoding="utf-8") as f:
    metrics = json.load(f)

train_losses = metrics["train_losses"]
test_mses = metrics["test_mses"]
test_maes = metrics["test_maes"]

fig = plt.figure(figsize=(10, 6))
# Epochs are numbered from 1. Each series uses its own length, so lists of
# different lengths still plot instead of raising.
plt.plot(range(1, len(train_losses) + 1), train_losses,
         label="Train Loss", color="blue")
plt.plot(range(1, len(test_mses) + 1), test_mses,
         label="Test MSE", color="red")
plt.plot(range(1, len(test_maes) + 1), test_maes,
         label="Test MAE", color="green")
plt.xlabel("Epoch")
plt.ylabel("Loss / Metric")
plt.title("Training Loss vs Test Metrics")
plt.legend()
plt.grid(True)
# Same as the comparison plots: keep axis labels from being clipped in the PNG.
plt.tight_layout()
plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
plt.savefig(plot_path)
print(f"[Saved] Training metrics plot: {plot_path}")
plt.show()
# On a non-interactive backend show() does not close the figure; release it
# explicitly so figures do not pile up across the script.
plt.close(fig)
# === Plot 2: Predictions vs Ground Truth (Full Range) ===
# Load comparison results. "Timestamp" is parsed to datetime so it can serve
# as the x-axis and support date arithmetic/filtering in Plot 3.
comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])

fig = plt.figure(figsize=(15, 6))
plt.plot(df_comparison["Timestamp"], df_comparison["True Consumption (MW)"],
         label="True", color="darkblue")
plt.plot(df_comparison["Timestamp"], df_comparison["Predicted Consumption (MW)"],
         label="Predicted", color="red", linestyle="--")
plt.title("Energy Consumption: Predictions vs Ground Truth")
plt.xlabel("Time")
plt.ylabel("Consumption (MW)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
plt.savefig(plot_path)
print(f"[Saved] Full range comparison plot: {plot_path}")
plt.show()
# On a non-interactive backend show() does not close the figure; release it.
plt.close(fig)
# === Plot 3: Predictions vs Ground Truth (First Month) ===
# Zoom into the start of the test period.
# NOTE(review): the window is 25 days even though the title and output
# filename say "month" — confirm this length is intended.
FIRST_WINDOW_DAYS = 25
first_month_start = df_comparison["Timestamp"].min()
first_month_end = first_month_start + pd.Timedelta(days=FIRST_WINDOW_DAYS)
# Series.between is inclusive on both ends by default, matching >= start and <= end.
df_first_month = df_comparison[
    df_comparison["Timestamp"].between(first_month_start, first_month_end)
]

fig = plt.figure(figsize=(15, 6))
plt.plot(df_first_month["Timestamp"], df_first_month["True Consumption (MW)"],
         label="True", color="darkblue")
plt.plot(df_first_month["Timestamp"], df_first_month["Predicted Consumption (MW)"],
         label="Predicted", color="red", linestyle="--")
plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
plt.xlabel("Time")
plt.ylabel("Consumption (MW)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
plt.savefig(plot_path)
print(f"[Saved] 1-Month comparison plot: {plot_path}")
plt.show()
# On a non-interactive backend show() does not close the figure; release it.
plt.close(fig)