import torch
from torch import Tensor

import roma
from gsplat.rendering import rasterization

from scene.cameras import Camera
from scene.gaussian_model import GaussianModel
from scene.ground_model import GroundModel

def euler2matrix(yaw):
    """Rotation about the +y (up) axis; `yaw` is expected to be a scalar tensor in radians."""
    return torch.tensor([
        [torch.cos(-yaw), 0, torch.sin(-yaw)],
        [0, 1, 0],
        [-torch.sin(-yaw), 0, torch.cos(-yaw)]
    ]).cuda()
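
# Usage sketch (assumption: yaw arrives as a 0-dim tensor in radians; a plain
# Python float would need torch.tensor(yaw) first, since torch.cos expects a tensor):
#   R = euler2matrix(torch.tensor(0.5))  # (3, 3) yaw rotation on CUDA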

def cat_bgfg(bg, fg, only_xyz=False):
    """Concatenate background gaussians with world-space foreground gaussians.

    Features follow a fixed order: xyz, opacity, scaling, rotation,
    SH features, 3D (semantic) features.
    """
    if only_xyz:
        if bg.ground_model is None:
            bg_feats = [bg.get_xyz]
        else:
            bg_feats = [bg.get_full_xyz]
    else:
        if bg.ground_model is None:
            bg_feats = [bg.get_xyz, bg.get_opacity, bg.get_scaling, bg.get_rotation,
                        bg.get_features, bg.get_3D_features]
        else:
            bg_feats = [bg.get_full_xyz, bg.get_full_opacity, bg.get_full_scaling,
                        bg.get_full_rotation, bg.get_full_features, bg.get_full_3D_features]

    if len(fg) == 0:
        return bg_feats

    output = []
    for fg_feat, bg_feat in zip(fg, bg_feats):
        if fg_feat is None:
            output.append(bg_feat)
        else:
            # Truncate foreground SH coefficients when the foreground model uses
            # a higher SH degree than the background.
            if bg_feat.shape[1] != fg_feat.shape[1]:
                fg_feat = fg_feat[:, :bg_feat.shape[1], :]
            output.append(torch.cat((bg_feat, fg_feat), dim=0))
    return output

def concatenate_all(all_fg):
    """Transpose per-object feature lists into per-feature lists and concatenate each."""
    output = []
    for feat in zip(*all_fg):
        output.append(torch.cat(feat, dim=0))
    return output
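
# Shape sketch: concatenate_all turns per-object feature lists into per-feature
# lists before concatenating, e.g. for two objects
#   concatenate_all([[xyz0, op0], [xyz1, op1]]) == [cat((xyz0, xyz1)), cat((op0, op1))]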

def proj_uv(xyz, cam):
    """Project world-space points into pixel coordinates of `cam`."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    intr = torch.as_tensor(cam.K[:3, :3]).float().to(device)  # (3, 3)
    w2c = torch.linalg.inv(cam.c2w)[:3, :]  # (3, 4); assumes c2w is already on `device`
    c_xyz = (w2c[:3, :3] @ xyz.T).T + w2c[:3, 3]  # world -> camera space
    i_xyz = (intr @ c_xyz.mT).mT  # (N, 3)
    # Note the axis swap: this returns (v, u) pixel coordinates; depth is clipped
    # to avoid dividing by values at or behind the image plane.
    uv = i_xyz[:, [1, 0]] / i_xyz[:, -1:].clip(1e-3)  # (N, 2)
    return uv
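
# proj_uv feeds the optical-flow channels in render() below: the flow target for
# each gaussian is proj_uv(prev_xyz, prev_cam) - proj_uv(xyz, cam). Sketch with a
# hypothetical camera (assumes cam.K and cam.c2w are CUDA tensors):
#   uv = proj_uv(torch.rand(100, 3, device="cuda"), cam)  # (100, 2)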

def unicycle_b2w(timestamp, model):
    """Query a unicycle motion model at `timestamp` and build a 4x4 box-to-world pose."""
    pred = model(timestamp)
    if pred is None:
        return None
    pred_a, pred_b, pred_v, pitchroll, pred_yaw, pred_h = pred
    rt = torch.eye(4).float().cuda()
    rt[:3, :3] = roma.euler_to_rotmat('xzy', [-pitchroll[0] + torch.pi / 2,
                                              -pitchroll[1] + torch.pi / 2,
                                              -pred_yaw + torch.pi / 2])
    # Translation: ground-plane position (a, b) on x/z, height h on y.
    rt[1, 3], rt[0, 3], rt[2, 3] = pred_h, pred_a, pred_b
    return rt
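
# Sketch (assumes the unicycle model is a callable mapping a timestamp to the
# (a, b, v, pitchroll, yaw, h) tuple unpacked above, returning None out of range):
#   B2W = unicycle_b2w(viewpoint.timestamp, unicycles[track_id]['model'])
#   if B2W is not None:
#       track_dict[track_id] = B2W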

def render(viewpoint: Camera, prev_viewpoint: Camera, pc: GaussianModel, dynamic_gaussians: dict,
           unicycles: dict, bg_color: Tensor, render_optical=False, planning=[]):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!
    """
    timestamp = viewpoint.timestamp

    # Collect per-track box-to-world poses, preferring unicycle-model predictions
    # over the poses stored with the viewpoint.
    if unicycles is None or len(unicycles) == 0:
        track_dict = viewpoint.dynamics
        if prev_viewpoint is not None:
            prev_track_dict = prev_viewpoint.dynamics
    else:
        track_dict, prev_track_dict = {}, {}
        for track_id, B2W in viewpoint.dynamics.items():
            if track_id in unicycles:
                B2W = unicycle_b2w(timestamp, unicycles[track_id]['model'])
            track_dict[track_id] = B2W
            # Only query the unicycle model for tracks that actually have one;
            # unicycles[track_id] would raise a KeyError otherwise.
            if prev_viewpoint is not None and track_id in unicycles:
                prev_B2W = unicycle_b2w(prev_viewpoint.timestamp, unicycles[track_id]['model'])
                prev_track_dict[track_id] = prev_B2W

    # Planned (simulated) agent poses extend or override the tracked ones.
    if len(planning) > 0:
        for plan_id, B2W in planning[0].items():
            track_dict[plan_id] = B2W
        if prev_viewpoint is not None:
            for plan_id, B2W in planning[1].items():
                prev_track_dict[plan_id] = B2W

    # Transform each dynamic object's gaussians from box space to world space.
    all_fg, prev_all_fg = [], []
    for track_id, B2W in track_dict.items():
        w_dxyz = (B2W[:3, :3] @ dynamic_gaussians[track_id].get_xyz.T).T + B2W[:3, 3]
        drot = roma.quat_wxyz_to_xyzw(dynamic_gaussians[track_id].get_rotation)
        drot = roma.unitquat_to_rotmat(drot)
        w_drot = roma.quat_xyzw_to_wxyz(roma.rotmat_to_unitquat(B2W[:3, :3] @ drot))
        fg = [w_dxyz,
              dynamic_gaussians[track_id].get_opacity,
              dynamic_gaussians[track_id].get_scaling,
              w_drot,
              dynamic_gaussians[track_id].get_features,
              dynamic_gaussians[track_id].get_3D_features]
        all_fg.append(fg)
        if render_optical and prev_viewpoint is not None:
            if track_id in prev_track_dict:
                prev_B2W = prev_track_dict[track_id]
                prev_w_dxyz = torch.mm(prev_B2W[:3, :3], dynamic_gaussians[track_id].get_xyz.T).T + prev_B2W[:3, 3]
                prev_all_fg.append([prev_w_dxyz])
            else:
                # No previous pose: reuse current positions so the flow is zero.
                prev_all_fg.append([w_dxyz])

    all_fg = concatenate_all(all_fg)
    xyz, opacities, scales, rotations, shs, feats3D = cat_bgfg(pc, all_fg)

    if render_optical and prev_viewpoint is not None:
        # Per-gaussian optical flow: displacement of the projected positions
        # between the previous frame and the current one.
        prev_all_fg = concatenate_all(prev_all_fg)
        prev_xyz = cat_bgfg(pc, prev_all_fg, only_xyz=True)[0]
        uv = proj_uv(xyz, viewpoint)
        prev_uv = proj_uv(prev_xyz, prev_viewpoint)
        delta_uv = prev_uv - uv
        delta_uv = torch.cat([delta_uv, torch.ones_like(delta_uv[:, :1])], dim=-1)
    else:
        delta_uv = torch.zeros_like(xyz)

    if pc.affine:
        # Per-view affine color correction conditioned on camera position and
        # viewing direction.
        cam_xyz, cam_dir = viewpoint.c2w[:3, 3].cuda(), viewpoint.c2w[:3, 2].cuda()
        o_enc = pc.pos_enc(cam_xyz[None, :] / 60)
        d_enc = pc.dir_enc(cam_dir[None, :])
        appearance = pc.appearance_model(torch.cat([o_enc, d_enc], dim=1)) * 1e-1
        affine_weight, affine_bias = appearance[:, :9].view(3, 3), appearance[:, -3:]
        affine_weight = affine_weight + torch.eye(3, device=appearance.device)

    # 'RGB+ED+S' renders color, expected depth and semantic features; '+F' adds
    # optical flow (the smts/flows inputs and these modes come from this
    # project's extended rasterizer rather than stock gsplat).
    render_mode = 'RGB+ED+S+F' if render_optical else 'RGB+ED+S'

    renders, render_alphas, info = rasterization(
        means=xyz,
        quats=rotations,
        scales=scales,
        opacities=opacities[:, 0],
        colors=shs,
        viewmats=torch.linalg.inv(viewpoint.c2w)[None, ...],  # [C, 4, 4]
        Ks=viewpoint.K[None, :3, :3],  # [C, 3, 3]
        width=viewpoint.width,
        height=viewpoint.height,
        smts=feats3D[None, ...],
        flows=delta_uv[None, ...],
        render_mode=render_mode,
        sh_degree=pc.active_sh_degree,
        near_plane=0.01,
        far_plane=500,
        packed=False,
        backgrounds=bg_color[None, :],
    )

    renders = renders[0]
    rendered_image = renders[..., :3].permute(2, 0, 1)  # (3, H, W)
    depth = renders[..., 3][None, ...]  # (1, H, W)
    smt = renders[..., 4:(4 + feats3D.shape[-1])].permute(2, 0, 1)

    if pc.affine:
        colors = rendered_image.view(3, -1).permute(1, 0)  # (H*W, 3)
        refined_image = (colors @ affine_weight + affine_bias).clip(0, 1).permute(1, 0).view(*rendered_image.shape)
    else:
        refined_image = rendered_image

    return {"render": refined_image,
            "feats": smt,
            "depth": depth,
            "opticalflow": renders[..., -2:].permute(2, 0, 1) if render_optical else None,
            "alphas": render_alphas,
            "viewspace_points": info["means2d"],
            "info": info,
            }
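
# Minimal call sketch (hypothetical setup: loaded Camera objects, one
# GaussianModel per dynamic track, white background on the GPU):
#   bg = torch.ones(3, device="cuda")
#   out = render(cam, prev_cam, pc, dynamic_gaussians, unicycles, bg, render_optical=True)
#   image, flow = out["render"], out["opticalflow"]  # (3, H, W) and (2, H, W)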

def render_ground(viewpoint: Camera, pc: GroundModel, bg_color: Tensor):
    """Render the ground model alone (no dynamic objects, no optical flow)."""
    xyz, opacities, scales = pc.get_xyz, pc.get_opacity, pc.get_scaling
    rotations, shs, feats3D = pc.get_rotation, pc.get_features, pc.get_3D_features
    K = viewpoint.K[None, :3, :3]

    renders, render_alphas, info = rasterization(
        means=xyz,
        quats=rotations,
        scales=scales,
        opacities=opacities[:, 0],
        colors=shs,
        viewmats=torch.linalg.inv(viewpoint.c2w)[None, ...],  # [C, 4, 4]
        Ks=K,  # [C, 3, 3]
        width=viewpoint.width,
        height=viewpoint.height,
        smts=feats3D[None, ...],
        render_mode='RGB+ED+S',
        sh_degree=pc.active_sh_degree,
        near_plane=0.01,
        far_plane=500,
        packed=False,
        backgrounds=bg_color[None, :],
    )

    renders = renders[0]
    rendered_image = renders[..., :3].permute(2, 0, 1)
    depth = renders[..., 3][None, ...]
    smt = renders[..., 4:(4 + feats3D.shape[-1])].permute(2, 0, 1)

    return {"render": rendered_image,
            "feats": smt,
            "depth": depth,
            "opticalflow": None,
            "alphas": render_alphas,
            "viewspace_points": info["means2d"],
            "info": info,
            }
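
# render_ground mirrors render()'s output dict; sketch (hypothetical setup):
#   ground_out = render_ground(cam, pc.ground_model, torch.zeros(3, device="cuda"))
#   ground_depth = ground_out["depth"]  # (1, H, W)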