|
import numpy as np |
|
import torch |
|
from einops import repeat |
|
from jaxtyping import Float |
|
from scipy.spatial.transform import Rotation as R |
|
from torch import Tensor |
|
|
|
|
|
def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    """Generate ``num_frames`` 4x4 transforms that spin about the y axis.

    Frame ``k`` is ``T_azimuth[k] @ T_elevation @ T_translation`` where

    * ``T_translation`` negates the x and y axes and offsets z by ``-radius``,
      so the origin is mapped to ``(0, 0, -radius)``;
    * ``T_elevation`` rotates by ``elevation`` degrees about the x axis;
    * ``T_azimuth[k]`` rotates by ``2 * pi * k / num_frames`` about the y axis
      for ``k = 0 .. num_frames - 1`` (the endpoint ``2 * pi`` is excluded, so
      the last frame does not duplicate the first).

    Because both rotations fix the origin, the translation column of every
    frame lies at distance ``abs(radius)`` from the origin.

    NOTE(review): these are presumably camera-to-world poses orbiting the
    origin; the camera axis convention is not visible here — confirm against
    the caller.

    Args:
        num_frames: Number of transforms to generate; must be non-negative.
        device: Device on which all returned data is allocated.
        elevation: Elevation angle in degrees.
        radius: Offset applied along -z before the rotations.

    Returns:
        A float32 tensor of shape ``(num_frames, 4, 4)`` on ``device``.

    Raises:
        ValueError: If ``num_frames`` is negative.
    """
    if num_frames < 0:
        raise ValueError(f"num_frames must be non-negative, got {num_frames}")
    if num_frames == 0:
        # Nothing to rotate; skip building an empty scipy Rotation.
        return torch.empty((0, 4, 4), dtype=torch.float32, device=device)

    # Base offset: flip x and y, then move -radius along z.
    tf_translation = torch.eye(4, dtype=torch.float32, device=device)
    tf_translation[:2] *= -1
    tf_translation[2, 3] = -radius

    # Azimuth: evenly spaced angles in [0, 2*pi) as rotation vectors about y.
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    azimuth_rotvecs = np.stack(
        [np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1
    )
    azimuth_matrices = torch.as_tensor(
        R.from_rotvec(azimuth_rotvecs).as_matrix(),
        dtype=torch.float32,
        device=device,
    )
    # Tensor.repeat allocates a fresh, writable (frame, 4, 4) tensor.
    tf_azimuth = torch.eye(4, dtype=torch.float32, device=device).repeat(
        num_frames, 1, 1
    )
    tf_azimuth[:, :3, :3] = azimuth_matrices

    # Elevation: one rotation about x shared by every frame. The angle stays in
    # float64 through scipy and is cast to float32 exactly once.
    elevation_rad = np.deg2rad(elevation)
    elevation_matrix = torch.as_tensor(
        R.from_rotvec([elevation_rad, 0.0, 0.0]).as_matrix(),
        dtype=torch.float32,
        device=device,
    )
    tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
    tf_elevation[:3, :3] = elevation_matrix

    # (frame, 4, 4) @ (4, 4) @ (4, 4) broadcasts over the frame dimension.
    return tf_azimuth @ tf_elevation @ tf_translation
|
|