|
import torch |
|
from einops import rearrange |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
|
|
|
|
@torch.no_grad()
def generate_wobble_transformation(
    radius: Float[Tensor, "*#batch"],
    t: Float[Tensor, " time_step"],
    num_rotations: int = 1,
    scale_radius_with_t: bool = True,
) -> Float[Tensor, "*batch time_step 4 4"]:
    """Build homogeneous transforms that translate along a circle in the x-y plane.

    For every batch entry and every time step ``t[k]`` the returned 4x4 matrix
    has an identity upper-left 3x3 block, a zero z-translation, and an x/y
    translation of
    ``(sin(2*pi*num_rotations*t[k]), -cos(2*pi*num_rotations*t[k])) * r``,
    where ``r`` is ``radius`` itself or, when ``scale_radius_with_t`` is true,
    ``radius * t[k]`` (so the path spirals outward as ``t`` grows).
    NOTE(review): the x-y plane is presumably the image plane of a camera -
    confirm with callers.

    Args:
        radius: Circle radius per batch entry; its shape becomes the leading
            batch shape of the output.
        t: 1-D tensor of time values; ``t.shape[0]`` becomes the time axis.
        num_rotations: Number of full revolutions per unit of ``t``.
        scale_radius_with_t: If true, multiply the radius by ``t``.

    Returns:
        Tensor of shape ``(*radius.shape, t.shape[0], 4, 4)`` with dtype
        ``float32`` on ``t.device``. Runs under ``torch.no_grad``, so no
        autograd graph is recorded.
    """
    steps = t.shape[0]
    angle = 2 * torch.pi * num_rotations * t  # (time_step,)

    # Trailing axis lets the per-entry radius broadcast against the time axis.
    offset = radius[..., None]
    if scale_radius_with_t:
        offset = offset * t

    # (x, y) unit-circle directions on a last axis of size 2, scaled by the
    # radius; the product broadcasts to (*batch, time_step, 2).
    direction = torch.stack((torch.sin(angle), -torch.cos(angle)), dim=-1)
    xy = direction * offset[..., None]

    # One identity matrix per (batch, time) pair. clone() turns the zero-stride
    # broadcast view into real, writable storage before the in-place write.
    transform = (
        torch.eye(4, dtype=torch.float32, device=t.device)
        .broadcast_to((*radius.shape, steps, 4, 4))
        .clone()
    )
    transform[..., :2, 3] = xy
    return transform
|
|
|
|
|
@torch.no_grad()
def generate_wobble(
    extrinsics: Float[Tensor, "*#batch 4 4"],
    radius: Float[Tensor, "*#batch"],
    t: Float[Tensor, " time_step"],
    num_rotations: int = 1,
    scale_radius_with_t: bool = True,
) -> Float[Tensor, "*batch time_step 4 4"]:
    """Apply a circular wobble to every extrinsics matrix for each value in ``t``.

    Builds per-time-step wobble matrices with
    ``generate_wobble_transformation`` and right-multiplies them onto
    ``extrinsics``. Before the product, ``extrinsics`` gets a singleton time axis
    inserted ahead of its last two dimensions. The result for time step ``k``
    is therefore ``extrinsics @ wobble[k]``.

    Because the wobble is applied on the right, its translation is expressed in
    the input frame of ``extrinsics``. NOTE(review): this is presumably camera
    space if these are camera-to-world matrices - confirm with callers.

    Args:
        extrinsics: Stack of 4x4 matrices. Its leading dims must broadcast with
            ``radius.shape``, and its dtype must be compatible with the wobble
            matrices for ``@``.
        radius: Wobble radius per batch entry.
        t: 1-D tensor of time values; one output matrix is produced per entry.
        num_rotations: Full revolutions per unit of ``t``; forwarded to
            ``generate_wobble_transformation``. The default of 1 matches that
            function's default.
        scale_radius_with_t: Whether the radius grows linearly with ``t``;
            forwarded as well. The default of True matches that function's
            default.

    Returns:
        Tensor of shape ``(*batch, t.shape[0], 4, 4)``.
    """
    # Forward by keyword so the call stays correct if the helper's positional
    # order ever changes.
    tf = generate_wobble_transformation(
        radius,
        t,
        num_rotations=num_rotations,
        scale_radius_with_t=scale_radius_with_t,
    )
    # "()" inserts a size-1 time axis so each matrix broadcasts over all steps.
    return rearrange(extrinsics, "... i j -> ... () i j") @ tf
|
|