|
import torch |
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
|
|
|
|
def quaternion_to_matrix(
    quaternions: Float[Tensor, "*batch 4"],
    eps: float = 1e-8,
) -> Float[Tensor, "*batch 3 3"]:
    """Convert (x, y, z, w) quaternions into 3x3 rotation matrices.

    The last axis of ``quaternions`` is read as ``(x, y, z, w)``: the three
    vector components first, the scalar part last. Inputs do not have to be
    unit length, because every off-identity term is scaled by
    ``2 / (|q|^2 + eps)``. As a result ``q`` and ``c * q`` (for ``c != 0``)
    give the same matrix, up to the tiny ``eps`` bias.

    Args:
        quaternions: Tensor whose last dimension has size 4, ordered xyzw.
            Any leading batch dimensions are carried through.
        eps: Added to the squared norm before dividing, so an all-zero
            quaternion does not divide by zero. For float32/float64 inputs
            the zero quaternion then yields the identity matrix.

    Returns:
        Tensor of shape ``(*batch, 3, 3)``; entry ``[..., r, c]`` is row ``r``,
        column ``c`` of the rotation matrix.
    """
    x, y, z, w = quaternions.unbind(-1)

    # Implicit normalization factor: 2 / |q|^2 (eps guards the zero case).
    s = 2 / ((quaternions * quaternions).sum(-1) + eps)

    # Products reused across several matrix entries.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    rows = (
        (1 - s * (yy + zz), s * (xy - zw), s * (xz + yw)),
        (s * (xy + zw), 1 - s * (xx + zz), s * (yz - xw)),
        (s * (xz - yw), s * (yz + xw), 1 - s * (xx + yy)),
    )

    # Stack each row's three entries on a new last axis, then stack the rows
    # on the axis before it, giving (..., 3 rows, 3 columns).
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)
|
|
|
|
|
def build_covariance(
    scale: Float[Tensor, "*#batch 3"],
    rotation_xyzw: Float[Tensor, "*#batch 4"],
) -> Float[Tensor, "*batch 3 3"]:
    """Assemble covariance matrices ``R @ S @ S^T @ R^T``.

    ``S`` is the diagonal matrix built from ``scale``. ``R`` is the rotation
    matrix obtained from ``rotation_xyzw`` via ``quaternion_to_matrix``.

    Args:
        scale: Per-axis scale factors; last dimension has size 3. They are
            placed on the diagonal of ``S``.
        rotation_xyzw: Quaternions in (x, y, z, w) order; last dimension has
            size 4. Leading dimensions broadcast against those of ``scale``
            through the batched matrix products.

    Returns:
        Tensor of shape ``(*batch, 3, 3)``. It is mathematically symmetric,
        and symmetric up to floating-point rounding in practice.
    """
    scale_matrix = torch.diag_embed(scale)
    rotation = quaternion_to_matrix(rotation_xyzw)

    # Evaluated strictly left to right: ((R @ S) @ S^T) @ R^T.
    covariance = rotation @ scale_matrix
    covariance = covariance @ scale_matrix.transpose(-1, -2)
    covariance = covariance @ rotation.transpose(-1, -2)
    return covariance
|
|