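# NOTE: the next three lines look like web-page residue (uploader caption, commit
# message, commit hash) rather than source; kept as comments so the module parses.
# venite's picture
# initial
# f670afc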
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import csv
import time
# For binary dilation
from scipy import ndimage
import os
from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper
def load_voxel_new(voxel_path, shape=(256, 512, 512)):
    r"""Load a raw int32 voxel dump from disk and un-chunk it into a dense 3-D grid.

    The file is read as a flat int32 array stored as
    [shape[1] // 16, shape[2] // 16, 16, 16, shape[0]], i.e. in 16 x 16 columns
    with dimension 0 varying fastest. The columns are stitched back together so
    that the result is indexed as [shape[0], shape[1], shape[2]].

    Args:
        voxel_path (str): Path of the raw binary file.
        shape (sequence of 3 ints): Size of the output grid. Entries 1 and 2
            must be multiples of 16. (Presumably [Y, X, Z]; confirm with caller.)
    Returns:
        (torch.IntTensor, CPU): Tensor of size shape[0] x shape[1] x shape[2].
    Raises:
        ValueError: If shape[1] or shape[2] is not a multiple of 16, or if the
            file does not hold exactly shape[0] * shape[1] * shape[2] values.
    """
    dim0, dim1, dim2 = shape[0], shape[1], shape[2]
    # Fail early with a readable message instead of an opaque reshape error.
    if dim1 % 16 != 0 or dim2 % 16 != 0:
        raise ValueError(
            'shape[1] and shape[2] must be multiples of 16, got {}'.format(list(shape)))
    voxel_world = np.fromfile(voxel_path, dtype='int32')
    voxel_world = voxel_world.reshape(dim1 // 16, dim2 // 16, 16, 16, dim0)
    # Bring dimension 0 to the front and pair each column index with its in-column offset.
    voxel_world = voxel_world.transpose(4, 0, 2, 1, 3)
    voxel_world = voxel_world.reshape(dim0, dim1, dim2)
    # Data is already int32, so no extra dtype conversion (and copy) is needed.
    voxel_world = np.ascontiguousarray(voxel_world)
    return torch.from_numpy(voxel_world)
def gen_corner_voxel(voxel):
    r"""Converting voxel center array to voxel corner array. The size of the
    produced array grows by 1 on every dimension.

    A corner is marked when at least one of the (up to eight) voxels touching
    it is non-zero.

    Args:
        voxel (torch.IntTensor, CPU): Input voxel of three dimensions
    Returns:
        (torch.IntTensor, CPU): Array with every dimension one larger than the
            input, 1 where a corner belongs to a non-zero voxel and 0 elsewhere.
    """
    # Builtin bool instead of np.bool: the alias was removed in NumPy 1.24.
    structure = np.zeros([3, 3, 3], dtype=bool)
    # Only offsets {0, +1} on each axis: voxel (i, j, k) owns corners i..i+1, j..j+1, k..k+1.
    structure[1:, 1:, 1:] = True
    # Pad one empty cell at the far end of every axis to make room for the last corners.
    voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1))
    corners = ndimage.binary_dilation(voxel_p.numpy(), structure)
    corners = torch.tensor(corners, dtype=torch.int32)
    return corners
def calc_height_map(voxel_t):
    r"""Compute, for every vertical column, the index of its topmost non-air block.

    Air is the value 0. A column that contains no block at all is reported
    with height 0. The elapsed time is printed.

    Args:
        voxel_t (Y x X x Z torch.IntTensor, CPU): Input voxel of three dimensions
    Output:
        heightmap (X x Z integer tensor): Y index of the surface block.
    """
    tic = time.time()
    top_index = voxel_t.shape[0] - 1
    # Scan each column from the top: the first hit of the maximum (1) is the surface block.
    solid_top_down = (torch.flip(voxel_t, [0]) != 0).int()
    has_solid, steps_from_top = solid_top_down.max(dim=0)
    # A maximum of 0 means the whole column is air; such columns get height 0.
    heightmap = torch.where(
        has_solid == 0, torch.zeros_like(steps_from_top), top_index - steps_from_top)
    toc = time.time()
    print("[GANcraft-utils] Heightmap time: {}".format(toc - tic))
    return heightmap
def trans_vec_homo(m, v, is_vec=False):
    r"""Apply a 4 x 4 homogeneous transform to a 3-d point or direction.

    v is lifted to a homogeneous 4-vector (w = 0 for a direction, w = 1 for a
    point), multiplied by m, and reduced back to three components. Points are
    divided by the resulting w; directions are not.
    Note that this function does not support autograd.

    Args:
        m (4 x 4 tensor): a homogeneous matrix
        v (3 tensor): a 3-d vector
        is_vec (bool): if true, v is direction. Otherwise v is point
    Returns:
        (3 tensor): the transformed point or direction.
    """
    w = 0 if is_vec else 1
    homogeneous = torch.tensor([v[0], v[1], v[2], w], dtype=v.dtype)
    transformed = torch.mv(m, homogeneous)
    if is_vec:
        # Directions carry no translation and need no perspective divide.
        return transformed[:3]
    return (transformed / transformed[3])[:3]
def cumsum_exclusive(tensor, dim):
    r"""Exclusive cumulative sum along dim.

    Element i of the result is the sum of all elements before i, so the first
    element along dim is always 0.

    Args:
        tensor (tensor): Input values.
        dim (int): Dimension to accumulate over.
    Returns:
        (tensor): Same size as the input.
    """
    # Shift the inclusive sum by one slot; the wrapped-around total lands in slot 0 ...
    shifted = torch.cumsum(tensor, dim).roll(shifts=1, dims=dim)
    # ... and is overwritten with 0 there.
    shifted.select(dim, 0).zero_()
    return shifted
def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4):
    r""" Make best effort to sample points within the same distance for every ray.
    Exception: When there is not enough voxel.

    Samples are first placed on the distance travelled *inside* the intersected
    boxes (gaps between boxes are skipped), then mapped back to depth along the
    ray by adding the entrance depth and the gaps that precede the box each
    sample falls into.

    Args:
        depth2 (N x 2 x 256 x 256 x 4 x 1 tensor):
        - N: Batch.
        - 2: Entrance / exit depth for each intersected box.
        - 256, 256: Height, Width.
        - 4: Number of intersected boxes along the ray.
        - 1: One extra dim for consistent tensor dims.
        depth2 can include NaNs.
        nsamples (int): Number of stratified (or equally spaced) samples per ray, before the optional
            boundary samples are appended.
        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
        use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
        sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.
    Returns:
        rand_depth (N x H x W x M x 1 tensor): Depth of each sample along the ray, taken at the midpoint
            between consecutive sorted samples. M is the number of sorted samples minus one.
        new_dists (N x H x W x M x 1 tensor): In-box distance between consecutive sorted samples.
        idx (N x H x W x M x 1 integer tensor): For each midpoint, how many boxes have a cumulative in-box
            distance smaller than the midpoint; used as the index of the box the sample falls into.
    """
    bs = depth2.size(0)
    dim0 = depth2.size(2)
    dim1 = depth2.size(3)
    # Thickness of every intersected box (exit - entrance); NaN entries count as zero thickness.
    dists = depth2[:, 1] - depth2[:, 0]
    dists[torch.isnan(dists)] = 0  # N, 256, 256, 4, 1
    # Distance travelled inside boxes up to and including each box.
    accu_depth = torch.cumsum(dists, dim=-2)  # N, 256, 256, 4, 1
    total_depth = accu_depth[..., [-1], :]  # N, 256, 256, 1, 1
    # Upper-bound the sampled range by sample_depth.
    total_depth = torch.clamp(total_depth, None, sample_depth)

    # Ignore out of range box boundaries. Fill with random samples.
    if use_box_boundaries:
        boundary_samples = accu_depth.clone().detach()
        boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth
        # A boundary is unusable when it lies beyond sample_depth or belongs to a zero-thickness box.
        bad_mask = (accu_depth > sample_depth) | (dists == 0)
        boundary_samples[bad_mask] = boundary_samples_filler[bad_mask]

    rand_shape = [bs, dim0, dim1, nsamples, 1]
    # N, 256, 256, nsamples, 1
    if deterministic:
        rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device)
        # nsamples equally spaced values strictly inside (0, 1).
        rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2)[1:-1]
    else:
        rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device)  # 256, 256, N, 1
        # Stratified sampling as in NeRF
        # One uniform draw inside each of nsamples equal-width bins of [0, 1).
        rand_samples = rand_samples / nsamples
        rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1]
    # Scale the unit-interval samples to the (clamped) in-box travel distance of each ray.
    rand_samples = rand_samples * total_depth  # 256, 256, N, 1

    # Can also include boundaries
    if use_box_boundaries:
        # Append the box boundaries and a sample at distance 0.
        rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros(
            [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2)
    rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)

    # Work with the intervals between consecutive samples: their midpoints and lengths.
    midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2
    new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :]

    # Scatter the random samples back
    # 256, 256, 1, M, 1 > 256, 256, N, 1, 1
    # Count, per midpoint, the boxes whose cumulative distance it has already passed.
    idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3)  # 256, 256, M, 1
    # print(idx.shape, idx.max(), idx.min()) # max 3, min 0

    # Gap between the exit of box k and the entrance of box k+1.
    depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :]  # There might be NaNs!
    depth_deltas = torch.cumsum(depth_deltas, dim=-2)
    # Offset for box k = entrance depth of the first box + all gaps before box k.
    depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2)
    heads = torch.gather(depth_deltas, -2, idx)  # 256 256 M 1
    # heads = torch.gather(depth2[0], -2, idx) # 256 256 M 1

    # print(torch.any(torch.isnan(heads)))
    rand_depth = heads + midpoints  # 256 256 N 1

    return rand_depth, new_dists, idx
def volum_rendering_relu(sigma, dists, dim=2):
    r"""Turn per-sample densities into per-sample hit probabilities along a ray.

    Args:
        sigma (tensor): Raw density per sample; negative values are clipped to 0.
        dists (tensor): Length of the interval each sample stands for.
        dim (int): Dimension that enumerates the samples along the ray.
    Returns:
        (tensor): Probability that the ray hits something at each sample.
    """
    free_energy = F.relu(sigma) * dists
    # Probability that the interval itself is not empty.
    hit_here = 1 - torch.exp(-free_energy.float())
    # Probability that everything before the interval is empty.
    reached = torch.exp(-cumsum_exclusive(free_energy, dim=dim))
    return hit_here * reached
class McVoxel(nn.Module):
    r"""Voxel management.

    Cleans up a Minecraft voxel grid, derives its height map, crops it to the
    occupied vertical range and builds a corner lookup table in which every
    occupied corner gets a unique positive id.

    Args:
        voxel_t (torch.IntTensor, CPU): Voxel grid, first dimension vertical.
            It is modified in place.
        preproc_ver (int): Selects which block remapping / hollowing rules apply.
    """

    def __init__(self, voxel_t, preproc_ver):
        super(McVoxel, self).__init__()
        # Filter voxel
        voxel_t[voxel_t == 246] = 0  # lily_pad
        voxel_t[voxel_t == 241] = 0  # vine
        voxel_t[voxel_t == 611] = 26  # Blue ice -> water
        voxel_t[voxel_t == 183] = 26  # ice -> water
        voxel_t[voxel_t == 401] = 25  # Packed ice -> bedrock

        if 3 <= preproc_ver < 6:
            voxel_t[voxel_t == 27] = 25  # Lava -> bedrock
            voxel_t[voxel_t == 616] = 9  # void_air -> dirt
            voxel_t[voxel_t == 617] = 25  # cave_air -> bedrock

        if preproc_ver >= 6:
            voxel_t[voxel_t == 616] = 0  # void_air -> air
            voxel_t[voxel_t == 617] = 0  # cave_air -> air

        # Simplify voxel
        # The morphology functions are called on scipy.ndimage directly; the
        # scipy.ndimage.morphology namespace is deprecated.
        structure = ndimage.generate_binary_structure(3, 3)
        mask = voxel_t.numpy() > 0
        if preproc_ver == 4:  # Hollow bottom
            # Everything that survives two erosions is interior and gets emptied.
            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
            voxel_t[mask] = 0
        if preproc_ver >= 5:  # Close cell before hollow bottom
            mask = ndimage.binary_dilation(mask, iterations=1, border_value=1)
            mask = ndimage.binary_erosion(mask, iterations=1, border_value=1)
            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
            voxel_t[mask] = 0

        self.register_buffer('voxel_t', voxel_t, persistent=False)

        self.trans_mat = torch.eye(4)  # Transform voxel to world
        # Generate heightmap for camera trajectory generation
        # (computed before truncation, so heights are in un-truncated coordinates)
        self.heightmap = calc_height_map(self.voxel_t)
        self._truncate_voxel()
        # Convert voxel ([X, Y, Z], int32) to corner ([X+1, Y+1, Z+1], int32) (Requires CPU tensor)
        corner_t = gen_corner_voxel(self.voxel_t)
        self.register_buffer('corner_t', corner_t, persistent=False)

        # Generate 3D position to 1D feature LUT table
        nfilledvox = torch.sum(self.corner_t > 0)
        print('[GANcraft-utils] Number of filled voxels: {} / {}'.format(nfilledvox.item(), torch.numel(self.corner_t)))
        # Zero means non-existent voxel.
        self.corner_t[self.corner_t > 0] = torch.arange(start=1, end=nfilledvox+1, step=1, dtype=torch.int32)
        self.nfilledvox = nfilledvox

    def world2local(self, v, is_vec=False):
        r"""Transform a world-space point (or direction if is_vec) into voxel-local space."""
        mat_world2local = torch.inverse(self.trans_mat)
        return trans_vec_homo(mat_world2local, v, is_vec)

    def _truncate_voxel(self):
        r"""Crop the grid vertically to [min height, max height] and record the offset in trans_mat."""
        gnd_level = self.heightmap.min()
        sky_level = self.heightmap.max() + 1
        self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
        # The removed bottom layers become a translation along the first axis.
        self.trans_mat[0, 3] += gnd_level
        print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item()))

    def is_sea(self, loc):
        r"""Check whether the surface block of a column is water (id 26).

        Args:
            loc (sequence of 3): Location; loc[1] is read as x and loc[2] as z,
                loc[0] is not used.
        Returns:
            (bool): True if the surface block at (x, z) is water, or if (x, z)
                lies outside the height map. False otherwise.
        """
        x = int(loc[1])
        z = int(loc[2])
        # Valid indices are 0 .. size - 1, so size itself is already out of bound.
        if x < 0 or x >= self.heightmap.size(0) or z < 0 or z >= self.heightmap.size(1):
            print('[McVoxel] is_sea(): Index out of bound.')
            return True
        # The height map is in un-truncated coordinates; shift into the truncated grid.
        y = self.heightmap[x, z] - self.trans_mat[0, 3]
        y = int(y)
        if self.voxel_t[y, x, z] == 26:
            print('[McVoxel] is_sea(): Get a sea.')
            # The surface may be the top layer of the truncated grid; then nothing is above it.
            above = self.voxel_t[y+1, x, z] if y + 1 < self.voxel_t.size(0) else None
            print(self.voxel_t[y, x, z], above)
            return True
        else:
            return False
class MCLabelTranslator:
    r"""Resolving mapping across Minecraft voxel, coco-stuff label and GANcraft reduced label set."""

    def __init__(self):
        lut_dir = os.path.dirname(os.path.abspath(__file__))

        # Minecraft block table; columns: id, name, color, GauGAN label name.
        id2name_lut, id2color_lut, id2glbl_lut = {}, {}, {}
        with open(os.path.join(lut_dir, 'id2name_gg.csv'), newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter=','):
                mc_id = int(row[0])
                id2name_lut[mc_id] = row[1]
                id2color_lut[mc_id] = int(row[2])
                id2glbl_lut[mc_id] = row[3]

        # GauGAN label table; columns: label name, '#'-prefixed hex color.
        glbl2color_lut, glbl2cocoidx_lut = {}, {}
        with open(os.path.join(lut_dir, 'gaugan_lbl2col.csv'), newline='') as csvfile:
            # Row order defines the coco index; counting starts at 1 because 0 is "Others".
            for cocoidx, row in enumerate(csv.reader(csvfile, delimiter=','), start=1):
                glbl2color_lut[row[0]] = int(row[1].lstrip('#'), 16)
                glbl2cocoidx_lut[row[0]] = cocoidx

        # Per Minecraft id: GauGAN color and coco index (0 when the block has no GauGAN label).
        id2ggcolor_lut, id2cocoidx_lut = {}, {}
        for mc_id, glbl in id2glbl_lut.items():
            id2ggcolor_lut[mc_id] = glbl2color_lut[glbl] if glbl else 0
            id2cocoidx_lut[mc_id] = glbl2cocoidx_lut[glbl] if glbl else 0

        self.id2color_lut = id2color_lut
        self.id2name_lut = id2name_lut
        self.id2glbl_lut = id2glbl_lut
        self.id2ggcolor_lut = id2ggcolor_lut
        self.id2cocoidx_lut = id2cocoidx_lut

        # Reduced label set.
        mapper = ReducedLabelMapper()
        self.mcid2rdid_lut = torch.tensor(mapper.mcid2rdid_lut, dtype=torch.long)
        self.num_reduced_lbls = len(mapper.reduced_lbls)
        self.ignore_id = mapper.ignore_id
        self.dirt_id = mapper.dirt_id
        self.water_id = mapper.water_id
        self.mapper = mapper

        # Last index is ignore
        self.ggid2rdid_lut = torch.tensor(mapper.ggid2rdid + [0], dtype=torch.long)

        # Dense Minecraft id -> coco index table, ordered by Minecraft id.
        self.mc2coco_lut = torch.tensor(self._values_by_key(self.id2cocoidx_lut), dtype=torch.long)

    @staticmethod
    def _values_by_key(lut):
        r"""Return the values of dict lut as a tuple ordered by ascending key."""
        return list(zip(*sorted(lut.items())))[1]

    def gglbl2ggid(self, gglbl):
        r"""Look up the GauGAN id of a GauGAN label via the reduced label mapper."""
        return self.mapper.gglbl2ggid[gglbl]

    def mc2coco(self, mc):
        r"""Map a tensor of Minecraft ids to coco indices (the LUT is moved to mc's device)."""
        self.mc2coco_lut = self.mc2coco_lut.to(mc.device)
        return self.mc2coco_lut[mc.long()]

    def mc2reduced(self, mc, ign2dirt=False):
        r"""Map a tensor of Minecraft ids to reduced label ids.

        Args:
            mc (integer tensor): Minecraft ids.
            ign2dirt (bool): If true, the ignore label is replaced by the dirt label.
        """
        self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device)
        reduced = self.mcid2rdid_lut[mc.long()]
        if not ign2dirt:
            return reduced
        return reduced.masked_fill_(reduced == self.ignore_id, self.dirt_id)

    def coco2reduced(self, coco):
        r"""Map a tensor of coco indices to reduced label ids (the LUT is moved to coco's device)."""
        self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device)
        return self.ggid2rdid_lut[coco.long()]

    def get_num_reduced_lbls(self):
        r"""Return the number of labels in the reduced label set."""
        return self.num_reduced_lbls

    @staticmethod
    def uint32_to_4uint8(x):
        r"""Reinterpret each 4-byte element of x as four uint8 values (adds a trailing dim of 4)."""
        four_bytes = np.dtype(('i4', [('bytes', 'u1', 4)]))
        return x.view(dtype=four_bytes)['bytes']

    def mc_color(self, img):
        r"""Obtaining Minecraft default color.

        Args:
            img (H x W x 1 int32 numpy tensor): Segmentation map.
        Returns:
            (uint8 numpy array): img with a trailing dimension of 3 appended,
                holding the first three bytes of each block's packed color.
        """
        palette = np.array(self._values_by_key(self.id2color_lut), dtype=np.uint32)
        packed = palette[img]
        return self.uint32_to_4uint8(packed)[..., :3]
def rand_crop(cam_c, cam_res, target_res):
    r"""Shift the camera center so that rendering at target_res equals a random crop.

    Rendering with the returned center and target_res gives the same image as
    rendering with cam_c and cam_res and then cropping a target_res window at
    a uniformly random offset.

    Args:
        cam_c (sequence of 2): Camera center.
        cam_res (sequence of 2): Full rendering resolution.
        target_res (sequence of 2): Resolution of the crop.
    Returns:
        (list of 2): The shifted camera center.
    """
    # One offset per axis, each drawn from [0, cam_res - target_res].
    shifts = [np.random.randint(cam_res[axis] - target_res[axis] + 1) for axis in (0, 1)]
    return [cam_c[0] - shifts[0], cam_c[1] - shifts[1]]
def segmask_smooth(seg_mask, kernel_size=7):
    r"""Smooth a segmentation mask by a per-pixel majority vote over a square window.

    Every channel is box-filtered with stride 1, then each pixel is reassigned
    to the channel with the largest filtered response.

    Args:
        seg_mask (N x C x H x W float tensor): Per-class mask.
        kernel_size (int): Side length of the averaging window.
    Returns:
        (float tensor): One-hot mask along dim 1.
    """
    half = kernel_size // 2
    votes = F.avg_pool2d(seg_mask, kernel_size, stride=1, padding=half)
    winner = votes.argmax(dim=1, keepdim=True)
    # Reuse the pooled tensor as the one-hot output.
    return votes.fill_(0.0).scatter_(1, winner, 1.0)
def colormap(x, cmap='viridis'):
    r"""Map an array of scalars to RGB colors with a matplotlib colormap.

    Non-finite entries (NaN, +inf, -inf) are treated as missing and ignored
    when the data range is computed. The finite values are shifted and scaled
    so that the minimum maps to 0 and the maximum to 1. If all finite values
    are equal, the scaling divides by zero and the result is NaN before the
    color lookup.

    Args:
        x (array_like): Scalar values.
        cmap (str): Name of the matplotlib colormap.
    Returns:
        (float numpy array): x with a trailing dimension of 3 (RGB) appended.
    """
    # Keyword arguments are required here: passed positionally, the values bind to
    # (copy, nan, posinf) and -inf is turned into a huge finite number instead of NaN.
    x = np.nan_to_num(x, nan=np.nan, posinf=np.nan, neginf=np.nan)
    x = x - np.nanmin(x)
    x = x / np.nanmax(x)
    # The colormap returns RGBA; drop the alpha channel.
    rgb = plt.get_cmap(cmap)(x)[..., :3]
    return rgb