|
"""Module to define the train and test functions.""" |
|
|
|
|
|
|
|
import modules.config as config |
|
import pytorch_lightning as pl |
|
import torch |
|
from modules.utils import create_folder_if_not_exists |
|
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary |
|
|
|
|
|
from pytorch_lightning.tuner.tuning import Tuner |
|
|
|
|
|
# Module-level alias of the configured starting learning rate; it is passed as
# the lower bound (``min_lr``) of the learning-rate search in this module.
PREFERRED_START_LR = config.PREFERRED_START_LR
|
|
|
|
|
def train_and_test_model(
    batch_size,
    num_epochs,
    model,
    datamodule,
    logger,
    debug=False,
):
    """Tune the learning rate, then fit and test ``model`` with a Lightning Trainer.

    The function builds the callbacks and the Trainer, runs the Lightning
    learning-rate finder, fits and tests the model, and finally writes the
    model weights and a small sample of misclassified images to disk.

    Args:
        batch_size: Batch size in use. Only reported in the start-up message;
            it is not passed to the Trainer or the datamodule here.
        num_epochs: Maximum number of epochs to train for. Overridden to ``1``
            when ``debug`` is true.
        model: The Lightning module to train. It must expose a
            ``learning_rate`` attribute (named as ``attr_name`` for the
            learning-rate finder), ``results`` and
            ``misclassified_image_data`` attributes, and ``state_dict()``.
        datamodule: Data module used for tuning and fitting; its
            ``test_dataloader()`` supplies the data for the test run.
        logger: Logger object handed to ``pl.Trainer``.
        debug: When true, run a quick smoke test: ``fast_dev_run`` enabled,
            ``overfit_batches=0.1``, the ``"advanced"`` profiler, one epoch.

    Returns:
        A tuple ``(trainer, results, misclassified_image_data)`` where
        ``results`` and ``misclassified_image_data`` are read from the model
        after testing. The full misclassified data is returned, not the
        truncated subset that is saved to disk.

    Side effects:
        Saves ``model.state_dict()`` to ``config.MODEL_PATH`` and the first 20
        misclassified samples to ``config.MISCLASSIFIED_PATH``; checkpoints are
        written under ``config.CHECKPOINT_PATH`` by the checkpoint callback.
    """
    print(f"\n\nBatch size: {batch_size}, Total epochs: {num_epochs}\n\n")

    print("Defining Lightning Callbacks")

    # Track the checkpoint with the highest "val_acc"; save_last=True also
    # requests a copy of the most recent checkpoint.
    checkpoint = ModelCheckpoint(
        dirpath=config.CHECKPOINT_PATH,
        monitor="val_acc",
        mode="max",
        filename="model_best_epoch",
        save_last=True,
    )
    lr_rate_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=False)
    model_summary = ModelSummary(max_depth=0)

    print("Defining Lightning Trainer")

    if debug:
        # Smoke-test configuration: a single short, profiled run.
        num_epochs = 1
        fast_dev_run = True
        overfit_batches = 0.1
        profiler = "advanced"
    else:
        fast_dev_run = False
        overfit_batches = 0.0
        profiler = None

    trainer = pl.Trainer(
        precision=16,
        fast_dev_run=fast_dev_run,
        max_epochs=num_epochs,
        logger=logger,
        overfit_batches=overfit_batches,
        log_every_n_steps=10,
        profiler=profiler,
        callbacks=[checkpoint, lr_rate_monitor, model_summary],
    )

    print("Finding the optimal learning rate using Lightning Tuner.")
    tuner = Tuner(trainer)
    # Linear sweep from the preferred start LR up to 5 over 200 steps;
    # "learning_rate" names the model attribute the tuner works with.
    tuner.lr_find(
        model=model,
        datamodule=datamodule,
        min_lr=PREFERRED_START_LR,
        max_lr=5,
        num_training=200,
        mode="linear",
        early_stop_threshold=10,
        attr_name="learning_rate",
    )

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, dataloaders=datamodule.test_dataloader())

    print("Collecting epoch level model results.")
    results = model.results

    print("Collecting misclassified images.")
    misclassified_image_data = model.misclassified_image_data

    print("Saving the model.")
    # NOTE(review): presumably ensures the destination directory exists --
    # confirm against modules.utils.create_folder_if_not_exists.
    create_folder_if_not_exists(config.MODEL_PATH)
    torch.save(model.state_dict(), config.MODEL_PATH)
    # Report success only after the weights have actually been written.
    print(f"Model saved to {config.MODEL_PATH}")

    num_elements = 20
    print(f"Saving first {num_elements} misclassified images.")
    # Keep only the leading entries of each field so the saved file stays small.
    subset_keys = ("images", "ground_truths", "predicted_vals")
    subset_misclassified_image_data = {
        key: misclassified_image_data[key][:num_elements] for key in subset_keys
    }
    create_folder_if_not_exists(config.MISCLASSIFIED_PATH)
    torch.save(subset_misclassified_image_data, config.MISCLASSIFIED_PATH)

    return trainer, results, misclassified_image_data
|
|