# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utility functions for DUSt3R
# --------------------------------------------------------
import torch


def fill_default_args(kwargs, func):
    """ Complete `kwargs` in-place with the default values from `func`'s signature. """
    import inspect  # a bit hacky, but it works reliably
    signature = inspect.signature(func)
    for k, v in signature.parameters.items():
        if v.default is inspect.Parameter.empty:
            continue  # required parameter, nothing to fill in
        kwargs.setdefault(k, v.default)
    return kwargs
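
# Minimal usage sketch (`f` is a hypothetical function, for illustration only):
#
#   def f(x, y=2, z=3):
#       return x + y + z
#
#   fill_default_args({'z': 5}, f)   # -> {'z': 5, 'y': 2}
#
# `x` has no default, so it is left alone; `z` is already set by the caller,
# so setdefault() keeps the caller's value.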


def freeze_all_params(modules):
    """ Disable gradients for every parameter in `modules`. """
    for module in modules:
        try:
            for _, param in module.named_parameters():
                param.requires_grad = False
        except AttributeError:
            # module is directly a parameter (e.g. a bare nn.Parameter)
            module.requires_grad = False
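
# Usage sketch (assuming standard torch.nn objects; names are illustrative):
#
#   encoder = torch.nn.Linear(4, 4)
#   freeze_all_params([encoder])
#   assert not any(p.requires_grad for p in encoder.parameters())
#
#   # bare parameters also work, via the AttributeError fallback:
#   token = torch.nn.Parameter(torch.zeros(1, 4))
#   freeze_all_params([token])
#   assert not token.requires_grad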


def is_symmetrized(gt1, gt2):
    """ Return True if the batch contains every image pair twice, in both orders,
        i.e. consecutive entries of gt1/gt2 mirror each other. """
    x = gt1['instance']
    y = gt2['instance']
    if len(x) == len(y) and len(x) == 1:
        return False  # special case of batch size 1
    ok = True
    for i in range(0, len(x), 2):
        ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
    return ok
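
# Example: a symmetrized batch stores each pair in both orders, so the
# instance ids of gt1/gt2 are pairwise mirrored:
#
#   gt1 = {'instance': ['img0', 'img1', 'img2', 'img3']}
#   gt2 = {'instance': ['img1', 'img0', 'img3', 'img2']}
#   is_symmetrized(gt1, gt2)   # True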


def flip(tensor):
    """ flip so that tensor[0::2] <=> tensor[1::2] """
    return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
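
# Example: consecutive pairs are swapped in place:
#
#   flip(torch.arange(4))   # tensor([1, 0, 3, 2])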


def interleave(tensor1, tensor2):
    """ Interleave two batches along dim 0, in both orders. """
    res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
    res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
    return res1, res2
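
# Example:
#
#   a = torch.tensor([0, 1])
#   b = torch.tensor([10, 11])
#   interleave(a, b)   # (tensor([ 0, 10,  1, 11]), tensor([10,  0, 11,  1]))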


def _interleave_imgs(img1, img2):
    """ Interleave every field of two view dicts (tensors and lists alike). """
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor):
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res


def make_batch_symmetric(view1, view2):
    """ Duplicate the batch so that every pair (i, j) also appears as (j, i). """
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2
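
# Sketch of the effect on the 'instance' field (tensor fields are interleaved
# the same way via torch.stack, list fields via pairwise zip):
#
#   view1 = {'instance': ['a', 'b']}
#   view2 = {'instance': ['c', 'd']}
#   make_batch_symmetric(view1, view2)
#   # -> ({'instance': ['a', 'c', 'b', 'd']}, {'instance': ['c', 'a', 'd', 'b']})
#
# The resulting views satisfy is_symmetrized().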


def transpose_to_landscape(head, activate=True):
    """ Predict in the correct aspect-ratio,
        then transpose the result in landscape
        and stack everything back together.
    """
    def wrapper_no(decout, true_shape, ray_embedding=None):
        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
        H, W = true_shape[0].cpu().tolist()
        res = head(decout, (H, W), ray_embedding=ray_embedding)
        return res

    def wrapper_yes(decout, true_shape, ray_embedding=None):
        B = len(true_shape)
        # by definition, the batch is in landscape mode, so W >= H
        H, W = int(true_shape.min()), int(true_shape.max())

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # fast paths: the whole batch shares one orientation
        if is_landscape.all():
            return head(decout, (H, W), ray_embedding=ray_embedding)
        if is_portrait.all():
            return transposed(head(decout, (W, H), ray_embedding=ray_embedding))

        # batch is a mix of both portrait & landscape: run the head on each
        # group separately, then scatter the results back into a full batch
        def selout(ar): return [d[ar] for d in decout]
        l_result = head(selout(is_landscape), (H, W), ray_embedding=ray_embedding)
        p_result = transposed(head(selout(is_portrait), (W, H), ray_embedding=ray_embedding))

        # allocate full result
        result = {}
        for k in l_result | p_result:
            x = l_result[k].new(B, *l_result[k].shape[1:])
            x[is_landscape] = l_result[k]
            x[is_portrait] = p_result[k]
            result[k] = x
        return result

    return wrapper_yes if activate else wrapper_no
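
# Usage sketch (`head` and `decout` are as above; the output shapes assume a
# head returning a dict of (B, H, W, C) tensors, which is an assumption here):
#
#   head_wrapped = transpose_to_landscape(head, activate=True)
#   true_shape = torch.tensor([[384, 512],    # landscape image (H, W)
#                              [512, 384]])   # portrait image (H > W)
#   out = head_wrapped(decout, true_shape)
#   # portrait entries are predicted at (512, 384) and then transposed, so
#   # every tensor in `out` ends up in landscape layout: (2, 384, 512, C)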


def transposed(dic):
    """ Swap the H and W axes of every tensor in `dic`. """
    return {k: v.swapaxes(1, 2) for k, v in dic.items()}


def invalid_to_nans(arr, valid_mask, ndim=999):
    """ Replace invalid entries by NaN, optionally flattening spatial dims down to `ndim`. """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = float('nan')
    if arr.ndim > ndim:
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr
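
# Example: masked-out points become NaN and, when `ndim` is given, the
# trailing spatial dims are flattened so losses can reduce over points:
#
#   pts = torch.rand(2, 4, 4, 3)               # (B, H, W, 3)
#   mask = torch.rand(2, 4, 4) > 0.5           # (B, H, W)
#   invalid_to_nans(pts, mask, ndim=3).shape   # torch.Size([2, 16, 3])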


def invalid_to_zeros(arr, valid_mask, ndim=999):
    """ Replace invalid entries by zero and also return the number of valid points per image. """
    if valid_mask is not None:
        arr = arr.clone()
        arr[~valid_mask] = 0
        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
    else:
        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image
    if arr.ndim > ndim:
        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
    return arr, nnz
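
# Example: same flattening as invalid_to_nans, but invalid points are zeroed
# and the per-image count of valid points comes back alongside (handy for
# computing masked averages):
#
#   pts = torch.rand(2, 4, 4, 3)
#   mask = torch.zeros(2, 4, 4, dtype=torch.bool)
#   mask[0, :2] = True                 # 8 valid points in image 0
#   arr, nnz = invalid_to_zeros(pts, mask, ndim=3)
#   arr.shape, nnz                     # (torch.Size([2, 16, 3]), tensor([8, 0]))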