File size: 4,261 Bytes
e8bdafd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from einops import rearrange, repeat
from packaging import version as pver
from torch import Tensor


@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=False)
def ray_condition(K, c2w, H, W, device, flip_flag=None):
    """Compute a per-pixel Plücker ray embedding for every view.

    Args:
        K: pinhole intrinsics, shape (B, V, 3, 3); only fx, fy, cx, cy are read.
        c2w: camera-to-world transforms, shape (B, V, 4, 4).
        H, W: image height and width in pixels.
        device: device on which the pixel grid is allocated.
        flip_flag: optional mask over the view axis (presumably a bool tensor of
            shape (V,) -- TODO confirm against callers); selected views use
            horizontally mirrored pixel x-coordinates.

    Returns:
        Tensor of shape (B, 6, V, H, W). Channels 0-2 hold the moment o x d,
        channels 3-5 the unit world-space ray direction d.
    """
    # `indexing=` was added to torch.meshgrid in 1.10; older releases are 'ij' by default.
    old_torch = pver.parse(torch.__version__) < pver.parse('1.10')

    def pixel_coords(x_first, x_last):
        """Return flattened (x, y) pixel coordinates, each shaped [1, 1, HxW]."""
        ys = torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype)
        xs = torch.linspace(x_first, x_last, W, device=device, dtype=c2w.dtype)
        if old_torch:
            grid_y, grid_x = torch.meshgrid(ys, xs)
        else:
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
        return grid_x.reshape([1, 1, H * W]), grid_y.reshape([1, 1, H * W])

    n_batch, n_view = K.shape[:2]
    n_pix = H * W

    # Sample at pixel centres (+0.5). The addition materialises a fresh tensor,
    # so the in-place writes for flipped views below are safe.
    px, py = pixel_coords(0, W - 1)
    u = px.expand([n_batch, n_view, n_pix]) + 0.5  # [B, V, HxW]
    v = py.expand([n_batch, n_view, n_pix]) + 0.5  # [B, V, HxW]

    if flip_flag is not None and torch.sum(flip_flag).item() > 0:
        # Mirrored x-coordinates (W-1 .. 0) for the flagged views.
        px_rev, py_rev = pixel_coords(W - 1, 0)
        u[:, flip_flag, ...] = px_rev.expand(n_batch, 1, n_pix) + 0.5
        v[:, flip_flag, ...] = py_rev.expand(n_batch, 1, n_pix) + 0.5

    # Each intrinsic is [B, V, 1] so it broadcasts over the pixel axis.
    fx, fy, cx, cy = (
        K[..., row, col].unsqueeze(-1) for row, col in ((0, 0), (1, 1), (0, 2), (1, 2))
    )

    # Back-project onto the z = 1 plane in camera space, then normalise.
    ones = torch.ones_like(u)
    dirs = torch.stack(((u - cx) / fx, (v - cy) / fy, ones), dim=-1)  # [B, V, HxW, 3]
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)

    # Rotate into world space (row-vector form of R @ d) and attach the camera centre.
    rays_d = dirs @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HxW, 3]
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)  # [B, V, HxW, 3]
    moment = torch.cross(rays_o, rays_d, dim=-1)  # o x d

    plucker = torch.cat([moment, rays_d], dim=-1).reshape(n_batch, c2w.shape[1], H, W, 6)
    return plucker.permute(0, 4, 1, 2, 3)  # [B, 6, V, H, W]


@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=False)
def get_relative_pose(RT_4x4: Tensor, cond_frame_index: Tensor, mode='left'):
    '''
    Express every pose of a sequence relative to a per-sample reference frame.

    :param
        RT_4x4: (B, F, 4, 4) homogeneous transforms.
        cond_frame_index: (B,) index tensor, or a Python int / sequence, selecting
            for each batch element the reference frame along the F axis.
        mode: 'left'  -> inverse(RT_ref) @ RT_4x4
              'right' -> RT_4x4 @ inverse(RT_ref)
    :return:
        relative_RT_4x4: (B, F, 4, 4)
    :raises ValueError: if ``mode`` is neither 'left' nor 'right'.
    '''
    # Validate up front: previously an unknown mode fell through both branches
    # and failed with UnboundLocalError on the return statement.
    if mode not in ('left', 'right'):
        raise ValueError(f"Unknown mode {mode}")

    b, _, _, _ = RT_4x4.shape  # b,t,4,4
    device = RT_4x4.device
    batch_idx = torch.arange(b, device=device)
    # Keep the frame index on the same device as the poses so advanced indexing
    # works for ints, lists, and CPU tensors alike.
    frame_idx = torch.as_tensor(cond_frame_index, device=device)

    # One reference pose per batch element, with a singleton frame axis so it
    # broadcasts across all F frames.
    ref_inv = RT_4x4[batch_idx, frame_idx, ...].unsqueeze(1).inverse()  # (B, 1, 4, 4)

    if mode == 'left':
        return ref_inv @ RT_4x4
    return RT_4x4 @ ref_inv


@torch.no_grad()
@torch.autocast(device_type="cuda", enabled=False)
def get_camera_condition(H, W, camera_intrinsics, camera_extrinsics, mode, cond_frame_index=0, align_factor=1.0):
    '''
    Build a Plücker ray embedding for a camera trajectory expressed relative to
    a conditioning frame.

    :param H: output height in pixels.
    :param W: output width in pixels.
    :param camera_intrinsics: (B, F, 3, 3)
    :param camera_extrinsics: (B, F, 4, 4); used directly as camera-to-world when
        ``mode == "c2w"``, inverted first when ``mode == "w2c"``.
    :param mode: "c2w" or "w2c".
    :param cond_frame_index: (B,) tensor or int; the frame whose pose becomes the
        reference for all relative poses.
    :param align_factor: scalar, or a (B,) tensor with one factor per batch
        element; multiplies the translation part of the relative poses.
    :return: (plucker_embedding (B, 6, F, H, W), relative_c2w_RT_4x4 (B, F, 4, 4))
    :raises ValueError: if ``mode`` is not "c2w" or "w2c".
    '''
    camera_intrinsics_3x3 = camera_intrinsics.float()  # B, F, 3, 3
    if mode == "c2w":
        c2w_RT_4x4 = camera_extrinsics.float()  # B, F, 4, 4
    elif mode == "w2c":
        c2w_RT_4x4 = camera_extrinsics.float().inverse()  # B, F, 4, 4
    else:
        raise ValueError(f"Unknown mode {mode}")

    relative_c2w_RT_4x4 = get_relative_pose(c2w_RT_4x4, cond_frame_index, mode='left')  # B,F,4,4

    scale = align_factor
    batch = relative_c2w_RT_4x4.shape[0]
    if isinstance(scale, torch.Tensor) and scale.ndim == 1 and scale.shape[0] == batch:
        # A per-sample (B,) factor must broadcast over frames and xyz; left as-is
        # it would line up with the trailing xyz axis instead.
        scale = scale.reshape(-1, 1, 1)
    relative_c2w_RT_4x4[:, :, :3, 3] = relative_c2w_RT_4x4[:, :, :3, 3] * scale

    plucker_embedding = ray_condition(
        camera_intrinsics_3x3, relative_c2w_RT_4x4, H, W, relative_c2w_RT_4x4.device, flip_flag=None
    )  # B, 6, F, H, W

    return plucker_embedding, relative_c2w_RT_4x4