erav2s13 / utilities /callbacks.py
Mojo
add utilities file
229755d
raw
history blame
2.65 kB
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from .visualize import plot_model_training_curves
class TrainingEndCallback(Callback):
def on_train_end(self, trainer, pl_module):
# Perform actions at the end of the entire training process
print("Training, validation, and testing completed!")
logged_metrics = pl_module.log_store
plot_model_training_curves(
train_accs=logged_metrics["train_acc_epoch"],
test_accs=logged_metrics["val_acc_epoch"],
train_losses=logged_metrics["train_loss_epoch"],
test_losses=logged_metrics["val_loss_epoch"],
)
class PrintLearningMetricsCallback(Callback):
def on_train_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
super().on_train_epoch_end(trainer, pl_module)
print(
f"\nEpoch: {trainer.current_epoch}, Train Loss: {trainer.logged_metrics['train_loss_epoch']}, Train Accuracy: {trainer.logged_metrics['train_acc_epoch']}"
)
pl_module.log_store.get("train_loss_epoch").append(
trainer.logged_metrics["train_loss_epoch"].cpu().detach().item()
)
pl_module.log_store.get("train_acc_epoch").append(
trainer.logged_metrics["train_acc_epoch"].cpu().detach().item()
)
def on_validation_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
super().on_validation_epoch_end(trainer, pl_module)
print(
f"\nEpoch: {trainer.current_epoch}, Val Loss: {trainer.logged_metrics['val_loss_epoch']}, Val Accuracy: {trainer.logged_metrics['val_acc_epoch']}"
)
pl_module.log_store.get("val_loss_epoch").append(
trainer.logged_metrics["val_loss_epoch"].cpu().detach().item()
)
pl_module.log_store.get("val_acc_epoch").append(
trainer.logged_metrics["val_acc_epoch"].cpu().detach().item()
)
def on_test_epoch_end(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
) -> None:
super().on_test_epoch_end(trainer, pl_module)
print(
f"\nEpoch: {trainer.current_epoch}, Test Loss: {trainer.logged_metrics['test_loss_epoch']}, Test Accuracy: {trainer.logged_metrics['test_acc_epoch']}"
)
pl_module.log_store.get("test_loss_epoch").append(
trainer.logged_metrics["test_loss_epoch"].cpu().detach().item()
)
pl_module.log_store.get("test_acc_epoch").append(
trainer.logged_metrics["test_acc_epoch"].cpu().detach().item()
)