import torch
import torch.nn as nn
from copy import copy, deepcopy

from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri
from src.model.encoder.vggt.utils.rotation import mat_to_quat

def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    This function transforms camera parameters into a unified pose encoding format,
    which can be used for various downstream tasks like pose prediction or representation.

    Args:
        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
            where B is batch size and S is sequence length.
            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3,
            with format:
            [[fx, 0, cx],
             [0, fy, cy],
             [0,  0,  1]]
            where fx, fy are focal lengths and (cx, cy) is the principal point.
            Note: the active code path assumes the intrinsics are normalized by the
            image size (fx/W, fy/H); the pixel-space variant is kept as a comment below.
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels,
            e.g. (256, 512). Only needed by the pixel-space variant; unused by the
            normalized-intrinsics path currently in use.
        pose_encoding_type (str): Type of pose encoding to use. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: Encoded camera pose parameters with shape BxSx9.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
    """

    # extrinsics: BxSx3x4
    # intrinsics: BxSx3x3

    if pose_encoding_type == "absT_quaR_FoV":
        R = extrinsics[:, :, :3, :3]  # BxSx3x3
        T = extrinsics[:, :, :3, 3]  # BxSx3
        
        quat = mat_to_quat(R)
        # Note the order of h and w here: fov_h uses fy, fov_w uses fx.
        # Pixel-space variant (requires image_size_hw):
        # H, W = image_size_hw
        # fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
        # fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
        # Normalized-intrinsics variant used here (fy, fx already divided by H, W):
        fov_h = 2 * torch.atan(0.5 / intrinsics[..., 1, 1])
        fov_w = 2 * torch.atan(0.5 / intrinsics[..., 0, 0])
        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    else:
        raise NotImplementedError

    return pose_encoding
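
# The resulting (B, S, 9) encoding is laid out as [T (3), quaternion (4), fov_h, fov_w];
# pose_encoding_to_extri_intri (imported above) is the inverse mapping, used in the
# commented-out metric block at the bottom of this file.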

def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y"""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).to(diff.dtype)
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)

class HuberLoss(nn.Module):
    def __init__(self, alpha=1.0, delta=1.0, gamma=0.6, weight_T=1.0, weight_R=1.0, weight_fl=0.5):
        super().__init__()
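        # delta: Huber transition point; gamma: geometric decay applied to intermediate
        # predictions (the last one gets weight 1); weight_T / weight_R / weight_fl:
        # relative weights of the translation, rotation and field-of-view terms.
        # alpha is currently unused in this class.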
        self.alpha = alpha
        self.delta = delta
        self.gamma = gamma
        self.weight_T = weight_T
        self.weight_R = weight_R
        self.weight_fl = weight_fl

    def camera_loss_single(self, cur_pred_pose_enc, gt_pose_encoding, loss_type="l1"):
        if loss_type == "l1":
            loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).abs()
            loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).abs()
            loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).abs()
        elif loss_type == "l2":
            loss_T = (cur_pred_pose_enc[..., :3] - gt_pose_encoding[..., :3]).norm(dim=-1, keepdim=True)
            loss_R = (cur_pred_pose_enc[..., 3:7] - gt_pose_encoding[..., 3:7]).norm(dim=-1)
            loss_fl = (cur_pred_pose_enc[..., 7:] - gt_pose_encoding[..., 7:]).norm(dim=-1)
        elif loss_type == "huber":
            loss_T = huber_loss(cur_pred_pose_enc[..., :3], gt_pose_encoding[..., :3])
            loss_R = huber_loss(cur_pred_pose_enc[..., 3:7], gt_pose_encoding[..., 3:7])
            loss_fl = huber_loss(cur_pred_pose_enc[..., 7:], gt_pose_encoding[..., 7:])
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

        # Guard against NaN/Inf (e.g. from degenerate poses) and clip extreme values so a
        # single bad sample cannot dominate the batch.
        loss_T = torch.nan_to_num(loss_T, nan=0.0, posinf=0.0, neginf=0.0)
        loss_R = torch.nan_to_num(loss_R, nan=0.0, posinf=0.0, neginf=0.0)
        loss_fl = torch.nan_to_num(loss_fl, nan=0.0, posinf=0.0, neginf=0.0)

        loss_T = torch.clamp(loss_T, min=-100, max=100)
        loss_R = torch.clamp(loss_R, min=-100, max=100)
        loss_fl = torch.clamp(loss_fl, min=-100, max=100)

        loss_T = loss_T.mean()
        loss_R = loss_R.mean()
        loss_fl = loss_fl.mean()
        
        return loss_T, loss_R, loss_fl

    def forward(self, pred_pose_enc_list, batch):
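        # Expected inputs (assumptions based on this file): pred_pose_enc_list is a list of
        # (B, S, 9) pose encodings, one per refinement iteration; batch["context"] holds
        # "extrinsics" (B, S, 3, 4), "intrinsics" (B, S, 3, 3, assumed normalized, see
        # extri_intri_to_pose_encoding above) and "image" (B, S, 3, H, W).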
        context_extrinsics = batch["context"]["extrinsics"]
        context_intrinsics = batch["context"]["intrinsics"]
        image_size_hw = batch["context"]["image"].shape[-2:]
        
        # transform extrinsics and intrinsics to pose_enc
        GT_pose_enc = extri_intri_to_pose_encoding(context_extrinsics, context_intrinsics, image_size_hw)
        num_predictions = len(pred_pose_enc_list)
        loss_T = loss_R = loss_fl = 0
        
        for i in range(num_predictions):
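            # Geometric weighting across refinement iterations: the last prediction gets
            # weight 1 and earlier ones are down-weighted by powers of gamma.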
            i_weight = self.gamma ** (num_predictions - i - 1)

            cur_pred_pose_enc = pred_pose_enc_list[i]

            loss_T_i, loss_R_i, loss_fl_i = self.camera_loss_single(cur_pred_pose_enc.clone(), GT_pose_enc.clone(), loss_type="huber")
            loss_T += i_weight * loss_T_i
            loss_R += i_weight * loss_R_i
            loss_fl += i_weight * loss_fl_i

        loss_T = loss_T / num_predictions
        loss_R = loss_R / num_predictions
        loss_fl = loss_fl / num_predictions
        loss_camera = loss_T * self.weight_T + loss_R * self.weight_R + loss_fl * self.weight_fl

        loss_dict = {
            "loss_camera": loss_camera,
            "loss_T": loss_T,
            "loss_R": loss_R,
            "loss_fl": loss_fl
        }

        # with torch.no_grad():
        #     # compute auc
        #     last_pred_pose_enc = pred_pose_enc_list[-1]
            
        #     last_pred_extrinsic, _ = pose_encoding_to_extri_intri(last_pred_pose_enc.detach(), image_size_hw, pose_encoding_type='absT_quaR_FoV', build_intrinsics=False)

        #     rel_rangle_deg, rel_tangle_deg = camera_to_rel_deg(last_pred_extrinsic.float(), context_extrinsics.float(), context_extrinsics.device)

        #     if rel_rangle_deg.numel() == 0 and rel_tangle_deg.numel() == 0:
        #         rel_rangle_deg = torch.FloatTensor([0]).to(context_extrinsics.device).to(context_extrinsics.dtype)
        #         rel_tangle_deg = torch.FloatTensor([0]).to(context_extrinsics.device).to(context_extrinsics.dtype)

        #     thresholds = [5, 15]
        #     for threshold in thresholds:
        #         loss_dict[f"Rac_{threshold}"] = (rel_rangle_deg < threshold).float().mean()
        #         loss_dict[f"Tac_{threshold}"] = (rel_tangle_deg < threshold).float().mean()

        #     _, normalized_histogram = calculate_auc(
        #         rel_rangle_deg, rel_tangle_deg, max_threshold=30, return_list=True
        #     )

        #     auc_thresholds = [30, 10, 5, 3]
        #     for auc_threshold in auc_thresholds:
        #         cur_auc = torch.cumsum(
        #             normalized_histogram[:auc_threshold], dim=0
        #         ).mean()
        #         loss_dict[f"Auc_{auc_threshold}"] = cur_auc

        return loss_dict
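

if __name__ == "__main__":
    # Minimal self-check sketch (not part of the training pipeline): builds a synthetic
    # batch with identity extrinsics and normalized identity intrinsics, encodes it, and
    # runs HuberLoss on two noisy "refinement iterations". The shapes and the normalized
    # intrinsics convention are assumptions based on the functions above.
    B, S, H, W = 2, 4, 256, 512
    extrinsics = torch.eye(4)[:3].expand(B, S, 3, 4).contiguous()  # (B, S, 3, 4), [R|t]
    intrinsics = torch.eye(3).expand(B, S, 3, 3).contiguous()      # (B, S, 3, 3), normalized
    batch = {
        "context": {
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "image": torch.zeros(B, S, 3, H, W),
        }
    }
    gt_enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, (H, W))  # (B, S, 9)
    # Pretend the network produced two progressively refined, noisy predictions.
    pred_list = [gt_enc + 0.1 * torch.randn_like(gt_enc) for _ in range(2)]
    loss_dict = HuberLoss()(pred_list, batch)
    print({k: float(v) for k, v in loss_dict.items()})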