import torch
from torch import nn
from nsr.triplane import Triplane_fg_bg_plane
# import timm
from vit.vit_triplane import Triplane, ViTTriplaneDecomposed
import argparse
import inspect
import dnnlib
from guided_diffusion import dist_util
from pdb import set_trace as st
import vit.vision_transformer as vits
from guided_diffusion import logger
from .confnet import ConfNet

from ldm.modules.diffusionmodules.model import Encoder, MVEncoder, MVEncoderGS, MVEncoderGSDynamicInp
from ldm.modules.diffusionmodules.mv_unet import MVUNet, LGM_MVEncoder
# from ldm.modules.diffusionmodules.openaimodel import MultiViewUNetModel_Encoder


# * create pre-trained encoder & triplane / other nsr decoder
class AE(torch.nn.Module):

    def __init__(self,
                 encoder,
                 decoder,
                 img_size,
                 encoder_cls_token,
                 decoder_cls_token,
                 preprocess,
                 use_clip,
                 dino_version='v1',
                 clip_dtype=None,
                 no_dim_up_mlp=False,
                 dim_up_mlp_as_func=False,
                 uvit_skip_encoder=False,
                 confnet=None) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.img_size = img_size
        self.encoder_cls_token = encoder_cls_token
        self.decoder_cls_token = decoder_cls_token
        self.use_clip = use_clip
        self.dino_version = dino_version
        self.confnet = confnet
        self.preprocess = preprocess

        # initialize before the branches below, so that the dim-up MLP created
        # there is not overwritten afterwards
        self.dim_up_mlp = None  # CLIP/B-16
        self.dim_up_mlp_as_func = dim_up_mlp_as_func

        if self.dino_version == 'v2':
            self.encoder.mask_token = None
            self.decoder.vit_decoder.mask_token = None

        if 'sd' not in self.dino_version:

            self.uvit_skip_encoder = uvit_skip_encoder
            if uvit_skip_encoder:
                logger.log(
                    f'enables uvit: length of vit_encoder.blocks: {len(self.encoder.blocks)}'
                )
                for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
                    blk.skip_linear = nn.Linear(2 * self.encoder.embed_dim,
                                                self.encoder.embed_dim)

                    # trunc_normal_(blk.skip_linear.weight, std=.02)
                    nn.init.constant_(blk.skip_linear.weight, 0)
                    if isinstance(
                            blk.skip_linear,
                            nn.Linear) and blk.skip_linear.bias is not None:
                        nn.init.constant_(blk.skip_linear.bias, 0)
            else:
                logger.log('disable uvit')
        else:
            if 'dit' not in self.dino_version:  # dino vit, not dit
                self.decoder.vit_decoder.cls_token = None
                self.decoder.vit_decoder.patch_embed.proj = nn.Identity()
                self.decoder.triplane_decoder.planes = None
                self.decoder.vit_decoder.mask_token = None

        if self.use_clip:
            self.clip_dtype = clip_dtype  # torch.float16
        else:
            if not no_dim_up_mlp and self.encoder.embed_dim != self.decoder.vit_decoder.embed_dim:
                self.dim_up_mlp = nn.Linear(self.encoder.embed_dim,
                                            self.decoder.vit_decoder.embed_dim)
                logger.log(
                    f"dim_up_mlp: {self.encoder.embed_dim} -> {self.decoder.vit_decoder.embed_dim}, as_func: {self.dim_up_mlp_as_func}"
                )
            else:
                logger.log('ignore dim_up_mlp: ', no_dim_up_mlp)

        # * remove certain components to make sure no unused parameters during DDP
        # self.decoder.vit_decoder.cls_token = nn.Identity()
        torch.cuda.empty_cache()
    def encode(self, *args, **kwargs):
        if not self.use_clip:
            if self.dino_version == 'v1':
                latent = self.encode_dinov1(*args, **kwargs)
            elif self.dino_version == 'v2':
                if self.uvit_skip_encoder:
                    latent = self.encode_dinov2_uvit(*args, **kwargs)
                else:
                    latent = self.encode_dinov2(*args, **kwargs)
            else:
                latent = self.encoder(*args)
        else:
            latent = self.encode_clip(*args, **kwargs)

        return latent

    def encode_dinov1(self, x):
        # return self.encoder(img)
        x = self.encoder.prepare_tokens(x)
        for blk in self.encoder.blocks:
            x = blk(x)
        x = self.encoder.norm(x)
        if not self.encoder_cls_token:
            return x[:, 1:]

        return x

    def encode_dinov2(self, x):
        # return self.encoder(img)
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)
        for blk in self.encoder.blocks:
            x = blk(x)
        x_norm = self.encoder.norm(x)

        if not self.encoder_cls_token:
            return x_norm[:, 1:]
        # else:
        #     return x_norm[:, :1]
        # return {
        #     "x_norm_clstoken": x_norm[:, 0],
        #     "x_norm_patchtokens": x_norm[:, 1:],
        # }

        return x_norm

    def encode_dinov2_uvit(self, x):
        # return self.encoder(img)
        x = self.encoder.prepare_tokens_with_masks(x, masks=None)

        # in blks
        skips = [x]
        for blk in self.encoder.blocks[0:len(self.encoder.blocks) // 2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)

        # mid blks
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2 -
                                       1:len(self.encoder.blocks) // 2]:
            x = blk(x)  # B 3 N C

        # out blks
        for blk in self.encoder.blocks[len(self.encoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat(
                [x, skips.pop()], dim=-1))  # long skip connections in uvit
            x = blk(x)  # B 3 N C

        x_norm = self.encoder.norm(x)

        if not self.decoder_cls_token:
            return x_norm[:, 1:]

        return x_norm

    def encode_clip(self, x):
        # * replace with CLIP encoding pipeline
        # x = x.dtype(self.clip_dtype)
        x = self.encoder.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.encoder.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.encoder.positional_embedding.to(x.dtype)
        x = self.encoder.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.encoder.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.encoder.ln_post(x[:, 1:, :])  # * return the spatial tokens
        return x

        # x = self.ln_post(x[:, 0, :])  # cls token
        # if self.proj is not None:
        #     x = x @ self.proj
        # return x
    def decode_wo_triplane(self, latent, c=None, img_size=None):
        if img_size is None:
            img_size = self.img_size

        if self.dim_up_mlp is not None:
            if not self.dim_up_mlp_as_func:
                latent = self.dim_up_mlp(latent)
            else:
                return self.decoder.vit_decode(
                    latent, img_size,
                    dim_up_mlp=self.dim_up_mlp)  # used in vae-ldm

        return self.decoder.vit_decode(latent, img_size, c=c)

    def decode(self, latent, c, img_size=None, return_raw_only=False):
        latent = self.decode_wo_triplane(latent, img_size=img_size, c=c)
        # return self.decoder.triplane_decode(latent, c, return_raw_only=return_raw_only)
        return self.decoder.triplane_decode(latent, c)

    def decode_after_vae_no_render(
        self,
        ret_dict,
        img_size=None,
    ):
        if img_size is None:
            img_size = self.img_size

        assert self.dim_up_mlp is None
        latent = self.decoder.vit_decode_backbone(ret_dict, img_size)
        ret_dict = self.decoder.vit_decode_postprocess(latent, ret_dict)
        return ret_dict

    def decode_after_vae(
            self,
            ret_dict,  # vae_dict
            c,
            img_size=None,
            return_raw_only=False):
        ret_dict = self.decode_after_vae_no_render(ret_dict, img_size)
        return self.decoder.triplane_decode(ret_dict, c)

    def decode_confmap(self, img):
        assert self.confnet is not None
        # https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L152
        # conf_sigma_l1 = self.confnet(img)  # Bx2xHxW
        return self.confnet(img)  # Bx1xHxW

    def encode_decode(self, img, c, return_raw_only=False):
        latent = self.encode(img)
        pred = self.decode(latent, c, return_raw_only=return_raw_only)
        if self.confnet is not None:
            pred.update({
                'conf_sigma': self.decode_confmap(img)  # 224x224
            })

        return pred
    def forward(self,
                img=None,
                c=None,
                latent=None,
                behaviour='enc_dec',
                coordinates=None,
                directions=None,
                return_raw_only=False,
                *args,
                **kwargs):
        """Wrap all operations inside forward() for DDP use."""

        if behaviour == 'enc_dec':
            pred = self.encode_decode(img, c, return_raw_only=return_raw_only)
            return pred

        elif behaviour == 'enc':
            latent = self.encode(img)
            return latent

        elif behaviour == 'dec':
            assert latent is not None
            pred: dict = self.decode(latent,
                                     c,
                                     self.img_size,
                                     return_raw_only=return_raw_only)
            return pred
        elif behaviour == 'dec_wo_triplane':
            assert latent is not None
            # pass img_size by keyword; the second positional slot is c
            pred: dict = self.decode_wo_triplane(latent,
                                                 img_size=self.img_size)
            return pred
        elif behaviour == 'enc_dec_wo_triplane':
            latent = self.encode(img)
            pred: dict = self.decode_wo_triplane(latent,
                                                 img_size=self.img_size,
                                                 c=c)
            return pred

        elif behaviour == 'encoder_vae':
            latent = self.encode(img)
            ret_dict = self.decoder.vae_reparameterization(latent, True)
            return ret_dict

        elif behaviour == 'decode_after_vae_no_render':
            pred: dict = self.decode_after_vae_no_render(latent, self.img_size)
            return pred

        elif behaviour == 'decode_after_vae':
            pred: dict = self.decode_after_vae(latent, c, self.img_size)
            return pred

        # elif behaviour == 'gaussian_dec':
        #     assert latent is not None
        #     pred: dict = self.decoder.triplane_decode(
        #         latent, c, return_raw_only=return_raw_only, **kwargs)

        elif behaviour == 'triplane_dec':
            assert latent is not None
            pred: dict = self.decoder.triplane_decode(
                latent, c, return_raw_only=return_raw_only, **kwargs)

        elif behaviour == 'triplane_decode_grid':
            assert latent is not None
            pred: dict = self.decoder.triplane_decode_grid(latent, **kwargs)

        elif behaviour == 'vit_postprocess_triplane_dec':
            assert latent is not None
            latent = self.decoder.vit_decode_postprocess(
                latent)  # translate spatial token from vit-decoder into 2D
            pred: dict = self.decoder.triplane_decode(
                latent, c)  # render with triplane

        elif behaviour == 'triplane_renderer':
            assert latent is not None
            pred: dict = self.decoder.triplane_renderer(
                latent, coordinates, directions)

        # elif behaviour == 'triplane_SR':
        #     assert latent is not None
        #     pred: dict = self.decoder.triplane_renderer(
        #         latent, coordinates, directions)

        elif behaviour == 'get_rendering_kwargs':
            pred = self.decoder.triplane_decoder.rendering_kwargs

        return pred
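
# Hedged usage sketch (not part of the original module): under DDP only
# forward() is synchronized, so every stage is dispatched through the
# `behaviour` switch above. The argument shapes are illustrative assumptions,
# not values the class enforces.
def _example_ae_behaviours(ae: AE, img: torch.Tensor, c: torch.Tensor):
    latent = ae(img=img, behaviour='enc')  # encoder only
    pred = ae(latent=latent, c=c, behaviour='dec')  # ViT decode + triplane render
    full = ae(img=img, c=c, behaviour='enc_dec')  # end-to-end reconstruction
    return latent, pred, full
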
class AE_CLIPEncoder(AE):

    def __init__(self, encoder, decoder, img_size, cls_token) -> None:
        super().__init__(encoder, decoder, img_size, cls_token)


class AE_with_Diffusion(torch.nn.Module):

    def __init__(self, auto_encoder, denoise_model) -> None:
        super().__init__()
        self.auto_encoder = auto_encoder
        self.denoise_model = denoise_model  # simply for easy MPTrainer manipulation

    def forward(self,
                img,
                c,
                behaviour='enc_dec',
                latent=None,
                *args,
                **kwargs):
        # wrap auto_encoder and the denoising model inside a single forward
        # function, so that DDP (which only supports forward) and MPTrainer
        # (which expects a single model) are easier to use
        if behaviour == 'enc_dec':
            pred = self.auto_encoder(img, c)
            return pred
        elif behaviour == 'enc':
            latent = self.auto_encoder.encode(img)
            if self.auto_encoder.dim_up_mlp is not None:
                latent = self.auto_encoder.dim_up_mlp(latent)
            return latent
        elif behaviour == 'dec':
            assert latent is not None
            # img_size lives on the wrapped autoencoder, not on this wrapper
            pred: dict = self.auto_encoder.decode(latent, c,
                                                  self.auto_encoder.img_size)
            return pred
        elif behaviour == 'denoise':
            assert latent is not None
            pred: dict = self.denoise_model(*args, **kwargs)
            return pred
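
# Hedged sketch: AE_with_Diffusion keeps the autoencoder and the denoiser
# under one forward() so a single DDP wrapper covers both. `denoiser` is
# assumed to be any nn.Module; the 'denoise' branch forwards any extra
# keyword arguments verbatim, so `x=latent` below is an assumption about the
# denoiser's signature, not something this module defines.
def _example_ae_with_diffusion(ae: AE, denoiser: nn.Module, img, c):
    wrapped = AE_with_Diffusion(ae, denoiser)
    latent = wrapped(img, c, behaviour='enc')  # dim-up MLP applied if present
    return wrapped(img, c, behaviour='denoise', latent=latent, x=latent)
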
def eg3d_options_default():

    opts = dnnlib.EasyDict(
        dict(
            cbase=32768,
            cmax=512,
            map_depth=2,
            g_class_name='nsr.triplane.TriPlaneGenerator',  # TODO
            g_num_fp16_res=0,
        ))

    return opts
def rendering_options_defaults(opts):

    rendering_options = {
        # 'image_resolution': c.training_set_kwargs.resolution,
        'image_resolution': 256,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
        'c_gen_conditioning_zero': True,  # if true, fill generator pose conditioning label with dummy zero vector
        # 'gpc_reg_prob': opts.gpc_reg_prob if opts.gen_pose_cond else None,
        'c_scale': opts.c_scale,  # multiplier for generator pose conditioning label
        'superresolution_noise_mode': 'none',
        'density_reg': opts.density_reg,  # strength of density regularization
        'density_reg_p_dist': opts.density_reg_p_dist,  # distance at which to sample perturbed points for density regularization
        'reg_type': opts.reg_type,  # for experimenting with variations on density regularization
        'decoder_lr_mul': 1,  # opts.decoder_lr_mul, learning rate multiplier for decoder
        'decoder_activation': 'sigmoid',
        'sr_antialias': True,
        'return_triplane_features': False,  # for DDF supervision
        'return_sampling_details_flag': False,

        # * shape default sr
        # 'superresolution_module': 'nsr.superresolution.SuperresolutionHybrid4X',
        # 'superresolution_module': 'utils.torch_utils.components.PixelUnshuffleUpsample',
        'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
    }
    if opts.cfg == 'ffhq':
        rendering_options.update({
            'superresolution_module': 'nsr.superresolution.SuperresolutionHybrid8XDC',
            'focal': 2985.29 / 700,
            'depth_resolution': 48,  # number of uniform samples to take per ray.
            'depth_resolution_importance': 48,  # number of importance samples to take per ray.
            'bg_depth_resolution': 16,  # 4/14 in stylenerf, https://github.com/facebookresearch/StyleNeRF/blob/7f5610a058f27fcc360c6b972181983d7df794cb/conf/model/stylenerf_ffhq.yaml#L48
            'ray_start': 2.25,  # near point along each ray to start taking samples.
            'ray_end': 3.3,  # far point along each ray to stop taking samples.
            'box_warp': 1,  # the side-length of the bounding box spanned by the tri-planes; box_warp=1 means [-0.5, -0.5, -0.5] -> [0.5, 0.5, 0.5].
            'avg_camera_radius': 2.7,  # used only in the visualizer to specify camera orbit radius.
            'avg_camera_pivot': [0, 0, 0.2],  # used only in the visualizer to control center of camera rotation.
            'superresolution_noise_mode': 'random',
        })
    elif opts.cfg == 'afhq':
        rendering_options.update({
            'superresolution_module': 'nsr.superresolution.SuperresolutionHybrid8X',
            'superresolution_noise_mode': 'random',
            'focal': 4.2647,
            'depth_resolution': 48,
            'depth_resolution_importance': 48,
            'ray_start': 2.25,
            'ray_end': 3.3,
            'box_warp': 1,
            'avg_camera_radius': 2.7,
            'avg_camera_pivot': [0, 0, -0.06],
        })
    elif opts.cfg == 'shapenet':  # TODO, lies in a sphere
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': 0.2,
            'ray_end': 2.2,
            # 'ray_start': opts.ray_start,
            # 'ray_end': opts.ray_end,
            'box_warp': 2,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution':
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair':
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_64':
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'srn_shapenet_aug_resolution_chair_128':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 1.25,
            'ray_end': 2.75,
            'box_warp': 1.5,
            'white_back': True,
            'avg_camera_radius': 2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'eg3d_shapenet_aug_resolution_chair_128_residualSR':
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 0.1,
            'ray_end': 1.9,  # 2.6/1.7*1.2
            'box_warp': 1.1,
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR_Residual',
        })
    elif opts.cfg == 'shapenet_tuneray':  # TODO, lies in a sphere
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 80,
            'depth_resolution_importance': 80,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    # ! default version
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
        })
    # ! 64+64, since ssdnerf adopts this setting
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
        })
    # ! 64+64+patch, since ssdnerf adopts this setting
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestSR_patch':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
            # patch configs
            'PatchRaySampler': True,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_nearestSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,  # 'auto',
            'ray_end': opts.ray_end,  # 'auto',
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?  # 2,
            'white_back': True,
            'avg_camera_radius': 1.946,  # ?
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
            # patch configs
            # 'PatchRaySampler': False,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            # 'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
    elif opts.cfg == 'objverse_tuneray_aug_resolution_64_64_auto':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            'ray_start': 'auto',
            'ray_end': 'auto',
            'box_warp': 0.9,
            'white_back': True,
            # radius in [1.5, 2], https://github.com/modelscope/richdreamer/issues/12#issuecomment-1897734616
            'radius_range': [1.5, 2],
            'sampler_bbox_min': -0.45,
            'sampler_bbox_max': 0.45,
            # 'avg_camera_pivot': [0, 0, 0],  # not used
            'filter_out_of_bbox': True,
            # 'superresolution_module': 'utils.torch_utils.components.NearestConvSR',
            # patch configs
            'PatchRaySampler': True,
            # 'patch_rendering_resolution': 32,
            # 'patch_rendering_resolution': 48,
            'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
        rendering_options['z_near'] = rendering_options['radius_range'][0] + rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1] + rendering_options['sampler_bbox_max']
    elif opts.cfg == 'objverse_tuneray_aug_resolution_128_128_auto':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 128,
            'depth_resolution_importance': 128,
            'ray_start': 'auto',
            'ray_end': 'auto',
            'box_warp': 0.9,
            'white_back': True,
            # radius in [1.5, 2], https://github.com/modelscope/richdreamer/issues/12#issuecomment-1897734616
            'radius_range': [1.5, 2],
            'sampler_bbox_min': -0.45,
            'sampler_bbox_max': 0.45,
            'filter_out_of_bbox': True,
            # patch configs
            'PatchRaySampler': True,
            'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
        rendering_options['z_near'] = rendering_options['radius_range'][0] + rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1] + rendering_options['sampler_bbox_max']
    elif opts.cfg == 'objverse_tuneray_aug_resolution_96_96_auto':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            'ray_start': 'auto',
            'ray_end': 'auto',
            'box_warp': 0.9,
            'white_back': True,
            'radius_range': [1.5, 2],
            'sampler_bbox_min': -0.45,
            'sampler_bbox_max': 0.45,
            'filter_out_of_bbox': True,
            'PatchRaySampler': True,
            'patch_rendering_resolution': opts.patch_rendering_resolution,
        })
        rendering_options['z_near'] = rendering_options['radius_range'][0] + rendering_options['sampler_bbox_min']
        rendering_options['z_far'] = rendering_options['radius_range'][1] + rendering_options['sampler_bbox_max']
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_96_nearestResidualSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 96,
            'depth_resolution_importance': 96,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR_Residual',
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_64_nearestResidualSR':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 64,
            'depth_resolution_importance': 64,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
            'superresolution_module': 'utils.torch_utils.components.NearestConvSR_Residual',
        })
    elif opts.cfg == 'shapenet_tuneray_aug_resolution_64_104':  # to differentiate hwc
        rendering_options.update({
            'depth_resolution': 104,
            'depth_resolution_importance': 104,
            # * radius 1.2 setting, newly rendered images
            'ray_start': opts.ray_start,
            'ray_end': opts.ray_end,
            'box_warp': opts.ray_end - opts.ray_start,  # TODO, how to set this value?
            'white_back': True,
            'avg_camera_radius': 1.2,
            'avg_camera_pivot': [0, 0, 0],
        })
    rendering_options.update({'return_sampling_details_flag': True})

    return rendering_options
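
# Hedged sketch: how the option helpers compose. triplane_decoder_defaults()
# (defined further below) supplies the attributes rendering_options_defaults()
# reads (c_scale, density_reg, reg_type, cfg, ...); dnnlib.EasyDict provides
# the attribute access the function expects. The cfg name 'shapenet' is one of
# the branches handled above.
def _example_rendering_options():
    opts = dnnlib.EasyDict(triplane_decoder_defaults())
    opts.update(eg3d_options_default())
    rendering_kwargs = rendering_options_defaults(opts)
    return rendering_kwargs
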
def model_encoder_defaults():

    return dict(
        use_clip=False,
        arch_encoder="vits",
        arch_decoder="vits",
        load_pretrain_encoder=False,
        encoder_lr=1e-5,
        encoder_weight_decay=0.001,  # https://github.com/google-research/vision_transformer
        no_dim_up_mlp=False,
        dim_up_mlp_as_func=False,
        decoder_load_pretrained=True,
        uvit_skip_encoder=False,
        # vae ldm
        vae_p=1,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        use_conf_map=False,
        # sd E, lite version by default
        sd_E_ch=64,
        z_channels=3 * 4,
        sd_E_num_res_blocks=1,
        num_frames=4,
        # vit_decoder
        arch_dit_decoder='DiT2-B/2',
        return_all_dit_layers=False,
        # sd D
        # sd_D_ch=32,
        # sd_D_res_blocks=1,
        lrm_decoder=False,
        plane_n=3,
        gs_rendering=False,
    )
def triplane_decoder_defaults():
    opts = dict(
        triplane_fg_bg=False,
        cfg='shapenet',
        density_reg=0.25,
        density_reg_p_dist=0.004,
        reg_type='l1',
        triplane_decoder_lr=0.0025,  # follow eg3d G lr
        super_resolution_lr=0.0025,
        # triplane_decoder_wd=0.1,
        c_scale=1,
        nsr_lr=0.02,
        triplane_size=224,
        decoder_in_chans=32,
        triplane_in_chans=-1,
        decoder_output_dim=3,
        out_chans=96,
        c_dim=25,  # Conditioning label (C) dimensionality.
        # ray_start=0.2,
        # ray_end=2.2,
        ray_start=0.6,  # shapenet default
        ray_end=1.8,
        rendering_kwargs={},
        sr_training=False,
        bcg_synthesis=False,  # from panohead
        bcg_synthesis_kwargs={},  # G_kwargs.copy()
        image_size=128,  # raw 3D rendering output resolution.
        patch_rendering_resolution=45,
    )

    # else:
    #     assert False, "Need to specify config"
    # opts = dict(opts)
    # opts.pop('cfg')

    return opts
def vit_decoder_defaults():
    res = dict(
        vit_decoder_lr=1e-5,  # follow eg3d G lr
        vit_decoder_wd=0.001,
    )
    return res


def nsr_decoder_defaults():
    res = {
        'decomposed': False,
    }  # TODO, add defaults for all nsr
    res.update(triplane_decoder_defaults())  # triplane by default now
    res.update(vit_decoder_defaults())  # type: ignore
    return res
def loss_defaults():
    opt = dict(
        color_criterion='mse',
        l2_lambda=1.0,
        lpips_lambda=0.,
        lpips_delay_iter=0,
        sr_delay_iter=0,
        # kl_anneal=0,
        kl_anneal=False,
        latent_lambda=0.,
        latent_criterion='mse',
        kl_lambda=0.0,
        ssim_lambda=0.,
        l1_lambda=0.,
        id_lambda=0.0,
        depth_lambda=0.0,  # TODO
        alpha_lambda=0.0,  # TODO
        fg_mse=False,
        bg_lamdba=0.0,
        density_reg=0.0,  # tvloss in eg3d
        density_reg_p_dist=0.004,  # 'density regularization strength.'
        density_reg_every=4,  # lazy density reg

        # 3D supervision, ffhq/afhq eg3d warm up
        shape_uniform_lambda=0.005,
        shape_importance_lambda=0.01,
        shape_depth_lambda=0.,

        # gan loss
        rec_cvD_lambda=0.01,
        nvs_cvD_lambda=0.025,
        patchgan_disc_factor=0.01,
        patchgan_disc_g_weight=0.2,
        r1_gamma=1.0,  # ffhq default value for eg3d
        sds_lamdba=1.0,
        nvs_D_lr_mul=1,  # compared with 1e-4
        cano_D_lr_mul=1,  # compared with 1e-4

        # lsgm loss
        ce_balanced_kl=1.,
        p_eps_lambda=1,

        # symmetric loss
        symmetry_loss=False,
        depth_smoothness_lambda=0.0,
        ce_lambda=1.0,
        negative_entropy_lambda=1.0,
        grad_clip=False,
        online_mask=False,  # in unsup3d
    )
    return opt
def dataset_defaults():
    res = dict(
        use_lmdb=False,
        use_wds=False,
        use_lmdb_compressed=True,
        compile=False,
        interval=1,
        objv_dataset=False,
        decode_encode_img_only=False,
        load_wds_diff=False,
        load_wds_latent=False,
        eval_load_wds_instance=True,
        shards_lst="",
        eval_shards_lst="",
        mv_input=False,
        duplicate_sample=True,
        orthog_duplicate=False,
        split_chunk_input=False,  # split=8 per chunk
        load_real=False,
        four_view_for_latent=False,
        single_view_for_i23d=False,
        shuffle_across_cls=False,
        load_extra_36_view=False,
        mv_latent_dir='',
        append_depth=False,
        plucker_embedding=False,
        gs_cam_format=False,
        split_chunk_size=8,
    )
    return res
def encoder_and_nsr_defaults():
    """
    Defaults for image training.
    """
    # ViT configs
    res = dict(
        dino_version='v1',
        encoder_in_channels=3,
        img_size=[224],
        patch_size=16,  # ViT-S/16
        in_chans=384,
        num_classes=0,
        embed_dim=384,  # Check ViT encoder dim
        depth=6,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer='nn.LayerNorm',
        # img_resolution=128,  # Output resolution.
        cls_token=False,
        # image_size=128,  # rendered output resolution.
        # img_channels=3,  # Number of output color channels.
        encoder_cls_token=False,
        decoder_cls_token=False,
        sr_kwargs={},
        sr_ratio=2,
        # sd configs
    )
    # Triplane configs
    res.update(model_encoder_defaults())
    res.update(nsr_decoder_defaults())
    res.update(
        ae_classname='vit.vit_triplane.ViTTriplaneDecomposed')  # if add SR
    return res
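
# Hedged sketch: assembling the full kwargs dict for create_3DAE_model
# (defined below). encoder_and_nsr_defaults() already merges the encoder,
# triplane and vit-decoder defaults, so only rendering_kwargs still has to be
# filled in; the default cfg='shapenet' picks that branch of
# rendering_options_defaults().
def _example_full_config():
    cfg = dnnlib.EasyDict(encoder_and_nsr_defaults())
    cfg.rendering_kwargs = rendering_options_defaults(cfg)
    return cfg
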
def create_3DAE_model(
        arch_encoder,
        arch_decoder,
        dino_version='v1',
        img_size=[224],
        patch_size=16,
        in_chans=384,
        num_classes=0,
        embed_dim=1024,  # Check ViT encoder dim
        depth=6,
        num_heads=16,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.1,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        # norm_layer=nn.LayerNorm,
        norm_layer='nn.LayerNorm',
        out_chans=96,
        decoder_in_chans=32,
        triplane_in_chans=-1,
        decoder_output_dim=32,
        encoder_cls_token=False,
        decoder_cls_token=False,
        c_dim=25,  # Conditioning label (C) dimensionality.
        image_size=128,  # Output resolution.
        img_channels=3,  # Number of output color channels.
        rendering_kwargs={},
        load_pretrain_encoder=False,
        decomposed=True,
        triplane_size=224,
        ae_classname='ViTTriplaneDecomposed',
        use_clip=False,
        sr_kwargs={},
        sr_ratio=2,
        no_dim_up_mlp=False,
        dim_up_mlp_as_func=False,
        decoder_load_pretrained=True,
        uvit_skip_encoder=False,
        bcg_synthesis_kwargs={},
        # decoder params
        vae_p=1,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        use_conf_map=False,
        triplane_fg_bg=False,
        encoder_in_channels=3,
        sd_E_ch=64,
        z_channels=3 * 4,
        sd_E_num_res_blocks=1,
        num_frames=6,
        arch_dit_decoder='DiT2-B/2',
        lrm_decoder=False,
        gs_rendering=False,
        return_all_dit_layers=False,
        *args,
        **kwargs):
    # TODO, check pre-trained ViT encoder cfgs
    preprocess = None
    clip_dtype = None

    if load_pretrain_encoder:

        if not use_clip:
            if dino_version == 'v1':
                encoder = torch.hub.load(
                    'facebookresearch/dino:main',
                    'dino_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v1 ViT-S{patch_size} encoder ckpt'
                )
            elif dino_version == 'v2':
                encoder = torch.hub.load(
                    'facebookresearch/dinov2',
                    'dinov2_{}{}'.format(arch_encoder, patch_size))
                logger.log(
                    f'loaded pre-trained dino v2 {arch_encoder}{patch_size} encoder ckpt'
                )
            elif 'sd' in dino_version:  # just for compat

                if 'mv' in dino_version:
                    if 'lgm' in dino_version:
                        # built directly; the generic encoder_cls construction
                        # below does not apply to the LGM UNet
                        encoder = MVUNet(
                            input_size=256,
                            up_channels=(1024, 1024, 512, 256,
                                         128),  # one more decoder
                            up_attention=(True, True, True, False, False),
                            splat_size=128,
                            output_size=512,  # render & supervise Gaussians at a higher resolution.
                            batch_size=8,
                            num_views=8,
                            gradient_accumulation_steps=1,
                            # mixed_precision='bf16',
                        )
                    elif 'gs' in dino_version:
                        encoder_cls = MVEncoder
                    else:
                        encoder_cls = MVEncoder
                else:
                    encoder_cls = Encoder

                if 'lgm' not in dino_version:
                    encoder = encoder_cls(  # mono input
                        double_z=True,
                        resolution=256,
                        in_channels=encoder_in_channels,
                        # ch=128,
                        ch=64,  # ! fit in the memory
                        # ch_mult=[1,2,4,4],
                        # num_res_blocks=2,
                        ch_mult=[1, 2, 4, 4],
                        num_res_blocks=1,
                        dropout=0.0,
                        attn_resolutions=[],
                        out_ch=3,  # unused
                        z_channels=4 * 3,
                    )  # stable diffusion encoder
            else:
                raise NotImplementedError()

        else:
            import clip
            model, preprocess = clip.load("ViT-B/16", device=dist_util.dev())
            model.float()  # convert weights to float32
            clip_dtype = model.dtype
            encoder = getattr(
                model, 'visual')  # only use the CLIP visual encoder here
            encoder.requires_grad_(False)
            logger.log(
                f'loaded pre-trained CLIP ViT-B{patch_size} encoder, fixed.')

    elif 'sd' in dino_version:

        attn_kwargs = {}
        if 'mv' in dino_version:
            if 'lgm' in dino_version:
                encoder = LGM_MVEncoder(
                    in_channels=9,
                    # input_size=256,
                    up_channels=(1024, 1024, 512, 256,
                                 128),  # one more decoder
                    up_attention=(True, True, True, False, False),
                )
            else:
                if 'dynaInp' in dino_version:
                    encoder_cls = MVEncoderGSDynamicInp
                else:
                    encoder_cls = MVEncoder
                attn_kwargs = {
                    'n_heads': 8,
                    'd_head': 64,
                }
        else:
            encoder_cls = Encoder

        if 'lgm' not in dino_version:  # TODO, for compat now
            # st()
            encoder = encoder_cls(
                double_z=True,
                resolution=256,
                in_channels=encoder_in_channels,
                # ch=128,
                # ch=64,  # ! fit in the memory
                ch=sd_E_ch,
                # ch_mult=[1,2,4,4],
                # num_res_blocks=2,
                ch_mult=[1, 2, 4, 4],
                # num_res_blocks=1,
                num_res_blocks=sd_E_num_res_blocks,
                num_frames=num_frames,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=3,  # unused
                z_channels=z_channels,  # 4 * 3
                attn_kwargs=attn_kwargs,
            )  # stable diffusion encoder

    else:
        encoder = vits.__dict__[arch_encoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,  # stochastic depth
            img_size=img_size)
    # assert decomposed
    # if decomposed:
    if triplane_in_chans == -1:
        triplane_in_chans = decoder_in_chans

    # if triplane_fg_bg:
    #     triplane_renderer_cls = Triplane_fg_bg_plane
    # else:
    triplane_renderer_cls = Triplane

    # triplane_decoder = Triplane(
    triplane_decoder = triplane_renderer_cls(
        c_dim,  # Conditioning label (C) dimensionality.
        image_size,  # Output resolution.
        img_channels,  # Number of output color channels.
        rendering_kwargs=rendering_kwargs,
        out_chans=out_chans,
        # create_triplane=True,  # compatibility, remove later
        triplane_size=triplane_size,
        decoder_in_chans=triplane_in_chans,
        decoder_output_dim=decoder_output_dim,
        sr_kwargs=sr_kwargs,
        bcg_synthesis_kwargs=bcg_synthesis_kwargs,
        lrm_decoder=lrm_decoder)
    if load_pretrain_encoder:
        if dino_version == 'v1':
            vit_decoder = torch.hub.load(
                'facebookresearch/dino:main',
                'dino_{}{}'.format(arch_decoder, patch_size))
            logger.log(
                'loaded pre-trained decoder',
                "facebookresearch/dino:main', 'dino_{}{}".format(
                    arch_decoder, patch_size))
        else:
            vit_decoder = torch.hub.load(
                'facebookresearch/dinov2',
                'dinov2_{}{}'.format(arch_decoder, patch_size),
                pretrained=decoder_load_pretrained)
            logger.log(
                'loaded pre-trained decoder',
                "facebookresearch/dinov2', 'dinov2_{}{}".format(
                    arch_decoder, patch_size), 'pretrained=',
                decoder_load_pretrained)
    elif 'dit' in dino_version:
        from dit.dit_decoder import DiT2_models

        vit_decoder = DiT2_models[arch_dit_decoder](
            input_size=16,
            num_classes=0,
            learn_sigma=False,
            in_channels=embed_dim,
            mixed_prediction=False,
            context_dim=None,  # add CLIP text embedding
            roll_out=True,
            plane_n=4 if 'gs' in dino_version else 3,
            return_all_layers=return_all_dit_layers,
        )
    else:  # has bug on global token, to fix
        vit_decoder = vits.__dict__[arch_decoder](
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,  # stochastic depth
            img_size=img_size)
    # decoder = ViTTriplaneDecomposed(vit_decoder, triplane_decoder)
    decoder_kwargs = dict(
        class_name=ae_classname,
        vit_decoder=vit_decoder,
        triplane_decoder=triplane_decoder,
        # encoder_cls_token=encoder_cls_token,
        cls_token=decoder_cls_token,
        sr_ratio=sr_ratio,
        vae_p=vae_p,
        ldm_z_channels=ldm_z_channels,
        ldm_embed_dim=ldm_embed_dim,
    )
    decoder = dnnlib.util.construct_class_by_name(**decoder_kwargs)

    # if return_encoder_decoder:
    #     return encoder, decoder, img_size[0], cls_token

    if use_conf_map:
        confnet = ConfNet(cin=3, cout=1, nf=64, zdim=128)
    else:
        confnet = None

    auto_encoder = AE(
        encoder,
        decoder,
        img_size[0],
        encoder_cls_token,
        decoder_cls_token,
        preprocess,
        use_clip,
        dino_version,
        clip_dtype,
        no_dim_up_mlp=no_dim_up_mlp,
        dim_up_mlp_as_func=dim_up_mlp_as_func,
        uvit_skip_encoder=uvit_skip_encoder,
        confnet=confnet,
    )

    logger.log(auto_encoder)
    torch.cuda.empty_cache()

    return auto_encoder
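
# Hedged sketch: building the autoencoder from the defaults above. With
# load_pretrain_encoder=False and dino_version='v1' this takes the
# vits.__dict__[arch_encoder] path, so arch_encoder/arch_decoder must name
# constructors that exist in vit.vision_transformer (an assumption about that
# module, not something enforced here); any extra config keys are absorbed by
# **kwargs.
def _example_create_autoencoder():
    cfg = dnnlib.EasyDict(encoder_and_nsr_defaults())
    cfg.rendering_kwargs = rendering_options_defaults(cfg)
    auto_encoder = create_3DAE_model(**cfg)
    return auto_encoder
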
# def create_3DAE_Diffusion_model(
#         arch_encoder,
#         arch_decoder,
#         img_size=[224],
#         patch_size=16,
#         in_chans=384,
#         num_classes=0,
#         embed_dim=1024,  # Check ViT encoder dim
#         depth=6,
#         num_heads=16,
#         mlp_ratio=4.,
#         qkv_bias=False,
#         qk_scale=None,
#         drop_rate=0.1,
#         attn_drop_rate=0.,
#         drop_path_rate=0.,
#         # norm_layer=nn.LayerNorm,
#         norm_layer='nn.LayerNorm',
#         out_chans=96,
#         decoder_in_chans=32,
#         decoder_output_dim=32,
#         cls_token=False,
#         c_dim=25,  # Conditioning label (C) dimensionality.
#         img_resolution=128,  # Output resolution.
#         img_channels=3,  # Number of output color channels.
#         rendering_kwargs={},
#         load_pretrain_encoder=False,
#         decomposed=True,
#         triplane_size=224,
#         ae_classname='ViTTriplaneDecomposed',
#         # return_encoder_decoder=False,
#         *args,
#         **kwargs):
#     # TODO, check pre-trained ViT encoder cfgs
#     encoder, decoder, img_size, cls_token = create_3DAE_model(
#         arch_encoder,
#         arch_decoder,
#         img_size,
#         patch_size,
#         in_chans,
#         num_classes,
#         embed_dim,  # Check ViT encoder dim
#         depth,
#         num_heads,
#         mlp_ratio,
#         qkv_bias,
#         qk_scale,
#         drop_rate,
#         attn_drop_rate,
#         drop_path_rate,
#         # norm_layer=nn.LayerNorm,
#         norm_layer,
#         out_chans=96,
#         decoder_in_chans=32,
#         decoder_output_dim=32,
#         cls_token=False,
#         c_dim=25,  # Conditioning label (C) dimensionality.
#         img_resolution=128,  # Output resolution.
#         img_channels=3,  # Number of output color channels.
#         rendering_kwargs={},
#         load_pretrain_encoder=False,
#         decomposed=True,
#         triplane_size=224,
#         ae_classname='ViTTriplaneDecomposed',
#         return_encoder_decoder=False,
#         *args,
#         **kwargs
#     )  # type: ignore
def create_Triplane(
        c_dim=25,  # Conditioning label (C) dimensionality.
        img_resolution=128,  # Output resolution.
        img_channels=3,  # Number of output color channels.
        rendering_kwargs={},
        decoder_output_dim=32,
        *args,
        **kwargs):

    decoder = Triplane(
        c_dim,  # Conditioning label (C) dimensionality.
        img_resolution,  # Output resolution.
        img_channels,  # Number of output color channels.
        # TODO, replace with c
        rendering_kwargs=rendering_kwargs,
        create_triplane=True,
        decoder_output_dim=decoder_output_dim)
    return decoder
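
# Hedged sketch: a standalone triplane renderer without the ViT encoder/decoder
# stages; the rendering_kwargs come from the same helper used for the full
# model (assuming the Triplane constructor accepts these defaults).
def _example_create_triplane():
    opts = dnnlib.EasyDict(triplane_decoder_defaults())
    return create_Triplane(rendering_kwargs=rendering_options_defaults(opts))
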
def DiT_defaults():
    return {
        'dit_model': "DiT-B/16",
        'vae': "ema",
        # dit_model="DiT-XL/2",
        # dit_patch_size=8,
    }
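
# Hedged sketch: the *_defaults() helpers in this file are typically turned
# into CLI flags by argparse utilities elsewhere in the repo (an assumption
# about the surrounding code); this illustrates the pattern for DiT_defaults().
def _example_dit_defaults():
    parser = argparse.ArgumentParser()
    for k, v in DiT_defaults().items():
        parser.add_argument(f'--{k}', default=v, type=type(v))
    return parser.parse_args([])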