from math import sin, cos

import torch
from torch.cuda.amp import autocast

def transform_pts(pts: torch.Tensor, rel_pose: torch.Tensor) -> torch.Tensor:
    """Transform points by a relative pose.

    Args:
        pts (torch.Tensor): B, n_pts, 3
        rel_pose (torch.Tensor): B, 4, 4

    Returns:
        torch.Tensor: B, n_pts, 3
    """
    pts = torch.cat((pts, torch.ones_like(pts[..., :1])), dim=-1)
    return (pts @ rel_pose.transpose(-1, -2))[..., :3]
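
# Usage sketch (hypothetical, not part of the original API): a pure
# translation pose shifts every point by the same offset.
def _demo_transform_pts():
    pts = torch.rand(2, 100, 3)
    rel_pose = torch.eye(4).expand(2, 4, 4).clone()
    rel_pose[:, :3, 3] = torch.tensor([1.0, 2.0, 3.0])
    out = transform_pts(pts, rel_pose)
    assert torch.allclose(out, pts + torch.tensor([1.0, 2.0, 3.0]), atol=1e-6)
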
# TODO: unify distance_to_z and z_to_distance
def distance_to_z(depths: torch.Tensor, projs: torch.Tensor):
    """Convert per-ray Euclidean distances to z-depths.

    Args:
        depths (torch.Tensor): n, nv, h, w
        projs (torch.Tensor): n, nv, 3, 3 camera intrinsics

    Returns:
        torch.Tensor: n, nv, h, w
    """
    n, nv, h, w = depths.shape
    device = depths.device
    inv_K = torch.inverse(projs)

    # Pixel grid in normalized device coordinates ([-1, 1] in x and y).
    grid_x = (
        torch.linspace(-1, 1, w, device=device).view(1, 1, 1, -1).expand(-1, -1, h, -1)
    )
    grid_y = (
        torch.linspace(-1, 1, h, device=device).view(1, 1, -1, 1).expand(-1, -1, -1, w)
    )
    img_points = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=2).expand(
        n, nv, -1, -1, -1
    )

    # Back-project pixels to camera rays; z / ||ray|| is the cosine of the
    # angle between the ray and the optical axis.
    cam_points = (inv_K @ img_points.view(n, nv, 3, -1)).view(n, nv, 3, h, w)
    factors = cam_points[:, :, 2, :, :] / torch.norm(cam_points, dim=2)
    return depths * factors
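
# Usage sketch (hypothetical, not part of the original code): with identity
# intrinsics, the central ray of an odd-sized grid is the optical axis, so
# distance and z-depth coincide there.
def _demo_distance_to_z():
    depths = torch.full((1, 1, 5, 5), 10.0)
    projs = torch.eye(3).view(1, 1, 3, 3)
    z = distance_to_z(depths, projs)
    assert torch.isclose(z[0, 0, 2, 2], torch.tensor(10.0))
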
def z_to_distance(z: torch.Tensor, projs: torch.Tensor):
    """Inverse of distance_to_z: convert z-depths to per-ray distances."""
    n, nv, h, w = z.shape
    device = z.device
    inv_K = torch.inverse(projs)
    grid_x = (
        torch.linspace(-1, 1, w, device=device).view(1, 1, 1, -1).expand(-1, -1, h, -1)
    )
    grid_y = (
        torch.linspace(-1, 1, h, device=device).view(1, 1, -1, 1).expand(-1, -1, -1, w)
    )
    img_points = torch.stack((grid_x, grid_y, torch.ones_like(grid_x)), dim=2).expand(
        n, nv, -1, -1, -1
    )
    cam_points = (inv_K @ img_points.view(n, nv, 3, -1)).view(n, nv, 3, h, w)
    factors = cam_points[:, :, 2, :, :] / torch.norm(cam_points, dim=2)
    return z / factors
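
# Usage sketch (hypothetical): z_to_distance inverts distance_to_z for any
# positive depth map; the intrinsics here are arbitrary stand-ins.
def _demo_z_distance_round_trip():
    depths = torch.rand(2, 3, 8, 8) + 1.0
    projs = torch.eye(3).expand(2, 3, 3, 3)
    z = distance_to_z(depths, projs)
    assert torch.allclose(z_to_distance(z, projs), depths, atol=1e-5)
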
def azimuth_elevation_to_rotation(azimuth: float, elevation: float) -> torch.Tensor:
    # Rotation about the z axis by the azimuth angle.
    rot_z = torch.tensor(
        [
            [cos(azimuth), -sin(azimuth), 0.0],
            [sin(azimuth), cos(azimuth), 0.0],
            [0.0, 0.0, 1.0],
        ]
    )
    # Rotation about the x axis by the elevation angle.
    rot_x = torch.tensor(
        [
            [1.0, 0.0, 0.0],
            [0.0, cos(elevation), -sin(elevation)],
            [0.0, sin(elevation), cos(elevation)],
        ]
    )
    return rot_x @ rot_z
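
# Usage sketch (hypothetical): zero azimuth and elevation must give the
# identity rotation.
def _demo_azimuth_elevation_to_rotation():
    assert torch.allclose(azimuth_elevation_to_rotation(0.0, 0.0), torch.eye(3))
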
def estimate_frustum_overlap(proj_source: torch.Tensor, pose_source: torch.Tensor, proj_target: torch.Tensor, pose_target: torch.Tensor, dist_lim=50):
    device = proj_source.device
    dtype = proj_source.dtype

    # Check which camera has the higher z value in the target coordinate
    # system and swap (in place) so that the source sits in front of the target.
    with autocast(enabled=False):
        src2tgt = torch.inverse(pose_target) @ pose_source
    for i in range(len(src2tgt)):
        if src2tgt[i, 2, 3] < 0:
            proj_ = proj_target[i].clone()
            pose_ = pose_target[i].clone()
            proj_target[i] = proj_source[i]
            pose_target[i] = pose_source[i]
            proj_source[i] = proj_
            pose_source[i] = pose_

    # Frustum corner pixels in normalized device coordinates on the z = 1 plane.
    points = torch.tensor([[
        [-1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, -1, 1, 1],
        [-1, -1, 1, 1],
    ]], device=device, dtype=dtype)

    with autocast(enabled=False):
        K_src_inv = torch.inverse(proj_source)
        K_tgt_inv = torch.inverse(proj_target)

    # Embed the inverse 3x3 intrinsics into homogeneous 4x4 matrices.
    K_src_inv_h = K_src_inv.new_zeros(K_src_inv.shape[0], 4, 4)
    K_src_inv_h[:, 3, 3] = 1
    K_src_inv_h[:, :3, :3] = K_src_inv
    K_src_inv = K_src_inv_h
    K_tgt_inv_h = K_tgt_inv.new_zeros(K_tgt_inv.shape[0], 4, 4)
    K_tgt_inv_h[:, 3, 3] = 1
    K_tgt_inv_h[:, :3, :3] = K_tgt_inv
    K_tgt_inv = K_tgt_inv_h

    # Corner rays of both frusta in their respective camera spaces.
    points_src = K_src_inv @ points.permute(0, 2, 1)
    points_tgt = K_tgt_inv @ points.permute(0, 2, 1)

    # Normals of the four side planes of the target frustum: cross products
    # of adjacent corner rays (roll over the points dimension).
    normals_tgt = torch.cross(points_tgt[..., :3, :], torch.roll(points_tgt[..., :3, :], shifts=-1, dims=-1), dim=-2)
    normals_tgt = normals_tgt / torch.norm(normals_tgt, dim=-2, keepdim=True)

    with autocast(enabled=False):
        src2tgt = torch.inverse(pose_target) @ pose_source

    # Source corner rays expressed in the target frame as origin + direction.
    base = src2tgt[:, :3, 3, None]
    points_src_tgt = src2tgt @ points_src
    dirs = points_src_tgt[..., :3, :] - base
    # dirs are deliberately not normalized: each has z length 1 by construction.

    # Ray-plane intersection distance for every (ray, plane) pair; the
    # denominator is clamped so near-parallel rays yield huge distances.
    dists = - (base[..., None] * normals_tgt[..., None, :]).sum(dim=-3) / (dirs[..., None] * normals_tgt[..., None, :]).sum(dim=-3).clamp_min(1e-4)

    # Ignore all non-positive or too-distant intersections.
    mask = (dists <= 0) | (dists > dist_lim)
    dists[mask] = dist_lim
    dists = torch.min(dists, dim=-1)[0]
    mean_dist = dists.mean(dim=-1)

    # Pyramid volume estimate: 1/3 * base area * height, with the base given
    # by the source image-plane extents scaled to the mean distance.
    volume_estimate = \
        1/3 * \
        (torch.max(points_src[..., 0, :], dim=-1)[0] - torch.min(points_src[..., 0, :], dim=-1)[0]) * mean_dist * \
        (torch.max(points_src[..., 1, :], dim=-1)[0] - torch.min(points_src[..., 1, :], dim=-1)[0]) * mean_dist * \
        mean_dist
    return volume_estimate
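
# Usage sketch (hypothetical cameras): two identical pinhole cameras looking
# along +z, the target one unit behind the source, should report a positive
# overlap volume.
def _demo_estimate_frustum_overlap():
    proj = torch.eye(3).unsqueeze(0)
    pose_src = torch.eye(4).unsqueeze(0)
    pose_tgt = torch.eye(4).unsqueeze(0)
    pose_tgt[:, 2, 3] = -1.0  # target camera center at z = -1 in world space
    vol = estimate_frustum_overlap(proj, pose_src, proj.clone(), pose_tgt)
    assert vol.shape == (1,) and (vol > 0).all()
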
def estimate_frustum_overlap_2(proj_source: torch.Tensor, pose_source: torch.Tensor, proj_target: torch.Tensor, pose_target: torch.Tensor, z_range=(3, 40), res=(8, 8, 16)):
    device = proj_source.device
    dtype = proj_source.dtype

    with autocast(enabled=False):
        K_src_inv = torch.inverse(proj_source)

    n = proj_source.shape[0]
    w, h, d = res
    pixel_width = 2 / w
    pixel_height = 2 / h

    # Sample a w x h x d grid of points in the source frustum; pixel centers
    # are inset by half a pixel from the NDC boundary.
    x = torch.linspace(-1 + .5 * pixel_width, 1 - .5 * pixel_width, w, dtype=dtype, device=device).view(1, 1, 1, w).expand(n, d, h, w)
    y = torch.linspace(-1 + .5 * pixel_height, 1 - .5 * pixel_height, h, dtype=dtype, device=device).view(1, 1, h, 1).expand(n, d, h, w)
    z = torch.ones_like(x)
    xyz = torch.stack((x, y, z), dim=-1)

    # Back-project to camera space and scale along each ray by linearly
    # spaced depths in z_range.
    # Alternative: inverse-depth spacing via 1 / linspace(1/z_near, 1/z_far, d).
    xyz = K_src_inv @ xyz.reshape(n, -1, 3).permute(0, 2, 1)
    xyz = xyz.reshape(n, 3, d, h, w)
    xyz = xyz * torch.linspace(z_range[0], z_range[1], d, dtype=dtype, device=device).view(1, 1, d, 1, 1).expand(n, 1, d, h, w)
    xyz = torch.cat((xyz, torch.ones_like(xyz[:, :1])), dim=1)
    xyz = xyz.reshape(n, 4, -1)

    # Transform the samples into the target camera and project them.
    with autocast(enabled=False):
        src2tgt = torch.inverse(pose_target) @ pose_source
    xyz = src2tgt @ xyz
    xyz = proj_target @ xyz[:, :3, :]
    xyz[:, :2] = xyz[:, :2] / xyz[:, 2:3, :]

    # A sample counts as overlapping if it lands inside the target image and
    # beyond the near plane; the valid fraction estimates the overlap.
    valid = (xyz[:, 0].abs() < 1) & (xyz[:, 1].abs() < 1) & (xyz[:, 2].abs() > z_range[0])  # & (xyz[:, 2].abs() < z_range[1])
    volume_estimate = valid.to(dtype).mean(-1)
    return volume_estimate
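
# Usage sketch (hypothetical cameras): identical source and target cameras
# should give an overlap fraction close to 1 (the first depth slice sits
# exactly on the near plane and is excluded by the strict comparison).
def _demo_estimate_frustum_overlap_2():
    proj = torch.eye(3).unsqueeze(0)
    pose = torch.eye(4).unsqueeze(0)
    vol = estimate_frustum_overlap_2(proj, pose, proj, pose)
    assert (vol > 0.9).all()
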
def compute_occlusions(flow0, flow1):
    n, _, h, w = flow0.shape
    device = flow0.device

    # Base pixel grid in normalized device coordinates (align_corners=True
    # convention: -1 and 1 are the centers of the border pixels).
    x = torch.linspace(-1, 1, w, device=device).view(1, 1, w).expand(1, h, w)
    y = torch.linspace(-1, 1, h, device=device).view(1, h, 1).expand(1, h, w)
    xy = torch.cat((x, y), dim=0).view(1, 2, h, w).expand(n, 2, h, w)

    # Convert pixel-space flow to NDC offsets; 2 / (size - 1) per pixel
    # matches the align_corners grid above.
    flow0_r = torch.cat((flow0[:, 0:1, :, :] * 2 / (w - 1), flow0[:, 1:2, :, :] * 2 / (h - 1)), dim=1)
    flow1_r = torch.cat((flow1[:, 0:1, :, :] * 2 / (w - 1), flow1[:, 1:2, :, :] * 2 / (h - 1)), dim=1)

    xy_0 = (xy + flow0_r).view(n, 2, -1)
    xy_1 = (xy + flow1_r).view(n, 2, -1)

    # Prepend the batch index so every column is (batch, x, y).
    ns = torch.arange(n, device=device, dtype=xy_0.dtype)
    nxy_0 = torch.cat((ns.view(n, 1, 1).expand(-1, 1, xy_0.shape[-1]), xy_0), dim=1)
    nxy_1 = torch.cat((ns.view(n, 1, 1).expand(-1, 1, xy_1.shape[-1]), xy_1), dim=1)

    # A pixel in frame 0 is visible in frame 1 if some frame-1 pixel flows
    # (approximately) onto it; scatter ones at the rounded target locations.
    mask0 = torch.zeros_like(flow0[:, :1, :, :])
    mask0[nxy_1[:, 0, :].long(), 0, ((nxy_1[:, 2, :] * .5 + .5) * (h - 1)).round().long().clamp(0, h - 1), ((nxy_1[:, 1, :] * .5 + .5) * (w - 1)).round().long().clamp(0, w - 1)] = 1
    mask1 = torch.zeros_like(flow1[:, :1, :, :])
    mask1[nxy_0[:, 0, :].long(), 0, ((nxy_0[:, 2, :] * .5 + .5) * (h - 1)).round().long().clamp(0, h - 1), ((nxy_0[:, 1, :] * .5 + .5) * (w - 1)).round().long().clamp(0, w - 1)] = 1
    return mask0, mask1
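
# Usage sketch (hypothetical): with zero forward and backward flow every
# pixel maps onto itself, so nothing is occluded and both masks are ones.
def _demo_compute_occlusions():
    flow = torch.zeros(1, 2, 4, 4)
    mask0, mask1 = compute_occlusions(flow, flow.clone())
    assert mask0.eq(1).all() and mask1.eq(1).all()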