File size: 1,328 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
# https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert quaternions to 3x3 rotation matrices.

    Unlike the PyTorch3D original this is adapted from, the components are
    read in scalar-last (x, y, z, w) order, matching SciPy's convention.

    The quaternion does not have to be unit length. Every non-constant term
    is scaled by ``2 / (|q|^2 + eps)`` instead of ``2``, which normalizes
    the input implicitly.

    Args:
        quaternions: Tensor whose last dimension holds (x, y, z, w).
        eps: Added to the squared norm so an all-zero input does not divide
            by zero.

    Returns:
        Tensor of shape ``(*batch, 3, 3)`` holding row-major rotation
        matrices.
    """
    x, y, z, w = quaternions.unbind(-1)
    norm_scale = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)

    # Pairwise products that appear in the closed-form rotation matrix.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    entries = [
        # first row
        1 - norm_scale * (yy + zz),
        norm_scale * (xy - zw),
        norm_scale * (xz + yw),
        # second row
        norm_scale * (xy + zw),
        1 - norm_scale * (xx + zz),
        norm_scale * (yz - xw),
        # third row
        norm_scale * (xz - yw),
        norm_scale * (yz + xw),
        1 - norm_scale * (xx + yy),
    ]
    flat = torch.stack(entries, dim=-1)
    # The 9 trailing values are laid out row by row; fold them into 3x3.
    return flat.reshape(*flat.shape[:-1], 3, 3)
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Build covariance matrices ``R S S^T R^T`` from scales and rotations.

    ``S`` is the diagonal matrix built from ``scale`` and ``R`` is the
    rotation matrix of ``rotation_xyzw``.

    Args:
        scale: Per-axis scale factors (the diagonal of ``S``). Batch
            dimensions broadcast against those of ``rotation_xyzw``.
        rotation_xyzw: Quaternions in scalar-last (x, y, z, w) order. They
            are normalized by ``quaternion_to_matrix``, so unit length is
            not required.

    Returns:
        Covariance matrices of shape ``(*batch, 3, 3)``, where ``batch`` is
        the broadcast of the two input batch shapes.
    """
    rotation = quaternion_to_matrix(rotation_xyzw)
    # R @ diag(s) only multiplies column j of R by s_j, so a broadcast
    # multiply replaces building the dense diagonal matrix and one matmul.
    rotation_scale = rotation * scale.unsqueeze(-2)
    # (R S)(R S)^T equals R S S^T R^T. This form needs a single matmul and
    # skips the extra rounding of the R S^2 intermediate, keeping the
    # result closer to symmetric.
    return rotation_scale @ rearrange(rotation_scale, "... i j -> ... j i")
|