# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Utility functions."""
import collections
import os
from os import path
import pickle

from absl import flags
import flax
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from PIL import Image
import yaml

from jaxnerf.nerf import datasets

BASE_DIR = "jaxnerf"
INTERNAL = False
@flax.struct.dataclass
class TrainState:
  optimizer: flax.optim.Optimizer


@flax.struct.dataclass
class Stats:
  loss: float
  psnr: float
  loss_c: float
  psnr_c: float
  weight_l2: float
Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))


def namedtuple_map(fn, tup):
  """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
  return type(tup)(*map(fn, tup))
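

# Illustrative sketch (not part of the original module): how `Rays` and
# `namedtuple_map` are typically used together, e.g. flattening per-pixel
# rays of shape [H, W, 3] into [H * W, 3] before chunked rendering. The
# array shapes here are hypothetical.
def _example_namedtuple_map():
  rays = Rays(origins=jnp.zeros((4, 4, 3)),
              directions=jnp.ones((4, 4, 3)),
              viewdirs=jnp.ones((4, 4, 3)))
  flat_rays = namedtuple_map(lambda r: r.reshape((-1, 3)), rays)
  assert flat_rays.origins.shape == (16, 3)
  return flat_rays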
def define_flags():
  """Define flags for both training and evaluation modes."""
  flags.DEFINE_string("train_dir", None, "where to store ckpts and logs.")
  flags.DEFINE_string("data_dir", None, "input data directory.")
  flags.DEFINE_string("config", None,
                      "path of the YAML config file used to set hyperparameters "
                      "(relative to BASE_DIR, without the .yaml extension).")

  # CLIP part Flags
  flags.DEFINE_bool("use_semantic_loss", True,
                    "whether to use the semantic (CLIP) loss or not.")
  flags.DEFINE_string("precompute_pkl_path", None,
                      "where to load the pickle file of precomputed image features.")
  flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32",
                      "model type for CLIP.")
  flags.DEFINE_string("clip_output_dtype", "float32",
                      "float32 / float16 (float16 for memory saving).")
  flags.DEFINE_integer("sc_loss_factor", 4,
                       "factor for downsampling images (0/2/4); "
                       "it is compounded on top of another flag: factor.")
  flags.DEFINE_integer("sc_loss_every", 16,
                       "number of steps to take before performing a semantic "
                       "loss evaluation.")
  flags.DEFINE_float("sc_loss_mult", 10.,
                     "weighting for the semantic loss from CLIP.")
  # Dataset Flags
  # TODO(pratuls): rename to dataset_loader and consider cleaning up
  flags.DEFINE_enum("dataset", "blender",
                    list(datasets.dataset_dict.keys()),
                    "The type of dataset to feed to NeRF.")
  flags.DEFINE_enum(
      "batching", "single_image", ["single_image", "all_images"],
      "source of ray sampling when collecting the training batch: "
      "single_image for sampling from only one image per batch, "
      "all_images for sampling from all the training images.")
  flags.DEFINE_bool(
      "white_bkgd", True, "use white as the default background color "
      "(used in the blender dataset only).")
  flags.DEFINE_integer("batch_size", 1024,
                       "the number of rays in a mini-batch (for training).")
  flags.DEFINE_integer("factor", 4,
                       "the downsample factor of images, 0 for no downsampling.")
  flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
  flags.DEFINE_bool(
      "render_path", False, "render the generated path if set to True "
      "(used in the llff dataset only).")
  flags.DEFINE_integer(
      "llffhold", 8, "will take every 1/N images as the LLFF test set "
      "(used in the llff dataset only).")
  flags.DEFINE_bool(
      "use_pixel_centers", False,
      "If True, generate rays through the center of each pixel. Note: While "
      "this is the correct way to handle rays, it is not the way rays are "
      "handled in the original NeRF paper. Setting this TRUE yields ~ +1 PSNR "
      "compared to Vanilla NeRF.")
  # Model Flags
  flags.DEFINE_string("model", "nerf", "name of model to use.")
  flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
  flags.DEFINE_float("far", 6., "far clip of volumetric rendering.")
  flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.")
  flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.")
  flags.DEFINE_integer("net_depth_condition", 1,
                       "depth of the second part of MLP.")
  flags.DEFINE_integer("net_width_condition", 128,
                       "width of the second part of MLP.")
  flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay.")
  flags.DEFINE_integer(
      "skip_layer", 4, "add a skip connection to the output vector of every "
      "skip_layer layers.")
  flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
  flags.DEFINE_integer("num_sigma_channels", 1,
                       "the number of density channels.")
  flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
  flags.DEFINE_integer("min_deg_point", 0,
                       "Minimum degree of positional encoding for points.")
  flags.DEFINE_integer("max_deg_point", 10,
                       "Maximum degree of positional encoding for points.")
  flags.DEFINE_integer("deg_view", 4,
                       "Degree of positional encoding for viewdirs.")
  flags.DEFINE_integer(
      "num_coarse_samples", 64,
      "the number of samples on each ray for the coarse model.")
  flags.DEFINE_integer("num_fine_samples", 128,
                       "the number of samples on each ray for the fine model.")
  flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
  flags.DEFINE_float(
      "noise_std", None, "std dev of the noise added to regularize the sigma "
      "output (used in the llff dataset only).")
  flags.DEFINE_bool("lindisp", False,
                    "sampling linearly in disparity rather than depth.")
  flags.DEFINE_string("net_activation", "relu",
                      "activation function used within the MLP.")
  flags.DEFINE_string("rgb_activation", "sigmoid",
                      "activation function used to produce RGB.")
  flags.DEFINE_string("sigma_activation", "relu",
                      "activation function used to produce density.")
  flags.DEFINE_bool(
      "legacy_posenc_order", False,
      "If True, revert the positional encoding feature order to an older "
      "version of this codebase.")
  # Train Flags
  flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
  flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
  flags.DEFINE_integer(
      "lr_delay_steps", 0, "The number of steps at the beginning of "
      "training to reduce the learning rate by lr_delay_mult.")
  flags.DEFINE_float(
      "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
      "is < lr_delay_steps.")
  flags.DEFINE_float("grad_max_norm", 0.,
                     "The gradient clipping magnitude (disabled if == 0).")
  flags.DEFINE_float("grad_max_val", 0.,
                     "The gradient clipping value (disabled if == 0).")
  flags.DEFINE_integer("max_steps", 1000000,
                       "the number of optimization steps.")
  flags.DEFINE_integer("save_every", 10000,
                       "the number of steps between checkpoint saves.")
  flags.DEFINE_integer("print_every", 100,
                       "the number of steps between reports to tensorboard.")
  flags.DEFINE_integer(
      "render_every", 5000, "the number of steps between renders of a test "
      "image; better to be a multiple of 100 for an accurate step-time record.")
  flags.DEFINE_integer("gc_every", 10000,
                       "the number of steps between python garbage collections.")
| flags.DEFINE_integer("few_shot", -1, | |
| "the number of images.") | |
  # Eval Flags
  flags.DEFINE_bool(
      "eval_once", True,
      "evaluate the model only once if True, otherwise keep evaluating new "
      "checkpoints if there are any.")
  flags.DEFINE_bool("save_output", True,
                    "save predicted images to disk if True.")
  flags.DEFINE_integer(
      "chunk", 8192,
      "the size of chunks for evaluation inference; set to a value that "
      "fits your GPU/TPU memory.")
def update_flags(args):
  """Update the flags in `args` with the contents of the config YAML file."""
  pth = path.join(BASE_DIR, args.config + ".yaml")
  with open_file(pth, "r") as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)
  # Only allow args to be updated if they already exist.
  invalid_args = list(set(configs.keys()) - set(dir(args)))
  if invalid_args:
    raise ValueError(f"Invalid args {invalid_args} in {pth}.")
  args.__dict__.update(configs)
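

# Illustrative sketch (not part of the original module) of how the two
# helpers above are typically driven from a training/eval script: define the
# flags at import time, then merge in a YAML config after absl has parsed
# argv. The import path assumes this module lives at jaxnerf/nerf/utils.py
# (as its `datasets` import suggests); the --config value is hypothetical.
#
#   from absl import app, flags
#   from jaxnerf.nerf import utils
#
#   utils.define_flags()
#   FLAGS = flags.FLAGS
#
#   def main(unused_argv):
#     if FLAGS.config is not None:
#       utils.update_flags(FLAGS)   # e.g. --config=configs/blender
#     ...
#
#   if __name__ == "__main__":
#     app.run(main)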
def open_file(pth, mode="r"):
  if not INTERNAL:
    return open(pth, mode=mode)


def file_exists(pth):
  if not INTERNAL:
    return path.exists(pth)


def listdir(pth):
  if not INTERNAL:
    return os.listdir(pth)


def isdir(pth):
  if not INTERNAL:
    return path.isdir(pth)


def makedirs(pth):
  if not INTERNAL:
    os.makedirs(pth)
def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    render_fn: function, jit-ed render function.
    rays: a `Rays` namedtuple, the rays to be rendered.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  host_id = jax.host_id()
  results = []
  for i in range(0, num_rays, chunk):
    # pylint: disable=cell-var-from-loop
    chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
    chunk_size = chunk_rays[0].shape[0]
    rays_remaining = chunk_size % jax.device_count()
    if rays_remaining != 0:
      padding = jax.device_count() - rays_remaining
      chunk_rays = namedtuple_map(
          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
    else:
      padding = 0
    # After padding, the number of chunk_rays is always divisible by
    # host_count.
    rays_per_host = chunk_rays[0].shape[0] // jax.process_count()
    start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
    chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
    chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
    results.append([unshard(x[0], padding) for x in chunk_results])
    # pylint: enable=cell-var-from-loop
  rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return (rgb.reshape((height, width, -1)), disp.reshape(
      (height, width, -1)), acc.reshape((height, width, -1)))
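

# Illustrative sketch (not part of the original module) of the padding step
# used inside render_image(): when a chunk of rays is not evenly divisible
# across devices, it is edge-padded up to the next multiple of
# jax.device_count(), and unshard() later drops the padded rows. The chunk
# size used here is hypothetical.
def _example_chunk_padding():
  chunk_rays = jnp.ones((jax.device_count() * 2 + 1, 3))
  rays_remaining = chunk_rays.shape[0] % jax.device_count()
  padding = (jax.device_count() - rays_remaining) if rays_remaining else 0
  padded = jnp.pad(chunk_rays, ((0, padding), (0, 0)), mode="edge")
  assert padded.shape[0] % jax.device_count() == 0
  return padded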
def compute_psnr(mse):
  """Compute psnr value given mse (we assume the maximum pixel value is 1).

  Args:
    mse: float, mean squared error of pixels.

  Returns:
    psnr: float, the psnr value.
  """
  return -10. * jnp.log(mse) / jnp.log(10.)
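

# Quick numerical check of the formula above (not part of the original
# module): for images scaled to [0, 1], PSNR = -10 * log10(MSE), so an MSE
# of 1e-3 corresponds to 30 dB. The arrays below are hypothetical.
def _example_compute_psnr():
  pred = jnp.full((4, 4, 3), 0.5)
  target = pred + jnp.sqrt(1e-3)       # constant error => MSE = 1e-3
  mse = jnp.mean((pred - target) ** 2)
  return compute_psnr(mse)             # ~= 30.0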
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, will cause the per-pixel SSIM "map" to be
      returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than the 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs to the tensor size, and then compose them.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0 ** 2) - mu00
  sigma11 = filt_fn(img1 ** 2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))
  c1 = (k1 * max_val) ** 2
  c2 = (k2 * max_val) ** 2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim
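

# Illustrative sketch (not part of the original module): comparing a clean
# image against a noisy copy with the SSIM helper above. Image contents and
# noise level are hypothetical; images in [0, 1] use max_val=1.0.
def _example_compute_ssim():
  key = jax.random.PRNGKey(0)
  img0 = jax.random.uniform(key, (32, 32, 3))
  img1 = jnp.clip(img0 + 0.05 * jax.random.normal(key, img0.shape), 0., 1.)
  return compute_ssim(img0, img1, max_val=1.0)  # scalar mean SSIM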
def save_img(img, pth):
  """Save an image to disk.

  Args:
    img: jnp.ndarray, [height, width, channels], img will be clipped to [0, 1]
      before being saved to pth.
    pth: string, path to save the image to.
  """
  with open_file(pth, "wb") as imgout:
    Image.fromarray(np.array(
        (np.clip(img, 0., 1.) * 255.).astype(np.uint8))).save(imgout, "PNG")
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
  is log-linearly interpolated elsewhere (equivalent to exponential decay).
  If lr_delay_steps>0 then the learning rate will be scaled by some smooth
  function of lr_delay_mult, such that the initial learning rate is
  lr_init*lr_delay_mult at the beginning of optimization but will be eased back
  to the normal learning rate when step > lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: float, the learning rate for the current step.
  """
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp
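

# Quick numerical check of the schedule above (not part of the original
# module): with the default flags (lr_init=5e-4, lr_final=5e-6,
# max_steps=1000000) the rate starts at lr_init, ends at lr_final, and the
# halfway value is their geometric mean, sqrt(5e-4 * 5e-6) = 5e-5.
def _example_learning_rate_decay():
  lrs = [learning_rate_decay(s, 5e-4, 5e-6, max_steps=1000000)
         for s in (0, 500000, 1000000)]
  return lrs  # ~[5e-4, 5e-5, 5e-6]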
def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  '''
  if 'embedding' in xs:
    xs['pixels'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['pixels'])
    xs['rays'] = jax.tree_map(lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs['rays'])
    xs['embedding'] = np.stack([xs['embedding']] * jax.local_device_count(), 0)
    xs['random_rays'] = jax.tree_map(lambda x: np.stack([x] * jax.local_device_count(), 0), xs['random_rays'])
  else:
    xs = jax.tree_map(
        lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:])
        if len(x.shape) != 0 else x, xs)
  return xs
  '''
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:])
      if len(x.shape) != 0 else x, xs)


def to_device(xs):
  """Transfer data to devices (GPU/TPU)."""
  return jax.tree_map(jnp.array, xs)


def unshard(x, padding=0):
  """Collect the sharded tensor to the shape before sharding."""
  y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))
  if padding > 0:
    y = y[:-padding]
  return y
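

# Illustrative sketch (not part of the original module): shard() reshapes the
# leading axis to [num_local_devices, per_device, ...] and unshard() undoes
# it, so a round trip recovers the original array. The array contents are
# hypothetical; the leading dimension must be divisible by
# jax.local_device_count() (render_image pads chunks to guarantee this).
def _example_shard_roundtrip():
  n = 4 * jax.local_device_count()
  x = jnp.arange(n * 3, dtype=jnp.float32).reshape((n, 3))
  sharded = shard(x)             # [num_local_devices, 4, 3]
  restored = unshard(sharded)    # [n, 3]
  assert restored.shape == x.shape
  return restored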
def write_pickle(data, fn):
  with open(fn, 'wb') as f:
    pickle.dump(data, f)


def read_pickle(fn):
  with open(fn, 'rb') as f:
    data = pickle.load(f)
  return data
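

# Illustrative sketch (not part of the original module): a round trip through
# the two pickle helpers above, as one might do when caching precomputed CLIP
# image features for the `precompute_pkl_path` flag. The file name and the
# dictionary contents are hypothetical.
def _example_pickle_roundtrip(fn="/tmp/precomputed_features.pkl"):
  features = {"image_000": np.zeros((512,), dtype=np.float32)}
  write_pickle(features, fn)
  restored = read_pickle(fn)
  assert set(restored.keys()) == set(features.keys())
  return restored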