File size: 2,648 Bytes
229755d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from .visualize import plot_model_training_curves
class TrainingEndCallback(Callback):
    """Callback that plots the recorded metric histories when training ends."""

    def on_train_end(self, trainer, pl_module):
        """Print a completion message and plot the curves in ``pl_module.log_store``.

        The four per-epoch series are read from ``pl_module.log_store`` and
        forwarded to ``plot_model_training_curves``. The validation series are
        passed through that function's ``test_accs`` / ``test_losses``
        arguments.

        Args:
            trainer: The trainer invoking the hook (not used here).
            pl_module: The module whose ``log_store`` holds the histories.
        """
        print("Training, validation, and testing completed!")

        history = pl_module.log_store
        # Keys are looked up in the same order as the plot arguments are listed.
        curves = {
            "train_accs": history["train_acc_epoch"],
            "test_accs": history["val_acc_epoch"],
            "train_losses": history["train_loss_epoch"],
            "test_losses": history["val_loss_epoch"],
        }
        plot_model_training_curves(**curves)
class PrintLearningMetricsCallback(Callback):
    """Print epoch-level loss/accuracy for each stage and keep a history.

    At the end of every train, validation and test epoch the callback reads
    ``<stage>_loss_epoch`` and ``<stage>_acc_epoch`` from
    ``trainer.logged_metrics``, prints them, and appends their Python scalar
    values to the lists stored under the same keys in ``pl_module.log_store``.

    NOTE(review): ``pl_module.log_store`` is defined outside this file; it is
    assumed to be a dict-like mapping of metric name -> list. Confirm against
    the LightningModule.
    """

    def _report_and_store(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        label: str,
        prefix: str,
    ) -> None:
        """Print one stage's epoch metrics and append them to ``log_store``.

        Args:
            trainer: Trainer whose ``logged_metrics`` holds the epoch values.
            pl_module: Module whose ``log_store`` receives the scalar values.
            label: Human-readable stage name used in the printed line.
            prefix: Metric-name prefix (``"train"``, ``"val"`` or ``"test"``).

        Raises:
            KeyError: If either metric is missing from ``trainer.logged_metrics``.
        """
        loss_key = f"{prefix}_loss_epoch"
        acc_key = f"{prefix}_acc_epoch"

        # Resolve both metrics before touching log_store so a missing metric
        # cannot leave the loss and accuracy histories with different lengths.
        metrics = trainer.logged_metrics
        loss = metrics[loss_key]
        acc = metrics[acc_key]

        print(
            f"\nEpoch: {trainer.current_epoch}, {label} Loss: {loss}, {label} Accuracy: {acc}"
        )

        # setdefault creates the list on first use; a plain .get() would
        # return None for an uninitialised key and fail on .append().
        store = pl_module.log_store
        store.setdefault(loss_key, []).append(loss.cpu().detach().item())
        store.setdefault(acc_key, []).append(acc.cpu().detach().item())

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Print and store the training loss/accuracy of the finished epoch."""
        super().on_train_epoch_end(trainer, pl_module)
        self._report_and_store(trainer, pl_module, "Train", "train")

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Print and store the validation loss/accuracy of the finished epoch.

        Nothing is printed or stored while the trainer is running its
        pre-training sanity check, so the validation history only contains
        entries for real epochs.
        """
        super().on_validation_epoch_end(trainer, pl_module)
        if trainer.sanity_checking:
            return
        self._report_and_store(trainer, pl_module, "Val", "val")

    def on_test_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        """Print and store the test loss/accuracy of the finished epoch."""
        super().on_test_epoch_end(trainer, pl_module)
        self._report_and_store(trainer, pl_module, "Test", "test")