#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#
from __future__ import annotations  # allows "torch.Tensor | np.ndarray" annotations on Python < 3.10

import math
from typing import NamedTuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


class BasicPointCloud(NamedTuple):
    points: np.ndarray
    colors: np.ndarray
    normals: np.ndarray
    features: np.ndarray


def geom_transform_points(points, transf_matrix):
    """Apply a 4x4 transform to a set of 3D points via homogeneous coordinates."""
    P, _ = points.shape
    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
    points_hom = torch.cat([points, ones], dim=1)
    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))

    denom = points_out[..., 3:] + 0.0000001  # guard against division by zero
    return (points_out[..., :3] / denom).squeeze(dim=0)
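

# Usage sketch (illustrative; the `_demo_*` helpers are not part of the
# original module): an identity transform leaves points (approximately)
# unchanged, up to the small epsilon added to the homogeneous divisor.
def _demo_geom_transform_points():
    pts = torch.rand(100, 3)
    out = geom_transform_points(pts, torch.eye(4))
    assert torch.allclose(out, pts, atol=1e-5)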


def getWorld2View(R, t):
    """Build a 4x4 world-to-view matrix from a rotation R and translation t."""
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)


def getWorld2View2(R, t, translate=np.array([0.0, 0.0, 0.0]), scale=1.0):
    """World-to-view matrix with optional recentering/rescaling of the camera center."""
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    # Invert to camera-to-world, adjust the camera center, then invert back
    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)


def getWorld2View2_torch(R, t, translate=torch.tensor([0.0, 0.0, 0.0]), scale=1.0):
    """Torch counterpart of getWorld2View2."""
    # as_tensor avoids the copy-construction warning when translate is already a tensor
    translate = torch.as_tensor(translate, dtype=torch.float32)
    # Initialize the transformation matrix
    Rt = torch.zeros((4, 4), dtype=torch.float32)
    Rt[:3, :3] = R.t()  # transpose of R
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    # Compute the inverse to get the camera-to-world transformation
    C2W = torch.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    # Invert again to get the world-to-view transformation
    Rt = torch.linalg.inv(C2W)
    return Rt
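

# Consistency sketch (illustrative, not part of the original module): the
# NumPy and torch variants should agree for the same camera pose.
def _demo_world2view_consistency():
    R = np.eye(3)
    t = np.array([0.0, 0.0, 2.0])
    w2v_np = getWorld2View2(R, t)
    w2v_th = getWorld2View2_torch(torch.from_numpy(R).float(), torch.from_numpy(t).float())
    assert np.allclose(w2v_np, w2v_th.numpy(), atol=1e-6)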


def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Perspective projection matrix from near/far planes and FoVs (in radians)."""
    tanHalfFovY = math.tan(fovY / 2)
    tanHalfFovX = math.tan(fovX / 2)

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P
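

# Anatomy sketch (illustrative, not part of the original module): the frustum
# here is symmetric (left == -right, bottom == -top), so the off-center terms
# P[0, 2] and P[1, 2] vanish.
def _demo_projection_matrix():
    P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=1.2, fovY=0.9)
    assert P.shape == (4, 4)
    assert P[0, 2] == 0.0 and P[1, 2] == 0.0  # symmetric frustum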


def fov2focal(fov, pixels):
    """Convert a field of view (radians) to a focal length in pixels."""
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    """Convert a focal length in pixels to a field of view (radians)."""
    return 2 * math.atan(pixels / (2 * focal))
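

# Round-trip sketch (illustrative, not part of the original module): the two
# conversions are exact inverses for a fixed image size.
def _demo_fov_focal_roundtrip():
    fov = 1.2  # radians
    assert abs(focal2fov(fov2focal(fov, 1920), 1920) - fov) < 1e-9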


def resize_render(view, size=None):
    """Resize a view to a square render target.

    The image buffer is reset to zeros, and the FoVs are recomputed so that the
    focal lengths stay fixed at the new resolution.
    """
    image_size = size if size is not None else max(view.image_width, view.image_height)
    view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
    focal_length_x = fov2focal(view.FoVx, view.image_width)
    focal_length_y = fov2focal(view.FoVy, view.image_height)
    view.image_width = image_size
    view.image_height = image_size
    view.FoVx = focal2fov(focal_length_x, image_size)
    view.FoVy = focal2fov(focal_length_y, image_size)
    return view
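

# Usage sketch (illustrative, not part of the original module): any object with
# the expected attributes works; a SimpleNamespace stands in for a camera here.
def _demo_resize_render():
    from types import SimpleNamespace
    view = SimpleNamespace(
        image_width=640,
        image_height=480,
        FoVx=1.2,
        FoVy=0.9,
        original_image=torch.zeros(3, 480, 640),
    )
    view = resize_render(view)  # squares the render target to 640x640
    assert view.image_width == view.image_height == 640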


def make_video_divisble(
    video: torch.Tensor | np.ndarray, block_size=16
) -> torch.Tensor | np.ndarray:
    """Crop a video [T, H, W, C] so H and W are divisible by block_size (codec-friendly)."""
    H, W = video.shape[1:3]
    H_new = H - H % block_size
    W_new = W - W % block_size
    return video[:, :H_new, :W_new]
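

# Usage sketch (illustrative, not part of the original module): trailing rows
# and columns are cropped so common video codecs accept the frame size.
def _demo_make_video_divisble():
    frames = np.zeros((10, 481, 641, 3), dtype=np.uint8)
    cropped = make_video_divisble(frames, block_size=16)
    assert cropped.shape[1:3] == (480, 640)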


def depth_to_points(
    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
) -> Tensor:
    """Convert depth maps to 3D points.

    Args:
        depths: Depth maps [..., H, W, 1]
        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
        Ks: Camera intrinsics [..., 3, 3]
        z_depth: Whether the depth is z-depth (True) or ray depth (False)

    Returns:
        points: 3D points in the world coordinate system [..., H, W, 3]
    """
    assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}"
    assert camtoworlds.shape[-2:] == (4, 4), f"Invalid viewmats shape: {camtoworlds.shape}"
    assert Ks.shape[-2:] == (3, 3), f"Invalid Ks shape: {Ks.shape}"
    assert (
        depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2]
    ), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}"

    device = depths.device
    height, width = depths.shape[-3:-1]

    x, y = torch.meshgrid(
        torch.arange(width, device=device),
        torch.arange(height, device=device),
        indexing="xy",
    )  # [H, W]

    fx = Ks[..., 0, 0]  # [...]
    fy = Ks[..., 1, 1]  # [...]
    cx = Ks[..., 0, 2]  # [...]
    cy = Ks[..., 1, 2]  # [...]

    # camera directions in camera coordinates (pixel centers offset by +0.5)
    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx[..., None, None] + 0.5) / fx[..., None, None],
                (y - cy[..., None, None] + 0.5) / fy[..., None, None],
            ],
            dim=-1,
        ),
        (0, 1),
        value=1.0,
    )  # [..., H, W, 3]

    # ray directions in world coordinates
    directions = torch.einsum(
        "...ij,...hwj->...hwi", camtoworlds[..., :3, :3], camera_dirs
    )  # [..., H, W, 3]
    origins = camtoworlds[..., :3, -1]  # [..., 3]

    if not z_depth:
        directions = F.normalize(directions, dim=-1)

    points = origins[..., None, None, :] + depths * directions
    return points
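

# Sanity-check sketch (illustrative, not part of the original module): with an
# identity pose, a constant z-depth map unprojects to a plane at that depth.
def _demo_depth_to_points():
    H, W = 4, 6
    depths = torch.full((H, W, 1), 2.0)
    camtoworlds = torch.eye(4)
    Ks = torch.tensor([[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]])
    pts = depth_to_points(depths, camtoworlds, Ks)  # [H, W, 3]
    assert torch.allclose(pts[..., 2], torch.full((H, W), 2.0))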


def depth_to_normal(
    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
) -> Tensor:
    """Convert depth maps to surface normals.

    Args:
        depths: Depth maps [..., H, W, 1]
        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
        Ks: Camera intrinsics [..., 3, 3]
        z_depth: Whether the depth is z-depth (True) or ray depth (False)

    Returns:
        normals: Surface normals in the world coordinate system [..., H, W, 3]
    """
    points = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth)  # [..., H, W, 3]

    # central differences over the interior pixels
    dx = points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :]  # [..., H-2, W-2, 3]
    dy = points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :]  # [..., H-2, W-2, 3]

    normals = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)  # [..., H-2, W-2, 3]
    # zero-pad the one-pixel border where no normal is defined
    normals = F.pad(normals, (0, 0, 1, 1, 1, 1), value=0.0)  # [..., H, W, 3]
    return normals
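

# Sanity-check sketch (illustrative, not part of the original module): a flat
# constant-depth plane seen from an identity pose has unit normals along the z
# axis in the interior and zeros on the padded one-pixel border.
def _demo_depth_to_normal():
    H, W = 5, 5
    depths = torch.full((H, W, 1), 2.0)
    camtoworlds = torch.eye(4)
    Ks = torch.tensor([[100.0, 0.0, W / 2], [0.0, 100.0, H / 2], [0.0, 0.0, 1.0]])
    normals = depth_to_normal(depths, camtoworlds, Ks)  # [H, W, 3]
    assert torch.allclose(normals[1:-1, 1:-1, 2].abs(), torch.ones(H - 2, W - 2))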