# Uploaded by alexnasa ("Upload 243 files", commit 2568013, verified).
import numpy as np
import torch
from einops import repeat
from jaxtyping import Float
from scipy.spatial.transform import Rotation as R
from torch import Tensor
def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    """Generate one 4x4 transform per frame for a full spin about the Y axis.

    Every returned matrix is the product ``azimuth @ elevation @ translation``:

    * ``translation`` is the identity with its first two rows negated and
      ``-radius`` in the Z-translation slot (row 2, column 3);
    * ``elevation`` is a fixed rotation about the X axis;
    * ``azimuth`` is a per-frame rotation about the Y axis.

    Args:
        num_frames: Number of transforms to generate. Frame ``i`` uses the
            azimuth angle ``2 * pi * i / num_frames``, so the final frame
            stops one step short of a complete turn.
        device: Device on which every tensor is created.
        elevation: Rotation about the X axis, in degrees.
        radius: Distance used for the offset along the Z axis.

    Returns:
        A float32 tensor of shape ``(num_frames, 4, 4)`` on ``device``.
    """
    # Translate back along the camera's look vector.
    tf_translation = torch.eye(4, dtype=torch.float32, device=device)
    tf_translation[:2] *= -1  # Negate the first two rows (X and Y axes).
    tf_translation[2, 3] = -radius

    # Generate the transformation for the azimuth: one rotation vector about
    # Y per frame, evenly spaced over [0, 2*pi).
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    rotation_vectors = np.stack([np.zeros_like(phi), phi, np.zeros_like(phi)], axis=-1)
    azimuth = R.from_rotvec(rotation_vectors).as_matrix()
    azimuth = torch.tensor(azimuth, dtype=torch.float32, device=device)
    tf_azimuth = torch.eye(4, dtype=torch.float32, device=device)
    # Clone so the repeated identity owns its memory before being written to.
    tf_azimuth = repeat(tf_azimuth, "i j -> b i j", b=num_frames).clone()
    tf_azimuth[:, :3, :3] = azimuth

    # Generate the transformation for the elevation. The rotation is computed
    # in float64 (like the azimuth) and only cast to float32 at the end, so
    # the angle is not rounded before the trigonometry is evaluated.
    elevation_rad = np.deg2rad(elevation)
    elevation_matrix = R.from_rotvec(np.array([elevation_rad, 0.0, 0.0])).as_matrix()
    tf_elevation = torch.eye(4, dtype=torch.float32, device=device)
    tf_elevation[:3, :3] = torch.tensor(
        elevation_matrix, dtype=torch.float32, device=device
    )

    return tf_azimuth @ tf_elevation @ tf_translation