liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
raw
history blame
2.39 kB
import torch
from gsplat import rasterization
from dust3r.utils.geometry import inv, geotrf
def render(
intrinsics: torch.Tensor,
pts3d: torch.Tensor,
rgbs: torch.Tensor | None = None,
scale: float = 0.002,
opacity: float = 0.95,
):
device = pts3d.device
batch_size = len(intrinsics)
img_size = pts3d.shape[1:3]
pts3d = pts3d.reshape(batch_size, -1, 3)
num_pts = pts3d.shape[1]
quats = torch.randn((num_pts, 4), device=device)
quats = quats / quats.norm(dim=-1, keepdim=True)
scales = scale * torch.ones((num_pts, 3), device=device)
opacities = opacity * torch.ones((num_pts), device=device)
if rgbs is not None:
assert rgbs.shape[1] == 3
rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)
else:
rgbs = torch.ones_like(pts3d[:, :, :3])
rendered_rgbs = []
rendered_depths = []
accs = []
for i in range(batch_size):
rgbd, acc, _ = rasterization(
pts3d[i],
quats,
scales,
opacities,
rgbs[i],
torch.eye(4, device=device)[None],
intrinsics[[i]],
width=img_size[1],
height=img_size[0],
packed=False,
render_mode="RGB+D",
)
rendered_depths.append(rgbd[..., 3])
rendered_depths = torch.cat(rendered_depths, dim=0)
return rendered_rgbs, rendered_depths, accs
def get_render_results(gts, preds, self_view=False):
    """Render predicted and ground-truth point maps to depth, view by view.

    For each ``(gt, pred)`` pair the predicted points and the ground-truth
    points (transformed by the inverse of a reference ``camera_pose``) are
    passed through ``render`` with the reference intrinsics, and the two
    resulting depth maps are collected. Runs under ``torch.no_grad()``.

    Args:
        gts: Sequence of dicts providing ``"camera_pose"``,
            ``"camera_intrinsics"``, ``"img"`` and ``"pts3d"``.
        preds: Sequence of dicts providing ``"pts3d_in_self_view"`` and/or
            ``"pts3d_in_other_view"``; the device of
            ``preds[0]["pts3d_in_self_view"]`` is used for all tensors.
        self_view: When true, each view uses its own pose and intrinsics and
            the ``"pts3d_in_self_view"`` prediction; otherwise the pose and
            intrinsics of ``gts[0]`` and the ``"pts3d_in_other_view"``
            prediction are used.

    Returns:
        ``(depths, gt_depths)``: two lists with one rendered depth per pair.
    """
    device = preds[0]["pts3d_in_self_view"].device
    pred_key = "pts3d_in_self_view" if self_view else "pts3d_in_other_view"

    depths, gt_depths = [], []
    with torch.no_grad():
        for gt, pred in zip(gts, preds):
            # Reference view supplying the pose and intrinsics for this pair.
            ref = gt if self_view else gts[0]
            camera = inv(ref["camera_pose"]).to(device)
            intrinsics = ref["camera_intrinsics"].to(device)
            pred_pts = pred[pred_key]

            gt_img = gt["img"].to(device)
            gt_pts3d = gt["pts3d"].to(device)

            pred_depth = render(intrinsics, pred_pts, gt_img)[1]
            gt_depth = render(intrinsics, geotrf(camera, gt_pts3d), gt_img)[1]

            depths.append(pred_depth)
            gt_depths.append(gt_depth)
    return depths, gt_depths