from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from croco.models.blocks import Mlp
from dust3r.heads.postprocess import postprocess_pose

inf = float("inf")


class PoseDecoder(nn.Module):
    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=self.target_dim,
            drop=0,
        )

    def forward(self, pose_feat):
        """
        Args:
            pose_feat: BxC pose features.

        Returns:
            Bx7 pose encodings (3 for absT, 4 for quaR), used to build
            preliminary cameras in OpenCV coordinates.
        """
        pred_cameras = self.mlp(pose_feat)  # Bx7: 3 for absT, 4 for quaR
        return pred_cameras


class PoseEncoder(nn.Module):
    def __init__(
        self,
        hidden_size=768,
        mlp_ratio=4,
        pose_mode=("exp", -inf, inf),
        pose_encoding_type="absT_quaR",
    ):
        super().__init__()
        self.pose_encoding_type = pose_encoding_type
        self.pose_mode = pose_mode
        if self.pose_encoding_type == "absT_quaR":
            self.target_dim = 7

        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            out_dim=hidden_size,
            n_harmonic_functions=10,
            append_input=True,
        )
        self.pose_encoder = Mlp(
            in_features=self.embed_pose.out_dim,
            hidden_features=int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            drop=0,
        )

    def forward(self, camera):
        # camera: Bx4x4 cam2world matrices in OpenCV convention.
        pose_enc = camera_to_pose_encoding(
            camera,
            pose_encoding_type=self.pose_encoding_type,
        ).to(camera.dtype)
        pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
        pose_feat = self.embed_pose(pose_enc)
        pose_feat = self.pose_encoder(pose_feat)
        return pose_feat
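# Illustrative usage sketch, not part of the original module: the decoder maps
# a BxC feature to a Bx7 [absT | quaR] encoding; the encoder goes the other
# way, from Bx4x4 OpenCV cam2world matrices to BxC features. Shapes follow
# directly from the class definitions above.
def _example_pose_decoder():
    decoder = PoseDecoder(hidden_size=768)
    pose_feat = torch.randn(2, 768)  # BxC pose features
    pose_enc = decoder(pose_feat)    # Bx7: 3 for absT, 4 for quaR (real part first)
    assert pose_enc.shape == (2, 7)
    return pose_enc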
class HarmonicEmbedding(torch.nn.Module):
    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        The harmonic embedding layer supports the classical NeRF positional
        encoding and the integrated position encoding of MIP-NeRF.

        During inference you can provide the extra argument `diag_cov`.

        If `diag_cov is None`, it converts rays parametrized with a
        `ray_bundle` to 3D points by extending each ray according to the
        corresponding length. Then it converts each feature
        (i.e. vector along the last dimension) in `x`
        into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]::

            [
                sin(f_1*x[..., i]),
                sin(f_2*x[..., i]),
                ...
                sin(f_N * x[..., i]),
                cos(f_1*x[..., i]),
                cos(f_2*x[..., i]),
                ...
                cos(f_N * x[..., i]),
                x[..., i],              # only present if append_input is True.
            ]

        where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `diag_cov is not None`, it approximates conical frustums following
        a ray bundle as gaussians, defined by x, the means of the gaussians,
        and diag_cov, the diagonal covariances. Then it converts each gaussian
        into a series of harmonic features `embedding`, where for each
        i in range(dim) the following are present in embedding[...]::

            [
                sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
                sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
                ...
                sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
                cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
                cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
                ...
                cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
                x[..., i],              # only present if append_input is True.
            ]

        where N equals `n_harmonic_functions-1`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are powers
        of 2: `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
        `f_1, ..., f_N = torch.linspace(
            1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
        )`

        Note that `x` is also premultiplied by the base frequency `omega_0`
        before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic features
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in logspace
                or linear space
            append_input: bool, whether to concat the original input to the
                harmonic embedding. If true the output is of the form
                (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)` representing
                the diagonal covariance matrices of our Gaussians, joined with
                x as means of the Gaussians.

        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # Shifting the phase by 0.5*pi turns sin into cos, so broadcasting
        # against _zero_half_pi produces both the sin and cos features:
        # [..., 1, dim, n] + [2, 1, 1] -> [..., 2, dim, n]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        embed = embed.sin()
        if diag_cov is not None:
            # MIP-NeRF integrated positional encoding: attenuate each
            # frequency by the expected variance along the ray.
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # [..., 2, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        embed = embed.reshape(*x.shape[:-1], -1)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original input
                to the harmonic embedding

        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as above. The default for input_dims is 3 for 3D applications
        which use harmonic embedding for positional encoding,
        so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )


class PoseEmbedding(nn.Module):
    def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        self._emb_pose = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions, append_input=append_input
        )

        # NOTE: `out_dim` is accepted but unused; the output size is fully
        # determined by target_dim and the harmonic embedding settings.
        self.out_dim = self._emb_pose.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        e_pose_encoding = self._emb_pose(pose_encoding)
        return e_pose_encoding
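# A minimal sketch, not part of the original module, of the size bookkeeping
# used by PoseEncoder: a 7-dim absT_quaR encoding with 10 harmonic functions
# and append_input=True yields 7 * (2*10 + 1) = 147 features.
def _example_pose_embedding_dim():
    emb = PoseEmbedding(target_dim=7, out_dim=768, n_harmonic_functions=10)
    x = torch.randn(2, 7)
    out = emb(x)
    # 7 inputs * (2 * 10 harmonics + 1 appended input) = 147 features.
    assert emb.out_dim == 147 and out.shape == (2, 147)
    return out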
""" ret = torch.zeros_like(x) positive_mask = x > 0 ret[positive_mask] = torch.sqrt(x[positive_mask]) return ret def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: """ Convert rotations given as rotation matrices to quaternions. Args: matrix: Rotation matrices as tensor of shape (..., 3, 3). Returns: quaternions with real part first, as tensor of shape (..., 4). """ if matrix.size(-1) != 3 or matrix.size(-2) != 3: raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") batch_dim = matrix.shape[:-2] m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( matrix.reshape(batch_dim + (9,)), dim=-1 ) q_abs = _sqrt_positive_part( torch.stack( [ 1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22, ], dim=-1, ) ) quat_by_rijk = torch.stack( [ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), ], dim=-2, ) flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) out = quat_candidates[ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : ].reshape(batch_dim + (4,)) return standardize_quaternion(out) def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: """ Convert a unit quaternion to a standard form: one in which the real part is non negative. Args: quaternions: Quaternions with real part first, as tensor of shape (..., 4). Returns: Standardized quaternions as tensor of shape (..., 4). """ quaternions = F.normalize(quaternions, p=2, dim=-1) return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) def camera_to_pose_encoding( camera, pose_encoding_type="absT_quaR", ): """ Inverse to pose_encoding_to_camera camera: opencv, cam2world """ if pose_encoding_type == "absT_quaR": quaternion_R = matrix_to_quaternion(camera[:, :3, :3]) pose_encoding = torch.cat([camera[:, :3, 3], quaternion_R], dim=-1) else: raise ValueError(f"Unknown pose encoding {pose_encoding_type}") return pose_encoding def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: """ Convert rotations given as quaternions to rotation matrices. Args: quaternions: quaternions with real part first, as tensor of shape (..., 4). Returns: Rotation matrices as tensor of shape (..., 3, 3). """ r, i, j, k = torch.unbind(quaternions, -1) two_s = 2.0 / (quaternions * quaternions).sum(-1) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return o.reshape(quaternions.shape[:-1] + (3, 3)) def pose_encoding_to_camera( pose_encoding, pose_encoding_type="absT_quaR", ): """ Args: pose_encoding: A tensor of shape `BxC`, containing a batch of `B` `C`-dimensional pose encodings. 
def pose_encoding_to_camera(
    pose_encoding,
    pose_encoding_type="absT_quaR",
):
    """
    Args:
        pose_encoding: A tensor of shape `BxC`, containing a batch of
            `B` `C`-dimensional pose encodings.
        pose_encoding_type: The type of pose encoding.

    Returns:
        Bx4x4 cam2world matrices.
    """
    if pose_encoding_type == "absT_quaR":
        abs_T = pose_encoding[:, :3]
        quaternion_R = pose_encoding[:, 3:7]
        R = quaternion_to_matrix(quaternion_R)
    else:
        raise ValueError(f"Unknown pose encoding {pose_encoding_type}")

    c2w_mats = torch.eye(4, 4).to(R.dtype).to(R.device)
    c2w_mats = c2w_mats[None].repeat(len(R), 1, 1)
    c2w_mats[:, :3, :3] = R
    c2w_mats[:, :3, 3] = abs_T

    return c2w_mats


def quaternion_conjugate(q):
    """Compute the conjugate of quaternion q (w, x, y, z)."""
    q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
    return q_conj


def quaternion_multiply(q1, q2):
    """Multiply two quaternions q1 and q2 (Hamilton product)."""
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)

    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return torch.stack((w, x, y, z), dim=-1)


def rotate_vector(q, v):
    """Rotate vector v by quaternion q."""
    q_vec = q[..., 1:]
    q_w = q[..., :1]
    # Rodrigues-style rotation: v' = v + q_w * t + q_vec x t,
    # with t = 2 * (q_vec x v).
    t = 2.0 * torch.cross(q_vec, v, dim=-1)
    v_rot = v + q_w * t + torch.cross(q_vec, t, dim=-1)
    return v_rot


def relative_pose_absT_quatR(t1, q1, t2, q2):
    """Compute the relative translation and quaternion between two poses."""
    q1_inv = quaternion_conjugate(q1)
    q_rel = quaternion_multiply(q1_inv, q2)
    delta_t = t2 - t1
    t_rel = rotate_vector(q1_inv, delta_t)
    return t_rel, q_rel
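# Illustrative sketch, not part of the original module, tying the helpers
# together: encode a cam2world matrix, split it into absT and quaR, and take
# the relative pose. A pose relative to itself gives zero translation and the
# identity quaternion (1, 0, 0, 0); the encoding also round-trips back to the
# original 4x4 matrix.
def _example_relative_pose():
    c2w = torch.eye(4)[None]            # (1, 4, 4) cam2world
    enc = camera_to_pose_encoding(c2w)  # (1, 7): [absT | quaR]
    t, q = enc[:, :3], enc[:, 3:7]
    t_rel, q_rel = relative_pose_absT_quatR(t, q, t, q)
    assert torch.allclose(t_rel, torch.zeros(1, 3))
    assert torch.allclose(q_rel, torch.tensor([[1.0, 0.0, 0.0, 0.0]]))
    # Round-trip back to a 4x4 matrix.
    assert torch.allclose(pose_encoding_to_camera(enc), c2w)
    return t_rel, q_rel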