import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import copy, deepcopy

from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from src.model.encoder.vggt.utils.rotation import mat_to_quat
from src.utils.point import get_normal_map

def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    This function transforms camera parameters into a unified pose encoding format,
    which can be used for various downstream tasks like pose prediction or representation.

    Args:
        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
            where B is batch size and S is sequence length.
            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3,
            expressed in normalized image coordinates (fx and cx divided by the
            image width, fy and cy by the image height), with format:
            [[fx, 0, cx],
             [0, fy, cy],
             [0,  0,  1]]
            where fx, fy are the normalized focal lengths and (cx, cy) is the principal point.
        image_size_hw (tuple, optional): Tuple of (height, width) of the image in pixels,
            for example (256, 512). Unused here: the field of view is computed directly
            from the normalized intrinsics.
        pose_encoding_type (str): Type of pose encoding to use. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: Encoded camera pose parameters with shape BxSx9.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
    """

    # extrinsics: BxSx3x4
    # intrinsics: BxSx3x3

    if pose_encoding_type == "absT_quaR_FoV":
        R = extrinsics[:, :, :3, :3]  # BxSx3x3
        T = extrinsics[:, :, :3, 3]  # BxSx3
        
        quat = mat_to_quat(R)
        # Intrinsics are normalized by image size, so the field of view follows
        # directly from the normalized focal lengths (note the h/w ordering).
        # Pixel-space equivalent, kept for reference:
        #   H, W = image_size_hw
        #   fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
        #   fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
        fov_h = 2 * torch.atan(0.5 / intrinsics[..., 1, 1])
        fov_w = 2 * torch.atan(0.5 / intrinsics[..., 0, 0])
        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    else:
        raise NotImplementedError(f"Unsupported pose_encoding_type: {pose_encoding_type}")

    return pose_encoding
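
# Quick numeric check (illustrative): with normalized intrinsics, a normalized
# focal length fx = 0.5 gives fov_w = 2 * atan(0.5 / 0.5) = 2 * atan(1) = 90 degrees,
# i.e. the horizontal field of view of a camera whose pixel focal length equals
# half the image width.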
    
def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y"""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).to(diff.dtype)
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
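
# Worked example (delta=1.0): a residual of 0.5 falls in the quadratic branch,
# giving 0.5 * 0.5**2 = 0.125, while a residual of 3.0 falls in the linear
# branch, giving 1.0 * (3.0 - 0.5 * 1.0) = 2.5.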

class DistillLoss(nn.Module):
    """Distillation loss combining pose-encoding, depth, and normal-map consistency terms."""

    def __init__(self, delta=1.0, gamma=0.6, weight_pose=1.0, weight_depth=1.0, weight_normal=1.0):
        super().__init__()
        self.delta = delta
        self.gamma = gamma
        self.weight_pose = weight_pose
        self.weight_depth = weight_depth
        self.weight_normal = weight_normal

    def camera_loss_single(self, cur_pred_pose_enc, gt_pose_encoding, loss_type="l1"):
        """Per-component pose loss over translation ([:3]), quaternion ([3:7]), and FoV ([7:])."""
        if loss_type == "l1":
            loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).abs()
            loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).abs()
            loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).abs()
        elif loss_type == "l2":
            loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).norm(dim=-1, keepdim=True)
            loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).norm(dim=-1)
            loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).norm(dim=-1)
        elif loss_type == "huber":
            loss_T = huber_loss(cur_pred_pose_enc[..., :3], gt_pose_encoding[..., :3])
            loss_R = huber_loss(cur_pred_pose_enc[..., 3:7], gt_pose_encoding[..., 3:7])
            loss_fl = huber_loss(cur_pred_pose_enc[..., 7:], gt_pose_encoding[..., 7:])
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        loss_T = torch.nan_to_num(loss_T, nan=0.0, posinf=0.0, neginf=0.0)
        loss_R = torch.nan_to_num(loss_R, nan=0.0, posinf=0.0, neginf=0.0)
        loss_fl = torch.nan_to_num(loss_fl, nan=0.0, posinf=0.0, neginf=0.0)

        loss_T = torch.clamp(loss_T, min=-100, max=100)
        loss_R = torch.clamp(loss_R, min=-100, max=100)
        loss_fl = torch.clamp(loss_fl, min=-100, max=100)

        loss_T = loss_T.mean()
        loss_R = loss_R.mean()
        loss_fl = loss_fl.mean()
        
        return loss_T, loss_R, loss_fl

    def forward(self, distill_infos, pred_pose_enc_list, prediction, batch):
        """Distill pose, depth, and normal supervision from the pseudo ground truth in
        `distill_infos` onto the current predictions; returns a dict of weighted losses."""
        loss_pose = 0.0

        if pred_pose_enc_list is not None:
            # Supervise each pose refinement iteration against the pseudo ground-truth
            # encodings, weighting later iterations more heavily when gamma < 1.
            num_predictions = len(pred_pose_enc_list)
            pseudo_gt_pose_enc = distill_infos['pred_pose_enc_list']
            for i in range(num_predictions):
                i_weight = self.gamma ** (num_predictions - i - 1)
                cur_pred_pose_enc = pred_pose_enc_list[i]
                cur_pseudo_gt_pose_enc = pseudo_gt_pose_enc[i]
                loss_pose += i_weight * huber_loss(cur_pred_pose_enc, cur_pseudo_gt_pose_enc).mean()
            loss_pose = loss_pose / num_predictions
            loss_pose = torch.nan_to_num(loss_pose, nan=0.0, posinf=0.0, neginf=0.0)
        
        pred_depth = prediction.depth.flatten(0, 1)
        pseudo_gt_depth = distill_infos['depth_map'].flatten(0, 1).squeeze(-1)
        conf_mask = distill_infos['conf_mask'].flatten(0, 1)

        # Prefer the dataset's validity mask over the distillation confidence mask
        # whenever it marks at least one pixel as valid.
        if batch['context']['valid_mask'].sum() > 0:
            conf_mask = batch['context']['valid_mask'].flatten(0, 1)

        loss_depth = F.mse_loss(pred_depth[conf_mask], pseudo_gt_depth[conf_mask], reduction='none').mean()

        # Normal maps derived from the predicted depth and the pseudo ground-truth depth,
        # compared with a (1 - dot product) alignment term and an L1 term, then averaged.
        render_normal = get_normal_map(pred_depth, batch["context"]["intrinsics"].flatten(0, 1))
        pred_normal = get_normal_map(pseudo_gt_depth, batch["context"]["intrinsics"].flatten(0, 1))

        alpha1_loss = (1 - (render_normal[conf_mask] * pred_normal[conf_mask]).sum(-1)).mean()
        alpha2_loss = F.l1_loss(render_normal[conf_mask], pred_normal[conf_mask], reduction='mean')
        loss_normal = (alpha1_loss + alpha2_loss) / 2
        
        loss_distill = loss_pose * self.weight_pose + loss_depth * self.weight_depth + loss_normal * self.weight_normal
        loss_distill = torch.nan_to_num(loss_distill, nan=0.0, posinf=0.0, neginf=0.0)
        
        loss_dict = {
            "loss_distill": loss_distill,
            "loss_pose": loss_pose * self.weight_pose,
            "loss_depth": loss_depth * self.weight_depth,
            "loss_normal": loss_normal * self.weight_normal
        }

        return loss_dict
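

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the training pipeline).
# It assumes identity rotations are valid input to `mat_to_quat` and that the
# intrinsics are normalized by image size, as extri_intri_to_pose_encoding
# expects. Run this file directly to check output shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, S = 2, 4
    R = torch.eye(3).repeat(B, S, 1, 1)         # identity rotations, (B, S, 3, 3)
    t = torch.randn(B, S, 3, 1)                 # random translations, (B, S, 3, 1)
    extrinsics = torch.cat([R, t], dim=-1)      # camera-from-world [R|t], (B, S, 3, 4)

    intrinsics = torch.eye(3).repeat(B, S, 1, 1)
    intrinsics[..., 0, 0] = 0.8                 # fx / W (normalized)
    intrinsics[..., 1, 1] = 0.8                 # fy / H (normalized)
    intrinsics[..., 0, 2] = 0.5                 # cx / W (normalized)
    intrinsics[..., 1, 2] = 0.5                 # cy / H (normalized)

    enc = extri_intri_to_pose_encoding(extrinsics, intrinsics)
    print("pose encoding shape:", tuple(enc.shape))      # expected: (2, 4, 9)

    criterion = DistillLoss()
    loss_T, loss_R, loss_fl = criterion.camera_loss_single(
        enc + 0.01 * torch.randn_like(enc), enc, loss_type="huber"
    )
    print("camera losses:", loss_T.item(), loss_R.item(), loss_fl.item())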