from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from croco.models.blocks import Mlp
from dust3r.heads.postprocess import postprocess_pose
inf = float("inf")
class PoseDecoder(nn.Module):
def __init__(
self,
hidden_size=768,
mlp_ratio=4,
pose_encoding_type="absT_quaR",
):
super().__init__()
self.pose_encoding_type = pose_encoding_type
if self.pose_encoding_type == "absT_quaR":
self.target_dim = 7
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
out_features=self.target_dim,
drop=0,
)
def forward(
self,
pose_feat,
):
"""
pose_feat: BxC
preliminary_cameras: cameras in opencv coordinate.
"""
pred_cameras = self.mlp(pose_feat) # Bx7, 3 for absT, 4 for quaR
return pred_cameras
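# A minimal usage sketch (illustrative, not part of the original file); shapes
# follow the forward() docstring above:
#
#   decoder = PoseDecoder(hidden_size=768)
#   pose_feat = torch.randn(8, 768)        # BxC pooled pose features
#   pose_enc = decoder(pose_feat)          # Bx7: absT (3) + quaR (4)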
class PoseEncoder(nn.Module):
def __init__(
self,
hidden_size=768,
mlp_ratio=4,
pose_mode=("exp", -inf, inf),
pose_encoding_type="absT_quaR",
):
super().__init__()
self.pose_encoding_type = pose_encoding_type
self.pose_mode = pose_mode
if self.pose_encoding_type == "absT_quaR":
self.target_dim = 7
self.embed_pose = PoseEmbedding(
target_dim=self.target_dim,
out_dim=hidden_size,
n_harmonic_functions=10,
append_input=True,
)
self.pose_encoder = Mlp(
in_features=self.embed_pose.out_dim,
hidden_features=int(hidden_size * mlp_ratio),
out_features=hidden_size,
drop=0,
)
def forward(self, camera):
pose_enc = camera_to_pose_encoding(
camera,
pose_encoding_type=self.pose_encoding_type,
).to(camera.dtype)
pose_enc = postprocess_pose(pose_enc, self.pose_mode, inverse=True)
pose_feat = self.embed_pose(pose_enc)
pose_feat = self.pose_encoder(pose_feat)
return pose_feat
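# Sketch of the encoder direction, assuming dust3r's `postprocess_pose` accepts
# a Bx7 encoding under the ("exp", -inf, inf) pose_mode used here:
#
#   encoder = PoseEncoder(hidden_size=768)
#   c2w = torch.eye(4).repeat(8, 1, 1)     # Bx4x4 cam2world (OpenCV convention)
#   pose_feat = encoder(c2w)               # Bx768 pose features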
class HarmonicEmbedding(torch.nn.Module):
def __init__(
self,
n_harmonic_functions: int = 6,
omega_0: float = 1.0,
logspace: bool = True,
append_input: bool = True,
) -> None:
"""
The harmonic embedding layer supports the classical
Nerf positional encoding described in
`NeRF <https://arxiv.org/abs/2003.08934>`_
and the integrated position encoding in
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
        The forward pass accepts an optional argument `diag_cov`.
        If `diag_cov` is None, the layer converts each feature
        (i.e. vector along the last dimension) in `x`
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
in embedding[...]::
[
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i], # only present if append_input is True.
]
        where N equals `n_harmonic_functions`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.
        If `diag_cov` is not None, the input is treated as a set of
        Gaussians (e.g. approximating conical frustums along rays),
        with means `x` and diagonal covariances `diag_cov`.
Then it converts each gaussian
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
in embedding[...]::
[
            sin(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
            sin(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
            ...
            sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
            cos(f_1*x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
            cos(f_2*x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
            ...
            cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
x[..., i], # only present if append_input is True.
]
        where N equals `n_harmonic_functions`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
powers of 2:
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
If `logspace==False`, frequencies are linearly spaced between
`1.0` and `2**(n_harmonic_functions-1)`:
`f_1, ..., f_N = torch.linspace(
1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
)`
Note that `x` is also premultiplied by the base frequency `omega_0`
before evaluating the harmonic functions.
Args:
n_harmonic_functions: int, number of harmonic
features
omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
append_input: bool, whether to concat the original
input to the harmonic embedding. If true the
output is of the form (embed.sin(), embed.cos(), x)
"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (n_harmonic_functions - 1),
n_harmonic_functions,
dtype=torch.float32,
)
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
self.register_buffer(
"_zero_half_pi",
torch.tensor([0.0, 0.5 * torch.pi]),
persistent=False,
)
self.append_input = append_input
def forward(
self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
"""
Args:
x: tensor of shape [..., dim]
            diag_cov: an optional tensor of shape `(..., dim)` holding the
                diagonal covariances of Gaussians whose means are given by `x`.
Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., dim * (n_harmonic_functions * 2 + int(append_input))]
"""
        # [..., dim, n_harmonic_functions]: scale each input dim by every frequency
        embed = x[..., None] * self._frequencies
        # Adding the offsets [0, pi/2] before sin() produces (sin, cos) pairs
        # in a single pass, since sin(t + pi/2) == cos(t).
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        embed = embed.sin()
        if diag_cov is not None:
            # Integrated positional encoding (MIP-NeRF): damp each frequency
            # by exp(-0.5 * f**2 * sigma**2) of the Gaussian.
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            embed = embed * exp_var[..., None, :, :]
        # Flatten to [..., 2 * dim * n_harmonic_functions]
        embed = embed.reshape(*x.shape[:-1], -1)
        if self.append_input:
            return torch.cat([embed, x], dim=-1)
return embed
@staticmethod
def get_output_dim_static(
input_dims: int, n_harmonic_functions: int, append_input: bool
) -> int:
"""
Utility to help predict the shape of the output of `forward`.
Args:
input_dims: length of the last dimension of the input tensor
n_harmonic_functions: number of embedding frequencies
append_input: whether or not to concat the original
input to the harmonic embedding
Returns:
int: the length of the last dimension of the output tensor
"""
return input_dims * (2 * n_harmonic_functions + int(append_input))
def get_output_dim(self, input_dims: int = 3) -> int:
"""
        Same as `get_output_dim_static`. The default `input_dims` is 3 for 3D
        applications that use the harmonic embedding for positional encoding,
        where the input is typically xyz.
"""
return self.get_output_dim_static(
input_dims, len(self._frequencies), self.append_input
)
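# Dimension sanity check (sketch): with the defaults n_harmonic_functions=6 and
# append_input=True, a 3D input yields 3 * (2*6 + 1) = 39 features.
#
#   emb = HarmonicEmbedding()
#   x = torch.randn(2, 3)                  # a batch of xyz points
#   emb(x).shape                           # torch.Size([2, 39])
#   emb.get_output_dim(3)                  # 39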
class PoseEmbedding(nn.Module):
def __init__(self, target_dim, out_dim, n_harmonic_functions=10, append_input=True):
super().__init__()
self._emb_pose = HarmonicEmbedding(
n_harmonic_functions=n_harmonic_functions, append_input=append_input
)
self.out_dim = self._emb_pose.get_output_dim(target_dim)
def forward(self, pose_encoding):
e_pose_encoding = self._emb_pose(pose_encoding)
return e_pose_encoding
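# Note that the `out_dim` constructor argument is ignored: self.out_dim is
# overwritten with the harmonic embedding's output width. For the absT_quaR
# defaults, out_dim = 7 * (2*10 + 1) = 147, which is the in_features that
# PoseEncoder's MLP is built against:
#
#   pe = PoseEmbedding(target_dim=7, out_dim=768)
#   pe.out_dim                             # 147
#   pe(torch.randn(8, 7)).shape            # torch.Size([8, 147])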
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
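# Why the mask matters (sketch): torch.sqrt(torch.clamp(x, min=0)) would
# propagate an unbounded gradient at x == 0, while the masked version keeps
# the (sub)gradient there at zero:
#
#   x = torch.zeros(1, requires_grad=True)
#   _sqrt_positive_part(x).sum().backward()
#   x.grad                                 # tensor([0.])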
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
    # Floor the denominator at 0.1: if q_abs[i] is small, candidate i is
    # ill-conditioned, but it will not be selected by the argmax below anyway.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # All four candidates encode the same quaternion up to numerics; pick the
    # best-conditioned one (largest q_abs) for each batch element.
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))
return standardize_quaternion(out)
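# Quick check (sketch): the identity rotation maps to the identity quaternion.
#
#   matrix_to_quaternion(torch.eye(3)[None])   # tensor([[1., 0., 0., 0.]])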
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
quaternions = F.normalize(quaternions, p=2, dim=-1)
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
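# q and -q encode the same rotation; standardization picks the representative
# with a non-negative real part (and renormalizes):
#
#   standardize_quaternion(torch.tensor([[-1., 0., 0., 0.]]))
#   # tensor([[1., 0., 0., 0.]])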
def camera_to_pose_encoding(
camera,
pose_encoding_type="absT_quaR",
):
"""
Inverse to pose_encoding_to_camera
camera: opencv, cam2world
"""
if pose_encoding_type == "absT_quaR":
quaternion_R = matrix_to_quaternion(camera[:, :3, :3])
pose_encoding = torch.cat([camera[:, :3, 3], quaternion_R], dim=-1)
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
return pose_encoding
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
r, i, j, k = torch.unbind(quaternions, -1)
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
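# Round-trip sanity sketch: for standardized unit quaternions the two
# conversions are mutually inverse, up to floating point noise and away from
# the w == 0 sign ambiguity:
#
#   q = standardize_quaternion(torch.randn(5, 4))
#   R = quaternion_to_matrix(q)
#   torch.allclose(matrix_to_quaternion(R), q, atol=1e-5)   # True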
def pose_encoding_to_camera(
pose_encoding,
pose_encoding_type="absT_quaR",
):
"""
Args:
pose_encoding: A tensor of shape `BxC`, containing a batch of
`B` `C`-dimensional pose encodings.
pose_encoding_type: The type of pose encoding,
"""
if pose_encoding_type == "absT_quaR":
abs_T = pose_encoding[:, :3]
quaternion_R = pose_encoding[:, 3:7]
R = quaternion_to_matrix(quaternion_R)
else:
raise ValueError(f"Unknown pose encoding {pose_encoding_type}")
c2w_mats = torch.eye(4, 4).to(R.dtype).to(R.device)
c2w_mats = c2w_mats[None].repeat(len(R), 1, 1)
c2w_mats[:, :3, :3] = R
c2w_mats[:, :3, 3] = abs_T
return c2w_mats
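# Round-trip sketch with camera_to_pose_encoding above. Since
# camera_to_pose_encoding standardizes the quaternion, the encoding is
# recovered exactly only when its quaternion already has a non-negative
# real part:
#
#   enc = torch.cat([torch.randn(4, 3),
#                    standardize_quaternion(torch.randn(4, 4))], dim=-1)
#   c2w = pose_encoding_to_camera(enc)     # Bx4x4
#   torch.allclose(camera_to_pose_encoding(c2w), enc, atol=1e-5)   # True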
def quaternion_conjugate(q):
"""Compute the conjugate of quaternion q (w, x, y, z)."""
q_conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)
return q_conj
def quaternion_multiply(q1, q2):
"""Multiply two quaternions q1 and q2."""
w1, x1, y1, z1 = q1.unbind(dim=-1)
w2, x2, y2, z2 = q2.unbind(dim=-1)
w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
return torch.stack((w, x, y, z), dim=-1)
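# Hamilton product sanity sketch: a unit quaternion times its conjugate gives
# the identity quaternion.
#
#   q = F.normalize(torch.randn(4), p=2, dim=-1)
#   quaternion_multiply(q, quaternion_conjugate(q))   # ~ tensor([1., 0., 0., 0.])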
def rotate_vector(q, v):
    """Rotate vector v by unit quaternion q (w, x, y, z).
    Uses the two-cross-product identity
    v' = v + 2*w*(u x v) + 2*(u x (u x v)), where u is the vector part of q,
    avoiding an explicit rotation matrix.
    """
    q_vec = q[..., 1:]
    q_w = q[..., :1]
    t = 2.0 * torch.cross(q_vec, v, dim=-1)
    v_rot = v + q_w * t + torch.cross(q_vec, t, dim=-1)
    return v_rot
def relative_pose_absT_quatR(t1, q1, t2, q2):
"""Compute the relative translation and quaternion between two poses."""
q1_inv = quaternion_conjugate(q1)
q_rel = quaternion_multiply(q1_inv, q2)
delta_t = t2 - t1
t_rel = rotate_vector(q1_inv, delta_t)
return t_rel, q_rel
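# Consistency sketch, assuming (t, q) describe cam2world poses as elsewhere in
# this file: (t_rel, q_rel) matches the rotation and translation of
# T1^{-1} @ T2 with Ti = [[R(qi), ti], [0, 1]]:
#
#   q1 = standardize_quaternion(torch.randn(4))
#   q2 = standardize_quaternion(torch.randn(4))
#   t1, t2 = torch.randn(3), torch.randn(3)
#   t_rel, q_rel = relative_pose_absT_quatR(t1, q1, t2, q2)
#   R_rel = quaternion_to_matrix(q1).T @ quaternion_to_matrix(q2)
#   torch.allclose(quaternion_to_matrix(q_rel), R_rel, atol=1e-5)   # True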