import cv2
import numpy as np
import torch
from torchvision import transforms
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import functools
import math
import warnings
def image_float_to_uint8(img):
"""
Convert a float image (0.0-1.0) to uint8 (0-255)
"""
vmin = np.min(img)
vmax = np.max(img)
if vmax - vmin < 1e-10:
vmax += 1e-10
img = (img - vmin) / (vmax - vmin)
img *= 255.0
return img.astype(np.uint8)
def cmap(img, color_map=cv2.COLORMAP_HOT):
"""
Apply 'HOT' color to a float image
"""
return cv2.applyColorMap(image_float_to_uint8(img), color_map)
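# Usage sketch (illustrative input, not from the source):
#   heat = cmap(np.random.rand(32, 32))  # (32, 32, 3) uint8 BGR heatmap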
def batched_index_select_nd(t, inds):
"""
Index select on dim 1 of a n-dimensional batched tensor.
:param t (batch, n, ...)
:param inds (batch, k)
:return (batch, k, ...)
"""
return t.gather(
1, inds[(...,) + (None,) * (len(t.shape) - 2)].expand(-1, -1, *t.shape[2:])
)
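# Usage sketch (shapes are illustrative, not from the source):
#   feats = torch.randn(4, 100, 16)        # (batch, n, feature)
#   inds = torch.randint(0, 100, (4, 8))   # (batch, k)
#   batched_index_select_nd(feats, inds)   # -> (4, 8, 16)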
def batched_index_select_nd_last(t, inds):
"""
Index select on dim -1 of a >=2D multi-batched tensor. inds assumed
to have all batch dimensions except one data dimension 'n'
:param t (batch..., n, m)
:param inds (batch..., k)
:return (batch..., n, k)
"""
dummy = inds.unsqueeze(-2).expand(*inds.shape[:-1], t.size(-2), inds.size(-1))
out = t.gather(-1, dummy)
return out
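# Usage sketch (illustrative shapes): gather along the last dimension.
#   t = torch.randn(4, 16, 100)            # (batch..., n, m)
#   inds = torch.randint(0, 100, (4, 8))   # (batch..., k)
#   batched_index_select_nd_last(t, inds)  # -> (4, 16, 8)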
def repeat_interleave(input, repeats, dim=0):
"""
Repeat interleave along axis 0
torch.repeat_interleave is currently very slow
https://github.com/pytorch/pytorch/issues/31980
"""
output = input.unsqueeze(1).expand(-1, repeats, *input.shape[1:])
return output.reshape(-1, *input.shape[1:])
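# Usage sketch:
#   x = torch.tensor([[1], [2]])
#   repeat_interleave(x, 3)  # tensor([[1], [1], [1], [2], [2], [2]])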
def get_image_to_tensor_balanced(image_size=0):
ops = []
if image_size > 0:
ops.append(transforms.Resize(image_size))
ops.extend(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]
)
return transforms.Compose(ops)
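# Usage sketch (pil_image is a hypothetical PIL.Image input):
#   tf = get_image_to_tensor_balanced(image_size=128)
#   img_tensor = tf(pil_image)  # (3, H, W) float tensor, values in [-1, 1]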
def get_mask_to_tensor():
return transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
)
def homogeneous(points):
"""
Concat 1 to each point
:param points (..., 3)
:return (..., 4)
"""
return F.pad(points, (0, 1), "constant", 1.0)
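# Usage sketch:
#   homogeneous(torch.zeros(5, 3))  # (5, 4), last column all ones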
def gen_grid(*args, ij_indexing=False):
"""
    Generate a len(args)-dimensional grid.
    Each arg should be (lo, hi, sz) so that in that dimension points
    are taken at linspace(lo, hi, sz).
    Example: gen_grid((0,1,10), (-1,1,20))
    :return (prod_i args_i[2], len(args)) tensor of grid points
"""
return torch.from_numpy(
np.vstack(
np.meshgrid(
*(np.linspace(lo, hi, sz, dtype=np.float32) for lo, hi, sz in args),
indexing="ij" if ij_indexing else "xy"
)
)
.reshape(len(args), -1)
.T
)
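# Usage sketch: a 10 x 20 grid over [0, 1] x [-1, 1].
#   pts = gen_grid((0, 1, 10), (-1, 1, 20))  # (200, 2) float32 tensor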
def unproj_map(width, height, f, c=None, device="cpu", norm_dir=True, xy_offset=None):
"""
Get camera unprojection map for given image size.
[y,x] of output tensor will contain unit vector of camera ray of that pixel.
:param width image width
:param height image height
:param f focal length, either a number or tensor [fx, fy]
:param c principal point, optional, either None or tensor [fx, fy]
if not specified uses center of image
:return unproj map (height, width, 3)
"""
if c is None:
c = torch.tensor([[0.0, 0.0]], device=device)
elif isinstance(c, float):
c = torch.tensor([[c, c]], device=device)
elif len(c.shape) == 0:
c = c[None, None].expand(1, 2)
    elif len(c.shape) == 1:
        # unsqueeze(0) handles both a length-1 and a length-2 ([cx, cy]) tensor;
        # unsqueeze(-1) would fail to expand a length-2 tensor to (1, 2)
        c = c.unsqueeze(0).expand(1, 2)
if isinstance(f, float):
f = torch.tensor([[f, f]], device=device)
elif len(f.shape) == 0:
f = f[None, None].expand(1, 2)
    elif len(f.shape) == 1:
        f = f.unsqueeze(0).expand(1, 2)
n = f.shape[0]
pixel_width = 2 / width
pixel_height = 2 / height
    x = torch.linspace(
        -1 + 0.5 * pixel_width, 1 - 0.5 * pixel_width, width,
        dtype=torch.float32, device=device,
    ).view(1, 1, width).expand(n, height, width)
    y = torch.linspace(
        -1 + 0.5 * pixel_height, 1 - 0.5 * pixel_height, height,
        dtype=torch.float32, device=device,
    ).view(1, height, 1).expand(n, height, width)
if xy_offset is not None:
x = x + xy_offset[0] * pixel_width
y = y + xy_offset[1] * pixel_height
xy_img = torch.stack((x, y), dim=-1)
xy = (xy_img - c.view(n, 1, 1, 2)) / f.view(n, 1, 1, 2)
z = torch.ones_like(x).unsqueeze(-1)
unproj = torch.cat((xy, z), dim=-1)
if norm_dir:
unproj /= torch.norm(unproj, dim=-1).unsqueeze(-1)
return unproj, xy_img
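# Usage sketch (illustrative values; f and c live in the same normalized
# [-1, 1] image coordinates as the returned grid):
#   dirs, xy = unproj_map(64, 48, f=1.0)  # dirs (1, 48, 64, 3), xy (1, 48, 64, 2)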
def coord_from_blender(dtype=torch.float32, device="cpu"):
"""
Blender to standard coordinate system transform.
Standard coordinate system is: x right y up z out (out=screen to face)
Blender coordinate system is: x right y in z up
:return (4, 4)
"""
return torch.tensor(
[[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
dtype=dtype,
device=device,
)
def coord_to_blender(dtype=torch.float32, device="cpu"):
"""
Standard to Blender coordinate system transform.
Standard coordinate system is: x right y up z out (out=screen to face)
Blender coordinate system is: x right y in z up
:return (4, 4)
"""
return torch.tensor(
[[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
dtype=dtype,
device=device,
)
def look_at(origin, target, world_up=np.array([0, 1, 0], dtype=np.float32)):
"""
Get 4x4 camera to world space matrix, for camera looking at target
"""
back = origin - target
back /= np.linalg.norm(back)
right = np.cross(world_up, back)
right /= np.linalg.norm(right)
up = np.cross(back, right)
cam_to_world = np.empty((4, 4), dtype=np.float32)
cam_to_world[:3, 0] = right
cam_to_world[:3, 1] = up
cam_to_world[:3, 2] = back
cam_to_world[:3, 3] = origin
cam_to_world[3, :] = [0, 0, 0, 1]
return cam_to_world
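# Usage sketch: camera at (0, 0, 4) looking at the origin.
#   pose = look_at(np.array([0.0, 0.0, 4.0], dtype=np.float32),
#                  np.zeros(3, dtype=np.float32))  # (4, 4) camera-to-world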
def get_cuda(gpu_id):
"""
Get a torch.device for GPU gpu_id. If GPU not available,
returns CPU device.
"""
return (
torch.device("cuda:%d" % gpu_id)
if torch.cuda.is_available()
else torch.device("cpu")
)
def masked_sample(masks, num_pix, prop_inside, thresh=0.5):
"""
:return (num_pix, 3)
"""
num_inside = int(num_pix * prop_inside + 0.5)
num_outside = num_pix - num_inside
inside = (masks >= thresh).nonzero(as_tuple=False)
outside = (masks < thresh).nonzero(as_tuple=False)
pix_inside = inside[torch.randint(0, inside.shape[0], (num_inside,))]
pix_outside = outside[torch.randint(0, outside.shape[0], (num_outside,))]
pix = torch.cat((pix_inside, pix_outside))
return pix
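# Usage sketch (illustrative mask; assumes both classes are non-empty):
#   masks = (torch.rand(2, 32, 32) > 0.5).float()
#   pix = masked_sample(masks, 64, prop_inside=0.75)  # (64, 3) (batch, y, x)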
def bbox_sample(bboxes, num_pix):
"""
:return (num_pix, 3)
"""
image_ids = torch.randint(0, bboxes.shape[0], (num_pix,))
pix_bboxes = bboxes[image_ids]
x = (
torch.rand(num_pix) * (pix_bboxes[:, 2] + 1 - pix_bboxes[:, 0])
+ pix_bboxes[:, 0]
).long()
y = (
torch.rand(num_pix) * (pix_bboxes[:, 3] + 1 - pix_bboxes[:, 1])
+ pix_bboxes[:, 1]
).long()
pix = torch.stack((image_ids, y, x), dim=-1)
return pix
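# Usage sketch (illustrative bbox):
#   bboxes = torch.tensor([[10, 20, 50, 60]])  # (x_min, y_min, x_max, y_max)
#   pix = bbox_sample(bboxes, 128)             # (128, 3) (image_id, y, x)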
def gen_rays(poses, width, height, z_near, z_far, focal=None, c=None, norm_dir=True, xy_offset=None):
"""
Generate camera rays
:return (B, H, W, 8)
"""
num_images = poses.shape[0]
device = poses.device
cam_unproj_map, xy = unproj_map(width, height, focal, c=c, device=device, norm_dir=norm_dir, xy_offset=xy_offset)
cam_unproj_map = cam_unproj_map.expand(num_images, -1, -1, -1)
xy = xy.expand(num_images, -1, -1, -1)
cam_centers = poses[:, None, None, :3, 3].expand(-1, height, width, -1)
cam_raydir = torch.matmul(
poses[:, None, None, :3, :3], cam_unproj_map.unsqueeze(-1)
)[:, :, :, :, 0]
cam_nears = (
torch.tensor(z_near, device=device)
.view(1, 1, 1, 1)
.expand(num_images, height, width, -1)
)
cam_fars = (
torch.tensor(z_far, device=device)
.view(1, 1, 1, 1)
.expand(num_images, height, width, -1)
)
rays = torch.cat(
(cam_centers, cam_raydir, cam_nears, cam_fars), dim=-1
)
return rays, xy
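# Usage sketch (identity pose; focal in normalized image units):
#   poses = torch.eye(4)[None]  # (1, 4, 4) camera-to-world
#   rays, xy = gen_rays(poses, 64, 48, z_near=0.1, z_far=10.0, focal=1.0)
#   # rays: (1, 48, 64, 8) = origin (3) + direction (3) + near (1) + far (1)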
def trans_t(t):
return torch.tensor(
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1],], dtype=torch.float32,
)
def rot_phi(phi):
return torch.tensor(
[
[1, 0, 0, 0],
[0, np.cos(phi), -np.sin(phi), 0],
[0, np.sin(phi), np.cos(phi), 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
def rot_theta(th):
return torch.tensor(
[
[np.cos(th), 0, -np.sin(th), 0],
[0, 1, 0, 0],
[np.sin(th), 0, np.cos(th), 0],
[0, 0, 0, 1],
],
dtype=torch.float32,
)
def pose_spherical(theta, phi, radius):
"""
Spherical rendering poses, from NeRF
"""
c2w = trans_t(radius)
c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
c2w = (
torch.tensor(
[[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
dtype=torch.float32,
)
@ c2w
)
return c2w
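# Usage sketch (angles in degrees):
#   c2w = pose_spherical(30.0, -45.0, 4.0)  # (4, 4) pose on a radius-4 sphere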
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_norm_layer(norm_type="instance", group_norm_groups=32):
"""Return a normalization layer
Parameters:
norm_type (str) -- the name of the normalization layer: batch | instance | none
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
"""
if norm_type == "batch":
norm_layer = functools.partial(
nn.BatchNorm2d, affine=True, track_running_stats=True
)
elif norm_type == "instance":
norm_layer = functools.partial(
nn.InstanceNorm2d, affine=False, track_running_stats=False
)
elif norm_type == "group":
norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
elif norm_type == "none":
norm_layer = None
else:
raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
return norm_layer
def make_conv_2d(
dim_in,
dim_out,
padding_type="reflect",
norm_layer=None,
activation=None,
kernel_size=3,
use_bias=False,
stride=1,
no_pad=False,
zero_init=False,
):
conv_block = []
amt = kernel_size // 2
if stride > 1 and not no_pad:
raise NotImplementedError(
"Padding with stride > 1 not supported, use same_pad_conv2d"
)
if amt > 0 and not no_pad:
if padding_type == "reflect":
conv_block += [nn.ReflectionPad2d(amt)]
elif padding_type == "replicate":
conv_block += [nn.ReplicationPad2d(amt)]
elif padding_type == "zero":
conv_block += [nn.ZeroPad2d(amt)]
else:
raise NotImplementedError("padding [%s] is not implemented" % padding_type)
conv_block.append(
nn.Conv2d(
dim_in, dim_out, kernel_size=kernel_size, bias=use_bias, stride=stride
)
)
if zero_init:
nn.init.zeros_(conv_block[-1].weight)
# else:
# nn.init.kaiming_normal_(conv_block[-1].weight)
if norm_layer is not None:
conv_block.append(norm_layer(dim_out))
if activation is not None:
conv_block.append(activation)
return nn.Sequential(*conv_block)
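# Usage sketch (illustrative channel counts; the output channels must be
# divisible by the group count when using GroupNorm):
#   norm = get_norm_layer("group", group_norm_groups=8)
#   conv = make_conv_2d(16, 32, norm_layer=norm, activation=nn.ReLU())
#   conv(torch.randn(1, 16, 24, 24))  # -> (1, 32, 24, 24)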
def calc_same_pad_conv2d(t_shape, kernel_size=3, stride=1):
in_height, in_width = t_shape[-2:]
out_height = math.ceil(in_height / stride)
out_width = math.ceil(in_width / stride)
pad_along_height = max((out_height - 1) * stride + kernel_size - in_height, 0)
pad_along_width = max((out_width - 1) * stride + kernel_size - in_width, 0)
pad_top = pad_along_height // 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width // 2
pad_right = pad_along_width - pad_left
return pad_left, pad_right, pad_top, pad_bottom
def same_pad_conv2d(t, padding_type="reflect", kernel_size=3, stride=1, layer=None):
"""
    Perform SAME padding on tensor, given kernel size/stride of conv operator;
    assumes kernel/stride are equal in all dimensions.
    Use before the conv is called.
Dilation not supported.
:param t image tensor input (B, C, H, W)
:param padding_type padding type constant | reflect | replicate | circular
constant is 0-pad.
:param kernel_size kernel size of conv
:param stride stride of conv
:param layer optionally, pass conv layer to automatically get kernel_size and stride
(overrides these)
"""
if layer is not None:
if isinstance(layer, nn.Sequential):
layer = next(layer.children())
kernel_size = layer.kernel_size[0]
stride = layer.stride[0]
return F.pad(
t, calc_same_pad_conv2d(t.shape, kernel_size, stride), mode=padding_type
)
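# Usage sketch: SAME-style padding for a strided conv (output ceil(25/2) = 13).
#   conv = nn.Conv2d(3, 8, kernel_size=3, stride=2)
#   x = torch.randn(1, 3, 25, 25)
#   y = conv(same_pad_conv2d(x, kernel_size=3, stride=2))  # (1, 8, 13, 13)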
def same_unpad_deconv2d(t, kernel_size=3, stride=1, layer=None):
"""
Perform SAME unpad on tensor, given kernel/stride of deconv operator.
Use after deconv called.
Dilation not supported.
"""
if layer is not None:
if isinstance(layer, nn.Sequential):
layer = next(layer.children())
kernel_size = layer.kernel_size[0]
stride = layer.stride[0]
h_scaled = (t.shape[-2] - 1) * stride
w_scaled = (t.shape[-1] - 1) * stride
pad_left, pad_right, pad_top, pad_bottom = calc_same_pad_conv2d(
(h_scaled, w_scaled), kernel_size, stride
)
    # a pad of 0 would make the slice below empty (since -0 == 0), so use a
    # large negative sentinel: -(-10000) = 10000 slices through to the end
    if pad_right == 0:
        pad_right = -10000
    if pad_bottom == 0:
        pad_bottom = -10000
return t[..., pad_top:-pad_bottom, pad_left:-pad_right]
def combine_interleaved(t, inner_dims=(1,), agg_type="average"):
if len(inner_dims) == 1 and inner_dims[0] == 1:
return t
t = t.reshape(-1, *inner_dims, *t.shape[1:])
if agg_type == "average":
t = torch.mean(t, dim=1)
elif agg_type == "max":
t = torch.max(t, dim=1)[0]
else:
raise NotImplementedError("Unsupported combine type " + agg_type)
return t
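# Usage sketch: 2 objects x 3 views stored interleaved along dim 0.
#   t = torch.randn(6, 10)
#   combine_interleaved(t, inner_dims=(3,))  # (2, 10), views averaged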
def psnr(pred, target):
"""
Compute PSNR of two tensors in decibels.
pred/target should be of same size or broadcastable
"""
mse = ((pred - target) ** 2).mean()
psnr = -10 * math.log10(mse)
return psnr
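# Usage sketch: mse = 0.01 gives -10 * log10(0.01) = 20 dB.
#   psnr(torch.full((2, 2), 0.9), torch.full((2, 2), 1.0))  # ~20.0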
def quat_to_rot(q):
"""
Quaternion to rotation matrix
"""
batch_size, _ = q.shape
q = F.normalize(q, dim=1)
R = torch.ones((batch_size, 3, 3), device=q.device)
qr = q[:, 0]
qi = q[:, 1]
qj = q[:, 2]
qk = q[:, 3]
R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2)
R[:, 0, 1] = 2 * (qj * qi - qk * qr)
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2)
R[:, 1, 2] = 2 * (qj * qk - qi * qr)
R[:, 2, 0] = 2 * (qk * qi - qj * qr)
R[:, 2, 1] = 2 * (qj * qk + qi * qr)
R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2)
return R
def rot_to_quat(R):
"""
Rotation matrix to quaternion
"""
batch_size, _, _ = R.shape
q = torch.ones((batch_size, 4), device=R.device)
R00 = R[:, 0, 0]
R01 = R[:, 0, 1]
R02 = R[:, 0, 2]
R10 = R[:, 1, 0]
R11 = R[:, 1, 1]
R12 = R[:, 1, 2]
R20 = R[:, 2, 0]
R21 = R[:, 2, 1]
R22 = R[:, 2, 2]
q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
q[:, 1] = (R21 - R12) / (4 * q[:, 0])
q[:, 2] = (R02 - R20) / (4 * q[:, 0])
q[:, 3] = (R10 - R01) / (4 * q[:, 0])
return q
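# Usage sketch: the round trip recovers q up to sign, provided the scalar
# part is not near zero (rot_to_quat divides by it).
#   q = F.normalize(torch.randn(5, 4), dim=1)  # (w, x, y, z)
#   q2 = rot_to_quat(quat_to_rot(q))           # == q * torch.sign(q[:, :1])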
def get_module(net):
"""
Shorthand for either net.module (if net is instance of DataParallel) or net
"""
if isinstance(net, torch.nn.DataParallel):
return net.module
else:
return net
# Nan safe entropy
def normalized_entropy(p, dim=-1, eps=2 ** (-8)):
H_max = math.log2(p.shape[dim])
    # x * log2(x) -> 0 as x -> 0, so for small enough x we can set the
    # log2(x) term to 0. This keeps the computation numerically stable.
p_too_small = p < eps
p = p.clone()
p[p_too_small] = 1
plp = torch.log2(p) * p
plp[p_too_small] = 0
# This is the formula for the normalised entropy
entropy = -plp.sum(dim) / H_max
return entropy
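# Usage sketch: entropy normalized to [0, 1] by log2 of the bin count.
#   p = torch.tensor([[0.25, 0.25, 0.25, 0.25], [0.5, 0.5, 0.0, 0.0]])
#   normalized_entropy(p)  # tensor([1.0, 0.5])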
def kl_div(p, q, dim=-1, eps=2 ** (-8)):
p_too_small = p < eps
q_too_small = q < eps
too_small = p_too_small | q_too_small
p = p.clone()
q = q.clone()
p[too_small] = 0
q[too_small] = 0
    p = p / p.sum(dim, keepdim=True).detach()
    q = q / q.sum(dim, keepdim=True).detach()
p[too_small] = 1
q[too_small] = 1
els = p * (p.log() - q.log())
els[too_small] = 0
kl_div = els.sum(dim)
return kl_div
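# Usage sketch: entries below eps are masked out of the sum; result is in nats.
#   p = torch.tensor([[0.5, 0.5, 0.0]])
#   q = torch.tensor([[0.25, 0.75, 0.0]])
#   kl_div(p, q)  # ~0.1438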