import torch from einops import rearrange from jaxtyping import Float from torch import Tensor # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py def quaternion_to_matrix( quaternions: Float[Tensor, "*batch 4"], eps: float = 1e-8, ) -> Float[Tensor, "*batch 3 3"]: # Order changed to match scipy format! i, j, k, r = torch.unbind(quaternions, dim=-1) two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) o = torch.stack( ( 1 - two_s * (j * j + k * k), two_s * (i * j - k * r), two_s * (i * k + j * r), two_s * (i * j + k * r), 1 - two_s * (i * i + k * k), two_s * (j * k - i * r), two_s * (i * k - j * r), two_s * (j * k + i * r), 1 - two_s * (i * i + j * j), ), -1, ) return rearrange(o, "... (i j) -> ... i j", i=3, j=3) def build_covariance( scale: Float[Tensor, "*#batch 3"], rotation_xyzw: Float[Tensor, "*#batch 4"], ) -> Float[Tensor, "*batch 3 3"]: scale = scale.diag_embed() rotation = quaternion_to_matrix(rotation_xyzw) return ( rotation @ scale @ rearrange(scale, "... i j -> ... j i") @ rearrange(rotation, "... i j -> ... j i") )