Spaces:
Build error
Build error
| # coding=utf-8 | |
| # Copyright 2021 The Google Research Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # Lint as: python3 | |
| """Different model implementation plus a general port for all the models.""" | |
| from typing import Any, Callable | |
| from flax import linen as nn | |
| from jax import random | |
| import jax.numpy as jnp | |
| from jaxnerf.nerf import model_utils | |
| from jaxnerf.nerf import utils | |
def get_model(key, example_batch, args):
  """Build the model selected by name from a small registry ("model zoo").

  Args:
    key: PRNG key, forwarded unchanged to the selected constructor.
    example_batch: example batch of data, forwarded unchanged.
    args: hyperparameter object; `args.model` names the constructor to use.

  Returns:
    Whatever the selected constructor returns.

  Raises:
    KeyError: if `args.model` is not a registered model name.
  """
  zoo = {"nerf": construct_nerf}
  builder = zoo[args.model]
  return builder(key, example_batch, args)
class NerfModel(nn.Module):
  """Nerf NN Model with both coarse and fine MLPs.

  The coarse MLP is evaluated on stratified samples along each ray. When
  `num_fine_samples > 0`, a second ("fine") MLP is evaluated on samples
  produced by `model_utils.sample_pdf` from the coarse rendering weights.
  """
  num_coarse_samples: int  # The number of samples for the coarse nerf.
  num_fine_samples: int  # The number of samples for the fine nerf.
  use_viewdirs: bool  # If True, use viewdirs as an input.
  near: float  # The distance to the near plane
  far: float  # The distance to the far plane
  noise_std: float  # The std dev of noise added to raw sigma.
  net_depth: int  # The depth of the first part of MLP.
  net_width: int  # The width of the first part of MLP.
  net_depth_condition: int  # The depth of the second part of MLP.
  net_width_condition: int  # The width of the second part of MLP.
  net_activation: Callable[..., Any]  # MLP activation
  skip_layer: int  # How often to add skip connections.
  num_rgb_channels: int  # The number of RGB channels.
  num_sigma_channels: int  # The number of density channels.
  white_bkgd: bool  # If True, use a white background.
  min_deg_point: int  # The minimum degree of positional encoding for positions.
  max_deg_point: int  # The maximum degree of positional encoding for positions.
  deg_view: int  # The degree of positional encoding for viewdirs.
  lindisp: bool  # If True, sample linearly in disparity rather than in depth.
  rgb_activation: Callable[..., Any]  # Output RGB activation.
  sigma_activation: Callable[..., Any]  # Output sigma activation.
  legacy_posenc_order: bool  # Keep the same ordering as the original tf code.

  # `@nn.compact` is required: both MLPs are constructed inline below, and
  # Flax linen rejects submodules created outside `setup()` or a compact
  # method.
  @nn.compact
  def __call__(self, rng_0, rng_1, rays, randomized):
    """Nerf Model.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      rays: utils.Rays, a namedtuple of ray origins, directions, and viewdirs.
      randomized: bool, use randomized stratified sampling.

    Returns:
      ret: list. The first entry is (rgb_coarse, disp_coarse, acc_coarse).
        When num_fine_samples > 0, a second entry (rgb, disp, acc) from the
        fine model is appended; otherwise only the coarse entry is returned.
    """
    # Stratified sampling along rays
    key, rng_0 = random.split(rng_0)
    z_vals, samples = model_utils.sample_along_rays(
        key,
        rays.origins,
        rays.directions,
        self.num_coarse_samples,
        self.near,
        self.far,
        randomized,
        self.lindisp,
    )
    samples_enc = model_utils.posenc(
        samples,
        self.min_deg_point,
        self.max_deg_point,
        self.legacy_posenc_order,
    )

    # Construct the "coarse" MLP. Flax auto-names submodules in creation
    # order inside a compact method, so keep the coarse MLP created before
    # the fine one to keep parameter names stable.
    coarse_mlp = model_utils.MLP(
        net_depth=self.net_depth,
        net_width=self.net_width,
        net_depth_condition=self.net_depth_condition,
        net_width_condition=self.net_width_condition,
        net_activation=self.net_activation,
        skip_layer=self.skip_layer,
        num_rgb_channels=self.num_rgb_channels,
        num_sigma_channels=self.num_sigma_channels)

    # Point attribute predictions
    if self.use_viewdirs:
      # Encoded once here and reused by the fine MLP below.
      viewdirs_enc = model_utils.posenc(
          rays.viewdirs,
          0,
          self.deg_view,
          self.legacy_posenc_order,
      )
      raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
    else:
      viewdirs_enc = None
      raw_rgb, raw_sigma = coarse_mlp(samples_enc)

    # Add noises to regularize the density predictions if needed
    key, rng_0 = random.split(rng_0)
    raw_sigma = model_utils.add_gaussian_noise(
        key,
        raw_sigma,
        self.noise_std,
        randomized,
    )
    rgb = self.rgb_activation(raw_rgb)
    sigma = self.sigma_activation(raw_sigma)

    # Volumetric rendering. The coarse `weights` drive the fine sampling.
    comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
        rgb,
        sigma,
        z_vals,
        rays.directions,
        white_bkgd=self.white_bkgd,
    )
    ret = [
        (comp_rgb, disp, acc),
    ]

    # Hierarchical sampling based on coarse predictions
    if self.num_fine_samples > 0:
      # Midpoints between consecutive coarse depths (one fewer entry than
      # z_vals along the last axis). The interior weights (first and last
      # dropped) are passed with them -- presumably one weight per interval
      # between midpoints; see model_utils.sample_pdf to confirm.
      z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
      key, rng_1 = random.split(rng_1)
      z_vals, samples = model_utils.sample_pdf(
          key,
          z_vals_mid,
          weights[..., 1:-1],
          rays.origins,
          rays.directions,
          z_vals,
          self.num_fine_samples,
          randomized,
      )
      samples_enc = model_utils.posenc(
          samples,
          self.min_deg_point,
          self.max_deg_point,
          self.legacy_posenc_order,
      )

      # Construct the "fine" MLP (separate parameters from the coarse MLP).
      fine_mlp = model_utils.MLP(
          net_depth=self.net_depth,
          net_width=self.net_width,
          net_depth_condition=self.net_depth_condition,
          net_width_condition=self.net_width_condition,
          net_activation=self.net_activation,
          skip_layer=self.skip_layer,
          num_rgb_channels=self.num_rgb_channels,
          num_sigma_channels=self.num_sigma_channels)
      if self.use_viewdirs:
        raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
      else:
        raw_rgb, raw_sigma = fine_mlp(samples_enc)

      key, rng_1 = random.split(rng_1)
      raw_sigma = model_utils.add_gaussian_noise(
          key,
          raw_sigma,
          self.noise_std,
          randomized,
      )
      rgb = self.rgb_activation(raw_rgb)
      sigma = self.sigma_activation(raw_sigma)
      comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
          rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )
      ret.append((comp_rgb, disp, acc))
    return ret
def construct_nerf(key, example_batch, args):
  """Construct a Neural Radiance Field and initialize its variables.

  Args:
    key: jnp.ndarray. Random number generator; split three ways for
      `model.init` and the model's two sampling RNGs.
    example_batch: dict, an example of a batch of data. Its "rays" entry is
      used for initialization after indexing each field with [0] (via
      `utils.namedtuple_map`).
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: NerfModel. The (unbound) Flax linen module.
    init_variables: the variables returned by `model.init`.

  Raises:
    AttributeError: if an activation name in `args` is not an attribute of
      `flax.linen`.
    NotImplementedError: if `rgb_activation` produces values outside [0, 1]
      or `sigma_activation` produces negative values on the probe input.
  """
  net_activation = getattr(nn, str(args.net_activation))
  rgb_activation = getattr(nn, str(args.rgb_activation))
  sigma_activation = getattr(nn, str(args.sigma_activation))

  # Assert that rgb_activation always produces outputs in [0, 1], and
  # sigma_activation always produce non-negative outputs. The probe spans
  # magnitudes e^-90 .. e^90 of both signs; under JAX's default float32 the
  # most extreme values overflow to +/-inf, which also probes the limits.
  x = jnp.exp(jnp.linspace(-90, 90, 1024))
  x = jnp.concatenate([-x[::-1], x], 0)
  rgb = rgb_activation(x)
  if jnp.any(rgb < 0) or jnp.any(rgb > 1):
    raise NotImplementedError(
        "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
        .format(args.rgb_activation))
  sigma = sigma_activation(x)
  if jnp.any(sigma < 0):
    raise NotImplementedError(
        "Choice of sigma_activation `{}` produces negative densities".format(
            args.sigma_activation))

  model = NerfModel(
      min_deg_point=args.min_deg_point,
      max_deg_point=args.max_deg_point,
      deg_view=args.deg_view,
      num_coarse_samples=args.num_coarse_samples,
      num_fine_samples=args.num_fine_samples,
      use_viewdirs=args.use_viewdirs,
      near=args.near,
      far=args.far,
      noise_std=args.noise_std,
      white_bkgd=args.white_bkgd,
      net_depth=args.net_depth,
      net_width=args.net_width,
      net_depth_condition=args.net_depth_condition,
      net_width_condition=args.net_width_condition,
      skip_layer=args.skip_layer,
      num_rgb_channels=args.num_rgb_channels,
      num_sigma_channels=args.num_sigma_channels,
      lindisp=args.lindisp,
      net_activation=net_activation,
      rgb_activation=rgb_activation,
      sigma_activation=sigma_activation,
      legacy_posenc_order=args.legacy_posenc_order)

  rays = example_batch["rays"]
  # key1 seeds parameter init; key2/key3 are the coarse/fine sampling RNGs.
  key1, key2, key3 = random.split(key, num=3)
  init_variables = model.init(
      key1,
      rng_0=key2,
      rng_1=key3,
      # Index each ray field with [0]; `field` avoids shadowing `x` above.
      rays=utils.namedtuple_map(lambda field: field[0], rays),
      randomized=args.randomized)
  return model, init_variables