import os
import logging

import numpy as np
import torch
import wandb


class CheckpointSaver:
    def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
        """
        dirpath: directory in which to store all model weights
        decreasing: if `True`, a lower metric value is better
        top_n: number of checkpoints to keep, ranked by validation metric value
        """
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        self.top_model_paths = []
        self.best_metric_val = np.inf if decreasing else -np.inf
        self.run_name = run_name

    def __call__(self, model, epoch, metric_val, optimizer, loss):
        model_path = os.path.join(
            self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
        save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
        if save:
            logging.info(
                f"Current metric value {metric_val} better than best {self.best_metric_val}, "
                f"saving model at {model_path} and logging model weights to W&B.")
            self.best_metric_val = metric_val
            # Save a full training checkpoint so the run can be resumed later.
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, model_path)
            self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            self.top_model_paths = sorted(
                self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
        if len(self.top_model_paths) > self.top_n:
            self.cleanup()

    def log_artifact(self, filename, model_path, metric_val):
        # Log the checkpoint file to W&B as a versioned model artifact.
        artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)

    def cleanup(self):
        # Delete checkpoints that fell out of the top_n ranking.
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models: {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]
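

# --- Usage sketch (illustrative, not part of the original class) ---
# A minimal example of wiring CheckpointSaver into a training loop, assuming an
# active W&B run. The project/run names, model, data, and metric below are
# hypothetical placeholders standing in for a real training setup.
if __name__ == '__main__':
    wandb.init(project='demo-project', name='demo-run')  # hypothetical project/run names
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    saver = CheckpointSaver(dirpath='./checkpoints', run_name='demo-run', decreasing=True, top_n=3)

    for epoch in range(5):
        # Fake training step on random data, standing in for a real epoch.
        inputs = torch.randn(32, 10)
        targets = torch.randn(32, 1)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Use the training loss as the (decreasing) checkpointing metric;
        # in practice this would be a held-out validation metric.
        saver(model, epoch, loss.item(), optimizer, loss.item())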