import cv2
import numpy as np
import torch
from torchvision import transforms
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import functools
import math
import warnings


def image_float_to_uint8(img):
    """
    Convert a float image to uint8 (0-255), min-max normalizing to the full range
    """
    vmin = np.min(img)
    vmax = np.max(img)
    if vmax - vmin < 1e-10:
        vmax += 1e-10
    img = (img - vmin) / (vmax - vmin)
    img *= 255.0
    return img.astype(np.uint8)


def cmap(img, color_map=cv2.COLORMAP_HOT):
    """
    Apply 'HOT' color map to a float image
    """
    return cv2.applyColorMap(image_float_to_uint8(img), color_map)


def batched_index_select_nd(t, inds):
    """
    Index select on dim 1 of an n-dimensional batched tensor.
    :param t (batch, n, ...)
    :param inds (batch, k)
    :return (batch, k, ...)
    """
    return t.gather(
        1, inds[(...,) + (None,) * (len(t.shape) - 2)].expand(-1, -1, *t.shape[2:])
    )


def batched_index_select_nd_last(t, inds):
    """
    Index select on dim -1 of a >=2D multi-batched tensor.
    inds assumed to have all batch dimensions except one data dimension 'n'
    :param t (batch..., n, m)
    :param inds (batch..., k)
    :return (batch..., n, k)
    """
    dummy = inds.unsqueeze(-2).expand(*inds.shape[:-1], t.size(-2), inds.size(-1))
    out = t.gather(-1, dummy)
    return out


def repeat_interleave(input, repeats, dim=0):
    """
    Repeat interleave along axis 0
    torch.repeat_interleave is currently very slow
    https://github.com/pytorch/pytorch/issues/31980
    """
    output = input.unsqueeze(1).expand(-1, repeats, *input.shape[1:])
    return output.reshape(-1, *input.shape[1:])


def get_image_to_tensor_balanced(image_size=0):
    ops = []
    if image_size > 0:
        ops.append(transforms.Resize(image_size))
    ops.extend(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    return transforms.Compose(ops)


def get_mask_to_tensor():
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
    )


def homogeneous(points):
    """
    Concat 1 to each point
    :param points (..., 3)
    :return (..., 4)
    """
    return F.pad(points, (0, 1), "constant", 1.0)


def gen_grid(*args, ij_indexing=False):
    """
    Generate a len(args)-dimensional grid.
    Each arg should be (lo, hi, sz) so that in that dimension points
    are taken at linspace(lo, hi, sz).
    Example: gen_grid((0,1,10), (-1,1,20))
    :return (prod_i args_i[2], len(args)), len(args)-dimensional grid points
    """
    return torch.from_numpy(
        np.vstack(
            np.meshgrid(
                *(np.linspace(lo, hi, sz, dtype=np.float32) for lo, hi, sz in args),
                indexing="ij" if ij_indexing else "xy"
            )
        )
        .reshape(len(args), -1)
        .T
    )
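

# Illustrative sketch (not part of the original module): shows how
# batched_index_select_nd gathers rows along dim 1 with per-batch indices.
# The helper name and tensor shapes below are assumptions for this example only.
def _example_batched_index_select():
    feats = torch.randn(4, 100, 32)        # (batch, n, feature)
    inds = torch.randint(0, 100, (4, 8))   # (batch, k) indices into dim 1
    picked = batched_index_select_nd(feats, inds)  # -> (batch, k, feature)
    assert picked.shape == (4, 8, 32)
    return picked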


def unproj_map(width, height, f, c=None, device="cpu", norm_dir=True, xy_offset=None):
    """
    Get camera unprojection map for given image size.
    [y,x] of output tensor will contain unit vector of camera ray of that pixel.
    :param width image width
    :param height image height
    :param f focal length, either a number or tensor [fx, fy]
    :param c principal point, optional, either None or tensor [cx, cy];
    if not specified uses center of image
    :return unproj map (n, height, width, 3), image-plane xy (n, height, width, 2)
    """
    if c is None:
        c = torch.tensor([[0.0, 0.0]], device=device)
    elif isinstance(c, float):
        c = torch.tensor([[c, c]], device=device)
    elif len(c.shape) == 0:
        c = c[None, None].expand(1, 2)
    elif len(c.shape) == 1:
        c = c.unsqueeze(-1).expand(1, 2)
    if isinstance(f, float):
        f = torch.tensor([[f, f]], device=device)
    elif len(f.shape) == 0:
        f = f[None, None].expand(1, 2)
    elif len(f.shape) == 1:
        f = f.unsqueeze(-1).expand(1, 2)
    n = f.shape[0]

    pixel_width = 2 / width
    pixel_height = 2 / height
    x = torch.linspace(
        -1 + 0.5 * pixel_width,
        1 - 0.5 * pixel_width,
        width,
        dtype=torch.float32,
        device=device,
    ).view(1, 1, width).expand(n, height, width)
    y = torch.linspace(
        -1 + 0.5 * pixel_height,
        1 - 0.5 * pixel_height,
        height,
        dtype=torch.float32,
        device=device,
    ).view(1, height, 1).expand(n, height, width)
    if xy_offset is not None:
        x = x + xy_offset[0] * pixel_width
        y = y + xy_offset[1] * pixel_height
    xy_img = torch.stack((x, y), dim=-1)
    xy = (xy_img - c.view(n, 1, 1, 2)) / f.view(n, 1, 1, 2)
    z = torch.ones_like(x).unsqueeze(-1)
    unproj = torch.cat((xy, z), dim=-1)
    if norm_dir:
        unproj /= torch.norm(unproj, dim=-1).unsqueeze(-1)
    return unproj, xy_img


def coord_from_blender(dtype=torch.float32, device="cpu"):
    """
    Blender to standard coordinate system transform.
    Standard coordinate system is: x right y up z out (out=screen to face)
    Blender coordinate system is: x right y in z up
    :return (4, 4)
    """
    return torch.tensor(
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
        dtype=dtype,
        device=device,
    )


def coord_to_blender(dtype=torch.float32, device="cpu"):
    """
    Standard to Blender coordinate system transform.
    Standard coordinate system is: x right y up z out (out=screen to face)
    Blender coordinate system is: x right y in z up
    :return (4, 4)
    """
    return torch.tensor(
        [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
        dtype=dtype,
        device=device,
    )


def look_at(origin, target, world_up=np.array([0, 1, 0], dtype=np.float32)):
    """
    Get 4x4 camera to world space matrix, for camera looking at target
    """
    back = origin - target
    back /= np.linalg.norm(back)
    right = np.cross(world_up, back)
    right /= np.linalg.norm(right)
    up = np.cross(back, right)

    cam_to_world = np.empty((4, 4), dtype=np.float32)
    cam_to_world[:3, 0] = right
    cam_to_world[:3, 1] = up
    cam_to_world[:3, 2] = back
    cam_to_world[:3, 3] = origin
    cam_to_world[3, :] = [0, 0, 0, 1]
    return cam_to_world
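

# Illustrative sketch (not part of the original module): look_at builds a
# camera-to-world pose whose third (back) axis points from target toward origin,
# with the camera center in the last column. The positions below are assumptions
# chosen only for this example.
def _example_look_at():
    origin = np.array([0.0, 0.0, 2.0], dtype=np.float32)
    target = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    pose = look_at(origin, target)                 # (4, 4) cam-to-world
    assert pose.shape == (4, 4)
    assert np.allclose(pose[:3, 3], origin)        # camera center in last column
    return pose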
""" return ( torch.device("cuda:%d" % gpu_id) if torch.cuda.is_available() else torch.device("cpu") ) def masked_sample(masks, num_pix, prop_inside, thresh=0.5): """ :return (num_pix, 3) """ num_inside = int(num_pix * prop_inside + 0.5) num_outside = num_pix - num_inside inside = (masks >= thresh).nonzero(as_tuple=False) outside = (masks < thresh).nonzero(as_tuple=False) pix_inside = inside[torch.randint(0, inside.shape[0], (num_inside,))] pix_outside = outside[torch.randint(0, outside.shape[0], (num_outside,))] pix = torch.cat((pix_inside, pix_outside)) return pix def bbox_sample(bboxes, num_pix): """ :return (num_pix, 3) """ image_ids = torch.randint(0, bboxes.shape[0], (num_pix,)) pix_bboxes = bboxes[image_ids] x = ( torch.rand(num_pix) * (pix_bboxes[:, 2] + 1 - pix_bboxes[:, 0]) + pix_bboxes[:, 0] ).long() y = ( torch.rand(num_pix) * (pix_bboxes[:, 3] + 1 - pix_bboxes[:, 1]) + pix_bboxes[:, 1] ).long() pix = torch.stack((image_ids, y, x), dim=-1) return pix def gen_rays(poses, width, height, z_near, z_far, focal=None, c=None, norm_dir=True, xy_offset=None): """ Generate camera rays :return (B, H, W, 8) """ num_images = poses.shape[0] device = poses.device cam_unproj_map, xy = unproj_map(width, height, focal, c=c, device=device, norm_dir=norm_dir, xy_offset=xy_offset) cam_unproj_map = cam_unproj_map.expand(num_images, -1, -1, -1) xy = xy.expand(num_images, -1, -1, -1) cam_centers = poses[:, None, None, :3, 3].expand(-1, height, width, -1) cam_raydir = torch.matmul( poses[:, None, None, :3, :3], cam_unproj_map.unsqueeze(-1) )[:, :, :, :, 0] cam_nears = ( torch.tensor(z_near, device=device) .view(1, 1, 1, 1) .expand(num_images, height, width, -1) ) cam_fars = ( torch.tensor(z_far, device=device) .view(1, 1, 1, 1) .expand(num_images, height, width, -1) ) rays = torch.cat( (cam_centers, cam_raydir, cam_nears, cam_fars), dim=-1 ) return rays, xy def trans_t(t): return torch.tensor( [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1],], dtype=torch.float32, ) def rot_phi(phi): return torch.tensor( [ [1, 0, 0, 0], [0, np.cos(phi), -np.sin(phi), 0], [0, np.sin(phi), np.cos(phi), 0], [0, 0, 0, 1], ], dtype=torch.float32, ) def rot_theta(th): return torch.tensor( [ [np.cos(th), 0, -np.sin(th), 0], [0, 1, 0, 0], [np.sin(th), 0, np.cos(th), 0], [0, 0, 0, 1], ], dtype=torch.float32, ) def pose_spherical(theta, phi, radius): """ Spherical rendering poses, from NeRF """ c2w = trans_t(radius) c2w = rot_phi(phi / 180.0 * np.pi) @ c2w c2w = rot_theta(theta / 180.0 * np.pi) @ c2w c2w = ( torch.tensor( [[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, ) @ c2w ) return c2w def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def get_norm_layer(norm_type="instance", group_norm_groups=32): """Return a normalization layer Parameters: norm_type (str) -- the name of the normalization layer: batch | instance | none For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
""" if norm_type == "batch": norm_layer = functools.partial( nn.BatchNorm2d, affine=True, track_running_stats=True ) elif norm_type == "instance": norm_layer = functools.partial( nn.InstanceNorm2d, affine=False, track_running_stats=False ) elif norm_type == "group": norm_layer = functools.partial(nn.GroupNorm, group_norm_groups) elif norm_type == "none": norm_layer = None else: raise NotImplementedError("normalization layer [%s] is not found" % norm_type) return norm_layer def make_conv_2d( dim_in, dim_out, padding_type="reflect", norm_layer=None, activation=None, kernel_size=3, use_bias=False, stride=1, no_pad=False, zero_init=False, ): conv_block = [] amt = kernel_size // 2 if stride > 1 and not no_pad: raise NotImplementedError( "Padding with stride > 1 not supported, use same_pad_conv2d" ) if amt > 0 and not no_pad: if padding_type == "reflect": conv_block += [nn.ReflectionPad2d(amt)] elif padding_type == "replicate": conv_block += [nn.ReplicationPad2d(amt)] elif padding_type == "zero": conv_block += [nn.ZeroPad2d(amt)] else: raise NotImplementedError("padding [%s] is not implemented" % padding_type) conv_block.append( nn.Conv2d( dim_in, dim_out, kernel_size=kernel_size, bias=use_bias, stride=stride ) ) if zero_init: nn.init.zeros_(conv_block[-1].weight) # else: # nn.init.kaiming_normal_(conv_block[-1].weight) if norm_layer is not None: conv_block.append(norm_layer(dim_out)) if activation is not None: conv_block.append(activation) return nn.Sequential(*conv_block) def calc_same_pad_conv2d(t_shape, kernel_size=3, stride=1): in_height, in_width = t_shape[-2:] out_height = math.ceil(in_height / stride) out_width = math.ceil(in_width / stride) pad_along_height = max((out_height - 1) * stride + kernel_size - in_height, 0) pad_along_width = max((out_width - 1) * stride + kernel_size - in_width, 0) pad_top = pad_along_height // 2 pad_bottom = pad_along_height - pad_top pad_left = pad_along_width // 2 pad_right = pad_along_width - pad_left return pad_left, pad_right, pad_top, pad_bottom def same_pad_conv2d(t, padding_type="reflect", kernel_size=3, stride=1, layer=None): """ Perform SAME padding on tensor, given kernel size/stride of conv operator assumes kernel/stride are equal in all dimensions. Use before conv called. Dilation not supported. :param t image tensor input (B, C, H, W) :param padding_type padding type constant | reflect | replicate | circular constant is 0-pad. :param kernel_size kernel size of conv :param stride stride of conv :param layer optionally, pass conv layer to automatically get kernel_size and stride (overrides these) """ if layer is not None: if isinstance(layer, nn.Sequential): layer = next(layer.children()) kernel_size = layer.kernel_size[0] stride = layer.stride[0] return F.pad( t, calc_same_pad_conv2d(t.shape, kernel_size, stride), mode=padding_type ) def same_unpad_deconv2d(t, kernel_size=3, stride=1, layer=None): """ Perform SAME unpad on tensor, given kernel/stride of deconv operator. Use after deconv called. Dilation not supported. 
""" if layer is not None: if isinstance(layer, nn.Sequential): layer = next(layer.children()) kernel_size = layer.kernel_size[0] stride = layer.stride[0] h_scaled = (t.shape[-2] - 1) * stride w_scaled = (t.shape[-1] - 1) * stride pad_left, pad_right, pad_top, pad_bottom = calc_same_pad_conv2d( (h_scaled, w_scaled), kernel_size, stride ) if pad_right == 0: pad_right = -10000 if pad_bottom == 0: pad_bottom = -10000 return t[..., pad_top:-pad_bottom, pad_left:-pad_right] def combine_interleaved(t, inner_dims=(1,), agg_type="average"): if len(inner_dims) == 1 and inner_dims[0] == 1: return t t = t.reshape(-1, *inner_dims, *t.shape[1:]) if agg_type == "average": t = torch.mean(t, dim=1) elif agg_type == "max": t = torch.max(t, dim=1)[0] else: raise NotImplementedError("Unsupported combine type " + agg_type) return t def psnr(pred, target): """ Compute PSNR of two tensors in decibels. pred/target should be of same size or broadcastable """ mse = ((pred - target) ** 2).mean() psnr = -10 * math.log10(mse) return psnr def quat_to_rot(q): """ Quaternion to rotation matrix """ batch_size, _ = q.shape q = F.normalize(q, dim=1) R = torch.ones((batch_size, 3, 3), device=q.device) qr = q[:, 0] qi = q[:, 1] qj = q[:, 2] qk = q[:, 3] R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2) R[:, 0, 1] = 2 * (qj * qi - qk * qr) R[:, 0, 2] = 2 * (qi * qk + qr * qj) R[:, 1, 0] = 2 * (qj * qi + qk * qr) R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2) R[:, 1, 2] = 2 * (qj * qk - qi * qr) R[:, 2, 0] = 2 * (qk * qi - qj * qr) R[:, 2, 1] = 2 * (qj * qk + qi * qr) R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2) return R def rot_to_quat(R): """ Rotation matrix to quaternion """ batch_size, _, _ = R.shape q = torch.ones((batch_size, 4), device=R.device) R00 = R[:, 0, 0] R01 = R[:, 0, 1] R02 = R[:, 0, 2] R10 = R[:, 1, 0] R11 = R[:, 1, 1] R12 = R[:, 1, 2] R20 = R[:, 2, 0] R21 = R[:, 2, 1] R22 = R[:, 2, 2] q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2 q[:, 1] = (R21 - R12) / (4 * q[:, 0]) q[:, 2] = (R02 - R20) / (4 * q[:, 0]) q[:, 3] = (R10 - R01) / (4 * q[:, 0]) return q def get_module(net): """ Shorthand for either net.module (if net is instance of DataParallel) or net """ if isinstance(net, torch.nn.DataParallel): return net.module else: return net # Nan safe entropy def normalized_entropy(p, dim=-1, eps=2 ** (-8)): H_max = math.log2(p.shape[dim]) # x log2 (x) -> 0 . Therefore, we can set log2 (x) to 0 if x is small enough. # This should ensure numerical stability. p_too_small = p < eps p = p.clone() p[p_too_small] = 1 plp = torch.log2(p) * p plp[p_too_small] = 0 # This is the formula for the normalised entropy entropy = -plp.sum(dim) / H_max return entropy def kl_div(p, q, dim=-1, eps=2 ** (-8)): p_too_small = p < eps q_too_small = q < eps too_small = p_too_small | q_too_small p = p.clone() q = q.clone() p[too_small] = 0 q[too_small] = 0 p = p / p.sum(dim, keepdims=True).detach() q = q / q.sum(dim, keepdims=True).detach() p[too_small] = 1 q[too_small] = 1 els = p * (p.log() - q.log()) els[too_small] = 0 kl_div = els.sum(dim) return kl_div