Spaces:
Runtime error
Runtime error
File size: 15,236 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 |
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import csv
import time
# For binary dilation
from scipy import ndimage
import os
from imaginaire.model_utils.gancraft.mc_lbl_reduction import ReducedLabelMapper
def load_voxel_new(voxel_path, shape=(256, 512, 512)):
    r"""Load a raw int32 voxel dump from disk and return it as a dense 3D tensor.

    The file is read as a flat int32 array laid out as
    (shape[1] // 16, shape[2] // 16, 16, 16, shape[0]), i.e. in 16 x 16 chunks. The chunk axes
    are interleaved back so that the result has shape (shape[0], shape[1], shape[2]).

    Args:
        voxel_path (str): Path of the binary file, read with np.fromfile.
        shape (sequence of 3 ints): Size of the output volume. shape[1] and shape[2] are divided by 16 when
            interpreting the file, so they are expected to be multiples of 16.
    Returns:
        (torch.IntTensor, CPU): Voxel volume of size shape[0] x shape[1] x shape[2].
    """
    # The default is a tuple rather than a list so that it cannot be mutated across calls.
    voxel_world = np.fromfile(voxel_path, dtype='int32')
    voxel_world = voxel_world.reshape(
        shape[1] // 16, shape[2] // 16, 16, 16, shape[0])
    # Bring the shape[0] axis to the front and pair each chunk index with its in-chunk offset.
    voxel_world = voxel_world.transpose(4, 0, 2, 1, 3)
    voxel_world = voxel_world.reshape(shape[0], shape[1], shape[2])
    # Data is already int32; requesting the dtype here avoids the extra full copy that astype() would make.
    voxel_world = np.ascontiguousarray(voxel_world, dtype=np.int32)
    return torch.from_numpy(voxel_world)
def gen_corner_voxel(voxel):
    r"""Converting voxel center array to voxel corner array. The size of the
    produced array grows by 1 on every dimension.

    Every non-zero voxel marks its own index and the +1 neighbours along each axis (8 corners).

    Args:
        voxel (torch.IntTensor, CPU): Input voxel of three dimensions
    Returns:
        (torch.IntTensor, CPU): Array of 0 / 1 flags, one per voxel corner.
    """
    # np.bool was removed in NumPy 1.24; the builtin bool names the same dtype.
    structure = np.zeros([3, 3, 3], dtype=bool)
    # Only offsets {0, +1} on every axis are active, so dilation spreads towards higher indices.
    structure[1:, 1:, 1:] = True
    # Pad one cell at the end of each axis to make room for the far corners.
    voxel_p = F.pad(voxel, (0, 1, 0, 1, 0, 1))
    corners = ndimage.binary_dilation(voxel_p.numpy(), structure)
    corners = torch.tensor(corners, dtype=torch.int32)
    return corners
def calc_height_map(voxel_t):
    r"""Calculate height map given a voxel grid [Y, X, Z] as input.
    The height is defined as the Y index of the surface (non-air) block.
    Columns that contain no non-zero voxel get height 0.

    Args:
        voxel_t (Y x X x Z torch.IntTensor, CPU): Input voxel of three dimensions
    Output:
        heightmap (X x Z torch.IntTensor)
    """
    tic = time.time()
    num_levels = voxel_t.shape[0]
    # Scan each column from the top: the first occupied cell in the flipped volume is the surface.
    occupied_top_down = (torch.flip(voxel_t, [0]) != 0).int()
    has_block, depth_from_top = torch.max(occupied_top_down, dim=0, keepdim=False)
    heightmap = num_levels - 1 - depth_from_top
    # Special case when the whole vertical column is empty
    heightmap[has_block == 0] = 0
    print("[GANcraft-utils] Heightmap time: {}".format(time.time() - tic))
    return heightmap
def trans_vec_homo(m, v, is_vec=False):
    r"""3-dimensional Homogeneous matrix and regular vector multiplication
    Convert v to homogeneous vector, perform M-V multiplication, and convert back
    Note that this function does not support autograd.

    Args:
        m (4 x 4 tensor): a homogeneous matrix
        v (3 tensor): a 3-d vector
        is_vec (bool): if true, v is direction. Otherwise v is point
    Returns:
        (3 tensor): the transformed vector.
    """
    # Directions get w = 0 (unaffected by translation); points get w = 1.
    w = 0 if is_vec else 1
    homo = torch.tensor([v[0], v[1], v[2], w], dtype=v.dtype)
    out = torch.mv(m, homo)
    if is_vec:
        return out[:3]
    # Points are de-homogenized by the resulting w component.
    out = out / out[3]
    return out[:3]
def cumsum_exclusive(tensor, dim):
    r"""Exclusive cumulative sum along dim: element i holds the sum of elements before i.

    Args:
        tensor (tensor): Input values.
        dim (int): Dimension to accumulate over.
    Returns:
        (tensor): Same shape as the input; the first slice along dim is 0.
    """
    # Shift the inclusive cumsum by one position, then overwrite the wrapped-around first slice.
    shifted = torch.roll(torch.cumsum(tensor, dim), 1, dim)
    first_slot = torch.tensor([0], dtype=torch.long, device=tensor.device)
    shifted.index_fill_(dim, first_slot, 0)
    return shifted
def sample_depth_batched(depth2, nsamples, deterministic=False, use_box_boundaries=True, sample_depth=4):
    r""" Make best effort to sample points within the same distance for every ray.
    Exception: When there is not enough voxel.

    Args:
        depth2 (N x 2 x H x W x B x 1 tensor):
        - N: Batch.
        - 2: Entrance / exit depth for each intersected box.
        - H, W: Height, Width (e.g. 256, 256).
        - B: Number of intersected boxes along the ray (e.g. 4).
        - 1: One extra dim for consistent tensor dims.
        depth2 can include NaNs.
        nsamples (int): Number of samples drawn per ray (before box boundaries are added).
        deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling.
        use_box_boundaries (bool): Whether to add the entrance / exit points into the sample.
        sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels.
    Returns:
        (tuple):
          - rand_depth (N x H x W x M x 1 tensor): Depth along the ray of each interval midpoint.
          - new_dists (N x H x W x M x 1 tensor): Length (inside voxels) of each interval.
          - idx (N x H x W x M x 1 long tensor): Index of the intersected box each midpoint falls in.
          M is nsamples - 1 without box boundaries and nsamples + B with them.
    """
    bs = depth2.size(0)
    dim0 = depth2.size(2)
    dim1 = depth2.size(3)
    # Travel distance inside every intersected box; missing boxes (NaN) contribute zero.
    dists = depth2[:, 1] - depth2[:, 0]
    dists[torch.isnan(dists)] = 0  # N, H, W, B, 1
    accu_depth = torch.cumsum(dists, dim=-2)  # N, H, W, B, 1
    total_depth = accu_depth[..., [-1], :]  # N, H, W, 1, 1

    total_depth = torch.clamp(total_depth, None, sample_depth)

    # Ignore out of range box boundaries. Fill with random samples.
    if use_box_boundaries:
        boundary_samples = accu_depth.clone().detach()
        boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth
        bad_mask = (accu_depth > sample_depth) | (dists == 0)
        boundary_samples[bad_mask] = boundary_samples_filler[bad_mask]

    rand_shape = [bs, dim0, dim1, nsamples, 1]
    if deterministic:
        rand_samples = torch.empty(rand_shape, dtype=total_depth.dtype, device=total_depth.device)
        # Built on the target device, like the stratified branch, to avoid an implicit cross-device copy.
        rand_samples[..., :, 0] = torch.linspace(0, 1, nsamples+2, device=rand_samples.device)[1:-1]
    else:
        rand_samples = torch.rand(rand_shape, dtype=total_depth.dtype, device=total_depth.device)  # N, H, W, S, 1
        # Stratified sampling as in NeRF
        rand_samples = rand_samples / nsamples
        rand_samples[..., :, 0] += torch.linspace(0, 1, nsamples+1, device=rand_samples.device)[:-1]
    # Scale the unit-interval samples to the (clamped) in-voxel travel distance of each ray.
    rand_samples = rand_samples * total_depth  # N, H, W, S, 1

    # Can also include boundaries
    if use_box_boundaries:
        rand_samples = torch.cat([rand_samples, boundary_samples, torch.zeros(
            [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device)], dim=-2)
    rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False)

    midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2
    new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :]

    # Scatter the random samples back
    # For every midpoint, count how many boxes end before it: that is the box the midpoint lies in.
    # N, H, W, 1, M, 1 > N, H, W, B, 1, 1
    idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3)  # N, H, W, M, 1

    # Accumulated empty space before each box, offset by the first entrance depth.
    depth_deltas = depth2[:, 0, :, :, 1:, :] - depth2[:, 1, :, :, :-1, :]  # There might be NaNs!
    depth_deltas = torch.cumsum(depth_deltas, dim=-2)
    depth_deltas = torch.cat([depth2[:, 0, :, :, [0], :], depth_deltas+depth2[:, 0, :, :, [0], :]], dim=-2)
    heads = torch.gather(depth_deltas, -2, idx)  # N, H, W, M, 1

    # In-voxel travel distance plus the skipped empty space gives the depth along the ray.
    rand_depth = heads + midpoints  # N, H, W, M, 1
    return rand_depth, new_dists, idx
def volum_rendering_relu(sigma, dists, dim=2):
    r"""Convert per-sample densities into per-sample hit probabilities along a ray.

    Args:
        sigma (tensor): Raw density values; negative values are clipped by ReLU.
        dists (tensor): Length of the interval associated with each sample.
        dim (int): Dimension along which the samples of a ray are laid out.
    Returns:
        (tensor): Probability of the ray hitting something at each sample.
    """
    free_energy = F.relu(sigma) * dists
    # Probability that everything is empty up to (but excluding) the current sample.
    transmittance = torch.exp(-cumsum_exclusive(free_energy, dim=dim))
    # Probability that the current sample is not empty.
    opacity = 1 - torch.exp(-free_energy.float())
    return opacity * transmittance
class McVoxel(nn.Module):
    r"""Voxel management.

    Filters and simplifies a Minecraft voxel grid, computes its heightmap, truncates it to the occupied
    vertical range, and builds a corner index table (corner_t) mapping every used voxel corner to a
    1-based id (0 means non-existent).
    """

    def __init__(self, voxel_t, preproc_ver):
        r"""
        Args:
            voxel_t (Y x X x Z torch.IntTensor, CPU): Voxel grid of block ids. It is modified in place.
            preproc_ver (int): Preprocessing version; selects which label remapping and which
                hollowing strategy is applied.
        """
        super(McVoxel, self).__init__()
        # Filter voxel
        voxel_t[voxel_t == 246] = 0  # lily_pad
        voxel_t[voxel_t == 241] = 0  # vine
        voxel_t[voxel_t == 611] = 26  # Blue ice -> water
        voxel_t[voxel_t == 183] = 26  # ice -> water
        voxel_t[voxel_t == 401] = 25  # Packed ice -> bedrock

        if preproc_ver >= 3 and preproc_ver < 6:
            voxel_t[voxel_t == 27] = 25  # Lava -> bedrock
            voxel_t[voxel_t == 616] = 9  # void_air -> dirt
            voxel_t[voxel_t == 617] = 25  # cave_air -> bedrock

        if preproc_ver >= 6:
            voxel_t[voxel_t == 616] = 0  # void_air -> air
            voxel_t[voxel_t == 617] = 0  # cave_air -> air

        # Simplify voxel
        # ndimage.binary_* is used directly: the scipy.ndimage.morphology namespace is deprecated.
        structure = ndimage.generate_binary_structure(3, 3)
        mask = voxel_t.numpy() > 0
        if preproc_ver == 4:  # Hollow bottom
            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
            voxel_t[torch.from_numpy(mask)] = 0
        if preproc_ver >= 5:  # Close cell before hollow bottom
            mask = ndimage.binary_dilation(mask, iterations=1, border_value=1)
            mask = ndimage.binary_erosion(mask, iterations=1, border_value=1)
            mask = ndimage.binary_erosion(mask, structure=structure, iterations=2, border_value=1)
            voxel_t[torch.from_numpy(mask)] = 0

        self.register_buffer('voxel_t', voxel_t, persistent=False)

        self.trans_mat = torch.eye(4)  # Transform voxel to world
        # Generate heightmap for camera trajectory generation
        # NOTE: computed before truncation, so heights are expressed in untruncated Y indices.
        self.heightmap = calc_height_map(self.voxel_t)
        self._truncate_voxel()

        # Convert voxel ([X, Y, Z], int32) to corner ([X+1, Y+1, Z+1], int32) (Requires CPU tensor)
        corner_t = gen_corner_voxel(self.voxel_t)
        self.register_buffer('corner_t', corner_t, persistent=False)

        # Generate 3D position to 1D feature LUT table
        nfilledvox = torch.sum(self.corner_t > 0)
        print('[GANcraft-utils] Number of filled voxels: {} / {}'.format(nfilledvox.item(), torch.numel(self.corner_t)))
        # Zero means non-existent voxel.
        self.corner_t[self.corner_t > 0] = torch.arange(start=1, end=nfilledvox+1, step=1, dtype=torch.int32)
        self.nfilledvox = nfilledvox

    def world2local(self, v, is_vec=False):
        r"""Transform v from world space to voxel space using the inverse of trans_mat.

        Args:
            v (3 tensor): a 3-d vector
            is_vec (bool): if true, v is direction. Otherwise v is point
        """
        mat_world2local = torch.inverse(self.trans_mat)
        return trans_vec_homo(mat_world2local, v, is_vec)

    def _truncate_voxel(self):
        r"""Crop voxel_t along the first axis to [min height, max height] and record the offset in trans_mat."""
        gnd_level = self.heightmap.min()
        sky_level = self.heightmap.max() + 1
        self.voxel_t = self.voxel_t[gnd_level:sky_level, :, :]
        # The removed bottom layers become a translation along the first axis.
        self.trans_mat[0, 3] += gnd_level
        print('[GANcraft-utils] Voxel truncated. Gnd: {}; Sky: {}.'.format(gnd_level.item(), sky_level.item()))

    def is_sea(self, loc):
        r"""Check whether the surface block under a location is water (id 26).

        Args:
            loc (indexable of length >= 3): loc[1] is x and loc[2] is z; loc[0] is not used.
        Returns:
            (bool): True if the surface block is water or if (x, z) lies outside the heightmap.
        """
        x = int(loc[1])
        z = int(loc[2])
        # Valid indices are 0 .. size - 1, so size itself is already out of bound.
        if x < 0 or x >= self.heightmap.size(0) or z < 0 or z >= self.heightmap.size(1):
            print('[McVoxel] is_sea(): Index out of bound.')
            return True
        # Heightmap is in untruncated coordinates; subtract the offset stored by _truncate_voxel.
        y = int(self.heightmap[x, z] - self.trans_mat[0, 3])
        if self.voxel_t[y, x, z] == 26:
            print('[McVoxel] is_sea(): Get a sea.')
            # The surface block can be the topmost layer, in which case there is nothing above it.
            above = self.voxel_t[y + 1, x, z] if y + 1 < self.voxel_t.size(0) else None
            print(self.voxel_t[y, x, z], above)
            return True
        return False
class MCLabelTranslator:
    r"""Resolving mapping across Minecraft voxel, coco-stuff label and GANcraft reduced label set."""

    def __init__(self):
        r"""Build all lookup tables from the CSV files located next to this module and from ReducedLabelMapper."""
        this_path = os.path.dirname(os.path.abspath(__file__))

        # Load voxel name lut: columns are id, name, color, GauGAN label (may be empty).
        id2name_lut, id2color_lut, id2glbl_lut = {}, {}, {}
        with open(os.path.join(this_path, 'id2name_gg.csv'), newline='') as csvfile:
            for row in csv.reader(csvfile, delimiter=','):
                id2name_lut[int(row[0])] = row[1]
                id2color_lut[int(row[0])] = int(row[2])
                id2glbl_lut[int(row[0])] = row[3]

        # Load GauGAN color lut: columns are label, '#'-prefixed hex color.
        glbl2color_lut, glbl2cocoidx_lut = {}, {}
        with open(os.path.join(this_path, 'gaugan_lbl2col.csv'), newline='') as csvfile:
            # Row order defines the coco index, starting at 1 because 0 is "Others".
            for cocoidx, row in enumerate(csv.reader(csvfile, delimiter=','), start=1):
                glbl2color_lut[row[0]] = int(row[1].lstrip('#'), 16)
                glbl2cocoidx_lut[row[0]] = cocoidx

        # Generate id2ggcolor and id2cocoidx luts; ids without a GauGAN label map to 0.
        id2ggcolor_lut, id2cocoidx_lut = {}, {}
        for mc_id, glbl in id2glbl_lut.items():
            id2ggcolor_lut[mc_id] = glbl2color_lut[glbl] if glbl else 0
            id2cocoidx_lut[mc_id] = glbl2cocoidx_lut[glbl] if glbl else 0

        self.id2color_lut = id2color_lut
        self.id2name_lut = id2name_lut
        self.id2glbl_lut = id2glbl_lut
        self.id2ggcolor_lut = id2ggcolor_lut
        self.id2cocoidx_lut = id2cocoidx_lut

        # Reduced label set.
        mapper = ReducedLabelMapper()
        self.mcid2rdid_lut = torch.tensor(mapper.mcid2rdid_lut, dtype=torch.long)
        self.num_reduced_lbls = len(mapper.reduced_lbls)
        self.ignore_id = mapper.ignore_id
        self.dirt_id = mapper.dirt_id
        self.water_id = mapper.water_id
        self.mapper = mapper

        # Last index is ignore
        self.ggid2rdid_lut = torch.tensor(mapper.ggid2rdid + [0], dtype=torch.long)

        # Dense Minecraft id -> coco index table, ordered by Minecraft id.
        coco_by_sorted_id = list(zip(*sorted(self.id2cocoidx_lut.items())))[1]
        self.mc2coco_lut = torch.tensor(coco_by_sorted_id, dtype=torch.long)

    def gglbl2ggid(self, gglbl):
        r"""Look up the GauGAN id of a GauGAN label through the mapper."""
        return self.mapper.gglbl2ggid[gglbl]

    def mc2coco(self, mc):
        r"""Translate a tensor of Minecraft ids into coco indices (the LUT is moved to mc's device)."""
        self.mc2coco_lut = self.mc2coco_lut.to(mc.device)
        return self.mc2coco_lut[mc.long()]

    def mc2reduced(self, mc, ign2dirt=False):
        r"""Translate a tensor of Minecraft ids into reduced label ids.

        Args:
            mc (tensor): Minecraft ids.
            ign2dirt (bool): If true, entries mapped to the ignore id are replaced by the dirt id.
        """
        self.mcid2rdid_lut = self.mcid2rdid_lut.to(mc.device)
        reduced = self.mcid2rdid_lut[mc.long()]
        if ign2dirt:
            reduced[reduced == self.ignore_id] = self.dirt_id
        return reduced

    def coco2reduced(self, coco):
        r"""Translate a tensor of coco indices into reduced label ids (the LUT is moved to coco's device)."""
        self.ggid2rdid_lut = self.ggid2rdid_lut.to(coco.device)
        return self.ggid2rdid_lut[coco.long()]

    def get_num_reduced_lbls(self):
        r"""Return the number of labels in the reduced label set."""
        return self.num_reduced_lbls

    @staticmethod
    def uint32_to_4uint8(x):
        r"""Reinterpret every 32-bit element of x as 4 bytes, adding a trailing axis of size 4."""
        packed_dtype = np.dtype(('i4', [('bytes', 'u1', 4)]))
        return x.view(dtype=packed_dtype)['bytes']

    def mc_color(self, img):
        r"""Obtaining Minecraft default color.

        Args:
            img (H x W x 1 int32 numpy tensor): Segmentation map.
        Returns:
            (numpy uint8 array): img's shape with a trailing axis holding the first 3 bytes of each color.
        """
        # Dense color table ordered by Minecraft id.
        colors_by_sorted_id = list(zip(*sorted(self.id2color_lut.items())))[1]
        table = np.array(colors_by_sorted_id, dtype=np.uint32)
        packed = table[img]
        return self.uint32_to_4uint8(packed)[..., :3]
def rand_crop(cam_c, cam_res, target_res):
    r"""Produces a new cam_c so that the effect of rendering with the new cam_c and target_res is the same as rendering
    with the old parameters and then crop out target_res.

    Args:
        cam_c (sequence of 2 numbers): Original camera center.
        cam_res (sequence of 2 ints): Original rendering resolution.
        target_res (sequence of 2 ints): Resolution of the crop; must not exceed cam_res.
    Returns:
        (list of 2 numbers): Camera center shifted by a random crop offset.
    """
    # One random offset per axis, drawn uniformly from all crop positions that fit (axis 0 first, then axis 1).
    offsets = [np.random.randint(cam_res[axis] - target_res[axis] + 1) for axis in (0, 1)]
    return [cam_c[0] - offsets[0], cam_c[1] - offsets[1]]
def segmask_smooth(seg_mask, kernel_size=7):
    r"""Smooth a one-hot segmentation mask by a majority-style vote in a sliding window.

    Each channel is average-pooled (stride 1, padding kernel_size // 2) and every pixel is then assigned to
    the channel with the highest pooled value.

    Args:
        seg_mask (N x C x H x W float tensor): One-hot segmentation mask.
        kernel_size (int): Side length of the pooling window.
    Returns:
        (N x C x H x W float tensor): Smoothed one-hot segmentation mask.
    """
    labels = F.avg_pool2d(seg_mask, kernel_size, 1, kernel_size//2)
    # `keepdim` is the documented keyword of torch.argmax.
    onehot_idx = torch.argmax(labels, dim=1, keepdim=True)
    # Reuse the pooled tensor as the output buffer: clear it and set the winning channel to 1.
    labels.fill_(0.0)
    labels.scatter_(1, onehot_idx, 1.0)
    return labels
def colormap(x, cmap='viridis'):
    r"""Map an array of scalars to RGB colors with a matplotlib colormap.

    Non-finite entries (NaN, +inf, -inf) are treated as NaN and ignored when normalizing; the remaining values
    are shifted and scaled to [0, 1] before the colormap lookup.

    Args:
        x (numpy array): Scalar values.
        cmap (str): Name of the matplotlib colormap.
    Returns:
        (numpy array): x's shape with a trailing axis of size 3 (RGB; the alpha channel is dropped).
    """
    # Keywords are required: positionally, the second parameter of np.nan_to_num is `copy`, which left
    # `neginf` unset so that -inf became a huge finite negative number and wrecked the normalization.
    x = np.nan_to_num(x, nan=np.nan, posinf=np.nan, neginf=np.nan)
    x = x - np.nanmin(x)
    x = x / np.nanmax(x)
    rgb = plt.get_cmap(cmap)(x)[..., :3]
    return rgb
|