# NOTE(review): removed non-code extraction artifacts that preceded this file
# ("Spaces:" header and two "Runtime error" lines from the paste/table export).
# Imports reconstructed from a pipe-table extraction that destroyed the
# original file layout; grouped stdlib / third-party / project-local.
import logging

import numpy as np
import pytorch_lightning as pl
import torch
from skimage.transform import resize
from torch import nn

from lib.net import NormalNet
from lib.common.train_util import *  # NOTE(review): star import kept — supplies export_cfg, tf_log_convert, bar_log_convert, batch_mean

# Let cuDNN benchmark conv algorithms (fixed input sizes assumed).
torch.backends.cudnn.benchmark = True

# Silence Lightning's info/warning chatter.
logging.getLogger("lightning").setLevel(logging.ERROR)
class Normal(pl.LightningModule):
    """Lightning module that trains NormalNet's two normal-map heads.

    The wrapped ``self.netG`` (a ``NormalNet``) predicts a *front* normal
    map via ``netF`` and a *back* normal map via ``netB``; each head gets
    its own Adam optimizer and MultiStepLR scheduler.  Optimization is
    driven manually inside ``training_step`` (``self.optimizers()`` +
    ``manual_backward``), so this module is expected to run with
    automatic optimization disabled.
    """

    def __init__(self, cfg):
        """Build the network and cache config values.

        Args:
            cfg: project config node; must expose ``batch_size``, ``lr_N``,
                ``weight_decay``, ``momentum``, ``schedule``, ``gamma``,
                ``freq_show_train`` and ``net.in_nml``.
        """
        super(Normal, self).__init__()

        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        # Filled by configure_optimizers; read back in training_epoch_end
        # to log the current learning rates.
        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        # Names of the input tensors to pull out of each batch,
        # e.g. first element of every (name, channels) pair in cfg.net.in_nml.
        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
        """Return Lightning's progress-bar dict without the 'v_num' entry."""
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

    # Training related
    def configure_optimizers(self):
        """Create one Adam optimizer + MultiStepLR scheduler per head.

        Returns:
            (optimizers, schedulers): front-head pair first, back-head second —
            the order training_step relies on when unpacking self.optimizers().
        """
        # set optimizer
        weight_decay = self.cfg.weight_decay
        momentum = self.cfg.momentum  # NOTE(review): unused with Adam — kept for parity with the config

        optim_params_N_F = [
            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
        optim_params_N_B = [
            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]

        optimizer_N_F = torch.optim.Adam(
            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
        )
        optimizer_N_B = torch.optim.Adam(
            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
        )

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

    def render_func(self, render_tensor):
        """Tile every tensor in ``render_tensor`` into one side-by-side image.

        Each entry is assumed to be a CHW image batch in [-1, 1]; the first
        sample is mapped to [0, 1], resized to a square of the height of
        render_tensor["image"], and the results are concatenated horizontally.

        Returns:
            HxW*k x C float array suitable for TensorBoard's add_image.
        """
        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                        1, 2, 0
                    ),
                    (height, height),
                    anti_aliasing=True,
                )
            )
        result_array = np.concatenate(result_list, axis=1)

        return result_array

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Manually optimize both heads on one batch and log losses.

        NOTE(review): this performs manual optimization (zero_grad /
        manual_backward / step on both optimizers) yet keeps the
        ``optimizer_idx`` argument — confirm ``automatic_optimization`` is
        disabled elsewhere, otherwise each batch is stepped once per optimizer.
        """
        export_cfg(self.logger, self.cfg)

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        # Ground-truth front/back normal maps.
        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # Manual optimization: one backward+step per head.
        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        # Two-argument manual_backward is the pre-1.x Lightning API.
        self.manual_backward(error_NF, opt_nf)
        self.manual_backward(error_NB, opt_nb)

        opt_nf.step()
        opt_nb.step()

        # Periodically render predictions next to the inputs for TensorBoard.
        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:

            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        # metrics processing
        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):
        """Average the per-step losses and log them with the current LRs."""
        # With multiple optimizers Lightning may nest outputs per optimizer;
        # unwrap the first (populated) list in that case.
        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):
        """Compute the front/back normal losses on one validation batch."""
        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        # NOTE(review): train() mode during validation (BN/dropout active)
        # is kept from the original — confirm this is intentional.
        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # Render the first batch and every freq_show_train-th batch.
        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
            batch_idx == 0
        ):

            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):
        """Average the per-step validation losses for logging."""
        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}