import numpy as np
import roma
import torch
import torch.nn.functional as F


def rt_to_mat4(
    R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
) -> torch.Tensor:
    """Build homogeneous (..., 4, 4) transforms from rotations and translations.

    Args:
        R (torch.Tensor): (..., 3, 3) rotation matrices.
        t (torch.Tensor): (..., 3) translations.
        s (torch.Tensor): (...,) optional isotropic scales, encoded as 1/s in the
            bottom-right entry of the homogeneous matrix.

    Returns:
        torch.Tensor: (..., 4, 4)
    """
    mat34 = torch.cat([R, t[..., None]], dim=-1)
    if s is None:
        bottom = (
            mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
            .reshape((1,) * (mat34.dim() - 2) + (1, 4))
            .expand(mat34.shape[:-2] + (1, 4))
        )
    else:
        bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
    mat4 = torch.cat([mat34, bottom], dim=-2)
    return mat4
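

# Example usage (an illustrative sketch, not part of the original module):
#
#   R = torch.eye(3).expand(2, 3, 3)  # (2, 3, 3) identity rotations
#   t = torch.randn(2, 3)             # (2, 3) translations
#   T = rt_to_mat4(R, t)              # (2, 4, 4); bottom row is [0, 0, 0, 1]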


def get_avg_w2c(w2cs: torch.Tensor):
    """Compute an average world-to-camera transform from (N, 4, 4) w2cs."""
    c2ws = torch.linalg.inv(w2cs)
    # 1. Compute the center.
    center = c2ws[:, :3, -1].mean(0)
    # 2. Compute the z axis.
    z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)
    # 3. Compute axis y' (no need to normalize as it's not the final output).
    y_ = c2ws[:, :3, 1].mean(0)  # (3)
    # 4. Compute the x axis.
    x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1)  # (3)
    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1).
    y = torch.cross(z, x, dim=-1)  # (3)
    avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)
    avg_w2c = torch.linalg.inv(avg_c2w)
    return avg_w2c


# def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
#     """Estimate the lookat point from multiple camera rays.
#
#     Use the center of the camera positions as a reference point, then move
#     forward along the average view direction by the average distance to the
#     center.
#     """
#     # Calculate the center of the camera positions.
#     center = origins.mean(dim=0)
#     # Calculate the average view direction.
#     mean_dir = F.normalize(viewdirs.mean(dim=0), dim=-1)
#     # Calculate the average distance to the center point.
#     avg_dist = torch.norm(origins - center, dim=-1).mean()
#     # Move forward along the average view direction.
#     lookat = center + mean_dir * avg_dist
#     return lookat


def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
    """Triangulate a set of rays to find a single lookat point.

    Args:
        origins (torch.Tensor): A (N, 3) array of ray origins.
        viewdirs (torch.Tensor): A (N, 3) array of ray view directions.

    Returns:
        torch.Tensor: A (3,) lookat point.
    """
    viewdirs = F.normalize(viewdirs, dim=-1)
    eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
    # Calculate the projection matrices I - rr^T.
    I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
    # Compute the sum of the projections.
    sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
    # Solve for the intersection point in the least-squares sense.
    lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
    # Check for NaNs.
    assert not torch.any(torch.isnan(lookat))
    return lookat
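

# Sanity check (an illustrative sketch, not part of the original module): rays
# that all pass through a common point should triangulate back to that point.
#
#   point = torch.tensor([1.0, 2.0, 3.0])
#   origins = torch.randn(8, 3)
#   viewdirs = point[None] - origins  # rays from each origin through `point`
#   assert torch.allclose(get_lookat(origins, viewdirs), point, atol=1e-4)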


def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):
    """
    Args:
        positions: (N, 3) tensor of camera positions.
        lookat: (3,) tensor of lookat point.
        up: (3,) tensor of up vector.

    Returns:
        w2cs: (N, 4, 4) tensor of world-to-camera matrices.
    """
    forward_vectors = F.normalize(lookat - positions, dim=-1)
    right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)
    down_vectors = F.normalize(
        torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
    )
    Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
    w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
    return w2cs


def get_arc_w2cs(
    ref_w2c: torch.Tensor,
    lookat: torch.Tensor,
    up: torch.Tensor,
    num_frames: int,
    degree: float,
    **_,
) -> torch.Tensor:
    ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
    # Oscillate the rotation angle within [-degree / 2, degree / 2] (radians).
    thetas = (
        torch.sin(
            torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[
                :-1
            ]
        )
        * (degree / 2.0)
        / 180.0
        * torch.pi
    )
    # Rotate the lookat-to-camera offset about the up axis, then shift it back
    # so the arc stays centered on the lookat point.
    positions = (
        torch.einsum(
            "nij,j->ni",
            roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
            ref_position - lookat,
        )
        + lookat
    )
    return get_lookat_w2cs(positions, lookat, up)


def get_lemniscate_w2cs(
    ref_w2c: torch.Tensor,
    lookat: torch.Tensor,
    up: torch.Tensor,
    num_frames: int,
    degree: float,
    **_,
) -> torch.Tensor:
    ref_c2w = torch.linalg.inv(ref_w2c)
    a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
    # Lemniscate (of Bernoulli) curve in camera space, starting at the origin
    # thanks to the pi / 2 phase offset.
    thetas = (
        torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
        + torch.pi / 2
    )
    positions = torch.stack(
        [
            a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
            a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
            torch.zeros(num_frames, device=ref_w2c.device),
        ],
        dim=-1,
    )
    # Transform to world space.
    positions = torch.einsum(
        "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
    )
    return get_lookat_w2cs(positions, lookat, up)


def get_spiral_w2cs(
    ref_w2c: torch.Tensor,
    lookat: torch.Tensor,
    up: torch.Tensor,
    num_frames: int,
    rads: float | torch.Tensor,
    zrate: float,
    rots: int,
    **_,
) -> torch.Tensor:
    ref_c2w = torch.linalg.inv(ref_w2c)
    thetas = torch.linspace(
        0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device
    )[:-1]
    # Spiral curve in camera space, with per-axis radii `rads`.
    if isinstance(rads, torch.Tensor):
        rads = rads.reshape(-1, 3).to(ref_w2c.device)
    positions = (
        torch.stack(
            [
                torch.cos(thetas),
                -torch.sin(thetas),
                -torch.sin(thetas * zrate),
            ],
            dim=-1,
        )
        * rads
    )
    # Transform to world space.
    positions = torch.einsum(
        "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
    )
    return get_lookat_w2cs(positions, lookat, up)


def get_wander_w2cs(ref_w2c, focal_length, num_frames, max_disp, **_):
    device = ref_w2c.device
    c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())
    # Convert the maximum pixel disparity into a translation magnitude.
    max_trans = max_disp / focal_length
    output_poses = []
    for i in range(num_frames):
        x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
        y_trans = 0.0
        z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0
        # Small camera-space translation for this frame.
        i_pose = np.concatenate(
            [
                np.concatenate(
                    [
                        np.eye(3),
                        np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
                    ],
                    axis=1,
                ),
                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
            ],
            axis=0,
        )
        i_pose = np.linalg.inv(i_pose)
        ref_pose = np.concatenate(
            [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
        )
        render_pose = np.dot(ref_pose, i_pose)
        output_poses.append(render_pose)
    output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)
    w2cs = torch.linalg.inv(output_poses)
    return w2cs
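

if __name__ == "__main__":
    # Smoke test (an illustrative sketch added here, not part of the original
    # module): build a reference camera 3 units from the origin looking at it,
    # then generate a short arc trajectory around the lookat point.
    ref_position = torch.tensor([[0.0, 0.0, 3.0]])  # (1, 3)
    lookat = torch.zeros(3)
    up = torch.tensor([0.0, 1.0, 0.0])  # assumed world up vector
    ref_w2c = get_lookat_w2cs(ref_position, lookat, up)[0]
    arc_w2cs = get_arc_w2cs(ref_w2c, lookat=lookat, up=up, num_frames=8, degree=30.0)
    print(arc_w2cs.shape)  # torch.Size([8, 4, 4])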