File size: 2,648 Bytes
229755d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

from .visualize import plot_model_training_curves


class TrainingEndCallback(Callback):
    """Plot the per-epoch training/validation curves when fitting finishes.

    Reads the metric histories stored in ``pl_module.log_store`` (lists keyed
    by ``train_acc_epoch``, ``train_loss_epoch``, ``val_acc_epoch`` and
    ``val_loss_epoch``, as appended by ``PrintLearningMetricsCallback`` in this
    module) and hands them to ``plot_model_training_curves``.
    """

    def on_train_end(self, trainer, pl_module):
        # ``on_train_end`` fires at the end of ``trainer.fit``; testing is a
        # separate ``trainer.test()`` call, so it has not run yet at this point.
        print("Training and validation completed!")

        history = pl_module.log_store

        # The plotting helper names its held-out arguments ``test_*``; here
        # they receive the validation history collected during fitting.
        plot_model_training_curves(
            train_accs=history["train_acc_epoch"],
            test_accs=history["val_acc_epoch"],
            train_losses=history["train_loss_epoch"],
            test_losses=history["val_loss_epoch"],
        )


class PrintLearningMetricsCallback(Callback):
    """Print epoch-level loss/accuracy and append them to ``pl_module.log_store``.

    For each stage (train, validation, test) the ``<stage>_loss_epoch`` and
    ``<stage>_acc_epoch`` metrics are printed and appended, as Python numbers,
    to the list stored under the same key in ``pl_module.log_store`` (the list
    is created if the key is missing).

    Metrics are read from ``trainer.callback_metrics``: inside the
    ``on_*_epoch_end`` hooks Lightning has already reduced the epoch values
    into ``callback_metrics``, whereas ``trainer.logged_metrics`` is only
    refreshed after these hooks return (it would hold the previous epoch's
    value, or lack the key entirely on the first epoch).
    """

    def _record(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        label: str,
        prefix: str,
    ) -> None:
        """Print and store ``<prefix>_loss_epoch`` / ``<prefix>_acc_epoch``."""
        loss_key = f"{prefix}_loss_epoch"
        acc_key = f"{prefix}_acc_epoch"
        metrics = trainer.callback_metrics
        # ``.item()`` handles tensors on any device and with grad attached.
        loss = metrics[loss_key].item()
        acc = metrics[acc_key].item()
        print(
            f"\nEpoch: {trainer.current_epoch}, {label} Loss: {loss}, {label} Accuracy: {acc}"
        )
        pl_module.log_store.setdefault(loss_key, []).append(loss)
        pl_module.log_store.setdefault(acc_key, []).append(acc)

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        super().on_train_epoch_end(trainer, pl_module)
        self._record(trainer, pl_module, "Train", "train")

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        super().on_validation_epoch_end(trainer, pl_module)
        # The pre-fit sanity-check validation run also triggers this hook;
        # recording it would add a spurious entry and misalign the validation
        # history with the training history.
        if trainer.sanity_checking:
            return
        self._record(trainer, pl_module, "Val", "val")

    def on_test_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        super().on_test_epoch_end(trainer, pl_module)
        self._record(trainer, pl_module, "Test", "test")