# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import csv
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# For binary dilation
from scipy import ndimage

from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper


def load_voxel_new(voxel_path, shape=(256, 512, 512)):
    r"""Load a raw int32 voxel dump from disk and return it as a 3-D tensor.

    The file is read as a flat int32 array laid out as
    ``(shape[1]//16, shape[2]//16, 16, 16, shape[0])`` and is rearranged into
    a contiguous ``(shape[0], shape[1], shape[2])`` grid.

    Args:
        voxel_path (str): Path of the raw voxel file.
        shape (sequence of 3 ints): Target grid size. ``shape[1]`` and
            ``shape[2]`` are split into tiles of 16, so they are presumably
            expected to be multiples of 16.
    Returns:
        (torch.IntTensor, CPU): Voxel grid of size ``shape``.
    """
    voxel_world = np.fromfile(voxel_path, dtype='int32')
    voxel_world = voxel_world.reshape(
        shape[1]//16, shape[2]//16, 16, 16, shape[0])
    # Move the last axis to the front and interleave tile / in-tile indices.
    voxel_world = voxel_world.transpose(4, 0, 2, 1, 3)
    voxel_world = voxel_world.reshape(shape[0], shape[1], shape[2])
    voxel_world = np.ascontiguousarray(voxel_world)
    voxel_world = torch.from_numpy(voxel_world.astype(np.int32))
    return voxel_world


def gen_corner_voxel(voxel):
    r"""Converting voxel center array to voxel corner array. The size of the
    produced array grows by 1 on every dimension.

    Args:
        voxel (torch.IntTensor, CPU): Input voxel of three dimensions
    Returns:
        (torch.IntTensor, CPU): 0 / 1 occupancy of every voxel corner.
    """
    # The builtin ``bool`` is used because the ``np.bool`` alias no longer
    # exists in recent NumPy releases.
    structure = np.zeros([3, 3, 3], dtype=bool)
    structure[1:, 1:, 1:] = True
    # Pad one cell at the far end of every axis so the dilation has room for
    # the extra corner.
    voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1))
    corners = ndimage.binary_dilation(voxel_p.numpy(), structure)
    corners = torch.tensor(corners, dtype=torch.int32)
    return corners


def calc_height_map(voxel_t):
    r"""Calculate height map given a voxel grid [Y, X, Z] as input.
    The height is defined as the Y index of the surface (non-air) block

    Args:
        voxel_t (Y x X x Z torch.IntTensor, CPU): Input voxel of three
            dimensions
    Output:
        heightmap (X x Z torch.IntTensor)
    """
    start_time = time.time()
    # Flip along Y so that the first non-zero hit found by max() is the
    # top-most non-air block of each column.
    m, h = torch.max((torch.flip(voxel_t, [0]) != 0).int(), dim=0, keepdim=False)
    heightmap = voxel_t.shape[0] - 1 - h
    heightmap[m == 0] = 0  # Special case when the whole vertical column is empty
    elapsed_time = time.time() - start_time
    print("[GANcraft-utils] Heightmap time: {}".format(elapsed_time))
    return heightmap


def trans_vec_homo(m, v, is_vec=False):
    r"""3-dimensional Homogeneous matrix and regular vector multiplication
    Convert v to homogeneous vector, perform M-V multiplication, and convert back
    Note that this function does not support autograd.

    Args:
        m (4 x 4 tensor): a homogeneous matrix
        v (3 tensor): a 3-d vector
        is_vec (bool): if true, v is direction. Otherwise v is point
    Returns:
        (3 tensor): the transformed vector.
    """
    if is_vec:
        # Directions get w = 0 so the translation part of m is ignored.
        v = torch.tensor([v[0], v[1], v[2], 0], dtype=v.dtype)
    else:
        v = torch.tensor([v[0], v[1], v[2], 1], dtype=v.dtype)
    v = torch.mv(m, v)
    if not is_vec:
        # Points are de-homogenized by dividing by w.
        v = v / v[3]
    v = v[:3]
    return v


def cumsum_exclusive(tensor, dim):
    r"""Exclusive cumulative sum of ``tensor`` along ``dim``.

    Element ``i`` of the result is the sum of elements ``0 .. i-1`` of the
    input; the first element along ``dim`` is 0.
    """
    cumsum = torch.cumsum(tensor, dim)
    # Shift by one along dim, then overwrite the wrapped-around first slot.
    cumsum = torch.roll(cumsum, 1, dim)
    cumsum.index_fill_(dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0)
    return cumsum


def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4):
    r"""Make best effort to sample points within the same distance for every ray.
    Exception: When there is not enough voxel.

    Args:
        depth2 (N x 2 x 256 x 256 x 4 x 1 tensor):
        - N: Batch.
        - 2: Entrance / exit depth for each intersected box.
        - 256, 256: Height, Width.
        - 4: Number of intersected boxes along the ray.
        - 1: One extra dim for consistent tensor dims.
        depth2 can include NaNs.
        nsamples (int): Number of stratified / equal-distance samples drawn
            per ray (before the optional box boundaries are added).
        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
        use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
        sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.
    Returns:
        (tuple):
          - rand_depth: Depth along the ray of each interval midpoint.
          - new_dists: Length of each interval between consecutive sorted
            samples.
          - idx: For each midpoint, the number of boxes whose accumulated
            in-voxel depth lies before it, i.e. the index of the box it
            falls in.
    """
    bs = depth2.size(0)
    dim0 = depth2.size(2)
    dim1 = depth2.size(3)
    # Distance travelled inside each intersected box; NaN means "no box".
    dists = depth2[:, 1] - depth2[:, 0]
    dists[torch.isnan(dists)] = 0  # N, 256, 256, 4, 1
    accu_depth = torch.cumsum(dists, dim=-2)  # N, 256, 256, 4, 1
    total_depth = accu_depth[..., [-1], :]  # N, 256, 256, 1, 1

    total_depth = torch.clamp(total_depth, None, sample_depth)

    # Ignore out of range box boundaries. Fill with random samples.
    if use_box_boundaries:
        boundary_samples = accu_depth.clone().detach()
        boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth
        bad_mask = (accu_depth > sample_depth) | (dists == 0)
        boundary_samples[bad_mask] = boundary_samples_filler[bad_mask]

    rand_shape = [bs, dim0, dim1, nsamples, 1]
    # 256, 256, N, 1
    if deterministic:
        rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device)
        # Build the grid directly on the target device, as the stochastic
        # branch below does.
        rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2, device=rand_samples.device)[1:-1]
    else:
        rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device)  # 256, 256, N, 1
        # Stratified sampling as in NeRF
        rand_samples = rand_samples / nsamples
        rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1]
    rand_samples = rand_samples * total_depth  # 256, 256, N, 1

    # Can also include boundaries
    if use_box_boundaries:
        rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros(
            [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2)
    rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)

    midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2
    new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :]

    # Scatter the random samples back
    # 256, 256, 1, M, 1 > 256, 256, N, 1, 1
    idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3)  # 256, 256, M, 1

    # Gaps between the exit of one box and the entrance of the next.
    depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :]  # There might be NaNs!
    depth_deltas = torch.cumsum(depth_deltas, dim=-2)
    depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2)
    heads = torch.gather(depth_deltas, -2, idx)  # 256 256 M 1

    rand_depth = heads + midpoints  # 256 256 N 1

    return rand_depth, new_dists, idx


def volum_rendering_relu(sigma, dists, dim=2):
    r"""Turn per-sample densities into per-sample hit probabilities.

    Args:
        sigma (tensor): Raw density; negative values are clipped by ReLU.
        dists (tensor): Interval length associated with each sample.
        dim (int): Dimension that enumerates the samples along a ray.
    Returns:
        (tensor): Probability that the ray terminates at each sample.
    """
    free_energy = F.relu(sigma) * dists

    a = 1 - torch.exp(-free_energy.float())  # probability of it is not empty here
    b = torch.exp(-cumsum_exclusive(free_energy, dim=dim))  # probability of everything is empty up to now
    probs = a * b  # probability of the ray hits something here

    return probs


class McVoxel(nn.Module):
    r"""Voxel management."""

    def __init__(self, voxel_t, preproc_ver):
        r"""Filter, simplify and index a Minecraft voxel grid.

        Args:
            voxel_t (torch.IntTensor, CPU): Voxel grid of block ids. It is
                modified in place by the filtering steps below.
            preproc_ver (int): Selects which preprocessing variant is
                applied (see the branches below).
        """
        super(McVoxel, self).__init__()
        # Filter voxel
        voxel_t[voxel_t == 246] = 0  # lily_pad
        voxel_t[voxel_t == 241] = 0  # vine
        voxel_t[voxel_t == 611] = 26  # Blue ice -> water
        voxel_t[voxel_t == 183] = 26  # ice -> water
        voxel_t[voxel_t == 401] = 25  # Packed ice -> bedrock

        if preproc_ver >= 3 and preproc_ver < 6:
            voxel_t[voxel_t == 27] = 25  # Lava -> bedrock
            voxel_t[voxel_t == 616] = 9  # void_air -> dirt
            voxel_t[voxel_t == 617] = 25  # cave_air -> bedrock

        if preproc_ver >= 6:
            voxel_t[voxel_t == 616] = 0  # void_air -> air
            voxel_t[voxel_t == 617] = 0  # cave_air -> air

        # Simplify voxel
        structure = ndimage.generate_binary_structure(3, 3)
        mask = voxel_t.numpy() > 0
        # NOTE: the functions are called from ``scipy.ndimage`` directly; the
        # ``scipy.ndimage.morphology`` namespace is deprecated.
        if preproc_ver == 4:  # Hollow bottom
            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
            voxel_t[mask] = 0
        if preproc_ver >= 5:  # Close cell before hollow bottom
            mask = ndimage.binary_dilation(mask, iterations=1, border_value=1)
            mask = ndimage.binary_erosion(mask, iterations=1, border_value=1)
            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
            voxel_t[mask] = 0

        self.register_buffer('voxel_t', voxel_t, persistent=False)

        self.trans_mat = torch.eye(4)  # Transform voxel to world
        # Generate heightmap for camera trajectory generation
        self.heightmap = calc_height_map(self.voxel_t)
        self._truncate_voxel()
        # Convert voxel ([X, Y, Z], int32) to corner ([X+1, Y+1, Z+1], int32) (Requires CPU tensor)
        corner_t = gen_corner_voxel(self.voxel_t)
        self.register_buffer('corner_t', corner_t, persistent=False)

        # Generate 3D position to 1D feature LUT table
        nfilledvox = torch.sum(self.corner_t > 0)
        print('[GANcraft-utils] Number of filled voxels: {} / {}'.format(nfilledvox.item(), torch.numel(self.corner_t)))
        # Zero means non-existent voxel.
        self.corner_t[self.corner_t > 0] = torch.arange(start=1, end=nfilledvox+1, step=1, dtype=torch.int32)
        self.nfilledvox = nfilledvox

    def world2local(self, v, is_vec=False):
        r"""Transform ``v`` with the inverse of ``self.trans_mat``.

        Args:
            v (3 tensor): a 3-d point or direction.
            is_vec (bool): if true, v is direction. Otherwise v is point
        """
        mat_world2local = torch.inverse(self.trans_mat)
        return trans_vec_homo(mat_world2local, v, is_vec)

    def _truncate_voxel(self):
        r"""Crop the voxel grid along its first axis to the range spanned by
        the heightmap and record the offset in ``self.trans_mat``."""
        gnd_level = self.heightmap.min()
        sky_level = self.heightmap.max() + 1
        self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
        self.trans_mat[0, 3] += gnd_level
        print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item()))

    def is_sea(self, loc):
        r"""Check whether the surface block at a location is water (id 26).

        Args:
            loc (indexable of length >= 3): ``loc[1]`` is x and ``loc[2]`` is
                z; ``loc[0]`` is not read.
        Returns:
            (bool): True if the surface block is water or the location is
            outside the heightmap, False otherwise.
        """
        x = int(loc[1])
        z = int(loc[2])
        # Valid indices are 0 .. size-1, hence ``>=`` on the upper bound.
        if x < 0 or x >= self.heightmap.size(0) or z < 0 or z >= self.heightmap.size(1):
            print('[McVoxel] is_sea(): Index out of bound.')
            return True
        # The heightmap was computed before truncation; subtract the offset
        # stored in trans_mat to index the truncated grid.
        y = self.heightmap[x, z] - self.trans_mat[0, 3]
        y = int(y)
        if self.voxel_t[y, x, z] == 26:
            print('[McVoxel] is_sea(): Get a sea.')
            # The block above only exists when y is not the top layer.
            above = self.voxel_t[y+1, x, z] if y + 1 < self.voxel_t.size(0) else None
            print(self.voxel_t[y, x, z], above)
            return True
        else:
            return False


class MCLabelTranslator:
    r"""Resolving mapping across Minecraft voxel, coco-stuff label and GANcraft reduced label set."""

    def __init__(self):
        r"""Load the CSV lookup tables that sit next to this file and build
        the id-to-id tensors used by the ``mc2*`` / ``coco2*`` methods."""
        this_path = os.path.dirname(os.path.abspath(__file__))
        # Load voxel name lut
        id2name_lut = {}
        id2color_lut = {}
        id2glbl_lut = {}
        with open(os.path.join(this_path, 'id2name_gg.csv'), newline='') as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            for row in csvreader:
                mc_id = int(row[0])
                id2name_lut[mc_id] = row[1]
                id2color_lut[mc_id] = int(row[2])
                id2glbl_lut[mc_id] = row[3]

        # Load GauGAN color lut
        glbl2color_lut = {}
        glbl2cocoidx_lut = {}
        with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            # 0 is "Others", so real labels are numbered from 1.
            for cocoidx, row in enumerate(csvreader, start=1):
                color = int(row[1].lstrip('#'), 16)
                glbl2color_lut[row[0]] = color
                glbl2cocoidx_lut[row[0]] = cocoidx

        # Generate id2ggcolor lut (ids with an empty GauGAN label map to 0)
        id2ggcolor_lut = {k: (glbl2color_lut[v] if v else 0) for k, v in id2glbl_lut.items()}

        # Generate id2cocoidx (ids with an empty GauGAN label map to 0)
        id2cocoidx_lut = {k: (glbl2cocoidx_lut[v] if v else 0) for k, v in id2glbl_lut.items()}

        self.id2color_lut = id2color_lut
        self.id2name_lut = id2name_lut
        self.id2glbl_lut = id2glbl_lut
        self.id2ggcolor_lut = id2ggcolor_lut
        self.id2cocoidx_lut = id2cocoidx_lut

        mapper = ReducedLabelMapper()
        mcid2rdid_lut = mapper.mcid2rdid_lut
        mcid2rdid_lut = torch.tensor(mcid2rdid_lut, dtype=torch.long)
        self.mcid2rdid_lut = mcid2rdid_lut
        self.num_reduced_lbls = len(mapper.reduced_lbls)
        self.ignore_id = mapper.ignore_id
        self.dirt_id = mapper.dirt_id
        self.water_id = mapper.water_id

        self.mapper = mapper

        ggid2rdid_lut = mapper.ggid2rdid + [0]  # Last index is ignore
        ggid2rdid_lut = torch.tensor(ggid2rdid_lut, dtype=torch.long)
        self.ggid2rdid_lut = ggid2rdid_lut

        # Values ordered by Minecraft id. Using the result as an index table
        # presumes the ids in the CSV are contiguous from 0 - TODO confirm.
        mc2coco_lut = [v for _, v in sorted(self.id2cocoidx_lut.items())]
        mc2coco_lut = torch.tensor(mc2coco_lut, dtype=torch.long)
        self.mc2coco_lut = mc2coco_lut

    def gglbl2ggid(self, gglbl):
        r"""Look up ``gglbl`` in the mapper's ``gglbl2ggid`` table."""
        return self.mapper.gglbl2ggid[gglbl]

    def mc2coco(self, mc):
        r"""Map a tensor of Minecraft ids to coco indices via table lookup."""
        # Keep the LUT on the same device as the query to allow indexing.
        self.mc2coco_lut = self.mc2coco_lut.to(mc.device)
        coco = self.mc2coco_lut[mc.long()]
        return coco

    def mc2reduced(self, mc, ign2dirt=False):
        r"""Map a tensor of Minecraft ids to reduced label ids.

        Args:
            mc (tensor): Minecraft ids.
            ign2dirt (bool): If true, entries equal to ``ignore_id`` are
                replaced with ``dirt_id``.
        """
        self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device)
        reduced = self.mcid2rdid_lut[mc.long()]
        if ign2dirt:
            reduced[reduced == self.ignore_id] = self.dirt_id
        return reduced

    def coco2reduced(self, coco):
        r"""Map a tensor of coco / GauGAN ids to reduced label ids."""
        self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device)
        reduced = self.ggid2rdid_lut[coco.long()]
        return reduced

    def get_num_reduced_lbls(self):
        r"""Return the number of reduced labels."""
        return self.num_reduced_lbls

    @staticmethod
    def uint32_to_4uint8(x):
        r"""Reinterpret each 4-byte element of ``x`` as 4 separate bytes.

        Args:
            x (numpy array with 4-byte items): Packed colors.
        Returns:
            (numpy uint8 array): Same shape as ``x`` with a trailing axis of
            size 4.
        """
        dt1 = np.dtype(('i4', [('bytes', 'u1', 4)]))
        color = x.view(dtype=dt1)['bytes']
        return color

    def mc_color(self, img):
        r"""Obtaining Minecraft default color.

        Args:
            img (H x W x 1 int32 numpy tensor): Segmentation map.
        Returns:
            (numpy uint8 array): First three bytes of the packed color of
            every pixel.
        """
        lut = [v for _, v in sorted(self.id2color_lut.items())]
        lut = np.array(lut, dtype=np.uint32)
        rgb = lut[img]
        rgb = self.uint32_to_4uint8(rgb)[..., :3]

        return rgb


def rand_crop(cam_c, cam_res, target_res):
    r"""Produces a new cam_c so that the effect of rendering with the new cam_c and target_res is the same as rendering
    with the old parameters and then crop out target_res.

    Args:
        cam_c (sequence of 2 numbers): Original camera center.
        cam_res (sequence of 2 ints): Original rendering resolution.
        target_res (sequence of 2 ints): Resolution of the crop.
    Returns:
        (list of 2 numbers): Camera center shifted by a random crop offset.
    """
    d0 = np.random.randint(cam_res[0] - target_res[0] + 1)
    d1 = np.random.randint(cam_res[1] - target_res[1] + 1)
    cam_c = [cam_c[0]-d0, cam_c[1]-d1]
    return cam_c


def segmask_smooth(seg_mask, kernel_size=7):
    r"""Smooth a one-hot segmentation mask with a majority vote.

    The mask is average-pooled with a ``kernel_size`` window (stride 1) and
    every pixel is re-assigned to the channel with the highest average.

    Args:
        seg_mask (tensor): Mask with the class channels on dim 1.
        kernel_size (int): Side length of the pooling window.
    Returns:
        (tensor): One-hot mask over dim 1.
    """
    labels = F.avg_pool2d(seg_mask, kernel_size, 1, kernel_size//2)
    onehot_idx = torch.argmax(labels, dim=1, keepdim=True)
    labels.fill_(0.0)
    labels.scatter_(1, onehot_idx, 1.0)
    return labels


def colormap(x, cmap='viridis'):
    r"""Normalize ``x`` to [0, 1] and map it through a matplotlib colormap.

    Args:
        x (numpy array): Scalar field. Non-finite entries are treated as NaN
            and excluded from the normalization range.
        cmap (str): Name of the matplotlib colormap.
    Returns:
        (numpy array): ``x.shape + (3,)`` RGB values.
    """
    # Keywords are required here: the second positional parameter of
    # nan_to_num is ``copy``, not the NaN replacement value.
    x = np.nan_to_num(x, nan=np.nan, posinf=np.nan, neginf=np.nan)
    x = x - np.nanmin(x)
    x = x / np.nanmax(x)
    rgb = plt.get_cmap(cmap)(x)[..., :3]
    return rgb