"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from argparse import ArgumentParser

import torch
import wandb  # NOTE(review): not referenced in this module — confirm whether still needed

import fastmri
from fastmri import transforms

from ..varnet import VarNet
from .mri_module import MriModule


class VarNetModule(MriModule):
    """
    VarNet training module.

    This can be used to train variational networks from the paper:

    A. Sriram et al. End-to-end variational networks for accelerated MRI
    reconstruction. In International Conference on Medical Image Computing and
    Computer-Assisted Intervention, 2020.

    which was inspired by the earlier paper:

    K. Hammernik et al. Learning a variational network for reconstruction of
    accelerated MRI data. Magnetic Resonance in Medicine, 79(6):3055–3071,
    2018.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity
            map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate for the Adam optimizer.
        lr_step_size : int
            Period, in epochs, of learning rate decay (the ``StepLR`` step
            size).
        lr_gamma : float
            Multiplicative factor applied to the learning rate every
            ``lr_step_size`` epochs.
        weight_decay : float
            Parameter for penalizing weights norm (passed to Adam as
            ``weight_decay``).
        **kwargs
            Forwarded unchanged to ``MriModule.__init__``.

        Notes
        -----
        The number of low-frequency lines used for sensitivity map estimation
        is not a constructor argument: it is read per batch from
        ``batch.num_low_frequencies`` and passed through to ``VarNet``.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.varnet = VarNet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
        )

        # Training and validation both optimise / report SSIM loss.
        self.criterion = fastmri.SSIMLoss()
        # Total parameter count, computed once after the network is built.
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        """Run the wrapped ``VarNet`` on undersampled k-space."""
        return self.varnet(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        """Compute and log the SSIM training loss for one batch."""
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )

        # Output and target may differ in spatial size; compare on the overlap.
        target, output = transforms.center_crop_to_smallest(batch.target, output)
        # unsqueeze(1) inserts the singleton dimension the criterion is fed
        # with here (presumably a channel axis — confirm against SSIMLoss).
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Evaluate one validation batch.

        Returns a dict with the reconstruction, the cropped target, per-slice
        metadata and the SSIM loss, tagged with the ``slug`` of the
        validation dataloader the batch came from.
        """
        # Assumes ``trainer.val_dataloaders`` is a dict whose insertion order
        # matches ``dataloader_idx`` — confirm against the data module.
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(batch.target, output)

        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        """Adam with a step-decay learning rate schedule."""
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        # Multiply the learning rate by lr_gamma every lr_step_size epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades",
            default=12,
            type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools",
            default=4,
            type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans",
            default=18,
            type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools",
            default=4,
            type=int,
            help=(
                "Number of pooling layers for sense map estimation U-Net in"
                " VarNet"
            ),
        )
        # A channel count must be an integer; type=float would hand e.g. 8.0
        # to VarNet when the value is given on the command line.
        parser.add_argument(
            "--sens_chans",
            default=8,
            type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=40,
            type=int,
            help="Period (in epochs) of learning rate decay",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Multiplicative factor of learning rate decay",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )

        return parser