import os.path as osp

import numpy as np
import pytorch_lightning as pl
import torch
from skimage.transform import resize

from lib.common.train_util import batch_mean, convert_to_dict, export_cfg
from lib.net import NormalNet


class Normal(pl.LightningModule):
    def __init__(self, cfg):
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_F = self.cfg.lr_netF
        self.lr_B = self.cfg.lr_netB
        self.lr_D = self.cfg.lr_netD
        self.overfit = cfg.overfit

        self.F_losses = [item[0] for item in self.cfg.net.front_losses]
        self.B_losses = [item[0] for item in self.cfg.net.back_losses]
        self.ALL_losses = self.F_losses + self.B_losses

        self.automatic_optimization = False
        self.schedulers = []

        self.netG = NormalNet(self.cfg)
        self.in_nml = [item[0] for item in cfg.net.in_nml]
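
    # Config fields read by this module (inferred from usage in this file; the
    # full schema lives in the project's config): batch_size, lr_netF, lr_netB,
    # lr_netD, overfit, schedule, gamma, freq_show_train, fast_dev, devices,
    # results_path, name, and net.{front_losses, back_losses, in_nml}.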
    # Training related
    def configure_optimizers(self):

        optim_params_N_D = None
        optimizer_N_D = None
        scheduler_N_D = None

        # set optimizer
        optim_params_N_F = [{"params": self.netG.netF.parameters(), "lr": self.lr_F}]
        optim_params_N_B = [{"params": self.netG.netB.parameters(), "lr": self.lr_B}]

        optimizer_N_F = torch.optim.Adam(optim_params_N_F, lr=self.lr_F, betas=(0.5, 0.999))
        optimizer_N_B = torch.optim.Adam(optim_params_N_B, lr=self.lr_B, betas=(0.5, 0.999))

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        if "gan" in self.ALL_losses:
            optim_params_N_D = [{"params": self.netG.netD.parameters(), "lr": self.lr_D}]
            optimizer_N_D = torch.optim.Adam(optim_params_N_D, lr=self.lr_D, betas=(0.5, 0.999))
            scheduler_N_D = torch.optim.lr_scheduler.MultiStepLR(
                optimizer_N_D, milestones=self.cfg.schedule, gamma=self.cfg.gamma
            )
            self.schedulers = [scheduler_N_F, scheduler_N_B, scheduler_N_D]
            optims = [optimizer_N_F, optimizer_N_B, optimizer_N_D]
        else:
            self.schedulers = [scheduler_N_F, scheduler_N_B]
            optims = [optimizer_N_F, optimizer_N_B]

        # NOTE: with automatic_optimization = False, Lightning does not step these
        # schedulers on its own; they are kept on self.schedulers so they can be
        # stepped manually (e.g. once per epoch).
        return optims, self.schedulers
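
    # Visualization helper: maps each tensor in render_tensor from [-1, 1] to
    # [0, 1], resizes it to a square panel, tiles the panels side by side, and
    # logs the strip as a single image (log_image implies a W&B-style logger).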
    def render_func(self, render_tensor, dataset, idx):

        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(1, 2, 0),
                    (height, height),
                    anti_aliasing=True,
                )
            )

        self.logger.log_image(
            key=f"Normal/{dataset}/{idx if not self.overfit else 1}",
            images=[(np.concatenate(result_list, axis=1) * 255.0).astype(np.uint8)],
        )
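
    # One optimizer per sub-network: netG.netF (front normals), netG.netB (back
    # normals), and netG.netD when a "gan" loss is configured. Because
    # automatic_optimization is False, training_step below drives zero_grad /
    # manual_backward / step for each of them by hand.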
    def training_step(self, batch, batch_idx):

        # cfg log
        if not self.cfg.fast_dev and self.global_step == 0 and self.cfg.devices == 1:
            export_cfg(self.logger, osp.join(self.cfg.results_path, self.cfg.name), self.cfg)
            self.logger.experiment.config.update(convert_to_dict(self.cfg))

        self.netG.train()

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"], "normal_B": batch["normal_B"]}
        in_tensor.update(FB_tensor)

        preds_F, preds_B = self.netG(in_tensor)
        error_dict = self.netG.get_norm_error(preds_F, preds_B, FB_tensor)

        if "gan" in self.ALL_losses:
            (opt_F, opt_B, opt_D) = self.optimizers()

            opt_F.zero_grad()
            self.manual_backward(error_dict["netF"])
            opt_B.zero_grad()
            self.manual_backward(error_dict["netB"], retain_graph=True)
            opt_D.zero_grad()
            self.manual_backward(error_dict["netD"])

            opt_F.step()
            opt_B.step()
            opt_D.step()
        else:
            (opt_F, opt_B) = self.optimizers()

            opt_F.zero_grad()
            self.manual_backward(error_dict["netF"])
            opt_B.zero_grad()
            self.manual_backward(error_dict["netB"])

            opt_F.step()
            opt_B.step()

        if (
            batch_idx > 0
            and batch_idx % int(self.cfg.freq_show_train) == 0
            and self.cfg.devices == 1
        ):
            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                self.render_func(in_tensor, "train", self.global_step)

        # metrics processing
        metrics_log = {"loss": error_dict["netF"] + error_dict["netB"]}

        if "gan" in self.ALL_losses:
            metrics_log["loss"] += error_dict["netD"]

        for key in error_dict.keys():
            metrics_log["train/loss_" + key] = error_dict[key].item()

        self.log_dict(
            metrics_log,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        return metrics_log
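
    # Lightning collects the dicts returned by the *_step hooks and hands them to
    # the *_epoch_end hooks below, where batch_mean averages each key over the
    # epoch. These hooks follow the Lightning 1.x API (they receive `outputs`);
    # Lightning 2.0 removed them in favor of on_*_epoch_end.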
    def training_epoch_end(self, outputs):

        # metrics processing
        metrics_log = {}

        for key in outputs[0].keys():
            if "/" in key:
                [stage, loss_name] = key.split("/")
            else:
                stage = "train"
                loss_name = key
            metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)

        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True,
        )
    def validation_step(self, batch, batch_idx):

        self.netG.eval()
        self.netG.training = False

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"], "normal_B": batch["normal_B"]}
        in_tensor.update(FB_tensor)

        preds_F, preds_B = self.netG(in_tensor)
        error_dict = self.netG.get_norm_error(preds_F, preds_B, FB_tensor)

        if batch_idx % int(self.cfg.freq_show_train) == 0 and self.cfg.devices == 1:
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                self.render_func(in_tensor, "val", batch_idx)

        # metrics processing
        metrics_log = {"val/loss": error_dict["netF"] + error_dict["netB"]}

        if "gan" in self.ALL_losses:
            metrics_log["val/loss"] += error_dict["netD"]

        for key in error_dict.keys():
            metrics_log["val/" + key] = error_dict[key].item()

        return metrics_log
    def validation_epoch_end(self, outputs):

        # metrics processing
        metrics_log = {}

        for key in outputs[0].keys():
            [stage, loss_name] = key.split("/")
            metrics_log[f"{stage}/avg-{loss_name}"] = batch_mean(outputs, key)

        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True,
        )
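

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The cfg loader,
# config path, datamodule, and cfg.num_epoch below are hypothetical placeholders
# for the project's own pieces; pl.Trainer and WandbLogger are real PyTorch
# Lightning / W&B APIs, and a W&B logger matches the log_image and
# experiment.config calls above.
# -----------------------------------------------------------------------------
# if __name__ == "__main__":
#     from pytorch_lightning.loggers import WandbLogger
#
#     cfg = load_normal_cfg("configs/normal.yaml")      # hypothetical cfg loader
#     model = Normal(cfg)
#     datamodule = NormalDataModule(cfg)                 # hypothetical datamodule
#
#     trainer = pl.Trainer(
#         max_epochs=cfg.num_epoch,
#         accelerator="gpu",
#         devices=cfg.devices,
#         fast_dev_run=cfg.fast_dev,
#         logger=WandbLogger(project=cfg.name),
#     )
#     trainer.fit(model, datamodule=datamodule)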