# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from threading import local
import torch
import torch.nn as nn
from utils.torch_utils import persistence
from .networks_stylegan2 import Generator as StyleGAN2Backbone
from .networks_stylegan2 import ToRGBLayer, SynthesisNetwork, MappingNetwork
from .volumetric_rendering.renderer import ImportanceRenderer
from .volumetric_rendering.ray_sampler import RaySampler, PatchRaySampler
import dnnlib
from pdb import set_trace as st
import math
import torch.nn.functional as F
import itertools
from ldm.modules.diffusionmodules.model import SimpleDecoder, Decoder


class TriPlaneGenerator(torch.nn.Module):

    def __init__(
        self,
        z_dim,  # Input latent (Z) dimensionality.
        c_dim,  # Conditioning label (C) dimensionality.
        w_dim,  # Intermediate latent (W) dimensionality.
        img_resolution,  # Output resolution.
        img_channels,  # Number of output color channels.
        sr_num_fp16_res=0,
        mapping_kwargs={},  # Arguments for MappingNetwork.
        rendering_kwargs={},
        sr_kwargs={},
        bcg_synthesis_kwargs={},
        # pifu_kwargs={},
        # ada_kwargs={},  # not used, placeholder
        **synthesis_kwargs,  # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.renderer = ImportanceRenderer()
        # if 'PatchRaySampler' in rendering_kwargs:
        #     self.ray_sampler = PatchRaySampler()
        # else:
        #     self.ray_sampler = RaySampler()
        self.backbone = StyleGAN2Backbone(z_dim,
                                          c_dim,
                                          w_dim,
                                          img_resolution=256,
                                          img_channels=32 * 3,
                                          mapping_kwargs=mapping_kwargs,
                                          **synthesis_kwargs)
        self.superresolution = dnnlib.util.construct_class_by_name(
            class_name=rendering_kwargs['superresolution_module'],
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
            **sr_kwargs)
        # self.bcg_synthesis = None
        if rendering_kwargs.get('use_background', False):
            self.bcg_synthesis = SynthesisNetwork(
                w_dim,
                img_resolution=self.superresolution.input_resolution,
                img_channels=32,
                **bcg_synthesis_kwargs)
            self.bcg_mapping = MappingNetwork(z_dim=z_dim,
                                              c_dim=c_dim,
                                              w_dim=w_dim,
                                              num_ws=self.num_ws,
                                              **mapping_kwargs)
            # New mapping network for self-adaptive camera pose, dim = 3
        self.decoder = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs
        self._last_planes = None
        self.pool_256 = torch.nn.AdaptiveAvgPool2d((256, 256))

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        return self.backbone.mapping(z,
                                     c * self.rendering_kwargs.get('c_scale', 0),
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def synthesis(self,
                  ws,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  cache_backbone=False,
                  use_cached_backbone=False,
                  return_meta=False,
                  return_raw_only=False,
                  **synthesis_kwargs):

        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)
        if return_sampling_details_flag:
            return_meta = True

        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0)
        # c[:, :16] = cam2world_matrix.view(-1, 16)
        intrinsics = c[:, 16:25].view(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution
        H = W = self.neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix, intrinsics, neural_rendering_resolution)

        # Create triplanes by running StyleGAN backbone
        N, M, _ = ray_origins.shape
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
        else:
            planes = self.backbone.synthesis(
                ws[:, :self.backbone.num_ws, :],  # ws, BS 14 512
                update_emas=update_emas,
                **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes

        # Reshape output into three 32-channel planes
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])  # BS 96 256 256

        # Perform volume rendering
        # st()
        rendering_details = self.renderer(
            planes,
            self.decoder,
            ray_origins,
            ray_directions,
            self.rendering_kwargs,
            # return_meta=True)
            return_meta=return_meta)

        # calibs = create_calib_matrix(c)
        # all_coords = rendering_details['all_coords']
        # B, num_rays, S, _ = all_coords.shape
        # all_coords_B3N = all_coords.reshape(B, -1, 3).permute(0, 2, 1)
        # homo_coords = torch.cat([all_coords, torch.zeros_like(all_coords[..., :1])], -1)
        # homo_coords[..., -1] = 1
        # homo_coords = homo_coords.reshape(homo_coords.shape[0], -1, 4)
        # homo_coords = homo_coords.permute(0, 2, 1)
        # xyz = calibs @ homo_coords
        # xyz = xyz.permute(0, 2, 1).reshape(B, H, W, S, 4)
        # st()
        # xyz_proj = perspective(all_coords_B3N, calibs)
        # xyz_proj = xyz_proj.permute(0, 2, 1).reshape(B, H, W, S, 3)  # [0,0] - [1,1]
        # st()

        feature_samples, depth_samples, weights_samples = (
            rendering_details[k]
            for k in ['feature_samples', 'depth_samples', 'weights_samples'])

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Reshape into 'raw' neural-rendered image
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()  # B 32 H W
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run superresolution to get final image
        rgb_image = feature_image[:, :3]  # B 3 H W
        if not return_raw_only:
            sr_image = self.superresolution(
                rgb_image,
                feature_image,
                ws[:, -1:, :],  # only use the last layer
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })
        else:
            sr_image = rgb_image

        ret_dict = {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized
        }
        if return_meta:
            ret_dict.update({
                # 'feature_image': feature_image,
                'feature_volume': rendering_details['feature_volume'],
                'all_coords': rendering_details['all_coords'],
                'weights': rendering_details['weights'],
            })

        return ret_dict

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
        ws = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        planes = self.backbone.synthesis(ws,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     ws,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
        planes = self.backbone.synthesis(ws,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def forward(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                cache_backbone=False,
                use_cached_backbone=False,
                **synthesis_kwargs):
        # Render a batch of generated images.
        ws = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self.synthesis(
            ws,
            c,
            update_emas=update_emas,
            neural_rendering_resolution=neural_rendering_resolution,
            cache_backbone=cache_backbone,
            use_cached_backbone=use_cached_backbone,
            **synthesis_kwargs)
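

# ---------------------------------------------------------------------------
# Illustrative note (not part of the original file): synthesis() above parses an
# EG3D-style 25-dim conditioning vector c as [flattened 4x4 cam2world matrix |
# flattened 3x3 intrinsics]. A hedged sketch of how such a vector can be built
# (identity pose and made-up intrinsics, batch size 1):
#
#   cam2world = torch.eye(4)                       # camera-to-world pose
#   intrinsics = torch.tensor([[2.0, 0.0, 0.5],
#                              [0.0, 2.0, 0.5],
#                              [0.0, 0.0, 1.0]])   # normalized focal / principal point
#   c = torch.cat([cam2world.reshape(16), intrinsics.reshape(9)]).unsqueeze(0)
#   c.shape  # torch.Size([1, 25]); c[:, :16] is the pose, c[:, 16:25] the intrinsics
# ---------------------------------------------------------------------------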

from .networks_stylegan2 import FullyConnectedLayer

# class OSGDecoder(torch.nn.Module):

#     def __init__(self, n_features, options):
#         super().__init__()
#         self.hidden_dim = 64
#         self.output_dim = options['decoder_output_dim']
#         self.n_features = n_features

#         self.net = torch.nn.Sequential(
#             FullyConnectedLayer(n_features,
#                                 self.hidden_dim,
#                                 lr_multiplier=options['decoder_lr_mul']),
#             torch.nn.Softplus(),
#             FullyConnectedLayer(self.hidden_dim,
#                                 1 + options['decoder_output_dim'],
#                                 lr_multiplier=options['decoder_lr_mul']))

#     def forward(self, sampled_features, ray_directions):
#         # Aggregate features
#         sampled_features = sampled_features.mean(1)
#         x = sampled_features

#         N, M, C = x.shape
#         x = x.view(N * M, C)

#         x = self.net(x)
#         x = x.view(N, M, -1)
#         rgb = torch.sigmoid(x[..., 1:]) * (
#             1 + 2 * 0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
#         sigma = x[..., 0:1]
#         return {'rgb': rgb, 'sigma': sigma}


class OSGDecoder(torch.nn.Module):

    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64
        self.decoder_output_dim = options['decoder_output_dim']

        self.net = torch.nn.Sequential(
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=options['decoder_lr_mul']),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=options['decoder_lr_mul']))
        self.activation = options.get('decoder_activation', 'sigmoid')

    def forward(self, sampled_features, ray_directions):
        # Aggregate features
        sampled_features = sampled_features.mean(1)
        x = sampled_features

        N, M, C = x.shape
        x = x.view(N * M, C)

        x = self.net(x)
        x = x.view(N, M, -1)
        rgb = x[..., 1:]
        sigma = x[..., 0:1]
        if self.activation == "sigmoid":
            # Original EG3D
            rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001
        elif self.activation == "lrelu":
            # StyleGAN2-style, use with toRGB
            rgb = torch.nn.functional.leaky_relu(rgb, 0.2,
                                                 inplace=True) * math.sqrt(2)
        return {'rgb': rgb, 'sigma': sigma}
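

# Illustrative sketch (not in the original file): the renderer calls the decoder
# with plane-sampled features of shape (N, n_planes, M, C) and OSGDecoder
# averages over the plane axis. The batch/point sizes below are made up.
#
#   decoder = OSGDecoder(32, {'decoder_lr_mul': 1, 'decoder_output_dim': 32})
#   feats = torch.randn(2, 3, 4096, 32)          # N, n_planes, M, C
#   out = decoder(feats, ray_directions=None)    # ray_directions is unused here
#   out['rgb'].shape    # torch.Size([2, 4096, 32])
#   out['sigma'].shape  # torch.Size([2, 4096, 1])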


class LRMOSGDecoder(nn.Module):
    """
    Triplane decoder that gives RGB and sigma values from sampled features.
    Using ReLU here instead of Softplus in the original implementation.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(self,
                 n_features: int,
                 hidden_dim: int = 64,
                 num_layers: int = 4,
                 activation: nn.Module = nn.ReLU):
        super().__init__()
        self.decoder_output_dim = 3
        self.net = nn.Sequential(
            nn.Linear(3 * n_features, hidden_dim),
            activation(),
            *itertools.chain(*[[
                nn.Linear(hidden_dim, hidden_dim),
                activation(),
            ] for _ in range(num_layers - 2)]),
            nn.Linear(hidden_dim, 1 + self.decoder_output_dim),
        )
        # init all bias to zero
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, sampled_features, ray_directions):
        # Aggregate features by mean
        # sampled_features = sampled_features.mean(1)
        # Aggregate features by concatenation
        _N, n_planes, _M, _C = sampled_features.shape
        sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(
            _N, _M, n_planes * _C)
        x = sampled_features

        N, M, C = x.shape
        x = x.contiguous().view(N * M, C)

        x = self.net(x)
        x = x.view(N, M, -1)
        rgb = torch.sigmoid(x[..., 1:]) * (
            1 + 2 * 0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
        sigma = x[..., 0:1]
        return {'rgb': rgb, 'sigma': sigma}
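

# Illustrative sketch (not in the original file): unlike OSGDecoder, this decoder
# concatenates the three plane features (3 * n_features input channels) and
# always emits 3 RGB channels plus 1 density channel. Shapes are assumptions.
#
#   decoder = LRMOSGDecoder(n_features=32)       # expects 3 * 32 = 96 concatenated channels
#   feats = torch.randn(2, 3, 4096, 32)          # N, n_planes, M, C
#   out = decoder(feats, ray_directions=None)
#   out['rgb'].shape    # torch.Size([2, 4096, 3])
#   out['sigma'].shape  # torch.Size([2, 4096, 1])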


class Triplane(torch.nn.Module):

    def __init__(
            self,
            c_dim=25,  # Conditioning label (C) dimensionality.
            img_resolution=128,  # Output resolution.
            img_channels=3,  # Number of output color channels.
            out_chans=96,
            triplane_size=224,
            rendering_kwargs={},
            decoder_in_chans=32,
            decoder_output_dim=32,
            sr_num_fp16_res=0,
            sr_kwargs={},
            create_triplane=False,  # for overfitting single instance study
            bcg_synthesis_kwargs={},
            lrm_decoder=False,
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution  # TODO
        self.img_channels = img_channels
        self.triplane_size = triplane_size

        self.decoder_in_chans = decoder_in_chans
        self.out_chans = out_chans

        self.renderer = ImportanceRenderer()

        if 'PatchRaySampler' in rendering_kwargs:
            self.ray_sampler = PatchRaySampler()
        else:
            self.ray_sampler = RaySampler()

        if lrm_decoder:
            self.decoder = LRMOSGDecoder(decoder_in_chans)
        else:
            self.decoder = OSGDecoder(
                decoder_in_chans,
                {
                    'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                    # 'decoder_output_dim': 32
                    'decoder_output_dim': decoder_output_dim
                })

        self.neural_rendering_resolution = img_resolution  # TODO
        # self.neural_rendering_resolution = 128  # TODO
        self.rendering_kwargs = rendering_kwargs
        self.create_triplane = create_triplane
        if create_triplane:
            self.planes = nn.Parameter(torch.randn(1, out_chans, 256, 256))

        if bool(sr_kwargs):  # check whether empty
            assert decoder_in_chans == decoder_output_dim, 'tradition'
            if rendering_kwargs['superresolution_module'] in [
                    'utils.torch_utils.components.PixelUnshuffleUpsample',
                    'utils.torch_utils.components.NearestConvSR',
                    'utils.torch_utils.components.NearestConvSR_Residual'
            ]:
                self.superresolution = dnnlib.util.construct_class_by_name(
                    class_name=rendering_kwargs['superresolution_module'],
                    # * for PixelUnshuffleUpsample
                    sr_ratio=2,  # 2x SR, 128 -> 256
                    output_dim=decoder_output_dim,
                    num_out_ch=3,
                )
            else:
                self.superresolution = dnnlib.util.construct_class_by_name(
                    class_name=rendering_kwargs['superresolution_module'],
                    # * for stylegan upsample
                    channels=decoder_output_dim,
                    img_resolution=img_resolution,
                    sr_num_fp16_res=sr_num_fp16_res,
                    sr_antialias=rendering_kwargs['sr_antialias'],
                    **sr_kwargs)
        else:
            self.superresolution = None

        self.bcg_synthesis = None

    # * pure reconstruction
    def forward(
            self,
            planes=None,
            # img,
            c=None,
            ws=None,
            ray_origins=None,
            ray_directions=None,
            z_bcg=None,
            neural_rendering_resolution=None,
            update_emas=False,
            cache_backbone=False,
            use_cached_backbone=False,
            return_meta=False,
            return_raw_only=False,
            sample_ray_only=False,
            fg_bbox=None,
            **synthesis_kwargs):

        cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
        # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0)
        # c[:, :16] = cam2world_matrix.view(-1, 16)
        intrinsics = c[:, 16:25].reshape(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        if ray_directions is None:  # when output video
            H = W = self.neural_rendering_resolution
            # Create a batch of rays for volume rendering
            # ray_origins, ray_directions, ray_bboxes = self.ray_sampler(
            #     cam2world_matrix, intrinsics, neural_rendering_resolution)

            if sample_ray_only:  # ! for sampling
                ray_origins, ray_directions, ray_bboxes = self.ray_sampler(
                    cam2world_matrix, intrinsics,
                    self.rendering_kwargs.get('patch_rendering_resolution'),
                    self.neural_rendering_resolution, fg_bbox)

                # for patch supervision
                ret_dict = {
                    'ray_origins': ray_origins,
                    'ray_directions': ray_directions,
                    'ray_bboxes': ray_bboxes,
                }
                return ret_dict

            else:  # ! for rendering
                ray_origins, ray_directions, _ = self.ray_sampler(
                    cam2world_matrix, intrinsics,
                    self.neural_rendering_resolution,
                    self.neural_rendering_resolution)

        else:
            assert ray_origins is not None
            H = W = int(ray_directions.shape[1]**0.5)  # dynamically set patch resolution

        # ! match the batch size, if not returned
        if planes is None:
            assert self.planes is not None
            planes = self.planes.repeat_interleave(c.shape[0], dim=0)
        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)

        if return_sampling_details_flag:
            return_meta = True

        # Create triplanes by running StyleGAN backbone
        N, M, _ = ray_origins.shape

        # Reshape output into three 32-channel planes
        if planes.shape[1] == 3 * 2 * self.decoder_in_chans:
            # if isinstance(planes, tuple):
            #     N *= 2
            triplane_bg = True
            # planes = torch.cat(planes, 0)  # inference in parallel
            # ray_origins = ray_origins.repeat(2, 1, 1)
            # ray_directions = ray_directions.repeat(2, 1, 1)
        else:
            triplane_bg = False

        # assert not triplane_bg

        # ! hard coded, will fix later
        # if planes.shape[1] == 3 * self.decoder_in_chans:
        # else:

        # planes = planes.view(len(planes), 3, self.decoder_in_chans,
        planes = planes.reshape(
            len(planes),
            3,
            -1,  # ! support background plane
            planes.shape[-2],
            planes.shape[-1])  # BS 96 256 256

        # Perform volume rendering
        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples, depth_samples, weights_samples = (
            rendering_details[k]
            for k in ['feature_samples', 'depth_samples', 'weights_samples'])

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Reshape into 'raw' neural-rendered image
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H,
            W).contiguous()  # B 32 H W, in [-1,1]
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Generate Background
        # if self.bcg_synthesis:
        #     # bg composition
        #     # if self.decoder.activation == "sigmoid":
        #     #     feature_image = feature_image * 2 - 1  # Scale to (-1, 1), taken from ray marcher

        #     assert isinstance(
        #         z_bcg, torch.Tensor
        #     )  # 512 latents after reparameterization, reuse the name
        #     # ws_bcg = ws[:,:self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:,:self.bcg_synthesis.num_ws]

        #     with torch.autocast(device_type='cuda',
        #                         dtype=torch.float16,
        #                         enabled=False):
        #         ws_bcg = self.bcg_mapping(z_bcg, c=None)  # reuse the name
        #         if ws_bcg.size(1) < self.bcg_synthesis.num_ws:
        #             ws_bcg = torch.cat([
        #                 ws_bcg, ws_bcg[:, -1:].repeat(
        #                     1, self.bcg_synthesis.num_ws - ws_bcg.size(1), 1)
        #             ], 1)

        #         bcg_image = self.bcg_synthesis(ws_bcg,
        #                                        update_emas=update_emas,
        #                                        **synthesis_kwargs)
        #     bcg_image = torch.nn.functional.interpolate(
        #         bcg_image,
        #         size=feature_image.shape[2:],
        #         mode='bilinear',
        #         align_corners=False,
        #         antialias=self.rendering_kwargs['sr_antialias'])
        #     feature_image = feature_image + (1 - weights_samples) * bcg_image

        #     # Generate Raw image
        #     assert self.torgb
        #     rgb_image = self.torgb(feature_image,
        #                            ws_bcg[:, -1],
        #                            fused_modconv=False)
        #     rgb_image = rgb_image.to(dtype=torch.float32,
        #                              memory_format=torch.contiguous_format)
        #     # st()
        # else:
        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001

        if triplane_bg:
            # true_bs = N // 2
            # weights_samples = weights_samples[:true_bs]
            # mask_image = mask_image[:true_bs]
            # feature_image = feature_image[:true_bs] * mask_image + feature_image[true_bs:] * (1 - mask_image)  # the first is foreground
            # depth_image = depth_image[:true_bs]

            # ! composited colors
            # rgb_final = (
            #     1 - fg_ret_dict['weights']
            # ) * bg_ret_dict['rgb_final'] + fg_ret_dict[
            #     'feature_samples']  # https://github.com/SizheAn/PanoHead/blob/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/training/triplane.py#L127C45-L127C64
            # ret_dict.update({
            #     'feature_samples': rgb_final,
            # })
            # st()
            feature_image = (1 - mask_image) * rendering_details[
                'bg_ret_dict']['rgb_final'] + feature_image

        rgb_image = feature_image[:, :3]

        # Run superresolution to get the final image
        if self.superresolution is not None and not return_raw_only:
            # assert ws is not None, 'feed in [cls] token here for SR module'
            if ws is not None and ws.ndim == 2:
                ws = ws.unsqueeze(1)[:, -1:, :]  # follow stylegan tradition, B, N, C
            sr_image = self.superresolution(
                rgb=rgb_image,
                x=feature_image,
                base_x=rgb_image,
                ws=ws,  # only use the last layer
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],  # none
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })
        else:
            # sr_image = rgb_image
            sr_image = None

        if shape_synthesized is not None:
            shape_synthesized.update({
                'image_depth': depth_image,
            })  # for easy 3D loss computation, wrap all 3D outputs in a single dict

        ret_dict = {
            'feature_image': feature_image,
            # 'image_raw': feature_image[:, :3],
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            # 'silhouette': mask_image,
            # 'silhouette_normalized_3channel': (mask_image * 2 - 1).repeat_interleave(3, 1),  # N 3 H W
            'shape_synthesized': shape_synthesized,
            'image_mask': mask_image,
        }

        if sr_image is not None:
            ret_dict.update({
                'image_sr': sr_image,
            })

        if return_meta:
            ret_dict.update({
                'feature_volume': rendering_details['feature_volume'],
                'all_coords': rendering_details['all_coords'],
                'weights': rendering_details['weights'],
            })

        return ret_dict
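

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal Triplane setup
# for single-instance reconstruction with the learnable planes enabled. Only
# keys read directly in this file are shown; rendering_kwargs must additionally
# carry whatever ImportanceRenderer expects (ray bounds, sample counts, ...),
# which is elided here. All concrete values are assumptions.
#
#   model = Triplane(
#       img_resolution=128,
#       out_chans=96,              # 3 planes x 32 channels
#       decoder_in_chans=32,
#       lrm_decoder=True,          # concatenating LRM-style decoder
#       create_triplane=True,      # learnable self.planes for the overfitting study
#       rendering_kwargs={'superresolution_noise_mode': 'none'})
#   c = torch.cat([torch.eye(4).reshape(1, 16),
#                  torch.eye(3).reshape(1, 9)], dim=1)   # 25-dim camera vector
#   out = model(planes=None, c=c)  # falls back to the learnable self.planes
#   # expected: out['image_raw'] (1, 3, 128, 128), out['image_depth'] (1, 1, 128, 128)
# ---------------------------------------------------------------------------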


class Triplane_fg_bg_plane(Triplane):
    # a separate background plane

    def __init__(self,
                 c_dim=25,
                 img_resolution=128,
                 img_channels=3,
                 out_chans=96,
                 triplane_size=224,
                 rendering_kwargs={},
                 decoder_in_chans=32,
                 decoder_output_dim=32,
                 sr_num_fp16_res=0,
                 sr_kwargs={},
                 bcg_synthesis_kwargs={}):
        super().__init__(c_dim, img_resolution, img_channels, out_chans,
                         triplane_size, rendering_kwargs, decoder_in_chans,
                         decoder_output_dim, sr_num_fp16_res, sr_kwargs,
                         bcg_synthesis_kwargs=bcg_synthesis_kwargs)

        self.bcg_decoder = Decoder(
            ch=64,  # half channel size
            out_ch=32,
            # ch_mult=(1, 2, 4),
            ch_mult=(1, 2),  # use res=64 for now
            num_res_blocks=2,
            dropout=0.0,
            attn_resolutions=(),
            z_channels=4,
            resolution=64,
            in_channels=3,
        )

    # * pure reconstruction
    def forward(
            self,
            planes,
            bg_plane,
            # img,
            c,
            ws=None,
            z_bcg=None,
            neural_rendering_resolution=None,
            update_emas=False,
            cache_backbone=False,
            use_cached_backbone=False,
            return_meta=False,
            return_raw_only=False,
            **synthesis_kwargs):

        # ! match the batch size
        if planes is None:
            assert self.planes is not None
            planes = self.planes.repeat_interleave(c.shape[0], dim=0)
        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)

        if return_sampling_details_flag:
            return_meta = True

        cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
        # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0)
        # c[:, :16] = cam2world_matrix.view(-1, 16)
        intrinsics = c[:, 16:25].reshape(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution
        H = W = self.neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions, _ = self.ray_sampler(
            cam2world_matrix, intrinsics, neural_rendering_resolution)

        # Create triplanes by running StyleGAN backbone
        N, M, _ = ray_origins.shape

        # # Reshape output into three 32-channel planes
        # if planes.shape[1] == 3 * 2 * self.decoder_in_chans:
        #     # if isinstance(planes, tuple):
        #     #     N *= 2
        #     triplane_bg = True
        #     # planes = torch.cat(planes, 0)  # inference in parallel
        #     # ray_origins = ray_origins.repeat(2, 1, 1)
        #     # ray_directions = ray_directions.repeat(2, 1, 1)
        # else:
        #     triplane_bg = False

        # assert not triplane_bg

        planes = planes.view(
            len(planes),
            3,
            -1,  # ! support background plane
            planes.shape[-2],
            planes.shape[-1])  # BS 96 256 256

        # Perform volume rendering
        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples, depth_samples, weights_samples = (
            rendering_details[k]
            for k in ['feature_samples', 'depth_samples', 'weights_samples'])

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Reshape into 'raw' neural-rendered image
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H,
            W).contiguous()  # B 32 H W, in [-1,1]
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        bcg_image = self.bcg_decoder(bg_plane)
        bcg_image = torch.nn.functional.interpolate(
            bcg_image,
            size=feature_image.shape[2:],
            mode='bilinear',
            align_corners=False,
            antialias=self.rendering_kwargs['sr_antialias'])

        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001

        # ! fuse fg/bg model output
        feature_image = feature_image + (1 - weights_samples) * bcg_image

        rgb_image = feature_image[:, :3]

        # Run superresolution to get the final image
        if self.superresolution is not None and not return_raw_only:
            # assert ws is not None, 'feed in [cls] token here for SR module'
            if ws is not None and ws.ndim == 2:
                ws = ws.unsqueeze(1)[:, -1:, :]  # follow stylegan tradition, B, N, C
            sr_image = self.superresolution(
                rgb=rgb_image,
                x=feature_image,
                base_x=rgb_image,
                ws=ws,  # only use the last layer
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],  # none
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })
        else:
            # sr_image = rgb_image
            sr_image = None

        if shape_synthesized is not None:
            shape_synthesized.update({
                'image_depth': depth_image,
            })  # for easy 3D loss computation, wrap all 3D outputs in a single dict

        ret_dict = {
            'feature_image': feature_image,
            # 'image_raw': feature_image[:, :3],
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            # 'silhouette': mask_image,
            # 'silhouette_normalized_3channel': (mask_image * 2 - 1).repeat_interleave(3, 1),  # N 3 H W
            'shape_synthesized': shape_synthesized,
            'image_mask': mask_image,
        }

        if sr_image is not None:
            ret_dict.update({
                'image_sr': sr_image,
            })

        if return_meta:
            ret_dict.update({
                'feature_volume': rendering_details['feature_volume'],
                'all_coords': rendering_details['all_coords'],
                'weights': rendering_details['weights'],
            })

        return ret_dict
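

# Illustrative note (not part of the original file): the foreground/background
# fusion above is standard over-compositing with the accumulated rendering
# weights. A minimal sketch of the same operation on dummy tensors, with
# made-up sizes:
#
#   fg = torch.rand(1, 32, 64, 64)        # rendered foreground features
#   alpha = torch.rand(1, 1, 64, 64)      # accumulated weights in [0, 1]
#   bg = torch.rand(1, 32, 64, 64)        # decoded background plane
#   composite = fg + (1 - alpha) * bg     # matches the fusion in forward()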