""" NeRF differentiable renderer. References: https://github.com/bmild/nerf https://github.com/kwea123/nerf_pl """ import torch import torch.autograd.profiler as profiler from dotmap import DotMap class _RenderWrapper(torch.nn.Module): def __init__(self, net, renderer, simple_output): super().__init__() self.net = net self.renderer = renderer self.simple_output = simple_output def forward( self, rays, want_weights=False, want_alphas=False, want_z_samps=False, want_rgb_samps=False, sample_from_dist=None, ): if rays.shape[0] == 0: return ( torch.zeros(0, 3, device=rays.device), torch.zeros(0, device=rays.device), ) outputs = self.renderer( self.net, rays, want_weights=want_weights and not self.simple_output, want_alphas=want_alphas and not self.simple_output, want_z_samps=want_z_samps and not self.simple_output, want_rgb_samps=want_rgb_samps and not self.simple_output, sample_from_dist=sample_from_dist, ) if self.simple_output: if self.renderer.using_fine: rgb = outputs.fine.rgb depth = outputs.fine.depth else: rgb = outputs.coarse.rgb depth = outputs.coarse.depth return rgb, depth else: # Make DotMap to dict to support DataParallel return outputs.toDict() class NeRFRenderer(torch.nn.Module): """ NeRF differentiable renderer :param n_coarse number of coarse (binned uniform) samples :param n_fine number of fine (importance) samples :param n_fine_depth number of expected depth samples :param noise_std noise to add to sigma. We do not use it :param depth_std noise for depth samples :param eval_batch_size ray batch size for evaluation :param white_bkgd if true, background color is white; else black :param lindisp if to use samples linear in disparity instead of distance :param sched ray sampling schedule. list containing 3 lists of equal length. sched[0] is list of iteration numbers, sched[1] is list of coarse sample numbers, sched[2] is list of fine sample numbers """ def __init__( self, n_coarse=128, n_fine=0, n_fine_depth=0, noise_std=0.0, depth_std=0.01, eval_batch_size=100000, white_bkgd=False, lindisp=False, sched=None, # ray sampling schedule for coarse and fine rays hard_alpha_cap=False, render_mode="volumetric", surface_sigmoid_scale=.1, render_flow=False, normalize_dino=False, ): super().__init__() self.n_coarse, self.n_fine = n_coarse, n_fine self.n_fine_depth = n_fine_depth self.noise_std = noise_std self.depth_std = depth_std self.eval_batch_size = eval_batch_size self.white_bkgd = white_bkgd self.lindisp = lindisp if lindisp: print("Using linear displacement rays") self.using_fine = n_fine > 0 self.sched = sched if sched is not None and len(sched) == 0: self.sched = None self.register_buffer( "iter_idx", torch.tensor(0, dtype=torch.long), persistent=True ) self.register_buffer( "last_sched", torch.tensor(0, dtype=torch.long), persistent=True ) self.hard_alpha_cap = hard_alpha_cap assert render_mode in ("volumetric", "surface", "neus") self.render_mode = render_mode self.only_surface_color = (self.render_mode == "surface") self.surface_sigmoid_scale = surface_sigmoid_scale self.render_flow = render_flow self.normalize_dino = normalize_dino def sample_coarse(self, rays): """ Stratified sampling. Note this is different from original NeRF slightly. 

class NeRFRenderer(torch.nn.Module):
    """
    NeRF differentiable renderer
    :param n_coarse number of coarse (binned uniform) samples
    :param n_fine number of fine (importance) samples
    :param n_fine_depth number of expected depth samples
    :param noise_std noise to add to sigma. We do not use it
    :param depth_std noise for depth samples
    :param eval_batch_size ray batch size for evaluation
    :param white_bkgd if true, background color is white; else black
    :param lindisp whether to sample linearly in disparity instead of depth
    :param sched ray sampling schedule. list containing 3 lists of equal length.
    sched[0] is list of iteration numbers,
    sched[1] is list of coarse sample numbers,
    sched[2] is list of fine sample numbers
    """

    def __init__(
        self,
        n_coarse=128,
        n_fine=0,
        n_fine_depth=0,
        noise_std=0.0,
        depth_std=0.01,
        eval_batch_size=100000,
        white_bkgd=False,
        lindisp=False,
        sched=None,  # ray sampling schedule for coarse and fine rays
        hard_alpha_cap=False,
        render_mode="volumetric",
        surface_sigmoid_scale=0.1,
        render_flow=False,
        normalize_dino=False,
    ):
        super().__init__()
        self.n_coarse, self.n_fine = n_coarse, n_fine
        self.n_fine_depth = n_fine_depth

        self.noise_std = noise_std
        self.depth_std = depth_std

        self.eval_batch_size = eval_batch_size
        self.white_bkgd = white_bkgd
        self.lindisp = lindisp
        if lindisp:
            print("Using rays sampled linearly in disparity")
        self.using_fine = n_fine > 0
        self.sched = sched
        if sched is not None and len(sched) == 0:
            self.sched = None
        self.register_buffer(
            "iter_idx", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.register_buffer(
            "last_sched", torch.tensor(0, dtype=torch.long), persistent=True
        )
        self.hard_alpha_cap = hard_alpha_cap

        assert render_mode in ("volumetric", "surface", "neus")
        self.render_mode = render_mode
        self.only_surface_color = self.render_mode == "surface"
        self.surface_sigmoid_scale = surface_sigmoid_scale

        self.render_flow = render_flow

        self.normalize_dino = normalize_dino

    def sample_coarse(self, rays):
        """
        Stratified sampling. Note: this differs slightly from the original NeRF.
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :return (B, Kc)
        """
        device = rays.device
        near, far = rays[:, 6:7], rays[:, 7:8]  # (B, 1)

        step = 1.0 / self.n_coarse
        B = rays.shape[0]
        z_steps = torch.linspace(0, 1 - step, self.n_coarse, device=device)  # (Kc)
        z_steps = z_steps.unsqueeze(0).repeat(B, 1)  # (B, Kc)
        z_steps += torch.rand_like(z_steps) * step
        if not self.lindisp:  # Use linear sampling in depth space
            return near * (1 - z_steps) + far * z_steps  # (B, Kc)
        else:  # Use linear sampling in disparity space
            return 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kc)

    def sample_coarse_from_dist(self, rays, weights, z_samp):
        device = rays.device
        B = rays.shape[0]

        num_bins = weights.shape[-1]
        num_samples = self.n_coarse

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Kc)
        cdf = torch.cumsum(pdf, -1)  # (B, Kc)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Kc+1)

        u = torch.rand(B, num_samples, dtype=torch.float32, device=device)  # (B, Kc)
        interval_ids = torch.searchsorted(cdf, u, right=True) - 1  # (B, Kc)
        # Clamp to the valid bin range of the source distribution
        interval_ids = torch.clamp(interval_ids, 0, num_bins - 1)
        interval_interp = torch.rand_like(interval_ids, dtype=torch.float32)

        # z_samp describes the centers of the respective histogram bins,
        # so we have to extend the borders to the left and right
        if self.lindisp:
            z_samp = 1 / z_samp

        centers = 0.5 * (z_samp[:, 1:] + z_samp[:, :-1])
        interval_borders = torch.cat((z_samp[:, :1], centers, z_samp[:, -1:]), dim=-1)

        left_border = torch.gather(interval_borders, dim=-1, index=interval_ids)
        right_border = torch.gather(interval_borders, dim=-1, index=interval_ids + 1)

        z_samp_new = (
            left_border * (1 - interval_interp) + right_border * interval_interp
        )

        if self.lindisp:
            z_samp_new = 1 / z_samp_new

        assert not torch.any(torch.isnan(z_samp_new))

        return z_samp_new

    def sample_fine(self, rays, weights):
        """
        Weighted stratified (importance) sampling
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param weights (B, Kc)
        :return (B, Kf-Kfd)
        """
        device = rays.device
        B = rays.shape[0]

        weights = weights.detach() + 1e-5  # Prevent division by zero
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (B, Kc)
        cdf = torch.cumsum(pdf, -1)  # (B, Kc)
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (B, Kc+1)

        u = torch.rand(
            B, self.n_fine - self.n_fine_depth, dtype=torch.float32, device=device
        )  # (B, Kf)
        inds = torch.searchsorted(cdf, u, right=True).float() - 1.0  # (B, Kf)
        inds = torch.clamp_min(inds, 0.0)

        z_steps = (inds + torch.rand_like(inds)) / self.n_coarse  # (B, Kf)

        near, far = rays[:, 6:7], rays[:, 7:8]  # (B, 1)
        if not self.lindisp:  # Use linear sampling in depth space
            z_samp = near * (1 - z_steps) + far * z_steps  # (B, Kf)
        else:  # Use linear sampling in disparity space
            z_samp = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)  # (B, Kf)

        assert not torch.any(torch.isnan(z_samp))

        return z_samp

    def sample_fine_depth(self, rays, depth):
        """
        Sample around specified depth
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param depth (B)
        :return (B, Kfd)
        """
        z_samp = depth.unsqueeze(1).repeat((1, self.n_fine_depth))
        z_samp += torch.randn_like(z_samp) * self.depth_std
        # Clamp does not support tensor bounds
        z_samp = torch.max(torch.min(z_samp, rays[:, 7:8]), rays[:, 6:7])

        assert not torch.any(torch.isnan(z_samp))

        return z_samp
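
    # A minimal reference sketch of the inverse-CDF (importance) sampling that
    # `sample_fine` and `sample_coarse_from_dist` above implement; illustrative
    # only, the renderer itself does not call it.
    @staticmethod
    def _inverse_cdf_bins(weights, n_samples):
        """Draw n_samples bin indices per ray proportionally to weights (B, K)."""
        pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdim=True)  # (B, K)
        cdf = torch.cat(
            [torch.zeros_like(pdf[:, :1]), torch.cumsum(pdf, -1)], -1
        )  # (B, K+1)
        u = torch.rand(weights.shape[0], n_samples, device=weights.device)
        inds = torch.searchsorted(cdf, u, right=True) - 1  # (B, n_samples)
        return torch.clamp(inds, 0, weights.shape[-1] - 1)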

    def composite(self, model, rays, z_samp, coarse=True, sb=0):
        """
        Render RGB and depth for each ray using the NeRF alpha-compositing formula,
        given sampled positions along each ray (see sample_*)
        :param model should return (B, (r, g, b, sigma)) when called with (B, (x, y, z))
        should also support 'coarse' boolean argument
        :param rays ray [origins (3), directions (3), near (1), far (1)] (B, 8)
        :param z_samp z positions sampled for each ray (B, K)
        :param coarse whether to evaluate using coarse NeRF
        :param sb super-batch dimension; 0 = disable
        :return weights (B, K), rgb (B, 3), depth (B), alphas, invalid, z_samp,
        rgb samples, ray info, extras, and auxiliary state dict
        """
        with profiler.record_function("renderer_composite"):
            B, K = z_samp.shape
            r_dim = rays.shape[-1]

            deltas = z_samp[:, 1:] - z_samp[:, :-1]  # (B, K-1)
            delta_inf = 1e10 * torch.ones_like(deltas[:, :1])  # infty (B, 1)
            # delta_inf = rays[:, -1:] - z_samp[:, -1:]
            deltas = torch.cat([deltas, delta_inf], -1)  # (B, K)

            # (B, K, 3)
            points = rays[:, None, :3] + z_samp.unsqueeze(2) * rays[:, None, 3:6]
            points = points.reshape(-1, 3)  # (B*K, 3)

            if r_dim > 8:
                ray_info = rays[:, None, 8:].expand(-1, K, -1)
            else:
                ray_info = None

            if hasattr(model, "use_viewdirs"):
                use_viewdirs = model.use_viewdirs
            else:
                use_viewdirs = None

            viewdirs_all = []
            rgbs_all, invalid_all, sigmas_all, extras_all, state_dicts_all = (
                [],
                [],
                [],
                [],
                [],
            )
            if sb > 0:
                points = points.reshape(
                    sb, -1, 3
                )  # (SB, B'*K, 3) B' is real ray batch size
                if ray_info is not None:
                    ray_info = ray_info.reshape(sb, -1, ray_info.shape[-1])
                eval_batch_dim = 1
                eval_batch_size = (self.eval_batch_size - 1) // sb + 1
            else:
                eval_batch_size = self.eval_batch_size
                eval_batch_dim = 0

            split_points = torch.split(points, eval_batch_size, dim=eval_batch_dim)
            if ray_info is not None:
                split_ray_infos = torch.split(
                    ray_info, eval_batch_size, dim=eval_batch_dim
                )
            else:
                split_ray_infos = [None for _ in split_points]

            if use_viewdirs:
                dim1 = K
                viewdirs = rays[:, None, 3:6].expand(-1, dim1, -1)
                if sb > 0:
                    viewdirs = viewdirs.reshape(sb, -1, 3)  # (SB, B'*K, 3)
                else:
                    viewdirs = viewdirs.reshape(-1, 3)  # (B*K, 3)
                split_viewdirs = torch.split(
                    viewdirs, eval_batch_size, dim=eval_batch_dim
                )
                for i, pnts in enumerate(split_points):
                    dirs = split_viewdirs[i]
                    infos = split_ray_infos[i]
                    rgbs, invalid, sigmas, extras, state_dict = model(
                        pnts,
                        coarse=coarse,
                        viewdirs=dirs,
                        only_density=self.only_surface_color,
                        ray_info=infos,  # pass the per-chunk slice, not the full tensor
                        render_flow=self.render_flow,
                    )
                    rgbs_all.append(rgbs)
                    invalid_all.append(invalid)
                    sigmas_all.append(sigmas)
                    extras_all.append(extras)
                    viewdirs_all.append(dirs)
                    if state_dict is not None:
                        state_dicts_all.append(state_dict)
            else:
                for i, pnts in enumerate(split_points):
                    infos = split_ray_infos[i]
                    rgbs, invalid, sigmas, extras, state_dict = model(
                        pnts,
                        coarse=coarse,
                        only_density=self.only_surface_color,
                        ray_info=infos,
                        render_flow=self.render_flow,
                    )
                    rgbs_all.append(rgbs)
                    invalid_all.append(invalid)
                    sigmas_all.append(sigmas)
                    extras_all.append(extras)
                    if state_dict is not None:
                        state_dicts_all.append(state_dict)
            points, viewdirs = None, None

            # (B*K, 4) OR (SB, B'*K, 4)
            if not self.only_surface_color:
                rgbs = torch.cat(rgbs_all, dim=eval_batch_dim)
            else:
                rgbs = None
            invalid = torch.cat(invalid_all, dim=eval_batch_dim)
            sigmas = torch.cat(sigmas_all, dim=eval_batch_dim)
            if extras_all[0] is not None:
                extras = torch.cat(extras_all, dim=eval_batch_dim)
            else:
                extras = None

            deltas = deltas.float()
            sigmas = sigmas.float()

            if len(state_dicts_all) != 0:  # not empty
                state_dicts = {
                    key: torch.cat(
                        [state_dicts[key] for state_dicts in state_dicts_all],
                        dim=eval_batch_dim,
                    )
                    for key in state_dicts_all[0].keys()
                }
            else:
                state_dicts = None

            if rgbs is not None:
                rgbs = rgbs.reshape(B, K, -1)  # (B, K, 4 or 5)
            invalid = invalid.reshape(B, K, -1)
            sigmas = sigmas.reshape(B, K)
            if extras is not None:
                extras = extras.reshape(B, K, -1)
            if state_dicts is not None:
                state_dicts = {
                    key: value.reshape(B, K, *value.shape[2:])
                    for key, value in state_dicts.items()
                }  # BxKx... (BxKxn_viewsx...)

            if self.training and self.noise_std > 0.0:
                sigmas = sigmas + torch.randn_like(sigmas) * self.noise_std

            alphas = 1 - torch.exp(
                -deltas.abs() * torch.relu(sigmas)
            )  # (B, K) (delta should be positive anyways)

            if self.hard_alpha_cap:
                alphas[:, -1] = 1

            deltas, sigmas = None, None
            alphas_shifted = torch.cat(
                [torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1
            )  # (B, K+1) = [1, 1-a1, 1-a2, ...]
            T = torch.cumprod(alphas_shifted, -1)  # (B, K+1)
            weights = alphas * T[:, :-1]  # (B, K)
            alphas_shifted = None

            depth_final = torch.sum(weights * z_samp, -1)  # (B)

            # Guard: state_dicts is None when the model returns no state dict
            if state_dicts is not None and "dino_features" in state_dicts:
                state_dicts["dino_features"] = torch.sum(
                    state_dicts["dino_features"].mul_(weights.unsqueeze(-1)), -2
                )

            if self.render_mode == "neus":
                # dist_from_surf = z_samp - depth_final[..., None]
                indices = torch.arange(
                    0, weights.shape[-1], device=weights.device, dtype=weights.dtype
                ).unsqueeze(0)
                surface_index = torch.sum(weights * indices, dim=-1, keepdim=True)
                dist_from_surf = surface_index - indices
                weights = torch.exp(
                    -0.5 * (dist_from_surf * self.surface_sigmoid_scale) ** 2
                )
                weights = weights / torch.sum(weights, dim=-1, keepdim=True)

            if not self.only_surface_color:
                rgb_final = torch.sum(weights.unsqueeze(-1) * rgbs, -2)  # (B, 3)
            else:
                surface_points = (
                    rays[:, None, :3] + depth_final[:, None, None] * rays[:, None, 3:6]
                )
                surface_points = surface_points.reshape(sb, -1, 3)
                if ray_info is not None:
                    ray_info = ray_info.reshape(sb, -1, K, ray_info.shape[-1])[
                        :, :, 0, :
                    ]
                rgb_final, invalid_colors = model.sample_colors(
                    surface_points, ray_info=ray_info, render_flow=self.render_flow
                )
                rgb_final = rgb_final.permute(0, 2, 1, 3).reshape(B, -1)
                invalid_colors = invalid_colors.permute(0, 2, 1, 3).reshape(B, 1, -1)
                invalid = ((invalid > 0.5) | invalid_colors).float()

            if self.white_bkgd:
                # White background
                pix_alpha = weights.sum(dim=1)  # (B), pixel alpha
                rgb_final = rgb_final + 1 - pix_alpha.unsqueeze(-1)  # (B, 3)

            if extras is not None:
                extras_final = torch.sum(
                    weights.unsqueeze(-1) * extras, -2
                )  # (B, extras)
            else:
                extras_final = None

            for name, x in [
                ("weights", weights),
                ("rgb_final", rgb_final),
                ("depth_final", depth_final),
                ("alphas", alphas),
                ("invalid", invalid),
                ("z_samp", z_samp),
            ]:
                if torch.any(torch.isnan(x)):
                    print(f"Detected NaN in {name} ({x.dtype}):")
                    print(x)
                    exit()

            if ray_info is not None:
                ray_info = rays[:, None, 8:]

            return (
                weights,
                rgb_final,
                depth_final,
                alphas,
                invalid,
                z_samp,
                rgbs,
                ray_info,
                extras_final,
                state_dicts,
            )
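
    # A minimal reference sketch (not used by the renderer) of the quadrature in
    # `composite` above: alpha_i = 1 - exp(-sigma_i * delta_i), transmittance
    # T_i = prod_{j<i} (1 - alpha_j), weight w_i = alpha_i * T_i. Handy for
    # checking `composite` against a naive implementation.
    @staticmethod
    def _reference_weights(sigmas, deltas):
        alphas = 1 - torch.exp(-deltas.abs() * torch.relu(sigmas))  # (B, K)
        trans = torch.cumprod(
            torch.cat([torch.ones_like(alphas[:, :1]), 1 - alphas + 1e-10], -1), -1
        )  # (B, K+1)
        return alphas * trans[:, :-1]  # (B, K)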

    def forward(
        self,
        model,
        rays,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
        sample_from_dist=None,
    ):
        """
        :param model nerf model, should return (SB, B, (r, g, b, sigma))
        when called with (SB, B, (x, y, z)), for multi-object:
        SB = 'super-batch' = size of object batch,
        B = size of per-object ray batch.
        Should also support 'coarse' boolean argument for coarse NeRF.
        :param rays ray spec [origins (3), directions (3), near (1), far (1)] (SB, B, 8)
        :param want_weights if true, returns compositing weights (SB, B, K)
        :return render dict
        """
        with profiler.record_function("renderer_forward"):
            if self.sched is not None and self.last_sched.item() > 0:
                self.n_coarse = self.sched[1][self.last_sched.item() - 1]
                self.n_fine = self.sched[2][self.last_sched.item() - 1]

            assert len(rays.shape) == 3
            superbatch_size = rays.shape[0]
            r_dim = rays.shape[-1]
            rays = rays.reshape(-1, r_dim)  # (SB * B, 8)

            if sample_from_dist is None:
                z_coarse = self.sample_coarse(rays)  # (B, Kc)
            else:
                prop_weights, prop_z_samp = sample_from_dist
                n_samples = prop_weights.shape[-1]
                prop_weights = prop_weights.reshape(-1, n_samples)
                prop_z_samp = prop_z_samp.reshape(-1, n_samples)
                z_coarse = self.sample_coarse_from_dist(
                    rays, prop_weights, prop_z_samp
                )
                z_coarse, _ = torch.sort(z_coarse, dim=-1)

            coarse_composite = self.composite(
                model, rays, z_coarse, coarse=True, sb=superbatch_size
            )

            outputs = DotMap(
                coarse=self._format_outputs(
                    coarse_composite,
                    superbatch_size,
                    want_weights=want_weights,
                    want_alphas=want_alphas,
                    want_z_samps=want_z_samps,
                    want_rgb_samps=want_rgb_samps,
                ),
            )

            outputs.state_dict = coarse_composite[-1]

            if self.using_fine:
                all_samps = [z_coarse]
                if self.n_fine - self.n_fine_depth > 0:
                    all_samps.append(
                        self.sample_fine(rays, coarse_composite[0].detach())
                    )  # (B, Kf - Kfd)
                if self.n_fine_depth > 0:
                    all_samps.append(
                        self.sample_fine_depth(rays, coarse_composite[2])
                    )  # (B, Kfd)
                z_combine = torch.cat(all_samps, dim=-1)  # (B, Kc + Kf)
                z_combine_sorted, argsort = torch.sort(z_combine, dim=-1)
                fine_composite = self.composite(
                    model, rays, z_combine_sorted, coarse=False, sb=superbatch_size
                )
                outputs.fine = self._format_outputs(
                    fine_composite,
                    superbatch_size,
                    want_weights=want_weights,
                    want_alphas=want_alphas,
                    want_z_samps=want_z_samps,
                    want_rgb_samps=want_rgb_samps,
                )

            return outputs
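
    # The render dict built by `forward` above (via `_format_outputs` below) has
    # the layout:
    #   outputs.coarse.rgb (SB, B, 3), outputs.coarse.depth (SB, B),
    #   outputs.coarse.invalid (SB, B, K, out_d_i),
    #   plus weights / alphas / z_samps / rgb_samps when requested,
    #   and outputs.fine.* with the same fields when n_fine > 0.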

    def _format_outputs(
        self,
        rendered_outputs,
        superbatch_size,
        want_weights=False,
        want_alphas=False,
        want_z_samps=False,
        want_rgb_samps=False,
    ):
        (
            weights,
            rgb_final,
            depth,
            alphas,
            invalid,
            z_samps,
            rgb_samps,
            ray_info,
            extras,
            state_dict,
        ) = rendered_outputs
        n_smps = weights.shape[-1]
        out_d_rgb = rgb_final.shape[-1]
        out_d_i = invalid.shape[-1]

        if superbatch_size > 0:
            rgb_final = rgb_final.reshape(superbatch_size, -1, out_d_rgb)
            depth = depth.reshape(superbatch_size, -1)
            invalid = invalid.reshape(superbatch_size, -1, n_smps, out_d_i)
        ret_dict = DotMap(rgb=rgb_final, depth=depth, invalid=invalid)
        if ray_info is not None:
            ri_shape = ray_info.shape[-1]
            ray_info = ray_info.reshape(superbatch_size, -1, ri_shape)
            ret_dict.ray_info = ray_info
        if extras is not None:
            extras_shape = extras.shape[-1]
            extras = extras.reshape(superbatch_size, -1, extras_shape)
            ret_dict.extras = extras
        if want_weights:
            weights = weights.reshape(superbatch_size, -1, n_smps)
            ret_dict.weights = weights
        if want_alphas:
            alphas = alphas.reshape(superbatch_size, -1, n_smps)
            ret_dict.alphas = alphas
        if want_z_samps:
            z_samps = z_samps.reshape(superbatch_size, -1, n_smps)
            ret_dict.z_samps = z_samps
        if want_rgb_samps:
            rgb_samps = rgb_samps.reshape(superbatch_size, -1, n_smps, out_d_rgb)
            ret_dict.rgb_samps = rgb_samps
        # Guard: state_dict is None when the model returns no state dict
        if state_dict is not None and "dino_features" in state_dict:
            out_d_dino = state_dict["dino_features"].shape[-1]
            dino_features = state_dict["dino_features"].reshape(
                superbatch_size, -1, out_d_dino
            )
            ret_dict.dino_features = dino_features
        if state_dict is not None and "invalid_features" in state_dict:
            invalid_features = state_dict["invalid_features"].reshape(
                superbatch_size, -1, n_smps, out_d_i
            )
            ret_dict.invalid_features = invalid_features
        return ret_dict
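
    # Illustrative schedule (hypothetical values): with
    # sched = [[100000, 200000], [64, 96], [32, 64]], the renderer switches to
    # 64 coarse / 32 fine samples at iteration 100k and to 96 / 64 at 200k,
    # applied by `sched_step` below.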

    def sched_step(self, steps=1):
        """
        Called each training iteration to update sample numbers
        according to schedule
        """
        if self.sched is None:
            return
        self.iter_idx += steps
        while (
            self.last_sched.item() < len(self.sched[0])
            and self.iter_idx.item() >= self.sched[0][self.last_sched.item()]
        ):
            self.n_coarse = self.sched[1][self.last_sched.item()]
            self.n_fine = self.sched[2][self.last_sched.item()]
            print(
                "INFO: NeRF sampling resolution changed on schedule ==> c",
                self.n_coarse,
                "f",
                self.n_fine,
            )
            self.last_sched += 1

    @classmethod
    def from_conf(cls, conf, white_bkgd=False, eval_batch_size=100000):
        return cls(
            conf.get("n_coarse", 128),
            conf.get("n_fine", 0),
            n_fine_depth=conf.get("n_fine_depth", 0),
            noise_std=conf.get("noise_std", 0.0),
            depth_std=conf.get("depth_std", 0.01),
            white_bkgd=conf.get("white_bkgd", white_bkgd),
            lindisp=conf.get("lindisp", True),
            eval_batch_size=conf.get("eval_batch_size", eval_batch_size),
            sched=conf.get("sched", None),
            hard_alpha_cap=conf.get("hard_alpha_cap", False),
            render_mode=conf.get("render_mode", "volumetric"),
            surface_sigmoid_scale=conf.get("surface_sigmoid_scale", 1),
            render_flow=conf.get("render_flow", False),
            normalize_dino=conf.get("normalize_dino", False),
        )

    def bind_parallel(self, net, gpus=None, simple_output=False):
        """
        Returns a wrapper module compatible with DataParallel.
        Specifically, it renders rays with this renderer
        but always using the given network instance.
        Specify a list of GPU ids in 'gpus' to apply DataParallel automatically.
        :param net A PixelNeRF network
        :param gpus list of GPU ids to parallelize to. If length is 1,
        does not parallelize
        :param simple_output only returns rendered (rgb, depth) instead of the
        full render output map. Saves data transfer cost.
        :return torch module
        """
        wrapped = _RenderWrapper(net, self, simple_output=simple_output)
        if gpus is not None and len(gpus) > 1:
            print("Using multi-GPU", gpus)
            wrapped = torch.nn.DataParallel(wrapped, gpus, dim=1)
        return wrapped
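

# A minimal smoke test sketch. The toy model below is hypothetical: it only
# mirrors the 5-tuple interface (rgbs, invalid, sigmas, extras, state_dict)
# that `composite` expects, not an actual network from this repository.
if __name__ == "__main__":

    class _ToyModel(torch.nn.Module):
        use_viewdirs = False

        def forward(
            self,
            points,
            coarse=True,
            only_density=False,
            ray_info=None,
            render_flow=False,
            viewdirs=None,
        ):
            # Constant gray emission with uniform density everywhere.
            sb, n, _ = points.shape
            rgbs = torch.full((sb, n, 3), 0.5)
            invalid = torch.zeros(sb, n, 1)
            sigmas = torch.ones(sb, n)
            return rgbs, invalid, sigmas, None, None

    renderer = NeRFRenderer(n_coarse=32)
    origins = torch.zeros(1, 4, 3)
    dirs = torch.nn.functional.normalize(torch.randn(1, 4, 3), dim=-1)
    near = torch.full((1, 4, 1), 0.5)
    far = torch.full((1, 4, 1), 5.0)
    rays = _pack_rays(origins, dirs, near, far)  # (SB=1, B=4, 8)
    out = renderer(_ToyModel(), rays, want_weights=True)
    print(out.coarse.rgb.shape, out.coarse.depth.shape, out.coarse.weights.shape)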