Spaces:
Runtime error
Runtime error
| import copy | |
| import functools | |
| import json | |
| import os | |
| from pathlib import Path | |
| from pdb import set_trace as st | |
| from typing import Any | |
| import vision_aided_loss | |
| import blobfile as bf | |
| import imageio | |
| import numpy as np | |
| import torch as th | |
| import torch.distributed as dist | |
| import torchvision | |
| from PIL import Image | |
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP | |
| from torch.optim import AdamW | |
| from torch.utils.tensorboard.writer import SummaryWriter | |
| from tqdm import tqdm | |
| from dnnlib.util import requires_grad | |
| from guided_diffusion.nn import update_ema | |
| from guided_diffusion.fp16_util import MixedPrecisionTrainer | |
| from guided_diffusion import dist_util, logger | |
| from guided_diffusion.train_util import (calc_average_loss, | |
| log_rec3d_loss_dict, | |
| find_resume_checkpoint) | |
| from guided_diffusion.continuous_diffusion_utils import get_mixed_prediction, different_p_q_objectives, kl_per_group_vada, kl_balancer | |
| from .train_util_diffusion_lsgm_noD_joint import TrainLoop3DDiffusionLSGMJointnoD | |
| from nsr.losses.builder import kl_coeff | |
def get_blob_logdir():
    """Return the directory that checkpoints are written to.

    This is currently just the logger's output directory. Change it to save
    checkpoints to a separate location such as a blobstore or external drive.
    """
    checkpoint_dir = logger.get_dir()
    return checkpoint_dir
class TrainLoop3DDiffusionLSGM_cvD(TrainLoop3DDiffusionLSGMJointnoD):
    """Joint VAE + latent-diffusion trainer with optional vision-aided
    (CLIP-feature) discriminators.

    When ``init_cvD`` is True, two discriminators are built. Each has a frozen
    feature extractor (``cv_ensemble``) and a trainable ``decoder`` head.

    * ``nvs_cvD``: real = canonical-view renderings, fake = novel-view
      renderings of the same latent (cameras rolled by one across the batch).
    * ``cano_cvD``: real = ground-truth images, fake = canonical-view
      reconstructions.

    The reconstruction model's decoder is frozen at construction time.
    """

    def __init__(self,
                 *,
                 rec_model,
                 denoise_model,
                 diffusion,
                 sde_diffusion,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 triplane_scaling_divider=1,
                 use_amp=False,
                 diffusion_input_size=224,
                 init_cvD=True,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         denoise_model=denoise_model,
                         diffusion=diffusion,
                         sde_diffusion=sde_diffusion,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         triplane_scaling_divider=triplane_scaling_divider,
                         use_amp=use_amp,
                         diffusion_input_size=diffusion_input_size,
                         **kwargs)

        if init_cvD:
            self._setup_cvD(lr=lr,
                            use_amp=use_amp,
                            fp16_scale_growth=fp16_scale_growth)

        # Fix decoder: freeze the reconstruction model's decoder.
        requires_grad(self.rec_model.decoder, False)

    # ------------------------------------------------------------------
    # discriminator construction
    # ------------------------------------------------------------------
    def _setup_cvD(self, lr, use_amp, fp16_scale_growth):
        """Create the novel-view (``nvs_cvD``) and canonical (``cano_cvD``) Ds."""
        # TODO copied from nvs_canoD, could be merged
        # nvs D: lr follows the generator lr and the configured multiplier.
        cvD_lr = 2e-4 * (lr / 1e-5) * self.loss_class.opt.nvs_D_lr_mul
        (self.nvs_cvD, self.mp_trainer_nvs_cvD, self.opt_cvD,
         self.ddp_nvs_cvD) = self._create_cvD(
             ckpt_name='cvD',
             trainer_name='cvD',
             lr=cvD_lr,
             use_amp=use_amp,
             fp16_scale_growth=fp16_scale_growth)
        logger.log(f'cpt_cvD lr: {cvD_lr}')

        # cano D: D_lr=2e-4 in cvD by default. 1e-4 still overfitting
        cano_lr = 2e-4 * (lr / 1e-5)
        (self.cano_cvD, self.mp_trainer_cano_cvD, self.opt_cano_cvD,
         self.ddp_cano_cvD) = self._create_cvD(
             ckpt_name='cano_cvD',
             trainer_name='canonical_cvD',
             lr=cano_lr,
             use_amp=use_amp,
             fp16_scale_growth=fp16_scale_growth)
        logger.log(f'cpt_cano_cvD lr: {cano_lr}')

    def _create_cvD(self, ckpt_name, trainer_name, lr, use_amp,
                    fp16_scale_growth):
        """Build one CLIP vision-aided discriminator and its training state.

        Args:
            ckpt_name: name used to look up a resume checkpoint.
            trainer_name: ``model_name`` given to the MixedPrecisionTrainer.
            lr: AdamW learning rate.
            use_amp: forwarded to the MixedPrecisionTrainer.
            fp16_scale_growth: forwarded to the MixedPrecisionTrainer.

        Returns:
            ``(D, mp_trainer, optimizer, ddp_D)``. ``ddp_D`` is ``D`` itself
            when ``self.use_ddp`` is False. The original code always wrapped
            the canonical D in DDP, which fails without a process group.
        """
        device = dist_util.dev()
        D = vision_aided_loss.Discriminator(cv_type='clip',
                                            loss_type='multilevel_sigmoid_s',
                                            device=device).to(device)
        D.cv_ensemble.requires_grad_(False)  # Freeze feature extractor
        self._load_and_sync_parameters(model=D, model_name=ckpt_name)

        # Only the D decoder head is optimized.
        mp_trainer = MixedPrecisionTrainer(
            model=D,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name=trainer_name,
            use_amp=use_amp,
            model_params=list(D.decoder.parameters()))
        opt = AdamW(mp_trainer.master_params,
                    lr=lr,
                    betas=(0, 0.999),
                    eps=1e-8)  # dlr in biggan cfg

        if self.use_ddp:
            ddp_D = DDP(
                D,
                device_ids=[device],
                output_device=device,
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            ddp_D = D
        return D, mp_trainer, opt, ddp_D

    def _freeze_cvDs(self):
        """Disable grads and switch to eval mode for every D that exists.

        The Ds are not created when ``init_cvD=False``, so they are looked up
        with ``getattr`` rather than accessed directly.
        """
        for attr in ('ddp_cano_cvD', 'ddp_nvs_cvD'):
            D = getattr(self, attr, None)
            if D is not None:
                requires_grad(D, False)
                D.eval()

    # ------------------------------------------------------------------
    # training loop
    # ------------------------------------------------------------------
    def _post_run_step(self):
        """Per-iteration bookkeeping.

        Covers tensorboard logging, periodic evaluation, checkpointing,
        advancing ``self.step`` and stopping once ``self.iterations`` is
        exceeded.
        """
        if (self.step != 0 and self.step % self.log_interval == 0
                and dist_util.get_rank() == 0):
            out = logger.dumpkvs()
            # * log to tensorboard
            for k, v in out.items():
                self.writer.add_scalar(f'Loss/{k}', v,
                                       self.step + self.resume_step)

        if self.step != 0 and self.step % self.eval_interval == 0:
            # NOTE(review): the nesting of eval_loop under the rank-0 check was
            # reconstructed from whitespace-stripped source; confirm it.
            if dist_util.get_rank() == 0:
                self.eval_ddpm_sample(self.rec_model)
                if self.sde_diffusion.args.train_vae:
                    self.eval_loop(self.rec_model)

        if self.step != 0 and self.step % self.save_interval == 0:
            self.save(self.mp_trainer, self.mp_trainer.model_name)

        self.step += 1

        if self.step > self.iterations:
            print('reached maximum iterations, exiting')
            # Save the last checkpoint if it wasn't already saved.
            if (self.step - 1) % self.save_interval != 0:
                self.save(self.mp_trainer, self.mp_trainer.model_name)
            raise SystemExit

    def run_loop(self):
        """Train until ``lr_anneal_steps`` is reached.

        If ``lr_anneal_steps`` is 0, training stops only through the
        ``iterations`` limit in ``_post_run_step``.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):
            batch = next(self.data)
            # Other available schedules: 'cano_ddpm_step', 'd_step_rec',
            # 'nvs_ddpm_step', 'd_step_nvs' (see run_step).
            self.run_step(batch, 'cano_ddpm_only')
            self._post_run_step()

        # Save the last checkpoint if it wasn't already saved. Use the same
        # trainer/name as the periodic saves; the default `save()` target
        # (`mp_trainer_rec`) is not what this loop trains.
        if (self.step - 1) % self.save_interval != 0:
            self.save(self.mp_trainer, self.mp_trainer.model_name)

    def run_step(self, batch, step='g_step'):
        """Dispatch one optimization step by name, then anneal lr and log.

        Unrecognised step names (including the default 'g_step') perform no
        update; only lr annealing and logging run.
        """
        if step == 'ce_ddpm_step':
            self.ce_ddpm_step(batch)
        elif step in ['ce', 'ddpm', 'cano_ddpm_only']:
            self.joint_rec_ddpm(batch, step)
        elif step == 'cano_ddpm_step':
            self.joint_rec_ddpm(batch, 'cano')
        elif step == 'd_step_rec':
            self.forward_D(batch, behaviour='rec')
        elif step == 'nvs_ddpm_step':
            self.joint_rec_ddpm(batch, 'nvs')
        elif step == 'd_step_nvs':
            self.forward_D(batch, behaviour='nvs')

        self._anneal_lr()
        self.log_step()

    def flip_encoder_grad(self, mode=True):
        """Enable/disable gradients for the reconstruction model's encoder."""
        requires_grad(self.rec_model.encoder, mode)

    @staticmethod
    def _cat_raw_and_sr(raw, sr):
        """Resize ``raw`` to ``sr``'s spatial size and concatenate on dim 1."""
        resized = th.nn.functional.interpolate(raw,
                                               size=sr.shape[2:],
                                               mode='bilinear',
                                               align_corners=False,
                                               antialias=True)
        return th.cat([resized, sr], dim=1)

    def forward_D(self, batch, behaviour):  # update D
        """Run one discriminator update.

        Args:
            batch: dict of batched inputs. Uses 'img', 'img_to_encoder', 'c'
                and, when the renderer emits 'image_sr', 'img_sr'. Non-tensor
                values are passed through unchanged.
            behaviour: 'rec' trains the canonical D (GT vs reconstruction);
                'nvs' trains the novel-view D (canonical vs novel render).
        """
        self.flip_encoder_grad(False)
        self.rec_model.eval()

        # update one D, freeze the other
        if behaviour == 'nvs':
            self.mp_trainer_nvs_cvD.zero_grad()
            self.ddp_nvs_cvD.requires_grad_(True)
            self.ddp_nvs_cvD.train()
            self.ddp_cano_cvD.requires_grad_(False)
            self.ddp_cano_cvD.eval()
        else:  # update rec canonical D
            self.mp_trainer_cano_cvD.zero_grad()
            self.ddp_nvs_cvD.requires_grad_(False)
            self.ddp_nvs_cvD.eval()
            self.ddp_cano_cvD.requires_grad_(True)
            self.ddp_cano_cvD.train()

        batch_size = batch['img'].shape[0]

        # * sample a new batch for D training
        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev()).contiguous()
                if isinstance(v, th.Tensor) else v
                for k, v in batch.items()
            }

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_cano_cvD.use_amp):

                latent = self.ddp_rec_model(img=micro['img_to_encoder'],
                                            behaviour='enc_dec_wo_triplane')
                cano_pred = self.ddp_rec_model(latent=latent,
                                               c=micro['c'],
                                               behaviour='triplane_dec')

                # TODO, optimize with one encoder, and two triplane decoder
                if behaviour == 'rec':
                    if 'image_sr' in cano_pred:
                        # Concat the resized raw render with the SR output along
                        # channels (averaging them caused a color bias).
                        d_loss = self.run_D_Diter(
                            real=self._cat_raw_and_sr(micro['img'],
                                                      micro['img_sr']),
                            fake=self._cat_raw_and_sr(cano_pred['image_raw'],
                                                      cano_pred['image_sr']),
                            D=self.ddp_cano_cvD)  # TODO, add SR for FFHQ
                    else:
                        d_loss = self.run_D_Diter(real=micro['img'],
                                                  fake=cano_pred['image_raw'],
                                                  D=self.ddp_cano_cvD)
                    log_rec3d_loss_dict({'vision_aided_loss/D_cano': d_loss})
                else:
                    assert behaviour == 'nvs'
                    # Novel views: render each latent with its neighbour's camera.
                    novel_view_c = th.roll(micro['c'], 1, 0)
                    nvs_pred = self.ddp_rec_model(latent=latent,
                                                  c=novel_view_c,
                                                  behaviour='triplane_dec')

                    if 'image_sr' in nvs_pred:
                        d_loss = self.run_D_Diter(
                            real=self._cat_raw_and_sr(cano_pred['image_raw'],
                                                      cano_pred['image_sr']),
                            fake=self._cat_raw_and_sr(nvs_pred['image_raw'],
                                                      nvs_pred['image_sr']),
                            D=self.ddp_nvs_cvD)  # TODO, add SR for FFHQ
                    else:
                        d_loss = self.run_D_Diter(
                            real=cano_pred['image_raw'],
                            fake=nvs_pred['image_raw'],
                            D=self.ddp_nvs_cvD)  # TODO, add SR for FFHQ
                    log_rec3d_loss_dict({'vision_aided_loss/D_nvs': d_loss})

            # quit autocast to run backward()
            # NOTE(review): backward/optimize placement per microbatch was
            # reconstructed from whitespace-stripped source; confirm.
            if behaviour == 'rec':
                self.mp_trainer_cano_cvD.backward(d_loss)
                _ = self.mp_trainer_cano_cvD.optimize(self.opt_cano_cvD)
            else:
                assert behaviour == 'nvs'
                self.mp_trainer_nvs_cvD.backward(d_loss)
                _ = self.mp_trainer_nvs_cvD.optimize(self.opt_cvD)

        self.flip_encoder_grad(True)
        self.rec_model.train()

    # ddpm + rec loss
    def joint_rec_ddpm(self, batch, behaviour='cano', *args, **kwargs):
        """Joint VAE + latent-diffusion update (add SDS grad to AE-predicted x_0).

        For each microbatch this encodes the images, applies the diffusion
        objective to the latent and, depending on ``behaviour``, adds
        reconstruction and/or generator-side GAN losses. Gradients accumulate
        across microbatches; one optimizer step is taken at the end.

        Args:
            batch: dict of batched inputs. Tensor values are sliced per
                microbatch; non-tensor values are passed through.
            behaviour: 'cano_ddpm_only' adds diffusion + reconstruction loss;
                'cano' additionally adds the canonical-D generator loss;
                'nvs' adds diffusion + novel-view-D generator loss only.
        """
        # Renamed from `args`, which shadowed the *args parameter.
        sde_args = self.sde_diffusion.args

        # ! enable the gradient of both models, freeze any discriminators
        self.flip_encoder_grad(True)
        self.rec_model.train()
        requires_grad(self.ddpm_model, True)
        self.ddpm_model.train()
        self._freeze_cvDs()

        self.mp_trainer.zero_grad()

        assert sde_args.train_vae

        batch_size = batch['img'].shape[0]

        for i in range(0, batch_size, self.microbatch):

            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                if isinstance(v, th.Tensor) else v
                for k, v in batch.items()
            }

            # =================================== ae part ===================================
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer.use_amp):

                loss = th.tensor(0., device=dist_util.dev())
                # Named gan_loss so it does not shadow the vision_aided_loss module.
                gan_loss = th.tensor(0., device=dist_util.dev())

                vae_out = self.ddp_rec_model(
                    img=micro['img_to_encoder'],
                    c=micro['c'],
                    behaviour='encoder_vae',
                )  # pred: (B, 3, 64, 64)
                eps = vae_out[self.latent_name]

                if 'bg_plane' in vae_out:
                    # include background, B 12+4 32 32
                    eps = th.cat((eps, vae_out['bg_plane']), dim=1)

                # ! running diffusion forward
                p_sample_batch = self.prepare_ddpm(eps)
                ddpm_ret = self.apply_model(p_sample_batch)
                loss += ddpm_ret['p_eps_objective'].mean()

                # ! reconstruction loss + gan loss
                if behaviour == 'cano_ddpm_only':
                    cano_pred = self.ddp_rec_model(
                        latent=vae_out,
                        c=micro['c'],
                        behaviour=self.render_latent_behaviour)

                    with self.ddp_model.no_sync():  # type: ignore
                        q_vae_recon_loss, loss_dict = self.loss_class(
                            {
                                **vae_out,  # include latent here.
                                **cano_pred,
                            },
                            micro,
                            test_mode=False)

                    log_rec3d_loss_dict(loss_dict)
                    loss += q_vae_recon_loss

                elif behaviour == 'cano':
                    cano_pred = self.ddp_rec_model(
                        latent=vae_out,
                        c=micro['c'],
                        behaviour=self.render_latent_behaviour)

                    with self.ddp_model.no_sync():  # type: ignore
                        q_vae_recon_loss, loss_dict = self.loss_class(
                            cano_pred, micro, test_mode=False)
                    loss += q_vae_recon_loss

                    # add gan loss
                    gan_loss = self.ddp_cano_cvD(
                        cano_pred['image_raw'], for_G=True).mean(
                        ) * self.loss_class.opt.rec_cvD_lambda  # [B, 1] shape

                    loss_dict.update({
                        'vision_aided_loss/G_rec': gan_loss.detach(),
                    })
                    log_rec3d_loss_dict(loss_dict)

                    if dist_util.get_rank() == 0 and self.step % 500 == 0:
                        self.cano_ddpm_log(cano_pred, micro, ddpm_ret)

                else:
                    assert behaviour == 'nvs'

                    nvs_pred = self.ddp_rec_model(
                        img=micro['img_to_encoder'],
                        c=th.roll(micro['c'], 1, 0),
                    )  # ! render novel views only for D loss

                    gan_loss = self.ddp_nvs_cvD(
                        nvs_pred['image_raw'], for_G=True).mean(
                        ) * self.loss_class.opt.nvs_cvD_lambda  # [B, 1] shape

                    # Detached, consistent with the G_rec entry above.
                    log_rec3d_loss_dict(
                        {'vision_aided_loss/G_nvs': gan_loss.detach()})

                    if dist_util.get_rank() == 0 and self.step % 500 == 1:
                        self.nvs_log(nvs_pred, micro)

                loss += gan_loss

            self.mp_trainer.backward(loss)

        # quit for loop: one step on the accumulated gradients
        _ = self.mp_trainer.optimize(self.opt,
                                     clip_grad=self.loss_class.opt.grad_clip)

    # ------------------------------------------------------------------
    # visualization helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _minmax_normalize(x):
        """Rescale ``x`` to [0, 1] by its global min/max.

        A constant tensor maps to all zeros instead of NaN; the original
        division by ``max - min`` produced 0/0 in that case.
        """
        lo, hi = x.min(), x.max()
        span = hi - lo
        span = th.where(span > 0, span, th.ones_like(span))
        return (x - lo) / span

    def _sr_pool(self, sr_res):
        """Pick the pooling op matching the super-resolution output width."""
        if sr_res == 512:
            return self.pool_512
        if sr_res == 256:
            return self.pool_256
        return self.pool_128

    def _pool_vis_inputs(self, pred, micro, pred_depth, gt_depth):
        """Resize images and depths so they can be tiled side by side.

        Returns:
            ``(pred_img, gt_img, pred_depth, gt_depth)``. With an SR output,
            the pooled raw image is concatenated (on width) with the SR image.
        """
        pred_img = pred['image_raw']
        gt_img = micro['img']
        if 'image_sr' in pred:
            pool = self._sr_pool(pred['image_sr'].shape[-1])
            pred_img = th.cat([pool(pred_img), pred['image_sr']], dim=-1)
            gt_img = th.cat([pool(micro['img']), micro['img_sr']], dim=-1)
            pred_depth = pool(pred_depth)
            gt_depth = pool(gt_depth)
        else:
            gt_img = self.pool_64(gt_img)
            gt_depth = self.pool_64(gt_depth)
        return pred_img, gt_img, pred_depth, gt_depth

    def _split_latent(self, x, has_bg_plane):
        """Scale a (normalized) diffusion latent back and split off bg_plane.

        Channels 0-11 are the triplane latent and 12-15 the background plane,
        matching the 12+4 concat in ``joint_rec_ddpm``.
        """
        if has_bg_plane:
            return {
                'latent_normalized_2Ddiffusion':
                x[:, :12] * self.triplane_scaling_divider,
                'bg_plane': x[:, 12:16] * self.triplane_scaling_divider,
            }
        return {
            'latent_normalized_2Ddiffusion': x * self.triplane_scaling_divider,
        }

    @th.no_grad()
    def cano_ddpm_log(self, cano_pred, micro, ddpm_ret):
        """Save a canonical-view GT-vs-prediction image for the first sample.

        Top row: GT image(s), the input image twice, GT depth. Bottom row:
        reconstruction, decoding of the noised latent, decoding of the
        predicted x_0, predicted depth.

        Runs under ``no_grad``. The tensors are only visualized, and
        ``Tensor.numpy()`` refuses tensors that still require grad.
        """
        assert isinstance(cano_pred, dict)
        behaviour = 'cano'

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = self._minmax_normalize(gt_depth)

        if 'image_depth' in cano_pred:
            pred_depth = self._minmax_normalize(cano_pred['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)

        pred_img, gt_img, pred_depth, gt_depth = self._pool_vis_inputs(
            cano_pred, micro, pred_depth, gt_depth)

        gt_vis = th.cat([
            gt_img, micro['img'], micro['img'],
            gt_depth.repeat_interleave(3, dim=1)
        ],
                        dim=-1)[0:1]  # TODO, fail to load depth. range [0, 1]

        eps_t_p, pred_eps_p, logsnr_p = (ddpm_ret[k]
                                         for k in ('eps_t_p', 'pred_eps_p',
                                                   'logsnr_p'))

        # NOTE(review): the bg_plane split is keyed on cano_pred, whereas
        # joint_rec_ddpm appends bg_plane based on vae_out; confirm they agree.
        has_bg_plane = 'bg_plane' in cano_pred

        noised_ae_pred = self.ddp_rec_model(
            img=None,
            c=micro['c'][0:1],
            latent=self._split_latent(eps_t_p[0:1], has_bg_plane),
            behaviour=self.render_latent_behaviour)

        pred_x0 = self.sde_diffusion._predict_x0_from_eps(
            eps_t_p, pred_eps_p, logsnr_p)  # for VAE loss, denoised latent

        denoised_ae_pred = self.ddp_rec_model(
            img=None,
            c=micro['c'][0:1],
            latent=self._split_latent(pred_x0[0:1], has_bg_plane),
            behaviour=self.render_latent_behaviour)

        pred_vis = th.cat([
            pred_img[0:1], noised_ae_pred['image_raw'][0:1],
            denoised_ae_pred['image_raw'][0:1],
            pred_depth[0:1].repeat_interleave(3, dim=1)
        ],
                          dim=-1)  # B, 3, H, W

        vis = th.cat([gt_vis, pred_vis],
                     dim=-2)[0].permute(1, 2, 0).cpu()  # ! pred in range[-1, 1]
        vis = vis.numpy() * 127.5 + 127.5
        vis = vis.clip(0, 255).astype(np.uint8)

        save_path = f'{logger.get_dir()}/{self.step+self.resume_step}_{behaviour}.jpg'
        Image.fromarray(vis).save(save_path)
        print('log denoised vis to: ', save_path)

        del vis, pred_vis, pred_x0, pred_eps_p, micro
        th.cuda.empty_cache()

    @th.no_grad()
    def nvs_log(self, nvs_pred, micro):
        """Save a GT-vs-novel-view grid (rank 0, when step % 500 == 1).

        Runs under ``no_grad`` because ``Tensor.numpy()`` refuses tensors that
        require grad. Missing 'image_depth' falls back to zeros, matching
        ``cano_ddpm_log``.
        """
        if dist_util.get_rank() != 0 or self.step % 500 != 1:
            return

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = self._minmax_normalize(gt_depth)

        if 'image_depth' in nvs_pred:
            pred_depth = self._minmax_normalize(nvs_pred['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)

        pred_img, gt_img, pred_depth, gt_depth = self._pool_vis_inputs(
            nvs_pred, micro, pred_depth, gt_depth)

        gt_vis = th.cat([gt_img, gt_depth.repeat_interleave(3, dim=1)],
                        dim=-1)  # TODO, fail to load depth. range [0, 1]
        pred_vis = th.cat([pred_img, pred_depth.repeat_interleave(3, dim=1)],
                          dim=-1)  # B, 3, H, W

        vis = th.cat([gt_vis, pred_vis], dim=-2)
        vis = torchvision.utils.make_grid(
            vis, normalize=True, scale_each=True,
            value_range=(-1, 1)).cpu().permute(1, 2, 0)  # H W 3
        vis = vis.numpy() * 255
        vis = vis.clip(0, 255).astype(np.uint8)

        save_path = f'{logger.get_dir()}/{self.step+self.resume_step}_nvs.jpg'
        Image.fromarray(vis).save(save_path)
        print('log vis to: ', save_path)

    # ! all copied from train_util_cvD.py; should merge later.
    def run_D_Diter(self, real, fake, D=None):
        """Discriminator loss: push logits up for ``real`` and down for ``fake``.

        Defaults to the novel-view discriminator when ``D`` is None.
        """
        if D is None:
            D = self.ddp_nvs_cvD
        lossD = D(real, for_real=True).mean() + D(fake, for_real=False).mean()
        return lossD

    # ------------------------------------------------------------------
    # checkpointing
    # ------------------------------------------------------------------
    def save(self, mp_trainer=None, model_name='rec'):
        """Write master params (and, for 'ddpm', each EMA copy) from rank 0.

        Args:
            mp_trainer: trainer whose params are saved; defaults to
                ``self.mp_trainer_rec``.
            model_name: tag embedded in the checkpoint filename.
        """
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, mp_trainer.master_params)

        if model_name == 'ddpm':
            for rate, params in zip(self.ema_rate, self.ema_params):
                save_checkpoint(rate, params)

        # barrier() raises when no process group exists (single-process runs).
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

    def _load_and_sync_parameters(self, model=None, model_name='rec'):
        """Load a resume checkpoint into ``model`` (if any), then sync ranks.

        Keys whose name and shape match are copied. 2D ViT block weights are
        remapped into the fused 3D block layout. Known-incompatible keys are
        skipped with a log line.
        """
        found = find_resume_checkpoint(self.resume_checkpoint, model_name)
        if found:
            resume_checkpoint, self.resume_step = found
        else:
            # The old `found or self.resume_checkpoint` fallback tried to unpack
            # a path string into two names and raised; keep any step already set.
            resume_checkpoint = self.resume_checkpoint
            self.resume_step = getattr(self, 'resume_step', 0)

        if model is None:
            model = self.ddp_rec_model  # default model in the parent class
        logger.log(resume_checkpoint)

        if resume_checkpoint and Path(resume_checkpoint).exists():
            if dist_util.get_rank() == 0:
                logger.log(
                    f"loading model from checkpoint: {resume_checkpoint}...")
            map_location = {
                'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
            }  # configure map_location properly

            logger.log(f'mark {model_name} loading ', )
            resume_state_dict = dist_util.load_state_dict(
                resume_checkpoint, map_location=map_location)
            logger.log(f'mark {model_name} loading finished', )

            model_state_dict = model.state_dict()

            for k, v in resume_state_dict.items():
                if k in model_state_dict.keys() and v.size(
                ) == model_state_dict[k].size():
                    model_state_dict[k] = v
                elif 'attn.wk' in k or 'attn.wv' in k:  # old qkv
                    logger.log('ignore ', k)
                elif 'decoder.vit_decoder.blocks' in k:
                    # load from 2D ViT pre-trained into 3D ViT blocks.
                    assert len(model.decoder.vit_decoder.blocks[0].vit_blks
                               ) == 2  # assert depth=2 here.
                    fusion_ca_depth = len(
                        model.decoder.vit_decoder.blocks[0].vit_blks)
                    vit_subblk_index = int(k.split('.')[3])
                    vit_blk_keyname = ('.').join(k.split('.')[4:])
                    fusion_blk_index = vit_subblk_index // fusion_ca_depth
                    fusion_blk_subindex = vit_subblk_index % fusion_ca_depth
                    model_state_dict[
                        f'decoder.vit_decoder.blocks.{fusion_blk_index}.vit_blks.{fusion_blk_subindex}.{vit_blk_keyname}'] = v
                elif 'IN' in k:
                    logger.log('ignore ', k)
                elif 'quant_conv' in k:
                    logger.log('ignore ', k)
                else:
                    logger.log(
                        '!!!! ignore key: ',
                        k,
                        ": ",
                        v.size(),
                    )
                    if k in model_state_dict:
                        logger.log('shape in model: ',
                                   model_state_dict[k].size())
                    else:
                        logger.log(k, 'not in model_state_dict')

            model.load_state_dict(model_state_dict, strict=True)
            del model_state_dict

        if dist_util.get_world_size() > 1:
            dist_util.sync_params(model.parameters())
            logger.log(f'synced {model_name} params')
class TrainLoop3DDiffusionLSGM_cvD_scaling(TrainLoop3DDiffusionLSGM_cvD):
    """Variant that standardizes diffusion latents with running statistics.

    The statistics are stored on the denoiser as ``ema_latent_mean`` and
    ``ema_latent_std`` and are updated with ``update_ema``.
    """

    def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
                 loss_class, data, eval_data, batch_size, microbatch, lr,
                 ema_rate, log_interval, eval_interval, save_interval,
                 resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
                 weight_decay=0, lr_anneal_steps=0, iterations=10001,
                 triplane_scaling_divider=1, use_amp=False,
                 diffusion_input_size=224, init_cvD=True, **kwargs):
        # Pure pass-through: every option goes to the parent unchanged.
        super().__init__(
            rec_model=rec_model, denoise_model=denoise_model,
            diffusion=diffusion, sde_diffusion=sde_diffusion,
            loss_class=loss_class, data=data, eval_data=eval_data,
            batch_size=batch_size, microbatch=microbatch, lr=lr,
            ema_rate=ema_rate, log_interval=log_interval,
            eval_interval=eval_interval, save_interval=save_interval,
            resume_checkpoint=resume_checkpoint, use_fp16=use_fp16,
            fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay,
            lr_anneal_steps=lr_anneal_steps, iterations=iterations,
            triplane_scaling_divider=triplane_scaling_divider,
            use_amp=use_amp, diffusion_input_size=diffusion_input_size,
            init_cvD=init_cvD, **kwargs)

    def _update_latent_stat_ema(self, latent: th.Tensor):
        """Fold ``latent``'s batch mean/std into the denoiser's EMA buffers.

        Each buffer is zipped against ``self.ema_rate``, so only the first
        rate is applied. The mean is taken over dim 0 (keepdim). The std is
        taken over dims 1-3 per sample, then averaged over the batch. Both
        EMAs are logged afterwards.
        """
        denoiser = self.ddpm_model
        for rate, target in zip(self.ema_rate, (denoiser.ema_latent_mean, )):
            update_ema(target, latent.mean(0, keepdim=True), rate=rate)
        for rate, target in zip(self.ema_rate, (denoiser.ema_latent_std, )):
            update_ema(target,
                       latent.std([1, 2, 3]).mean(0, keepdim=True),
                       rate=rate)

        log_rec3d_loss_dict({'ema_latent_std': denoiser.ema_latent_std.mean()})
        log_rec3d_loss_dict(
            {'ema_latent_mean': denoiser.ema_latent_mean.mean()})

    def _standarize(self, eps):
        """Map ``eps`` to ``(eps - ema_mean) * (1 / ema_std)``."""
        centered = eps - self.ddpm_model.ema_latent_mean
        return centered * (1 / self.ddpm_model.ema_latent_std)

    def _unstandarize(self, scaled_eps):
        """Inverse of ``_standarize``: ``scaled_eps * ema_std + ema_mean``."""
        rescaled = scaled_eps * self.ddpm_model.ema_latent_std
        return rescaled + self.ddpm_model.ema_latent_mean
class TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm(TrainLoop3DDiffusionLSGM_cvD_scaling):
    """LSGM-style loop alternating diffusion-prior and autoencoder updates.

    ``joint_rec_ddpm`` dispatches on ``behaviour``:

    * contains ``'ce'``: the diffusion model is frozen and the reconstruction
      model is trained (only its encoder for ``'ce_E'``). The objective is
      the diffusion cross-entropy (``p_eps_objective`` of ``'q'``-mode
      samples) plus a weighted negative posterior entropy. If ``behaviour``
      also contains ``'D'``, the ``loss_class`` reconstruction loss of the
      rendered latent is added.
    * anything else: the reconstruction model is frozen and the diffusion
      model is trained on the ``'p'``-mode ``p_eps_objective``.
    """

    def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
                 loss_class, data, eval_data, batch_size, microbatch, lr,
                 ema_rate, log_interval, eval_interval, save_interval,
                 resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
                 weight_decay=0, lr_anneal_steps=0, iterations=10001,
                 triplane_scaling_divider=1, use_amp=False,
                 diffusion_input_size=224, init_cvD=False, **kwargs):
        super().__init__(rec_model=rec_model,
                         denoise_model=denoise_model,
                         diffusion=diffusion,
                         sde_diffusion=sde_diffusion,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         triplane_scaling_divider=triplane_scaling_divider,
                         use_amp=use_amp,
                         diffusion_input_size=diffusion_input_size,
                         init_cvD=init_cvD,
                         **kwargs)

    def _setup_opt(self):
        """Build one AdamW over the diffusion model plus the rec-model groups.

        The rec-model groups come from ``_init_optim_groups(rec_model, True)``
        (the ``True`` argument freezes the decoder).
        """
        # TODO: two optimiser groups.
        self.opt = AdamW([{
            'name': 'ddpm',
            'params': self.ddpm_model.parameters(),
        }],
                         lr=self.lr,
                         weight_decay=self.weight_decay)
        for rec_param_group in self._init_optim_groups(self.rec_model, True):  # freeze D
            self.opt.add_param_group(rec_param_group)
        logger.log(self.opt)

    def next_n_batch(self, n=1):
        """Draw ``n`` batches from ``self.data`` and concatenate each key along dim 0."""
        all_batch_list = [next(self.data) for _ in range(n)]
        return {
            k: th.cat([batch[k] for batch in all_batch_list], 0)
            for k in all_batch_list[0].keys()
        }

    def subset_batch(self, batch=None, micro_batchsize=4, big_endian=False):
        """Return the first (or, with ``big_endian``, last) ``micro_batchsize`` items of every key.

        A new batch is drawn from ``self.data`` when ``batch`` is None.
        """
        if batch is None:
            batch = next(self.data)
        if big_endian:
            return {k: v[-micro_batchsize:] for k, v in batch.items()}
        return {k: v[:micro_batchsize] for k, v in batch.items()}

    def run_loop(self):
        """Alternate a 'ddpm' step (6 concatenated batches) with a 'ce' step (1 batch).

        The loop runs until ``lr_anneal_steps`` is reached, or forever when it
        is 0. A final checkpoint is written unless the last step already saved one.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):
            batch = self.next_n_batch(n=6)
            self.run_step(batch, 'ddpm')  # autoencoder fixed
            batch = next(self.data)
            self.run_step(batch, 'ce')
            self._post_run_step()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def entropy_weight(self, normal_entropy=None):
        """Weight of the negative-entropy term; a constant from the loss options here."""
        return self.loss_class.opt.negative_entropy_lambda

    # ddpm + rec loss
    def joint_rec_ddpm(self, batch, behaviour='ddpm', *args, **kwargs):
        """Run one optimisation step selected by ``behaviour`` (see class docstring).

        The whole ``batch`` is processed as a single micro-batch. Gradients
        are zeroed, one loss is back-propagated, and ``self.opt`` is stepped
        with ``loss_class.opt.grad_clip`` clipping.

        Raises:
            NotImplementedError: on a 'ce' step when ``diffusion_ce_anneal``
                is enabled; annealing is not implemented in this loop.
        """
        ce_flag = 'ce' in behaviour
        if ce_flag:
            # ---- update the VAE encoder/decoder against a frozen prior ----
            requires_grad(self.ddpm_model, False)
            self.ddpm_model.eval()
            if behaviour == 'ce_E':  # unfreeze E and freeze D
                requires_grad(self.rec_model.encoder, True)
                self.rec_model.encoder.train()
                requires_grad(self.rec_model.decoder, False)
                self.rec_model.decoder.eval()
            else:  # train all
                requires_grad(self.rec_model, True)
                self.rec_model.train()
        else:
            # ---- update the diffusion prior against a frozen VAE ----
            requires_grad(self.rec_model, False)
            self.rec_model.eval()
            requires_grad(self.ddpm_model, True)
            self.ddpm_model.train()

        self.mp_trainer.zero_grad()

        # The whole batch is one micro-batch. Tensors are truncated to the
        # image batch size and moved to the training device.
        batch_size = batch['img'].shape[0]
        micro = {
            k: v[:batch_size].to(dist_util.dev())
            if isinstance(v, th.Tensor) else v
            for k, v in batch.items()
        }

        # =================================== ae part ===================================
        # AMP is deliberately disabled here (the encoder produced NaNs under fp16).
        with th.cuda.amp.autocast(dtype=th.float16, enabled=False):
            loss = th.tensor(0.).to(dist_util.dev())

            vae_out = self.ddp_rec_model(
                img=micro['img_to_encoder'],
                c=micro['c'],
                behaviour='encoder_vae',
            )
            eps = vae_out[self.latent_name]

            # ! prepare for diffusion
            if 'bg_plane' in vae_out:
                # Append the background plane along channels (B, 12+4, 32, 32 per the original note).
                eps = th.cat((eps, vae_out['bg_plane']), dim=1)

            if ce_flag:
                p_sample_batch = self.prepare_ddpm(eps, 'q')
            else:  # sgm prior
                eps.requires_grad_(True)
                p_sample_batch = self.prepare_ddpm(eps, 'p')

            # ! running diffusion forward
            ddpm_ret = self.apply_model(p_sample_batch)
            p_loss = ddpm_ret['p_eps_objective'].mean()

            if ce_flag:
                cross_entropy = p_loss
                normal_entropy = vae_out['posterior'].normal_entropy()
                negative_entropy = -normal_entropy * self.entropy_weight(normal_entropy)
                ce_loss = cross_entropy + negative_entropy.mean()

                # ``diffusion_ce_anneal`` is only assigned by some subclasses,
                # so default to False instead of raising AttributeError.
                if getattr(self, 'diffusion_ce_anneal', False):
                    raise NotImplementedError(
                        'diffusion_ce_anneal is not supported by this training loop')

                loss += ce_loss
            else:
                loss += p_loss  # p loss

            if ce_flag and 'D' in behaviour:
                # ! reconstruction loss + gan loss
                with th.cuda.amp.autocast(dtype=th.float16, enabled=False):
                    cano_pred = self.ddp_rec_model(
                        latent=vae_out,
                        c=micro['c'],
                        behaviour=self.render_latent_behaviour)

                    with self.ddp_model.no_sync():  # type: ignore
                        q_vae_recon_loss, loss_dict = self.loss_class(
                            {
                                **vae_out,  # include latent here.
                                **cano_pred,
                            },
                            micro,
                            test_mode=False)

                    log_rec3d_loss_dict({
                        **loss_dict,
                        'negative_entropy': negative_entropy.mean(),
                    })
                    loss += q_vae_recon_loss

                    # save image log
                    if dist_util.get_rank() == 0 and self.step % 500 == 0:
                        self.cano_ddpm_log(cano_pred, micro, ddpm_ret)

        self.mp_trainer.backward(loss)
        _ = self.mp_trainer.optimize(self.opt, clip_grad=self.loss_class.opt.grad_clip)
class TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD(TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm):
    """Same loop as its parent, but the rec-model optimiser groups are built
    with ``freeze_decoder=False`` (presumably leaving the decoder trainable)."""

    def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
                 loss_class, data, eval_data, batch_size, microbatch, lr,
                 ema_rate, log_interval, eval_interval, save_interval,
                 resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
                 weight_decay=0, lr_anneal_steps=0, iterations=10001,
                 triplane_scaling_divider=1, use_amp=False,
                 diffusion_input_size=224, init_cvD=False, **kwargs):
        super().__init__(
            rec_model=rec_model,
            denoise_model=denoise_model,
            diffusion=diffusion,
            sde_diffusion=sde_diffusion,
            loss_class=loss_class,
            data=data,
            eval_data=eval_data,
            batch_size=batch_size,
            microbatch=microbatch,
            lr=lr,
            ema_rate=ema_rate,
            log_interval=log_interval,
            eval_interval=eval_interval,
            save_interval=save_interval,
            resume_checkpoint=resume_checkpoint,
            use_fp16=use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            weight_decay=weight_decay,
            lr_anneal_steps=lr_anneal_steps,
            iterations=iterations,
            triplane_scaling_divider=triplane_scaling_divider,
            use_amp=use_amp,
            diffusion_input_size=diffusion_input_size,
            init_cvD=init_cvD,
            **kwargs,
        )

    def _setup_opt(self):
        """One AdamW: a 'ddpm' group for the diffusion model, then the rec-model groups."""
        # TODO, two optims groups.
        ddpm_group = {'name': 'ddpm', 'params': self.ddpm_model.parameters()}
        self.opt = AdamW([ddpm_group], lr=self.lr, weight_decay=self.weight_decay)

        rec_groups = self._init_optim_groups(self.rec_model, freeze_decoder=False)
        for group in rec_groups:
            self.opt.add_param_group(group)

        logger.log(self.opt)
class TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_weightingv0(TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD):
    """Adds EMA latent standardisation and std-dependent loss weighting.

    1. CE weight = ``ce_lambda * mean(ema_latent_std)``. The intent is that
       the weight falls as the latent sigma falls.
    2. Negative-entropy weight = ``negative_entropy_lambda / mean(ema_latent_std)``.
       The originally planned clipping of the entropy (log sigma) at 0 is
       NOT implemented in the current ``entropy_weight``.
    3. Latents are standardised with EMA mean/std before diffusion, with the
       aim of unit std. NOTE(review): the EMA rate was documented as 0.9999;
       it comes from ``self.ema_rate``, so confirm that value.
    4. Gradient clipping is applied by the inherited training step
       (``loss_class.opt.grad_clip``).
    """

    def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
                 loss_class, data, eval_data, batch_size, microbatch, lr,
                 ema_rate, log_interval, eval_interval, save_interval,
                 resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
                 weight_decay=0, lr_anneal_steps=0, iterations=10001,
                 triplane_scaling_divider=1, use_amp=False,
                 diffusion_input_size=224, init_cvD=False, **kwargs):
        super().__init__(
            rec_model=rec_model,
            denoise_model=denoise_model,
            diffusion=diffusion,
            sde_diffusion=sde_diffusion,
            loss_class=loss_class,
            data=data,
            eval_data=eval_data,
            batch_size=batch_size,
            microbatch=microbatch,
            lr=lr,
            ema_rate=ema_rate,
            log_interval=log_interval,
            eval_interval=eval_interval,
            save_interval=save_interval,
            resume_checkpoint=resume_checkpoint,
            use_fp16=use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            weight_decay=weight_decay,
            lr_anneal_steps=lr_anneal_steps,
            iterations=iterations,
            triplane_scaling_divider=triplane_scaling_divider,
            use_amp=use_amp,
            diffusion_input_size=diffusion_input_size,
            init_cvD=init_cvD,
            **kwargs,
        )
        # 0.5 * (log(2*pi) + 1): entropy of a unit Gaussian per dimension,
        # kept for dynamic entropy penalisation.
        self.entropy_const = 0.5 * (np.log(2 * np.pi) + 1)

    def prepare_ddpm(self, eps, mode='p'):
        """Log raw latent stats, standardise ``eps``, and delegate to the parent.

        The EMA statistics are updated only after the parent call, so this
        batch is standardised with the previous EMA values.
        """
        raw_stats = {
            'unscaled_eps_mean': eps.mean(),
            'unscaled_eps_std': eps.std([1, 2, 3]).mean(0),
        }
        log_rec3d_loss_dict(raw_stats)

        p_sample_batch = super().prepare_ddpm(self._standarize(eps), mode)
        self._update_latent_stat_ema(eps)
        return p_sample_batch

    def ce_weight(self):
        """``ce_lambda`` scaled by the (detached) mean EMA latent std."""
        lam = self.loss_class.opt.ce_lambda
        ema_std = self.ddpm_model.ema_latent_std.mean().detach()
        return lam * ema_std

    def entropy_weight(self, normal_entropy=None):
        """``negative_entropy_lambda`` divided by the (detached) mean EMA latent std.

        ``normal_entropy`` is accepted for interface compatibility but unused.
        """
        lam = self.loss_class.opt.negative_entropy_lambda
        ema_std = self.ddpm_model.ema_latent_std.mean().detach()
        return lam * (1 / ema_std)
class TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED(TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_weightingv0):
    """Alternates diffusion-prior steps with joint encoder/decoder CE steps.

    Each outer iteration runs a 'ddpm' step on 12 concatenated batches, then
    a 'ce_ED' step on 3 concatenated batches. The class also provides image
    loggers for denoised latents and patch renders.
    """

    # Behaviours accepted by ``run_step``.
    _VALID_STEPS = ('ce', 'ddpm', 'cano_ddpm_only', 'ce_ED', 'ce_E', 'ce_D',
                    'D', 'ED')

    def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
                 loss_class, data, eval_data, batch_size, microbatch, lr,
                 ema_rate, log_interval, eval_interval, save_interval,
                 resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
                 weight_decay=0, lr_anneal_steps=0, iterations=10001,
                 triplane_scaling_divider=1, use_amp=False,
                 diffusion_input_size=224, init_cvD=False,
                 diffusion_ce_anneal=False, **kwargs):
        super().__init__(rec_model=rec_model,
                         denoise_model=denoise_model,
                         diffusion=diffusion,
                         sde_diffusion=sde_diffusion,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         triplane_scaling_divider=triplane_scaling_divider,
                         use_amp=use_amp,
                         diffusion_input_size=diffusion_input_size,
                         init_cvD=init_cvD,
                         **kwargs)
        self.diffusion_ce_anneal = diffusion_ce_anneal

    def run_step(self, batch, step='g_step'):
        """Run ``joint_rec_ddpm`` for ``step``, then anneal the LR and log.

        Raises:
            ValueError: if ``step`` is not in ``_VALID_STEPS``. The default
                'g_step' is not accepted, so callers must pass a step.
        """
        if step not in self._VALID_STEPS:
            raise ValueError(
                f'unsupported step {step!r}; expected one of {self._VALID_STEPS}')
        self.joint_rec_ddpm(batch, step)
        self._anneal_lr()
        self.log_step()

    def run_loop(self):
        """Alternate 'ddpm' (12 batches) and 'ce_ED' (3 batches) steps until
        ``lr_anneal_steps`` is reached (forever when it is 0), then save a
        final checkpoint unless one was just written."""
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):
            batch = self.next_n_batch(n=12)
            self.run_step(batch, 'ddpm')  # ddpm with the AE fixed
            batch = self.next_n_batch(n=3)
            self.run_step(batch, 'ce_ED')
            self._post_run_step()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    @staticmethod
    def _minmax_01(x, eps=1e-8):
        """Min-max normalise ``x`` to [0, 1]; ``eps`` avoids 0/0 for constant inputs."""
        return (x - x.min()) / (x.max() - x.min()).clamp_min(eps)

    def _latent_render_dict(self, latent, has_bg_plane):
        """Take sample 0 of a diffusion latent and build the dict the decoder renders.

        With a background plane, channels [:12] are the triplane latent and
        [12:16] the background plane. This presumably mirrors how eps was
        concatenated in ``joint_rec_ddpm``; confirm if channel counts change.
        Values are rescaled by ``triplane_scaling_divider``.
        """
        if has_bg_plane:
            return {
                'latent_normalized_2Ddiffusion':
                latent[0:1, :12] * self.triplane_scaling_divider,
                'bg_plane':
                latent[0:1, 12:16] * self.triplane_scaling_divider,
            }
        return {
            'latent_normalized_2Ddiffusion':
            latent[0:1] * self.triplane_scaling_divider,
        }

    def log_diffusion_images(self, vae_out, p_sample_batch, micro, ddpm_ret):
        """Save a visualisation of sample 0 to ``{logdir}/{step}denoised_{t}.jpg``.

        Top row: GT image (x3) and GT depth. Bottom row: render of the clean
        VAE latent, render of the noised latent ``eps_t_p``, render of the
        predicted x0, and the predicted depth.

        Args:
            vae_out: encoder output dict. Tensors are sliced to sample 0, and
                'posterior' (if present) is skipped. The dict is not modified.
            p_sample_batch: dict with 'eps_t_p', 't_p', 'logsnr_p'.
            micro: batch dict with 'img', 'c' and optionally 'depth'.
            ddpm_ret: diffusion output dict with 'pred_eps_p'.
        """
        eps_t_p, t_p, logsnr_p = (p_sample_batch[k]
                                  for k in ('eps_t_p', 't_p', 'logsnr_p'))
        pred_eps_p = ddpm_ret['pred_eps_p']
        has_bg_plane = 'bg_plane' in vae_out

        # Filter out 'posterior' instead of popping it, so the caller's dict
        # (which may still need it for the KL loss) stays intact.
        vae_out_for_pred = {
            k: v[0:1].to(dist_util.dev()) if isinstance(v, th.Tensor) else v
            for k, v in vae_out.items() if k != 'posterior'
        }
        pred = self.ddp_rec_model(latent=vae_out_for_pred,
                                  c=micro['c'][0:1],
                                  behaviour=self.render_latent_behaviour)
        assert isinstance(pred, dict)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        if 'depth' in micro:
            gt_depth = micro['depth']
            if gt_depth.ndim == 3:
                gt_depth = gt_depth.unsqueeze(1)
            gt_depth = self._minmax_01(gt_depth)
        else:
            gt_depth = th.zeros_like(gt_img[:, 0:1, ...])

        if 'image_depth' in pred:
            pred_depth = self._minmax_01(pred['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)

        gt_img = self.pool_128(gt_img)
        gt_depth = self.pool_128(gt_depth)

        gt_vis = th.cat(
            [
                gt_img,
                gt_img,
                gt_img,
                gt_depth.repeat_interleave(3, dim=1),
            ],
            dim=-1)[0:1]  # depth range [0, 1]

        noised_ae_pred = self.ddp_rec_model(
            img=None,
            c=micro['c'][0:1],
            latent=self._latent_render_dict(eps_t_p, has_bg_plane),
            behaviour=self.render_latent_behaviour)

        # Denoised latent used for the VAE loss.
        pred_x0 = self.sde_diffusion._predict_x0_from_eps(
            eps_t_p, pred_eps_p, logsnr_p)
        denoised_ae_pred = self.ddp_rec_model(
            img=None,
            c=micro['c'][0:1],
            latent=self._latent_render_dict(pred_x0, has_bg_plane),
            behaviour=self.render_latent_behaviour)

        pred_vis = th.cat(
            [
                self.pool_128(img) for img in (
                    pred_img[0:1],
                    noised_ae_pred['image_raw'][0:1],
                    denoised_ae_pred['image_raw'][0:1],
                    pred_depth[0:1].repeat_interleave(3, dim=1))
            ],
            dim=-1)  # B, 3, H, W

        # Predictions are in [-1, 1]. Detach so .numpy() works even if the
        # renders still carry autograd history.
        vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(1, 2, 0).detach().cpu()
        vis = vis.numpy() * 127.5 + 127.5
        vis = vis.clip(0, 255).astype(np.uint8)

        save_path = f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t_p[0].item():3}.jpg'
        Image.fromarray(vis).save(save_path)
        print('log denoised vis to: ', save_path)

        th.cuda.empty_cache()

    def log_patch_img(self, micro, pred, pred_cano):
        """Save a GT-vs-prediction grid for rendered patches to ``{logdir}/{step}.jpg``.

        Rows, top to bottom:
          * GT: image | depth | zeros
          * ``pred``: image | depth | mask
          * ``pred_cano``: image | depth | mask

        Depths are min-max normalised and mapped to [-1, 1] with a sign flip.
        Masks go from [0, 1] to [-1, 1]. If ``micro`` has no 'depth', a
        zero GT depth is used.
        """

        def norm_depth(depth):  # to [-1, 1]
            return -(self._minmax_01(depth) * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        if 'depth' in micro:
            gt_depth = micro['depth']
            if gt_depth.ndim == 3:
                gt_depth = gt_depth.unsqueeze(1)
            gt_depth = norm_depth(gt_depth)
        else:
            gt_depth = th.zeros_like(gt_img[:, 0:1, ...])

        fg_mask = pred['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1

        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_nv_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_nv_depth = th.zeros_like(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_nv_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W

        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img),
        ],
                        dim=-1)

        vis = th.cat([gt_vis, pred_vis], dim=-2)
        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] // 64)

        save_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        torchvision.utils.save_image(vis_tensor,
                                     save_path,
                                     value_range=(-1, 1),
                                     normalize=True)
        logger.log('log vis to: ', save_path)
class TrainLoop3D_LDM(TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED):
    """Diffusion-only loop: every iteration is a single 'ddpm' step on two
    concatenated batches; no cross-entropy steps are run."""

    def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
                 loss_class, data, eval_data, batch_size, microbatch, lr,
                 ema_rate, log_interval, eval_interval, save_interval,
                 resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
                 weight_decay=0, lr_anneal_steps=0, iterations=10001,
                 triplane_scaling_divider=1, use_amp=False,
                 diffusion_input_size=224, init_cvD=False,
                 diffusion_ce_anneal=False, **kwargs):
        super().__init__(
            rec_model=rec_model,
            denoise_model=denoise_model,
            diffusion=diffusion,
            sde_diffusion=sde_diffusion,
            loss_class=loss_class,
            data=data,
            eval_data=eval_data,
            batch_size=batch_size,
            microbatch=microbatch,
            lr=lr,
            ema_rate=ema_rate,
            log_interval=log_interval,
            eval_interval=eval_interval,
            save_interval=save_interval,
            resume_checkpoint=resume_checkpoint,
            use_fp16=use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            weight_decay=weight_decay,
            lr_anneal_steps=lr_anneal_steps,
            iterations=iterations,
            triplane_scaling_divider=triplane_scaling_divider,
            use_amp=use_amp,
            diffusion_input_size=diffusion_input_size,
            init_cvD=init_cvD,
            diffusion_ce_anneal=diffusion_ce_anneal,
            **kwargs,
        )

    def run_loop(self):
        """Train the diffusion prior until ``lr_anneal_steps`` (forever when 0).

        A final checkpoint is saved unless the last step already wrote one.
        """
        while True:
            if (self.lr_anneal_steps
                    and self.step + self.resume_step >= self.lr_anneal_steps):
                break
            # Two batches per step (the original note: effective BS=64, micro=4, ~30.7 GiB).
            self.run_step(self.next_n_batch(n=2), 'ddpm')
            self._post_run_step()

        already_saved = (self.step - 1) % self.save_interval == 0
        if not already_saved:
            self.save()
| class TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED_nv(TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED): | |
| # reconstruction function from train_nv_util.py | |
def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion,
             loss_class, data, eval_data, batch_size, microbatch, lr,
             ema_rate, log_interval, eval_interval, save_interval,
             resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001,
             weight_decay=0, lr_anneal_steps=0, iterations=10001,
             triplane_scaling_divider=1, use_amp=False,
             diffusion_input_size=224, init_cvD=False,
             diffusion_ce_anneal=False, **kwargs):
    """Forward everything to the parent, then cache rendering helpers."""
    super().__init__(
        rec_model=rec_model,
        denoise_model=denoise_model,
        diffusion=diffusion,
        sde_diffusion=sde_diffusion,
        loss_class=loss_class,
        data=data,
        eval_data=eval_data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        log_interval=log_interval,
        eval_interval=eval_interval,
        save_interval=save_interval,
        resume_checkpoint=resume_checkpoint,
        use_fp16=use_fp16,
        fp16_scale_growth=fp16_scale_growth,
        weight_decay=weight_decay,
        lr_anneal_steps=lr_anneal_steps,
        iterations=iterations,
        triplane_scaling_divider=triplane_scaling_divider,
        use_amp=use_amp,
        diffusion_input_size=diffusion_input_size,
        init_cvD=init_cvD,
        diffusion_ce_anneal=diffusion_ce_anneal,
        **kwargs,
    )
    # Decoder's triplane module, used for ray sampling and patch rendering.
    decoder = self.rec_model.decoder
    self.eg3d_model = decoder.triplane_decoder  # type: ignore
    # Whether to also render the denoised latent for a reconstruction loss.
    self.renderdiff_loss = False
def run_loop(self):
    """Alternate a 'ddpm' step (2 concatenated batches) with a 'ce_E' step.

    The 'ce_E' step uses a freshly drawn single batch. The loop runs until
    ``lr_anneal_steps`` is reached (forever when it is 0). A final
    checkpoint is saved unless the last step already wrote one.
    """
    def _budget_left():
        return (not self.lr_anneal_steps
                or self.step + self.resume_step < self.lr_anneal_steps)

    while _budget_left():
        self.run_step(self.next_n_batch(n=2), 'ddpm')
        # New batch for the encoder-only cross-entropy step (decoder frozen).
        self.run_step(next(self.data), 'ce_E')
        self._post_run_step()

    if (self.step - 1) % self.save_interval:
        self.save()
| # ddpm + rec loss | |
| def joint_rec_ddpm(self, batch, behaviour='ddpm', *args, **kwargs): | |
| """ | |
| add sds grad to all ae predicted x_0 | |
| """ | |
| args = self.sde_diffusion.args | |
| # ! enable the gradient of both models | |
| # requires_grad(self.rec_model, True) | |
| # if behaviour == 'ce': # ll sampling? later. train encoder. | |
| ce_flag = False | |
| diffusion_flag = True | |
| if 'ce' in behaviour: # ll sampling? later. train encoder. | |
| ############################################## | |
| ###### Update the VAE encoder/decoder ######## | |
| ############################################## | |
| requires_grad(self.ddpm_model, False) | |
| self.ddpm_model.eval() | |
| ce_flag = True | |
| if behaviour == 'ce_E': # unfreeze E and freeze D | |
| requires_grad(self.rec_model.encoder, True) | |
| self.rec_model.encoder.train() | |
| requires_grad(self.rec_model.decoder, False) | |
| self.rec_model.decoder.eval() | |
| elif behaviour == 'ce_D': # unfreeze E and freeze D | |
| requires_grad(self.rec_model.encoder, False) | |
| self.rec_model.encoder.eval() | |
| requires_grad(self.rec_model.decoder, True) | |
| self.rec_model.decoder.train() | |
| else: # train all, may oom | |
| requires_grad(self.rec_model, True) | |
| self.rec_model.train() | |
| elif behaviour == 'ED': # just train E and D | |
| diffusion_flag = False | |
| requires_grad(self.ddpm_model, False) | |
| self.ddpm_model.eval() | |
| requires_grad(self.rec_model, True) | |
| self.rec_model.train() | |
| elif behaviour == 'D': | |
| diffusion_flag = False | |
| requires_grad(self.rec_model.encoder, False) | |
| self.rec_model.encoder.eval() | |
| requires_grad(self.rec_model.decoder, True) | |
| self.rec_model.decoder.train() | |
| else: # train ddpm. | |
| # self.flip_encoder_grad(False) | |
| requires_grad(self.rec_model, False) | |
| self.rec_model.eval() | |
| requires_grad(self.ddpm_model, True) | |
| self.ddpm_model.train() | |
| self.mp_trainer.zero_grad() | |
| assert args.train_vae | |
| batch_size = batch['img'].shape[0] | |
| # for i in range(0, batch_size, self.microbatch): | |
| for i in range(0, batch_size, batch_size): | |
| micro = { | |
| k: v[i:i + self.microbatch].to(dist_util.dev()) | |
| for k, v in batch.items() | |
| } | |
| # ! sample rendering patch | |
| target = { | |
| **self.eg3d_model( | |
| c=micro['nv_c'], # type: ignore | |
| ws=None, | |
| planes=None, | |
| sample_ray_only=True, | |
| fg_bbox=micro['nv_bbox']), # rays o / dir | |
| } | |
| patch_rendering_resolution = self.eg3d_model.rendering_kwargs[ | |
| 'patch_rendering_resolution'] # type: ignore | |
| cropped_target = { | |
| k: th.empty_like(v) | |
| [..., :patch_rendering_resolution, :patch_rendering_resolution] | |
| if k not in [ | |
| 'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder', | |
| 'nv_img_sr', 'c' | |
| ] else v | |
| for k, v in micro.items() | |
| } | |
| # crop according to uv sampling | |
| for j in range(micro['img'].shape[0]): | |
| top, left, height, width = target['ray_bboxes'][ | |
| j] # list of tuple | |
| # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore | |
| for key in ('img', 'depth_mask', 'depth'): # type: ignore | |
| # target[key][i:i+1] = torchvision.transforms.functional.crop( | |
| # cropped_target[key][ | |
| # j:j + 1] = torchvision.transforms.functional.crop( | |
| # micro[key][j:j + 1], top, left, height, width) | |
| cropped_target[f'{key}'][ # ! no nv_ here | |
| j:j + 1] = torchvision.transforms.functional.crop( | |
| micro[f'nv_{key}'][j:j + 1], top, left, height, | |
| width) | |
| # ! cano view loss | |
| cano_target = { | |
| **self.eg3d_model( | |
| c=micro['c'], # type: ignore | |
| ws=None, | |
| planes=None, | |
| sample_ray_only=True, | |
| fg_bbox=micro['bbox']), # rays o / dir | |
| } | |
| cano_cropped_target = { | |
| k: th.empty_like(v) | |
| for k, v in cropped_target.items() | |
| } | |
| for j in range(micro['img'].shape[0]): | |
| top, left, height, width = cano_target['ray_bboxes'][ | |
| j] # list of tuple | |
| # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore | |
| for key in ('img', 'depth_mask', 'depth'): # type: ignore | |
| # target[key][i:i+1] = torchvision.transforms.functional.crop( | |
| cano_cropped_target[key][ | |
| j:j + 1] = torchvision.transforms.functional.crop( | |
| micro[key][j:j + 1], top, left, height, width) | |
| # =================================== ae part =================================== | |
| with th.cuda.amp.autocast(dtype=th.float16, | |
| # enabled=self.mp_trainer.use_amp): | |
| enabled=False): | |
| # and args.train_vae): | |
| loss = th.tensor(0.).to(dist_util.dev()) | |
| # with th.cuda.amp.autocast(dtype=th.float16, | |
| # enabled=False): | |
| # quit amp in encoder, avoid nan. | |
| vae_out = self.ddp_rec_model( | |
| img=micro['img_to_encoder'], | |
| c=micro['c'], | |
| behaviour='encoder_vae', | |
| ) # pred: (B, 3, 64, 64) | |
| if diffusion_flag: | |
| eps = vae_out[self.latent_name] # 12542mib, bs=4 | |
| # ''' | |
| # ! prepare for diffusion | |
| if 'bg_plane' in vae_out: | |
| eps = th.cat((eps, vae_out['bg_plane']), dim=1) # include background, B 12+4 32 32 | |
| if ce_flag: | |
| p_sample_batch = self.prepare_ddpm(eps, 'q') | |
| else: | |
| eps.requires_grad_(True) | |
| p_sample_batch = self.prepare_ddpm(eps, 'p') | |
| # ! running diffusion forward | |
| ddpm_ret = self.apply_model(p_sample_batch) | |
| # p_loss = ddpm_ret['p_eps_objective'] | |
| p_loss = ddpm_ret['p_eps_objective'].mean() | |
| # st() # 12890mib | |
| if ce_flag: | |
| cross_entropy = p_loss # why collapse? | |
| normal_entropy = vae_out['posterior'].normal_entropy() | |
| entropy_weight = self.entropy_weight(normal_entropy) | |
| negative_entropy = -normal_entropy * entropy_weight | |
| ce_loss = (cross_entropy + negative_entropy.mean()) | |
| # if self.diffusion_ce_anneal: # gradually add ce lambda | |
| # diffusion_ce_lambda = kl_coeff( | |
| # step=self.step + self.resume_step, | |
| # constant_step=5e3+self.resume_step, | |
| # total_step=25e3, | |
| # min_kl_coeff=1e-5, | |
| # max_kl_coeff=self.loss_class.opt.negative_entropy_lambda) | |
| # # diffusion_ce_lambda = th.tensor(1e-5).to(dist_util.dev()) | |
| # ce_loss *= diffusion_ce_lambda | |
| log_rec3d_loss_dict({ | |
| # 'diffusion_ce_lambda': diffusion_ce_lambda, | |
| 'negative_entropy': negative_entropy.mean(), | |
| 'entropy_weight': entropy_weight, | |
| 'ce_loss': ce_loss | |
| }) | |
| loss += ce_loss | |
| else: | |
| loss += p_loss # p loss | |
| # ! do reconstruction supervision | |
| # ''' | |
| if ce_flag or not diffusion_flag: # vae part | |
| latent_to_decode = vae_out | |
| else: | |
| latent_to_decode = { # diffusion part | |
| self.latent_name: ddpm_ret['pred_x0_p'] | |
| } # render denoised latent | |
| # with th.cuda.amp.autocast(dtype=th.float16, | |
| # enabled=False): | |
| # st() | |
| if ce_flag or self.renderdiff_loss or not diffusion_flag: | |
| # ! do vae latent -> triplane decode | |
| latent_to_decode.update(self.ddp_rec_model(latent=latent_to_decode, behaviour='decode_after_vae_no_render')) # triplane, 19mib bs=4 | |
| # ! do render | |
| # st() | |
| pred_nv_cano = self.ddp_rec_model( # 24gb, bs=4 | |
| # latent=latent.expand(2,), | |
| latent={ | |
| 'latent_after_vit': # ! triplane for rendering | |
| latent_to_decode['latent_after_vit'].repeat(2, 1, 1, 1) | |
| }, | |
| c=th.cat([micro['nv_c'], | |
| micro['c']]), # predict novel view here | |
| behaviour='triplane_dec', | |
| # ray_origins=target['ray_origins'], | |
| # ray_directions=target['ray_directions'], | |
| ray_origins=th.cat( | |
| [target['ray_origins'], cano_target['ray_origins']], | |
| 0), | |
| ray_directions=th.cat([ | |
| target['ray_directions'], cano_target['ray_directions'] | |
| ]), | |
| ) | |
| pred_nv_cano.update({ # for kld | |
| 'posterior': vae_out['posterior'], | |
| 'latent_normalized_2Ddiffusion': vae_out['latent_normalized_2Ddiffusion'] | |
| }) | |
| # ! 2D loss | |
| with self.ddp_model.no_sync(): # type: ignore | |
| loss_rec, loss_rec_dict, _ = self.loss_class( | |
| pred_nv_cano, | |
| { | |
| k: th.cat([v, cano_cropped_target[k]], 0) | |
| for k, v in cropped_target.items() | |
| }, # prepare merged data | |
| step=self.step + self.resume_step, | |
| test_mode=False, | |
| return_fg_mask=True, | |
| conf_sigma_l1=None, | |
| conf_sigma_percl=None) | |
| if diffusion_flag and not ce_flag: | |
| prefix = 'denoised_' | |
| else: | |
| prefix = '' | |
| log_rec3d_loss_dict({ | |
| f'{prefix}{k}': v for k, v in loss_rec_dict.items() | |
| }) | |
| loss += loss_rec # l2, LPIPS, Alpha loss | |
| # save image log | |
| # if dist_util.get_rank() == 0 and self.step % 500 == 0: | |
| # self.cano_ddpm_log(cano_pred, micro, ddpm_ret) | |
| self.mp_trainer.backward(loss) # grad accumulation, 27gib | |
| # st() | |
| # for name, p in self.model.named_parameters(): | |
| # if p.grad is None: | |
| # logger.log(f"found rec unused param: {name}") | |
| # _ = self.mp_trainer.optimize(self.opt, clip_grad=self.loss_class.opt.grad_clip) | |
| _ = self.mp_trainer.optimize(self.opt, clip_grad=True) | |
| if dist_util.get_rank() == 0: | |
| if self.step % 500 == 0: # log diffusion | |
| self.log_diffusion_images(vae_out, p_sample_batch, micro, ddpm_ret) | |
| elif self.step % 500 == 1 and ce_flag: # log reconstruction | |
| # st() | |
| micro_bs = micro['img_to_encoder'].shape[0] | |
| self.log_patch_img( | |
| cropped_target, | |
| { | |
| k: pred_nv_cano[k][:micro_bs] | |
| for k in ['image_raw', 'image_depth', 'image_mask'] | |
| }, | |
| { | |
| k: pred_nv_cano[k][micro_bs:] | |
| for k in ['image_raw', 'image_depth', 'image_mask'] | |
| }, | |
| ) | |
def _init_optim_groups(self, rec_model, freeze_decoder=False):
    """Forward to the parent's ``_init_optim_groups`` with ``freeze_decoder`` forced on.

    The ``freeze_decoder`` argument is accepted only so the signature matches
    the parent; its value is ignored and ``True`` is always passed upward.

    NOTE(review): the previous inline comment read "unfreeze decoder when
    scaling is enabled", which contradicts the forced ``True`` here. Confirm
    which behaviour is intended before relying on either.

    Args:
        rec_model: forwarded unchanged to the parent implementation.
        freeze_decoder: unused; kept for interface compatibility.

    Returns:
        Whatever the parent ``_init_optim_groups`` returns.
    """
    forced_freeze = True  # caller-supplied freeze_decoder is deliberately overridden
    return super()._init_optim_groups(rec_model, freeze_decoder=forced_freeze)
| # class TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED_nv_noCE(TrainLoop3DDiffusionLSGM_cvD_scaling_lsgm_unfreezeD_iterativeED_nv): | |
| # """no separate CE schedule, use single schedule for joint ddpm/nv-rec training with entropy regularization | |
| # """ | |
| # def __init__(self, *, rec_model, denoise_model, diffusion, sde_diffusion, loss_class, data, eval_data, batch_size, microbatch, lr, ema_rate, log_interval, eval_interval, save_interval, resume_checkpoint, use_fp16=False, fp16_scale_growth=0.001, weight_decay=0, lr_anneal_steps=0, iterations=10001, triplane_scaling_divider=1, use_amp=False, diffusion_input_size=224, init_cvD=False, diffusion_ce_anneal=False, **kwargs): | |
| # super().__init__(rec_model=rec_model, denoise_model=denoise_model, diffusion=diffusion, sde_diffusion=sde_diffusion, loss_class=loss_class, data=data, eval_data=eval_data, batch_size=batch_size, microbatch=microbatch, lr=lr, ema_rate=ema_rate, log_interval=log_interval, eval_interval=eval_interval, save_interval=save_interval, resume_checkpoint=resume_checkpoint, use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth, weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps, iterations=iterations, triplane_scaling_divider=triplane_scaling_divider, use_amp=use_amp, diffusion_input_size=diffusion_input_size, init_cvD=init_cvD, diffusion_ce_anneal=diffusion_ce_anneal, **kwargs) | |
| # def run_loop(self): | |
| # while (not self.lr_anneal_steps | |
| # or self.step + self.resume_step < self.lr_anneal_steps): | |
| # batch = self.next_n_batch(n=2) # effective BS=2*8 | |
| # self.run_step(batch, 'ddpm') | |
| # # if self.step % self.inner_loop_k == 1: # train E per 2 steps | |
| # batch = next(self.data) # sample a new batch for rec training | |
| # self.run_step(self.subset_batch(batch, micro_batchsize=6, big_endian=False), 'ce_ED') # freeze D, train E with diffusion prior | |
| # self._post_run_step() | |
| # # Save the last checkpoint if it wasn't already saved. | |
| # if (self.step - 1) % self.save_interval != 0: | |
| # self.save() |