|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks import Callback |
|
|
|
from .visualize import plot_model_training_curves |
|
|
|
|
|
class TrainingEndCallback(Callback):
    """Plot the per-epoch learning curves once ``trainer.fit`` finishes."""

    def on_train_end(self, trainer, pl_module):
        """Announce completion and plot accuracy/loss curves.

        The per-epoch histories are read from ``pl_module.log_store`` under the
        keys ``train_acc_epoch``, ``val_acc_epoch``, ``train_loss_epoch`` and
        ``val_loss_epoch`` and forwarded to ``plot_model_training_curves``.
        A missing key raises ``KeyError``.

        NOTE(review): this hook runs at the end of fitting; a separate
        ``trainer.test()`` call would happen afterwards, so the "testing
        completed" wording below may be premature — confirm intent.
        """
        print("Training, validation, and testing completed!")

        history = pl_module.log_store
        # plot_model_training_curves keyword -> log_store key.
        # The validation history is what gets drawn as the "test" curves.
        curve_sources = {
            "train_accs": "train_acc_epoch",
            "test_accs": "val_acc_epoch",
            "train_losses": "train_loss_epoch",
            "test_losses": "val_loss_epoch",
        }
        plot_model_training_curves(
            **{param: history[key] for param, key in curve_sources.items()}
        )
|
|
|
|
|
class PrintLearningMetricsCallback(Callback):
    """Print epoch-level loss/accuracy and append them to ``pl_module.log_store``.

    At the end of each train / validation / test epoch the callback reads
    ``<stage>_loss_epoch`` and ``<stage>_acc_epoch`` from
    ``trainer.logged_metrics``, prints them, and appends their Python-scalar
    values to the lists at ``pl_module.log_store[<same key>]``.

    NOTE(review): ``log_store`` is assumed to be a mapping of lists created by
    the LightningModule with all six keys pre-populated — confirm against the
    module, which is not shown here.
    """

    def _report_epoch_metrics(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        stage: str,
        label: str,
    ) -> None:
        """Print and store the ``<stage>_loss_epoch`` / ``<stage>_acc_epoch`` pair.

        Args:
            trainer: Trainer whose ``logged_metrics`` hold the epoch aggregates.
            pl_module: Module owning the ``log_store`` history lists.
            stage: Metric-key prefix (``"train"``, ``"val"`` or ``"test"``).
            label: Human-readable stage name used in the printed line.

        Raises:
            KeyError: if a metric was not logged, or ``log_store`` has no list
                for it.
        """
        loss_key = f"{stage}_loss_epoch"
        acc_key = f"{stage}_acc_epoch"

        # .item() copies a single-element tensor to a Python scalar from any
        # device (and regardless of requires_grad), so .cpu().detach() is
        # unnecessary.
        loss = trainer.logged_metrics[loss_key].item()
        acc = trainer.logged_metrics[acc_key].item()

        print(
            f"\nEpoch: {trainer.current_epoch}, "
            f"{label} Loss: {loss}, {label} Accuracy: {acc}"
        )

        # Subscript rather than .get(): a missing history list now fails with a
        # clear KeyError instead of "'NoneType' object has no attribute 'append'".
        pl_module.log_store[loss_key].append(loss)
        pl_module.log_store[acc_key].append(acc)

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Report and record the training epoch's loss and accuracy."""
        super().on_train_epoch_end(trainer, pl_module)
        self._report_epoch_metrics(trainer, pl_module, "train", "Train")

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Report and record the validation epoch's loss and accuracy."""
        super().on_validation_epoch_end(trainer, pl_module)
        # Lightning also fires validation hooks during the pre-fit sanity check.
        # Recording that run would add an extra leading entry to the val history
        # (misaligning it with the train history when plotted), and its metrics
        # may not be present in logged_metrics at all.
        if trainer.sanity_checking:
            return
        self._report_epoch_metrics(trainer, pl_module, "val", "Val")

    def on_test_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Report and record the test epoch's loss and accuracy."""
        super().on_test_epoch_end(trainer, pl_module)
        self._report_epoch_metrics(trainer, pl_module, "test", "Test")