Spaces:
Paused
Paused
import torch | |
from scene.gaussian_model import GaussianModel | |
from scene.ground_model import GroundModel | |
from gsplat.rendering import rasterization | |
import roma | |
from scene.cameras import Camera | |
from torch import Tensor | |
def euler2matrix(yaw): | |
return torch.tensor([ | |
[torch.cos(-yaw), 0, torch.sin(-yaw)], | |
[0, 1, 0], | |
[-torch.sin(-yaw), 0, torch.cos(-yaw)] | |
]).cuda() | |
def cat_bgfg(bg, fg, only_xyz=False): | |
if only_xyz: | |
if bg.ground_model is None: | |
bg_feats = [bg.get_xyz] | |
else: | |
bg_feats = [bg.get_full_xyz] | |
else: | |
if bg.ground_model is None: | |
bg_feats = [bg.get_xyz, bg.get_opacity, bg.get_scaling, bg.get_rotation, bg.get_features, bg.get_3D_features] | |
else: | |
bg_feats = [bg.get_full_xyz, bg.get_full_opacity, bg.get_full_scaling, bg.get_full_rotation, bg.get_full_features, bg.get_full_3D_features] | |
if len(fg) == 0: | |
return bg_feats | |
output = [] | |
for fg_feat, bg_feat in zip(fg, bg_feats): | |
if fg_feat is None: | |
output.append(bg_feat) | |
else: | |
if bg_feat.shape[1] != fg_feat.shape[1]: | |
fg_feat = fg_feat[:, :bg_feat.shape[1], :] | |
output.append(torch.cat((bg_feat, fg_feat), dim=0)) | |
return output | |
def concatenate_all(all_fg): | |
output = [] | |
for feat in list(zip(*all_fg)): | |
output.append(torch.cat(feat, dim=0)) | |
return output | |
def proj_uv(xyz, cam): | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
intr = torch.as_tensor(cam.K[:3, :3]).float().to(device) # (3, 3) | |
w2c = torch.linalg.inv(cam.c2w)[:3, :] # (3, 4) | |
c_xyz = (w2c[:3, :3] @ xyz.T).T + w2c[:3, 3] | |
i_xyz = (intr @ c_xyz.mT).mT # (N, 3) | |
uv = i_xyz[:, [1,0]] / i_xyz[:, -1:].clip(1e-3) # (N, 2) | |
return uv | |
def unicycle_b2w(timestamp, model): | |
pred = model(timestamp) | |
if pred is None: | |
return None | |
pred_a, pred_b, pred_v, pitchroll, pred_yaw, pred_h = pred | |
rt = torch.eye(4).float().cuda() | |
rt[:3,:3] = roma.euler_to_rotmat('xzy', [-pitchroll[0]+torch.pi/2, -pitchroll[1]+torch.pi/2, -pred_yaw+torch.pi/2]) | |
rt[1, 3], rt[0, 3], rt[2, 3] = pred_h, pred_a, pred_b | |
return rt | |
def render(viewpoint:Camera, prev_viewpoint:Camera, pc:GaussianModel, dynamic_gaussians:dict, | |
unicycles:dict, bg_color:Tensor, render_optical=False, planning=[]): | |
""" | |
Render the scene. | |
Background tensor (bg_color) must be on GPU! | |
""" | |
timestamp = viewpoint.timestamp | |
all_fg = [None, None, None, None, None, None] | |
prev_all_fg = [None] | |
if unicycles is None or len(unicycles) == 0: | |
track_dict = viewpoint.dynamics | |
if prev_viewpoint is not None: | |
prev_track_dict = prev_viewpoint.dynamics | |
else: | |
track_dict, prev_track_dict = {}, {} | |
for track_id, B2W in viewpoint.dynamics.items(): | |
if track_id in unicycles: | |
B2W = unicycle_b2w(timestamp, unicycles[track_id]['model']) | |
track_dict[track_id] = B2W | |
if prev_viewpoint is not None: | |
prev_B2W = unicycle_b2w(prev_viewpoint.timestamp, unicycles[track_id]['model']) | |
prev_track_dict[track_id] = prev_B2W | |
if len(planning) > 0: | |
for plan_id, B2W in planning[0].items(): | |
track_dict[plan_id] = B2W | |
if prev_viewpoint is not None: | |
for plan_id, B2W in planning[1].items(): | |
prev_track_dict[plan_id] = B2W | |
all_fg, prev_all_fg = [], [] | |
for track_id, B2W in track_dict.items(): | |
w_dxyz = (B2W[:3, :3] @ dynamic_gaussians[track_id].get_xyz.T).T + B2W[:3, 3] | |
drot = roma.quat_wxyz_to_xyzw(dynamic_gaussians[track_id].get_rotation) | |
drot = roma.unitquat_to_rotmat(drot) | |
w_drot = roma.quat_xyzw_to_wxyz(roma.rotmat_to_unitquat(B2W[:3, :3] @ drot)) | |
fg = [w_dxyz, | |
dynamic_gaussians[track_id].get_opacity, | |
dynamic_gaussians[track_id].get_scaling, | |
w_drot, | |
# dynamic_gaussians[track_id].get_rotation, | |
dynamic_gaussians[track_id].get_features, | |
dynamic_gaussians[track_id].get_3D_features] | |
all_fg.append(fg) | |
if render_optical and prev_viewpoint is not None: | |
if track_id in prev_track_dict: | |
prev_B2W = prev_track_dict[track_id] | |
prev_w_dxyz = torch.mm(prev_B2W[:3, :3], dynamic_gaussians[track_id].get_xyz.T).T + prev_B2W[:3, 3] | |
prev_all_fg.append([prev_w_dxyz]) | |
else: | |
prev_all_fg.append([w_dxyz]) | |
all_fg = concatenate_all(all_fg) | |
xyz, opacities, scales, rotations, shs, feats3D = cat_bgfg(pc, all_fg) | |
if render_optical and prev_viewpoint is not None: | |
prev_all_fg = concatenate_all(prev_all_fg) | |
prev_xyz = cat_bgfg(pc, prev_all_fg, only_xyz=True)[0] | |
uv = proj_uv(xyz, viewpoint) | |
prev_uv = proj_uv(prev_xyz, prev_viewpoint) | |
delta_uv = prev_uv - uv | |
delta_uv = torch.cat([delta_uv, torch.ones_like(delta_uv[:, :1], device=delta_uv.device)], dim=-1) | |
else: | |
delta_uv = torch.zeros_like(xyz) | |
if pc.affine: | |
cam_xyz, cam_dir = viewpoint.c2w[:3, 3].cuda(), viewpoint.c2w[:3, 2].cuda() | |
o_enc = pc.pos_enc(cam_xyz[None, :] / 60) | |
d_enc = pc.dir_enc(cam_dir[None, :]) | |
appearance = pc.appearance_model(torch.cat([o_enc, d_enc], dim=1)) * 1e-1 | |
affine_weight, affine_bias = appearance[:, :9].view(3, 3), appearance[:, -3:] | |
affine_weight = affine_weight + torch.eye(3, device=appearance.device) | |
if render_optical: | |
render_mode = 'RGB+ED+S+F' | |
else: | |
render_mode = 'RGB+ED+S' | |
renders, render_alphas, info = rasterization( | |
means=xyz, | |
quats=rotations, | |
scales=scales, | |
opacities=opacities[:, 0], | |
colors=shs, | |
viewmats=torch.linalg.inv(viewpoint.c2w)[None, ...], # [C, 4, 4] | |
Ks=viewpoint.K[None, :3, :3], # [C, 3, 3] | |
width=viewpoint.width, | |
height=viewpoint.height, | |
smts=feats3D[None, ...], | |
flows= delta_uv[None, ...], | |
render_mode=render_mode, | |
sh_degree=pc.active_sh_degree, | |
near_plane=0.01, | |
far_plane=500, | |
packed=False, | |
backgrounds=bg_color[None, :], | |
) | |
renders = renders[0] | |
rendered_image = renders[..., :3].permute(2,0,1) | |
depth = renders[..., 3][None, ...] | |
smt = renders[..., 4:(4+feats3D.shape[-1])].permute(2,0,1) | |
if pc.affine: | |
colors = rendered_image.view(3, -1).permute(1, 0) # (H*W, 3) | |
refined_image = (colors @ affine_weight + affine_bias).clip(0, 1).permute(1, 0).view(*rendered_image.shape) | |
else: | |
refined_image = rendered_image | |
return {"render": refined_image, | |
"feats": smt, | |
"depth": depth, | |
"opticalflow": renders[..., -2:].permute(2,0,1) if render_optical else None, | |
"alphas": render_alphas, | |
"viewspace_points": info["means2d"], | |
"info": info, | |
} | |
def render_ground(viewpoint:Camera, pc:GroundModel, bg_color:Tensor): | |
xyz, opacities, scales = pc.get_xyz, pc.get_opacity, pc.get_scaling | |
rotations, shs, feats3D = pc.get_rotation, pc.get_features, pc.get_3D_features | |
K = viewpoint.K[None, :3, :3] | |
renders, render_alphas, info = rasterization( | |
means=xyz, | |
quats=rotations, | |
scales=scales, | |
opacities=opacities[:, 0], | |
colors=shs, | |
viewmats=torch.linalg.inv(viewpoint.c2w)[None, ...], # [C, 4, 4] | |
Ks=K, # [C, 3, 3] | |
width=viewpoint.width, | |
height=viewpoint.height, | |
smts=feats3D[None, ...], | |
render_mode='RGB+ED+S', | |
sh_degree=pc.active_sh_degree, | |
near_plane=0.01, | |
far_plane=500, | |
packed=False, | |
backgrounds=bg_color[None, :], | |
) | |
renders = renders[0] | |
rendered_image = renders[..., :3].permute(2,0,1) | |
depth = renders[..., 3][None, ...] | |
smt = renders[..., 4:(4+feats3D.shape[-1])].permute(2,0,1) | |
return {"render": rendered_image, | |
"feats": smt, | |
"depth": depth, | |
"opticalflow": None, | |
"alphas": render_alphas, | |
"viewspace_points": info["means2d"], | |
"info": info, | |
} |