|
import viser |
|
from pathlib import Path |
|
from typing import Literal |
|
from typing import Tuple, Callable |
|
from nerfview import Viewer, RenderTabState |
|
|
|
|
|
class GsplatRenderTabState(RenderTabState): |
|
|
|
total_gs_count: int = 0 |
|
rendered_gs_count: int = 0 |
|
|
|
|
|
max_sh_degree: int = 5 |
|
near_plane: float = 1e-2 |
|
far_plane: float = 1e2 |
|
radius_clip: float = 0.0 |
|
eps2d: float = 0.3 |
|
backgrounds: Tuple[float, float, float] = (0.0, 0.0, 0.0) |
|
render_mode: Literal[ |
|
"rgb", "depth(accumulated)", "depth(expected)", "alpha" |
|
] = "rgb" |
|
normalize_nearfar: bool = False |
|
inverse: bool = True |
|
colormap: Literal[ |
|
"turbo", "viridis", "magma", "inferno", "cividis", "gray" |
|
] = "turbo" |
|
rasterize_mode: Literal["classic", "antialiased"] = "classic" |
|
camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" |
|
|
|
|
|
class GsplatViewer(Viewer): |
|
""" |
|
Viewer for gsplat. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
server: viser.ViserServer, |
|
render_fn: Callable, |
|
output_dir: Path, |
|
mode: Literal["rendering", "training"] = "rendering", |
|
): |
|
super().__init__(server, render_fn, output_dir, mode) |
|
server.gui.set_panel_label("gsplat viewer") |
|
|
|
def _init_rendering_tab(self): |
|
self.render_tab_state = GsplatRenderTabState() |
|
self._rendering_tab_handles = {} |
|
self._rendering_folder = self.server.gui.add_folder("Rendering") |
|
|
|
def _populate_rendering_tab(self): |
|
server = self.server |
|
with self._rendering_folder: |
|
with server.gui.add_folder("Gsplat"): |
|
total_gs_count_number = server.gui.add_number( |
|
"Total", |
|
initial_value=self.render_tab_state.total_gs_count, |
|
disabled=True, |
|
hint="Total number of splats in the scene.", |
|
) |
|
rendered_gs_count_number = server.gui.add_number( |
|
"Rendered", |
|
initial_value=self.render_tab_state.rendered_gs_count, |
|
disabled=True, |
|
hint="Number of splats rendered.", |
|
) |
|
|
|
max_sh_degree_number = server.gui.add_number( |
|
"Max SH", |
|
initial_value=self.render_tab_state.max_sh_degree, |
|
min=0, |
|
max=5, |
|
step=1, |
|
hint="Maximum SH degree used", |
|
) |
|
|
|
@max_sh_degree_number.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.max_sh_degree = int( |
|
max_sh_degree_number.value |
|
) |
|
self.rerender(_) |
|
|
|
near_far_plane_vec2 = server.gui.add_vector2( |
|
"Near/Far", |
|
initial_value=( |
|
self.render_tab_state.near_plane, |
|
self.render_tab_state.far_plane, |
|
), |
|
min=(1e-3, 1e1), |
|
max=(1e1, 1e2), |
|
step=1e-3, |
|
hint="Near and far plane for rendering.", |
|
) |
|
|
|
@near_far_plane_vec2.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.near_plane = near_far_plane_vec2.value[0] |
|
self.render_tab_state.far_plane = near_far_plane_vec2.value[1] |
|
self.rerender(_) |
|
|
|
radius_clip_slider = server.gui.add_number( |
|
"Radius Clip", |
|
initial_value=self.render_tab_state.radius_clip, |
|
min=0.0, |
|
max=100.0, |
|
step=1.0, |
|
hint="2D radius clip for rendering.", |
|
) |
|
|
|
@radius_clip_slider.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.radius_clip = radius_clip_slider.value |
|
self.rerender(_) |
|
|
|
eps2d_slider = server.gui.add_number( |
|
"2D Epsilon", |
|
initial_value=self.render_tab_state.eps2d, |
|
min=0.0, |
|
max=1.0, |
|
step=0.01, |
|
hint="Epsilon added to the egienvalues of projected 2D covariance matrices.", |
|
) |
|
|
|
@eps2d_slider.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.eps2d = eps2d_slider.value |
|
self.rerender(_) |
|
|
|
backgrounds_slider = server.gui.add_rgb( |
|
"Background", |
|
initial_value=self.render_tab_state.backgrounds, |
|
hint="Background color for rendering.", |
|
) |
|
|
|
@backgrounds_slider.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.backgrounds = backgrounds_slider.value |
|
self.rerender(_) |
|
|
|
render_mode_dropdown = server.gui.add_dropdown( |
|
"Render Mode", |
|
("rgb", "depth(accumulated)", "depth(expected)", "alpha"), |
|
initial_value=self.render_tab_state.render_mode, |
|
hint="Render mode to use.", |
|
) |
|
|
|
@render_mode_dropdown.on_update |
|
def _(_) -> None: |
|
if "depth" in render_mode_dropdown.value: |
|
normalize_nearfar_checkbox.disabled = False |
|
else: |
|
normalize_nearfar_checkbox.disabled = True |
|
if render_mode_dropdown.value == "rgb": |
|
inverse_checkbox.disabled = True |
|
else: |
|
inverse_checkbox.disabled = False |
|
self.render_tab_state.render_mode = render_mode_dropdown.value |
|
self.rerender(_) |
|
|
|
normalize_nearfar_checkbox = server.gui.add_checkbox( |
|
"Normalize Near/Far", |
|
initial_value=self.render_tab_state.normalize_nearfar, |
|
disabled=True, |
|
hint="Normalize depth with near/far plane.", |
|
) |
|
|
|
@normalize_nearfar_checkbox.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.normalize_nearfar = ( |
|
normalize_nearfar_checkbox.value |
|
) |
|
self.rerender(_) |
|
|
|
inverse_checkbox = server.gui.add_checkbox( |
|
"Inverse", |
|
initial_value=self.render_tab_state.inverse, |
|
disabled=True, |
|
hint="Inverse the depth.", |
|
) |
|
|
|
@inverse_checkbox.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.inverse = inverse_checkbox.value |
|
self.rerender(_) |
|
|
|
colormap_dropdown = server.gui.add_dropdown( |
|
"Colormap", |
|
("turbo", "viridis", "magma", "inferno", "cividis", "gray"), |
|
initial_value=self.render_tab_state.colormap, |
|
hint="Colormap used for rendering depth/alpha.", |
|
) |
|
|
|
@colormap_dropdown.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.colormap = colormap_dropdown.value |
|
self.rerender(_) |
|
|
|
rasterize_mode_dropdown = server.gui.add_dropdown( |
|
"Anti-Aliasing", |
|
("classic", "antialiased"), |
|
initial_value=self.render_tab_state.rasterize_mode, |
|
hint="Whether to use classic or antialiased rasterization.", |
|
) |
|
|
|
@rasterize_mode_dropdown.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.rasterize_mode = rasterize_mode_dropdown.value |
|
self.rerender(_) |
|
|
|
camera_model_dropdown = server.gui.add_dropdown( |
|
"Camera", |
|
("pinhole", "ortho", "fisheye"), |
|
initial_value=self.render_tab_state.camera_model, |
|
hint="Camera model used for rendering.", |
|
) |
|
|
|
@camera_model_dropdown.on_update |
|
def _(_) -> None: |
|
self.render_tab_state.camera_model = camera_model_dropdown.value |
|
self.rerender(_) |
|
|
|
self._rendering_tab_handles.update( |
|
{ |
|
"total_gs_count_number": total_gs_count_number, |
|
"rendered_gs_count_number": rendered_gs_count_number, |
|
"near_far_plane_vec2": near_far_plane_vec2, |
|
"radius_clip_slider": radius_clip_slider, |
|
"eps2d_slider": eps2d_slider, |
|
"backgrounds_slider": backgrounds_slider, |
|
"render_mode_dropdown": render_mode_dropdown, |
|
"normalize_nearfar_checkbox": normalize_nearfar_checkbox, |
|
"inverse_checkbox": inverse_checkbox, |
|
"colormap_dropdown": colormap_dropdown, |
|
"rasterize_mode_dropdown": rasterize_mode_dropdown, |
|
"camera_model_dropdown": camera_model_dropdown, |
|
} |
|
) |
|
super()._populate_rendering_tab() |
|
|
|
def _after_render(self): |
|
|
|
self._rendering_tab_handles[ |
|
"total_gs_count_number" |
|
].value = self.render_tab_state.total_gs_count |
|
self._rendering_tab_handles[ |
|
"rendered_gs_count_number" |
|
].value = self.render_tab_state.rendered_gs_count |
|
|