import cv2
import numpy as np
import torch
from torchvision import transforms
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import functools
import math
import warnings
def image_float_to_uint8(img):
    """
    Convert a float image to uint8 (0-255), rescaling by the image's min/max
    """
    vmin = np.min(img)
    vmax = np.max(img)
    if vmax - vmin < 1e-10:
        vmax += 1e-10
    img = (img - vmin) / (vmax - vmin)
    img *= 255.0
    return img.astype(np.uint8)
def cmap(img, color_map=cv2.COLORMAP_HOT):
    """
    Apply a color map (HOT by default) to a float image
    """
    return cv2.applyColorMap(image_float_to_uint8(img), color_map)
def batched_index_select_nd(t, inds):
    """
    Index select on dim 1 of an n-dimensional batched tensor.
    :param t (batch, n, ...)
    :param inds (batch, k)
    :return (batch, k, ...)
    """
    return t.gather(
        1, inds[(...,) + (None,) * (len(t.shape) - 2)].expand(-1, -1, *t.shape[2:])
    )
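# A minimal usage sketch for batched_index_select_nd; the shapes below are
# illustrative assumptions, not taken from any caller in this file.
def _demo_batched_index_select_nd():
    t = torch.randn(4, 100, 16)             # (batch, n, feature)
    inds = torch.randint(0, 100, (4, 7))    # (batch, k)
    out = batched_index_select_nd(t, inds)  # gathers along dim 1
    assert out.shape == (4, 7, 16)          # (batch, k, feature)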
def batched_index_select_nd_last(t, inds):
    """
    Index select on dim -1 of a >=2D multi-batched tensor. inds is assumed
    to have all batch dimensions of t plus one data dimension 'k'
    :param t (batch..., n, m)
    :param inds (batch..., k)
    :return (batch..., n, k)
    """
    dummy = inds.unsqueeze(-2).expand(*inds.shape[:-1], t.size(-2), inds.size(-1))
    out = t.gather(-1, dummy)
    return out
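# Companion sketch for the dim=-1 variant; again, the shapes are assumptions
# chosen only to illustrate the contract stated in the docstring.
def _demo_batched_index_select_nd_last():
    t = torch.randn(4, 10, 100)                  # (batch..., n, m)
    inds = torch.randint(0, 100, (4, 7))         # (batch..., k)
    out = batched_index_select_nd_last(t, inds)  # gathers along dim -1
    assert out.shape == (4, 10, 7)               # (batch..., n, k)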
def repeat_interleave(input, repeats, dim=0):
    """
    Repeat-interleave along axis 0 (the dim argument is accepted for API
    compatibility, but only dim=0 is supported).
    torch.repeat_interleave is currently very slow:
    https://github.com/pytorch/pytorch/issues/31980
    """
    output = input.unsqueeze(1).expand(-1, repeats, *input.shape[1:])
    return output.reshape(-1, *input.shape[1:])
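# Sanity sketch (an added check, not original code): the expand+reshape trick
# should agree exactly with torch.repeat_interleave along dim 0.
def _demo_repeat_interleave():
    x = torch.arange(6).reshape(3, 2)
    assert torch.equal(repeat_interleave(x, 2), torch.repeat_interleave(x, 2, dim=0))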
def get_image_to_tensor_balanced(image_size=0):
    ops = []
    if image_size > 0:
        ops.append(transforms.Resize(image_size))
    ops.extend(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    return transforms.Compose(ops)
def get_mask_to_tensor():
    return transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.0,), (1.0,))]
    )
def homogeneous(points):
    """
    Append a 1 to each point (convert to homogeneous coordinates)
    :param points (..., 3)
    :return (..., 4)
    """
    return F.pad(points, (0, 1), "constant", 1.0)
def gen_grid(*args, ij_indexing=False):
    """
    Generate a len(args)-dimensional grid.
    Each arg should be (lo, hi, sz) so that in that dimension points
    are taken at linspace(lo, hi, sz).
    Example: gen_grid((0, 1, 10), (-1, 1, 20))
    :return (prod_i args_i[2], len(args)), len(args)-dimensional grid points
    """
    return torch.from_numpy(
        np.vstack(
            np.meshgrid(
                *(np.linspace(lo, hi, sz, dtype=np.float32) for lo, hi, sz in args),
                indexing="ij" if ij_indexing else "xy"
            )
        )
        .reshape(len(args), -1)
        .T
    )
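# Illustrative sketch of gen_grid's output shape for the docstring's own
# example arguments (the assertion is an added check, not original code).
def _demo_gen_grid():
    grid = gen_grid((0, 1, 10), (-1, 1, 20))
    assert grid.shape == (200, 2)  # prod of sizes (10 * 20) rows, one column per dim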
def unproj_map(width, height, f, c=None, device="cpu", norm_dir=True, xy_offset=None):
    """
    Get camera unprojection map for given image size.
    [y, x] of the output tensor will contain the unit vector of the camera ray
    of that pixel.
    :param width image width
    :param height image height
    :param f focal length, either a number or tensor [fx, fy]
    :param c principal point, optional, either None or tensor [cx, cy];
    if not specified uses center of image
    :return (unproj, xy_img), where unproj is the (n, height, width, 3)
    ray-direction map and xy_img is the (n, height, width, 2) image-plane grid
    """
    if c is None:
        c = torch.tensor([[0.0, 0.0]], device=device)
    elif isinstance(c, float):
        c = torch.tensor([[c, c]], device=device)
    elif len(c.shape) == 0:
        c = c[None, None].expand(1, 2)
    elif len(c.shape) == 1:
        # (n,) batch of scalar principal points -> (n, 2)
        c = c.unsqueeze(-1).expand(-1, 2)
    if isinstance(f, float):
        f = torch.tensor([[f, f]], device=device)
    elif len(f.shape) == 0:
        f = f[None, None].expand(1, 2)
    elif len(f.shape) == 1:
        # (n,) batch of scalar focal lengths -> (n, 2)
        f = f.unsqueeze(-1).expand(-1, 2)
    n = f.shape[0]
    pixel_width = 2 / width
    pixel_height = 2 / height
    x = torch.linspace(
        -1 + 0.5 * pixel_width, 1 - 0.5 * pixel_width, width,
        dtype=torch.float32, device=device,
    ).view(1, 1, width).expand(n, height, width)
    y = torch.linspace(
        -1 + 0.5 * pixel_height, 1 - 0.5 * pixel_height, height,
        dtype=torch.float32, device=device,
    ).view(1, height, 1).expand(n, height, width)
    if xy_offset is not None:
        x = x + xy_offset[0] * pixel_width
        y = y + xy_offset[1] * pixel_height
    xy_img = torch.stack((x, y), dim=-1)
    xy = (xy_img - c.view(n, 1, 1, 2)) / f.view(n, 1, 1, 2)
    z = torch.ones_like(x).unsqueeze(-1)
    unproj = torch.cat((xy, z), dim=-1)
    if norm_dir:
        unproj /= torch.norm(unproj, dim=-1).unsqueeze(-1)
    return unproj, xy_img
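# Hedged sketch of unproj_map with a scalar focal length; the image size and
# focal value are arbitrary assumptions, used only to check shapes and that
# norm_dir=True yields unit-length ray directions.
def _demo_unproj_map():
    dirs, xy_img = unproj_map(64, 48, f=1.2)
    assert dirs.shape == (1, 48, 64, 3) and xy_img.shape == (1, 48, 64, 2)
    assert torch.allclose(dirs.norm(dim=-1), torch.ones(1, 48, 64))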
def coord_from_blender(dtype=torch.float32, device="cpu"):
    """
    Blender to standard coordinate system transform.
    Standard coordinate system is: x right, y up, z out (out = screen to face)
    Blender coordinate system is: x right, y in, z up
    :return (4, 4)
    """
    return torch.tensor(
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]],
        dtype=dtype,
        device=device,
    )
def coord_to_blender(dtype=torch.float32, device="cpu"):
    """
    Standard to Blender coordinate system transform.
    Standard coordinate system is: x right, y up, z out (out = screen to face)
    Blender coordinate system is: x right, y in, z up
    :return (4, 4)
    """
    return torch.tensor(
        [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
        dtype=dtype,
        device=device,
    )
def look_at(origin, target, world_up=np.array([0, 1, 0], dtype=np.float32)):
    """
    Get a 4x4 camera-to-world matrix for a camera at `origin` looking at `target`
    """
    back = origin - target
    back /= np.linalg.norm(back)
    right = np.cross(world_up, back)
    right /= np.linalg.norm(right)
    up = np.cross(back, right)
    cam_to_world = np.empty((4, 4), dtype=np.float32)
    cam_to_world[:3, 0] = right
    cam_to_world[:3, 1] = up
    cam_to_world[:3, 2] = back
    cam_to_world[:3, 3] = origin
    cam_to_world[3, :] = [0, 0, 0, 1]
    return cam_to_world
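# Small self-check sketch for look_at (values assumed for illustration): the
# rotation block should be orthonormal, and the last column is the camera origin.
def _demo_look_at():
    m = look_at(np.array([0.0, 0.0, 4.0], dtype=np.float32), np.zeros(3, dtype=np.float32))
    assert np.allclose(m[:3, :3] @ m[:3, :3].T, np.eye(3), atol=1e-6)
    assert np.allclose(m[:3, 3], [0.0, 0.0, 4.0])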
def get_cuda(gpu_id):
    """
    Get a torch.device for GPU gpu_id. If CUDA is not available,
    returns the CPU device.
    """
    return (
        torch.device("cuda:%d" % gpu_id)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
def masked_sample(masks, num_pix, prop_inside, thresh=0.5):
    """
    Sample num_pix pixel indices, approximately prop_inside of them
    where masks >= thresh and the rest where masks < thresh
    :return (num_pix, 3)
    """
    num_inside = int(num_pix * prop_inside + 0.5)
    num_outside = num_pix - num_inside
    inside = (masks >= thresh).nonzero(as_tuple=False)
    outside = (masks < thresh).nonzero(as_tuple=False)
    pix_inside = inside[torch.randint(0, inside.shape[0], (num_inside,))]
    pix_outside = outside[torch.randint(0, outside.shape[0], (num_outside,))]
    pix = torch.cat((pix_inside, pix_outside))
    return pix
def bbox_sample(bboxes, num_pix):
    """
    Sample num_pix (image_id, y, x) pixel indices uniformly within
    per-image bounding boxes (x0, y0, x1, y1)
    :return (num_pix, 3)
    """
    image_ids = torch.randint(0, bboxes.shape[0], (num_pix,))
    pix_bboxes = bboxes[image_ids]
    x = (
        torch.rand(num_pix) * (pix_bboxes[:, 2] + 1 - pix_bboxes[:, 0])
        + pix_bboxes[:, 0]
    ).long()
    y = (
        torch.rand(num_pix) * (pix_bboxes[:, 3] + 1 - pix_bboxes[:, 1])
        + pix_bboxes[:, 1]
    ).long()
    pix = torch.stack((image_ids, y, x), dim=-1)
    return pix
def gen_rays(poses, width, height, z_near, z_far, focal=None, c=None, norm_dir=True, xy_offset=None):
    """
    Generate camera rays
    :return (rays, xy), where rays is (B, H, W, 8) with origin (3),
    direction (3), near (1), far (1), and xy is the (B, H, W, 2) image-plane grid
    """
    num_images = poses.shape[0]
    device = poses.device
    cam_unproj_map, xy = unproj_map(
        width, height, focal, c=c, device=device, norm_dir=norm_dir, xy_offset=xy_offset
    )
    cam_unproj_map = cam_unproj_map.expand(num_images, -1, -1, -1)
    xy = xy.expand(num_images, -1, -1, -1)
    cam_centers = poses[:, None, None, :3, 3].expand(-1, height, width, -1)
    cam_raydir = torch.matmul(
        poses[:, None, None, :3, :3], cam_unproj_map.unsqueeze(-1)
    )[:, :, :, :, 0]
    cam_nears = (
        torch.tensor(z_near, device=device)
        .view(1, 1, 1, 1)
        .expand(num_images, height, width, -1)
    )
    cam_fars = (
        torch.tensor(z_far, device=device)
        .view(1, 1, 1, 1)
        .expand(num_images, height, width, -1)
    )
    rays = torch.cat((cam_centers, cam_raydir, cam_nears, cam_fars), dim=-1)
    return rays, xy
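# Hedged sketch of gen_rays with a single identity pose; the sizes, focal and
# near/far values are assumptions, used only to illustrate the 8-channel layout.
def _demo_gen_rays():
    poses = torch.eye(4)[None]  # (B, 4, 4), B = 1
    rays, xy = gen_rays(poses, 32, 24, z_near=0.5, z_far=4.0, focal=1.0)
    assert rays.shape == (1, 24, 32, 8)  # origin(3) + dir(3) + near(1) + far(1)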
def trans_t(t):
    return torch.tensor(
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]],
        dtype=torch.float32,
    )
def rot_phi(phi):
    return torch.tensor(
        [
            [1, 0, 0, 0],
            [0, np.cos(phi), -np.sin(phi), 0],
            [0, np.sin(phi), np.cos(phi), 0],
            [0, 0, 0, 1],
        ],
        dtype=torch.float32,
    )
def rot_theta(th):
    return torch.tensor(
        [
            [np.cos(th), 0, -np.sin(th), 0],
            [0, 1, 0, 0],
            [np.sin(th), 0, np.cos(th), 0],
            [0, 0, 0, 1],
        ],
        dtype=torch.float32,
    )
def pose_spherical(theta, phi, radius):
    """
    Spherical rendering poses, from NeRF
    """
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
    # Without the final @ c2w, the translation/rotation above would be discarded
    c2w = (
        torch.tensor(
            [[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
            dtype=torch.float32,
        )
        @ c2w
    )
    return c2w
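# Quick sketch (angles/radius assumed): every factor after the initial
# translation is a rotation or axis flip, so the camera center must stay at
# distance `radius` from the origin.
def _demo_pose_spherical():
    c2w = pose_spherical(30.0, -45.0, 4.0)
    assert c2w.shape == (4, 4)
    assert abs(torch.norm(c2w[:3, 3]).item() - 4.0) < 1e-5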
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_norm_layer(norm_type="instance", group_norm_groups=32):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == "batch":
        norm_layer = functools.partial(
            nn.BatchNorm2d, affine=True, track_running_stats=True
        )
    elif norm_type == "instance":
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    elif norm_type == "group":
        norm_layer = functools.partial(nn.GroupNorm, group_norm_groups)
    elif norm_type == "none":
        norm_layer = None
    else:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return norm_layer
def make_conv_2d(
    dim_in,
    dim_out,
    padding_type="reflect",
    norm_layer=None,
    activation=None,
    kernel_size=3,
    use_bias=False,
    stride=1,
    no_pad=False,
    zero_init=False,
):
    conv_block = []
    amt = kernel_size // 2
    if stride > 1 and not no_pad:
        raise NotImplementedError(
            "Padding with stride > 1 not supported, use same_pad_conv2d"
        )
    if amt > 0 and not no_pad:
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(amt)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(amt)]
        elif padding_type == "zero":
            conv_block += [nn.ZeroPad2d(amt)]
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
    conv_block.append(
        nn.Conv2d(
            dim_in, dim_out, kernel_size=kernel_size, bias=use_bias, stride=stride
        )
    )
    if zero_init:
        nn.init.zeros_(conv_block[-1].weight)
    # else:
    #     nn.init.kaiming_normal_(conv_block[-1].weight)
    if norm_layer is not None:
        conv_block.append(norm_layer(dim_out))
    if activation is not None:
        conv_block.append(activation)
    return nn.Sequential(*conv_block)
def calc_same_pad_conv2d(t_shape, kernel_size=3, stride=1):
    in_height, in_width = t_shape[-2:]
    out_height = math.ceil(in_height / stride)
    out_width = math.ceil(in_width / stride)
    pad_along_height = max((out_height - 1) * stride + kernel_size - in_height, 0)
    pad_along_width = max((out_width - 1) * stride + kernel_size - in_width, 0)
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    return pad_left, pad_right, pad_top, pad_bottom
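# Worked example of the SAME-padding arithmetic (input sizes assumed): a 3x3
# stride-2 conv on a 5x6 input needs symmetric (top, bottom) = (1, 1) and
# asymmetric (left, right) = (0, 1).
def _demo_calc_same_pad_conv2d():
    assert calc_same_pad_conv2d((1, 3, 5, 6), kernel_size=3, stride=2) == (0, 1, 1, 1)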
def same_pad_conv2d(t, padding_type="reflect", kernel_size=3, stride=1, layer=None):
    """
    Perform SAME padding on tensor, given kernel size/stride of conv operator;
    assumes kernel/stride are equal in all dimensions.
    Use before the conv is called.
    Dilation not supported.
    :param t image tensor input (B, C, H, W)
    :param padding_type padding type: constant | reflect | replicate | circular
    (constant is 0-pad)
    :param kernel_size kernel size of conv
    :param stride stride of conv
    :param layer optionally, pass conv layer to automatically get kernel_size
    and stride (overrides the explicit arguments)
    """
    if layer is not None:
        if isinstance(layer, nn.Sequential):
            layer = next(layer.children())
        kernel_size = layer.kernel_size[0]
        stride = layer.stride[0]
    return F.pad(
        t, calc_same_pad_conv2d(t.shape, kernel_size, stride), mode=padding_type
    )
def same_unpad_deconv2d(t, kernel_size=3, stride=1, layer=None):
    """
    Perform SAME unpad on tensor, given kernel/stride of deconv operator.
    Use after the deconv is called.
    Dilation not supported.
    """
    if layer is not None:
        if isinstance(layer, nn.Sequential):
            layer = next(layer.children())
        kernel_size = layer.kernel_size[0]
        stride = layer.stride[0]
    h_scaled = (t.shape[-2] - 1) * stride
    w_scaled = (t.shape[-1] - 1) * stride
    pad_left, pad_right, pad_top, pad_bottom = calc_same_pad_conv2d(
        (h_scaled, w_scaled), kernel_size, stride
    )
    # A slice end of -0 would yield an empty tensor; substituting -10000 turns
    # the slice into t[..., a:10000], which keeps the whole axis instead.
    if pad_right == 0:
        pad_right = -10000
    if pad_bottom == 0:
        pad_bottom = -10000
    return t[..., pad_top:-pad_bottom, pad_left:-pad_right]
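# Hedged round-trip sketch (layer sizes assumed): padding with same_pad_conv2d
# before a strided conv should give output spatial size ceil(input / stride).
def _demo_same_pad_conv2d():
    conv = nn.Conv2d(3, 8, kernel_size=3, stride=2)
    x = torch.randn(1, 3, 25, 25)
    y = conv(same_pad_conv2d(x, layer=conv))
    assert y.shape[-2:] == (13, 13)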
def combine_interleaved(t, inner_dims=(1,), agg_type="average"):
    if len(inner_dims) == 1 and inner_dims[0] == 1:
        return t
    t = t.reshape(-1, *inner_dims, *t.shape[1:])
    if agg_type == "average":
        t = torch.mean(t, dim=1)
    elif agg_type == "max":
        t = torch.max(t, dim=1)[0]
    else:
        raise NotImplementedError("Unsupported combine type " + agg_type)
    return t
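# Illustrative sketch (group sizes assumed): 6 interleaved rows treated as
# 3 groups of 2 views, averaged within each group.
def _demo_combine_interleaved():
    t = torch.arange(12, dtype=torch.float32).reshape(6, 2)
    out = combine_interleaved(t, inner_dims=(2,), agg_type="average")
    assert out.shape == (3, 2)
    assert torch.allclose(out[0], t[:2].mean(dim=0))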
def psnr(pred, target):
    """
    Compute PSNR of two tensors in decibels.
    pred and target should have the same shape or be broadcastable
    """
    mse = ((pred - target) ** 2).mean()
    psnr = -10 * math.log10(mse)
    return psnr
def quat_to_rot(q):
    """
    Quaternion to rotation matrix
    :param q (batch, 4), in (w, x, y, z) order
    :return (batch, 3, 3)
    """
    batch_size, _ = q.shape
    q = F.normalize(q, dim=1)
    R = torch.ones((batch_size, 3, 3), device=q.device)
    qr = q[:, 0]
    qi = q[:, 1]
    qj = q[:, 2]
    qk = q[:, 3]
    R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2)
    R[:, 0, 1] = 2 * (qj * qi - qk * qr)
    R[:, 0, 2] = 2 * (qi * qk + qr * qj)
    R[:, 1, 0] = 2 * (qj * qi + qk * qr)
    R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2)
    R[:, 1, 2] = 2 * (qj * qk - qi * qr)
    R[:, 2, 0] = 2 * (qk * qi - qj * qr)
    R[:, 2, 1] = 2 * (qj * qk + qi * qr)
    R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2)
    return R
def rot_to_quat(R):
    """
    Rotation matrix to quaternion, (w, x, y, z) order with w >= 0.
    Numerically unstable when the trace of R is close to -1.
    :param R (batch, 3, 3)
    :return (batch, 4)
    """
    batch_size, _, _ = R.shape
    q = torch.ones((batch_size, 4), device=R.device)
    R00 = R[:, 0, 0]
    R01 = R[:, 0, 1]
    R02 = R[:, 0, 2]
    R10 = R[:, 1, 0]
    R11 = R[:, 1, 1]
    R12 = R[:, 1, 2]
    R20 = R[:, 2, 0]
    R21 = R[:, 2, 1]
    R22 = R[:, 2, 2]
    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
    q[:, 1] = (R21 - R12) / (4 * q[:, 0])
    q[:, 2] = (R02 - R20) / (4 * q[:, 0])
    q[:, 3] = (R10 - R01) / (4 * q[:, 0])
    return q
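# Hedged round-trip sketch: rot_to_quat returns w >= 0, so random quaternions
# are flipped to that hemisphere first (q and -q encode the same rotation).
# The seed and tolerance are assumptions for a stable float32 check.
def _demo_quat_rot_round_trip():
    torch.manual_seed(0)
    q = F.normalize(torch.randn(5, 4), dim=1)
    q = torch.where(q[:, :1] < 0, -q, q)
    assert torch.allclose(rot_to_quat(quat_to_rot(q)), q, atol=1e-4)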
def get_module(net):
    """
    Shorthand for either net.module (if net is an instance of DataParallel) or net
    """
    if isinstance(net, torch.nn.DataParallel):
        return net.module
    else:
        return net
# NaN-safe entropy
def normalized_entropy(p, dim=-1, eps=2 ** (-8)):
    H_max = math.log2(p.shape[dim])
    # x * log2(x) -> 0 as x -> 0, so we can treat log2(x) as 0 when x is
    # small enough. This keeps the computation numerically stable.
    p_too_small = p < eps
    p = p.clone()
    p[p_too_small] = 1
    plp = torch.log2(p) * p
    plp[p_too_small] = 0
    # Normalize by the maximum possible entropy, log2(num_classes)
    entropy = -plp.sum(dim) / H_max
    return entropy
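# Boundary-case sketch (sizes assumed for illustration): a uniform distribution
# has normalized entropy 1, and a one-hot distribution has 0.
def _demo_normalized_entropy():
    uniform = torch.full((2, 8), 1.0 / 8)
    one_hot = torch.eye(8)[:1]
    assert torch.allclose(normalized_entropy(uniform), torch.ones(2))
    assert torch.allclose(normalized_entropy(one_hot), torch.zeros(1))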
def kl_div(p, q, dim=-1, eps=2 ** (-8)):
    p_too_small = p < eps
    q_too_small = q < eps
    too_small = p_too_small | q_too_small
    p = p.clone()
    q = q.clone()
    # Zero out entries that are too small in either distribution, then
    # renormalize over the remaining support
    p[too_small] = 0
    q[too_small] = 0
    p = p / p.sum(dim, keepdim=True).detach()
    q = q / q.sum(dim, keepdim=True).detach()
    # Set masked entries to 1 so their logs are 0, then exclude them from the sum
    p[too_small] = 1
    q[too_small] = 1
    els = p * (p.log() - q.log())
    els[too_small] = 0
    kl_div = els.sum(dim)
    return kl_div
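# Hand-checked sketch (values assumed): entries below eps in either input are
# masked out before renormalization, so only the first two bins contribute.
def _demo_kl_div():
    p = torch.tensor([[0.5, 0.5, 0.0]])
    q = torch.tensor([[0.25, 0.75, 0.0]])
    expected = 0.5 * math.log(0.5 / 0.25) + 0.5 * math.log(0.5 / 0.75)
    assert torch.allclose(kl_div(p, q), torch.tensor([expected]), atol=1e-6)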