# alexnasa's picture
# Upload 243 files
# 2568013 verified
import torch
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert quaternions to 3x3 rotation matrices.

    Adapted from PyTorch3D's ``quaternion_to_matrix``, except that the
    components are read in scalar-last ``(x, y, z, w)`` order (the SciPy
    convention) instead of PyTorch3D's scalar-first order.

    Inputs need not be unit length: every term is scaled by
    ``2 / (|q|^2 + eps)``, which folds the normalization into the formula.

    Args:
        quaternions: Tensor whose last dimension holds ``(x, y, z, w)``.
        eps: Small constant added to the squared norm before dividing, so the
            division stays finite for an all-zero quaternion.

    Returns:
        Tensor of shape ``(*batch, 3, 3)`` holding the rotation matrices.
    """
    # Scalar-last layout: the real part is the final component.
    x, y, z, w = quaternions.unbind(-1)
    # 2 / |q|^2 -- normalizes non-unit quaternions on the fly.
    s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    entries = [
        # first row
        1 - s * (y * y + z * z),
        s * (x * y - z * w),
        s * (x * z + y * w),
        # second row
        s * (x * y + z * w),
        1 - s * (x * x + z * z),
        s * (y * z - x * w),
        # third row
        s * (x * z - y * w),
        s * (y * z + x * w),
        1 - s * (x * x + y * y),
    ]
    flat = torch.stack(entries, dim=-1)
    # Split the trailing 9 entries into a row-major 3x3 matrix.
    return flat.reshape(*flat.shape[:-1], 3, 3)
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Assemble covariance matrices ``R S S^T R^T`` from scales and rotations.

    Args:
        scale: Per-axis scale factors; the last dimension becomes the diagonal
            of ``S``.
        rotation_xyzw: Quaternions in ``(x, y, z, w)`` order, converted to
            ``R`` via ``quaternion_to_matrix``.

    Returns:
        Tensor of shape ``(*batch, 3, 3)``. Leading dimensions of the two
        inputs are broadcast against each other by the matrix products.
    """
    scale_mat = torch.diag_embed(scale)
    rot_mat = quaternion_to_matrix(rotation_xyzw)
    # Evaluate strictly left to right, i.e. ((R @ S) @ S^T) @ R^T.
    result = rot_mat @ scale_mat
    result = result @ scale_mat.transpose(-1, -2)
    return result @ rot_mat.transpose(-1, -2)