Spaces:
Runtime error
Runtime error
File size: 4,261 Bytes
e8bdafd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import torch
from einops import rearrange, repeat
from packaging import version as pver
from torch import Tensor
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=False)
def ray_condition(K, c2w, H, W, device, flip_flag=None):
# c2w: B, V, 4, 4
# K: B, V, 3, 3
def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse('1.10'):
return torch.meshgrid(*args)
else:
return torch.meshgrid(*args, indexing='ij')
B, V = K.shape[:2]
j, i = custom_meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
)
i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
if n_flip > 0:
j_flip, i_flip = custom_meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
)
i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
i[:, flip_flag, ...] = i_flip
j[:, flip_flag, ...] = j_flip
fx = K[..., 0, 0].unsqueeze(-1)
fy = K[..., 1, 1].unsqueeze(-1)
cx = K[..., 0, 2].unsqueeze(-1)
cy = K[..., 1, 2].unsqueeze(-1)
zs = torch.ones_like(i) # [B, V, HxW]
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
zs = zs.expand_as(ys)
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3
# c2w @ dirctions
rays_dxo = torch.cross(rays_o, rays_d, dim=-1) # B, V, HW, 3
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
plucker = rearrange(plucker, "b f h w c -> b c f h w") # [B, 6, F, H, W]
return plucker
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=False)
def get_relative_pose(RT_4x4: Tensor, cond_frame_index: Tensor, mode='left'):
'''
:param
RT: (B,F,4,4)
cond_frame_index: (B,)
:return:
relative_RT_4x4: (B,F,4,4)
'''
b, t, _, _ = RT_4x4.shape # b,t,4,4
first_frame_RT = RT_4x4[torch.arange(b, device=RT_4x4.device), cond_frame_index, ...].unsqueeze(1) # (B, 1, 4, 4)
if mode == 'left':
relative_RT_4x4 = first_frame_RT.inverse() @ RT_4x4
elif mode == 'right':
relative_RT_4x4 = RT_4x4 @ first_frame_RT.inverse()
return relative_RT_4x4
@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=False)
def get_camera_condition(H, W, camera_intrinsics, camera_extrinsics, mode, cond_frame_index=0, align_factor=1.0):
'''
:param camera_intrinsics: (B, F, 3, 3)
:param camera_extrinsics: (B, F, 4, 4)
:param cond_frame_index: (B,)
:param trace_scale_factor: (B,)
:return: plucker_embedding: (B, 6, F, H, W)
'''
B, F = camera_extrinsics.shape[:2]
camera_intrinsics_3x3 = camera_intrinsics.float() # B, F, 3, 3
if mode == "c2w":
c2w_RT_4x4 = camera_extrinsics.float() # B, F, 4, 4
elif mode =="w2c":
c2w_RT_4x4 = camera_extrinsics.float().inverse() # B, F, 4, 4
else:
raise ValueError(f"Unknown mode {mode}")
B, F, device = c2w_RT_4x4.shape[0], c2w_RT_4x4.shape[1], c2w_RT_4x4.device
relative_c2w_RT_4x4 = get_relative_pose(c2w_RT_4x4, cond_frame_index, mode='left') # B,F,4,4
relative_c2w_RT_4x4[:, :, :3, 3] = relative_c2w_RT_4x4[:, :, :3, 3] * align_factor
plucker_embedding = ray_condition(camera_intrinsics_3x3, relative_c2w_RT_4x4, H, W, device, flip_flag=None) # B, 6, F, H, W
return plucker_embedding, relative_c2w_RT_4x4 # B 6 F H W
|