|
from typing import *
import math
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.types
import utils3d
|
|
|
|
|
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    """
    Scatter-reduce `src` into a tensor of the given `size` along dimension `dim`, taking the minimum over entries that share the same `index`.
    Returns the minima together with the positions in `src` (along `dim`) that attain them; positions are -1 where nothing was scattered.
    """
    shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
    minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
    indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
    indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
    return torch.return_types.min((minimum, indices))
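# Example (illustrative sketch; values are arbitrary):
#   src = torch.tensor([3.0, 1.0, 2.0, 5.0])
#   index = torch.tensor([0, 0, 1, 1])
#   values, argmin = scatter_min(size=2, dim=0, index=index, src=src)
#   # values -> tensor([1., 2.]), argmin -> tensor([1, 2])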
|
|
|
|
|
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    """Evaluate `fn` on chunks of at most `chunk_size` along the batch dimension and concatenate the results, to bound peak memory."""
    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
    split_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
    split_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
    results = []
    for i in range(n_chunks):
        chunk_args = tuple(arg[i] for arg in split_args)
        chunk_kwargs = {k: v[i] for k, v in split_kwargs.items()}
        results.append(fn(*chunk_args, **chunk_kwargs))

    if isinstance(results[0], tuple):
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    else:
        return torch.cat(results, dim=0)
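# Example (illustrative sketch; the function and chunk size are arbitrary):
#   x = torch.randn(10, 3)
#   y = split_batch_fwd(F.relu, 4, x)   # processed in chunks of 4 rows
#   assert torch.equal(y, F.relu(x))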
|
|
|
|
|
def _pad_inf(x_: torch.Tensor):
    # Pad the last dimension with -inf on the left and +inf on the right (sentinels for searchsorted).
    return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
|
|
|
|
|
def _pad_cumsum(cumsum: torch.Tensor):
    # Pad a cumulative sum with 0 on the left and repeat the total on the right, matching the padding of `_pad_inf`.
    return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
|
|
|
|
|
def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
    # Truncated weighted L1 residual: sum_i min(trunc, w_i * |a * x_i - y_i|), with xyw[..., i, :] = (x_i, y_i, w_i).
    return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
|
|
|
|
|
def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If `trunc` is None, solve `min_a sum_i w_i * |a * x_i - y_i|`; otherwise solve `min_a sum_i min(trunc, w_i * |a * x_i - y_i|)`.

    `w_i` must be >= 0.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor of shape (..., n) or None

    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of the loss function at `a`, detached
    - `index`: tensor of shape (...), where `a = y[..., index] / x[..., index]`
    """
    if trunc is None:
        x, y, w = torch.broadcast_tensors(x, y, w)
        # Flip signs so that x >= 0; the candidate solutions are the ratios y_i / x_i.
        sign = torch.sign(x)
        x, y = x * sign, y * sign
        y_div_x = y / x.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)

        # The loss is piecewise linear in a; find where its derivative changes sign.
        wx = torch.gather(x * w, dim=-1, index=argsort)
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)

        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)

    else:
        x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape = x.shape[:-1]
        batch_size = math.prod(batch_shape)
        x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])

        sign = torch.sign(x)
        x, y = x * sign, y * sign
        wx, wy = w * x, w * y
        xyw = torch.stack([x, y, w], dim=-1)

        # Breakpoints of the truncated loss: A where a residual crosses zero, B and C where it reaches the truncation level.
        y_div_x = A = y / x.clamp_min(eps)
        B = (wy - trunc) / wx.clamp_min(eps)
        C = (wy + trunc) / wx.clamp_min(eps)
        with torch.no_grad():
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)

            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)

            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # One-sided derivatives of the loss at each candidate a = y_i / x_i.
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Local minima are candidates where the derivative changes from negative to non-negative.
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Evaluate the truncated loss at each candidate, in chunks to bound memory.
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]
            MAX_ELEMENTS = 4096 ** 2
            SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
            extrema_value = torch.cat([
                _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])

            # Keep the best candidate per batch element.
            minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)
            index = where_extrema_index[indices]

        # Recompute `a` differentiably from the winning (x, y) pair.
        a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)

    return a, loss, index
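# Example (illustrative sketch; values are arbitrary):
#   x = torch.tensor([[1.0, 2.0, 4.0]])
#   y = torch.tensor([[2.0, 4.1, 50.0]])       # last pair is an outlier
#   w = torch.ones_like(x)
#   a, loss, index = align(x, y, w)            # plain weighted-L1 fit
#   a_t, loss_t, _ = align(x, y, w, trunc=1.0) # each term's contribution capped at 1.0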
|
|
|
|
|
def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` by a single scale factor, using the given weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor of shape (..., N) or None

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    """
    scale, _, _ = align(depth_src, depth_tgt, weight, trunc)

    return scale
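# Example (illustrative sketch; values are arbitrary):
#   pred = torch.rand(2, 1000) + 0.1     # predicted depth, known only up to scale
#   gt = 3.0 * pred                      # target depth
#   s = align_depth_scale(pred, gt, torch.ones_like(pred))
#   # s ≈ tensor([3., 3.]); pred * s[..., None] matches gt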
|
|
|
|
|
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` by a scale and a shift, using the given weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor of shape (..., N) or None

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    - `shift: torch.Tensor` of shape (...)
    """
    dtype, device = depth_src.dtype, depth_src.device

    batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
    batch_size = math.prod(batch_shape)
    depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)

    # Each positively weighted point serves as an anchor; subtracting it reduces the affine problem to a scale-only problem.
    anchors_where_batch, anchors_where_n = torch.where(weight > 0)

    with torch.no_grad():
        depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n]
        depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n]

        depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None]
        depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None]
        weight_anchored = weight[anchors_where_batch, :]

        scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc)

        # Keep, for each batch element, the anchor with the lowest loss.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss)

    index_1 = anchors_where_n[index_anchor]
    index_2 = index[index_anchor]

    # Recompute scale and shift differentiably from the two winning points.
    tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
    shift = tgt_1 - scale * src_1

    scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)

    return scale, shift
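# Example (illustrative sketch; values are arbitrary):
#   pred = torch.rand(4, 500)
#   gt = 2.0 * pred + 0.5                # affine-related target
#   s, t = align_depth_affine(pred, gt, torch.ones_like(pred))
#   # s ≈ 2.0, t ≈ 0.5 per batch element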
|
|
|
def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Align `depth_src` to `depth_tgt` by a scale and a shift, using iteratively reweighted least squares (IRLS) with the given initial weights.

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    - `shift: torch.Tensor` of shape (...)
    """
    dtype, device = depth_src.dtype, depth_src.device

    w = torch.ones_like(depth_src) if weight is None else weight
    x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
    y = depth_tgt

    for _ in range(max_iter):
        # Weighted least-squares step: beta = (X^T W X)^{-1} X^T W y
        beta = ((x.transpose(-1, -2) @ (w[..., None] * x)).inverse() @ (x.transpose(-1, -2) @ (w * y)[..., None])).squeeze(-1)
        # Reweight by the inverse absolute residual, approximating an L1 objective.
        w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)

    return beta[..., 0], beta[..., 1]
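# Example (illustrative sketch; values are arbitrary):
#   pred = torch.rand(100)
#   gt = 1.5 * pred - 0.2
#   s, t = align_depth_affine_irls(pred, gt, torch.ones_like(pred), max_iter=20)
#   # s ≈ 1.5, t ≈ -0.2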
|
|
|
|
|
def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` by a single scale factor, using the given per-point weights.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...). Only positive solutions are guaranteed. You should filter out negative scales before using it.
    """
    dtype, device = points_src.dtype, points_src.device

    scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)

    return scale
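# Example (illustrative sketch; values are arbitrary):
#   src = torch.rand(8, 200, 3)
#   tgt = 0.7 * src
#   s = align_points_scale(src, tgt, torch.ones(8, 200))
#   # s ≈ tensor of shape (8,) filled with 0.7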
|
|
|
|
|
def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and a z shift.
    Similar to an affine alignment, except that the scale is shared across x, y and z while the shift applies to z only.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    - `shift: torch.Tensor` of shape (..., 3). The x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Each positively weighted point serves as a z anchor; subtracting it reduces the problem to a scale-only fit.
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)

        # Solve the scale-only problem per anchor, in chunks to bound memory.
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)

        # Keep, for each batch element, the anchor with the lowest loss.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)

        index_2 = index[index_anchor]
        index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3

    # Recompute scale and shift differentiably from the two winning coordinates.
    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
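# Example (illustrative sketch; values are arbitrary):
#   src = torch.rand(2, 300, 3)
#   tgt = 1.8 * src + torch.tensor([0.0, 0.0, 0.4])   # shared scale, z-only shift
#   s, t = align_points_scale_z_shift(src, tgt, torch.ones(2, 300))
#   # s ≈ 1.8; t ≈ tensor([0., 0., 0.4]) per batch element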
|
|
|
|
|
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared scale and an xyz shift.
    Similar to an affine alignment, except that a single scale is shared across x, y and z while the shift has separate x, y and z components.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Each positively weighted point serves as an anchor; subtracting it reduces the problem to a scale-only fit.
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)

    with torch.no_grad():
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)

        # Solve the scale-only problem per anchor, in chunks to bound memory.
        MAX_ELEMENTS = 2 ** 20
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)

        # Keep, for each batch element, the anchor with the lowest loss.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)

        index_2 = index[index_anchor]
        index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3

    # Recompute scale and shift differentiably from the two winning coordinates.
    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)

    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
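# Example (illustrative sketch; values are arbitrary):
#   src = torch.rand(2, 300, 3)
#   tgt = 0.5 * src + torch.tensor([0.1, -0.2, 0.3])   # shared scale, per-axis shift
#   s, t = align_points_scale_xyz_shift(src, tgt, torch.ones(2, 300))
#   # s ≈ 0.5; t ≈ tensor([0.1, -0.2, 0.3]) per batch element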
|
|
|
|
|
def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a Z-axis shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3). The x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Fitting a scale to constant ones against the z residuals gives a weighted robust shift estimate.
    shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
    shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)

    return shift
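# Example (illustrative sketch; values are arbitrary):
#   src = torch.rand(100, 3)
#   tgt = src + torch.tensor([0.0, 0.0, 0.25])
#   t = align_points_z_shift(src, tgt, torch.ones(100))
#   # t ≈ tensor([0.0000, 0.0000, 0.2500])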
|
|
|
|
|
def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a per-axis xyz shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    # Solve one robust shift per axis by fitting a scale to constant ones against the per-axis residuals.
    shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)

    return shift
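# Example (illustrative sketch; values are arbitrary):
#   src = torch.rand(100, 3)
#   tgt = src + torch.tensor([0.1, 0.2, 0.3])
#   t = align_points_xyz_shift(src, tgt, torch.ones(100))
#   # t ≈ tensor([0.1000, 0.2000, 0.3000])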
|
|
|
|
|
def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i) ^ 2` with respect to the scalars `a` and `b` using least squares.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N), optional

    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    # Weighted least squares: scale both columns of the design matrix and the target by sqrt(w).
    A = torch.stack([w_sqrt * x, w_sqrt], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b
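# Example (illustrative sketch; values are arbitrary):
#   x = torch.rand(5, 50)
#   y = 2.0 * x + 1.0
#   a, b = align_affine_lstsq(x, y)
#   # a ≈ 2.0, b ≈ 1.0 per batch element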