import tempfile
import os
import sys
import yaml
import numpy as np
import open3d as o3d
import gradio as gr
import spaces
import torch
sys.path.append("./sscbench")
from sscbench.gen_voxelgrid_npy import save_as_voxel_ply, classes_to_colors
from download_checkpoint_hf import download_scenedino_checkpoint
from demo_utils.utils import (load_modules,
                              load_sample_from_path,
                              load_sample_from_dataset,
                              get_fov_mask,
                              inference_3d,
                              inference_rendered_2d)
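
# Numba's "workqueue" threading layer is its simplest built-in backend; assumption:
# it is selected here to avoid TBB/OpenMP threading issues in the ZeroGPU sandbox.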
os.environ['NUMBA_THREADING_LAYER'] = 'workqueue'
# Load checkpoints from Hugging Face
download_scenedino_checkpoint("ssc-kitti-360-dino")
download_scenedino_checkpoint("ssc-kitti-360-dinov2")
def convert_voxels(arr, map_dict):
    f = np.vectorize(map_dict.__getitem__)
    return f(arr)
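

# Example: convert_voxels(np.array([1, 2]), {1: 5, 2: 7}) returns array([5, 7]);
# every voxel label is remapped element-wise through the dictionary.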
with open("sscbench/label_maps.yaml", "r") as f:
    label_maps = yaml.safe_load(f)
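
# label_maps holds the "cityscapes_to_label" id remapping used below for both the
# 2D segmentation output and the 3D voxel classes.
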
@spaces.GPU
def load_scenedino():
    # Populate the module-level handles used by demo_run
    global net_v1, renderer_v1, ray_sampler_v1
    global net_v2, renderer_v2, ray_sampler_v2

    # Load model, ray sampler, datasets
    ckpt_path = "out/scenedino-pretrained/seg-best-dino/"
    ckpt_name = "checkpoint.pt"
    net_v1, renderer_v1, ray_sampler_v1 = load_modules(ckpt_path, ckpt_name)
    renderer_v1.eval()

    ckpt_path = "out/scenedino-pretrained/seg-best-dinov2/"
    ckpt_name = "checkpoint.pt"
    net_v2, renderer_v2, ray_sampler_v2 = load_modules(ckpt_path, ckpt_name)
    renderer_v2.eval()

@spaces.GPU
def demo_run(image: str,
             backbone: str,
             mode: str,
             sigma_threshold: float,
             resolution: float,
             x_range: int,
             y_range: int,
             z_range: int):
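    """Run SceneDINO on one input image.

    Returns a 2D visualization (feature PCA or SSC segmentation) and the path to
    a voxel mesh in .glb format. sigma_threshold is the density cutoff for voxel
    occupancy, resolution the voxel size in meters, and x/y/z_range the extent
    of the reconstructed volume in meters.
    """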
    if backbone == "DINO (ViT-B)":
        net, renderer, ray_sampler = net_v1, renderer_v1, ray_sampler_v1
    elif backbone == "DINOv2 (ViT-B)":
        net, renderer, ray_sampler = net_v2, renderer_v2, ray_sampler_v2

    prediction_mode = "stego_kmeans"
    if mode == "Feature PCA 1-3":
        segmentation = False
        rgb_from_pca_dim = 0
    elif mode == "Feature PCA 4-6":
        segmentation = False
        rgb_from_pca_dim = 3
    elif mode == "Feature PCA 7-9":
        segmentation = False
        rgb_from_pca_dim = 6
    elif mode == "SSC (unsup.)":
        segmentation = True
    elif mode == "SSC (linear)":
        segmentation = True
        prediction_mode = "direct_linear"

    # Cast from str: values may arrive as strings when loaded from the examples table
    sigma_threshold, resolution = float(sigma_threshold), float(resolution)
    x_range, y_range, z_range = int(x_range), int(y_range), int(z_range)
    # Guard against excessive voxel counts before running inference
    max_voxel_count = 5000000
    voxel_count = (x_range//resolution + 1) * (y_range//resolution + 1) * (z_range//resolution + 1)
    if voxel_count > max_voxel_count:
        raise gr.Error(f"Too many voxels ({int(voxel_count) / 1_000_000:.1f}M > {max_voxel_count / 1_000_000:.1f}M).\n" +
                       "Reduce voxel resolution or range.", duration=5)

    with torch.no_grad():
        images, poses, projs = load_sample_from_path(image, intrinsic=None)
        net.encode(images, projs, poses, ids_encoder=[0])
        net.set_scale(0)

        # 2D Features output
        dino_full_2d, depth_2d, seg_2d = inference_rendered_2d(net, poses, projs, ray_sampler, renderer, prediction_mode)
        net.encoder.fit_visualization(dino_full_2d.flatten(0, -2))

        if segmentation:
            output_2d = convert_voxels(seg_2d.detach().cpu(), label_maps["cityscapes_to_label"])
            output_2d = classes_to_colors[output_2d].cpu().detach().numpy()
        else:
            output_2d = net.encoder.transform_visualization(dino_full_2d, from_dim=rgb_from_pca_dim)
            output_2d -= output_2d.min()
            output_2d /= output_2d.max()
            output_2d = output_2d.cpu().detach().numpy()
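
        # output_2d is now an RGB visualization: either per-class colors (SSC)
        # or the selected 3-channel slice of the PCA-projected DINO features.
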
        # Chunking
        max_chunk_size = 100000
        z_layers_per_chunk = max_chunk_size // ((x_range//resolution + 1) * (y_range//resolution + 1))
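        # Each inference_3d call is capped at ~100k voxels. At the defaults, one
        # z-layer has 51 * 51 = 2601 cells, so 100000 // 2601 = 38 layers per chunk.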

        # 3D Features output: volume symmetric in x/y around the camera, z ahead of it
        x_range = (-x_range/2, x_range/2)
        y_range = (-y_range/2, y_range/2)
        z_range = (0, z_range)

        is_occupied, output_3d, fov_mask = [], [], []
        current_z = 0
        while current_z <= z_range[1]:
            z_range_chunk = (current_z, min(current_z + z_layers_per_chunk*resolution, z_range[1]))
            current_z += (z_layers_per_chunk+1) * resolution

            xyz_chunk, dino_full_3d_chunk, sigma_3d_chunk, seg_3d_chunk = inference_3d(net, x_range, y_range, z_range_chunk, resolution, prediction_mode)
            fov_mask_chunk = get_fov_mask(projs[0, 0], xyz_chunk)
            is_occupied_chunk = sigma_3d_chunk > sigma_threshold

            if segmentation:
                output_3d_chunk = seg_3d_chunk
            else:
                output_3d_chunk = net.encoder.transform_visualization(dino_full_3d_chunk, from_dim=rgb_from_pca_dim)
                output_3d_chunk -= output_3d_chunk.min()
                output_3d_chunk /= output_3d_chunk.max()
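                # Mildly stretch contrast so the feature colors span the full range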
                output_3d_chunk = torch.clamp(output_3d_chunk*1.2 - 0.1, 0.0, 1.0)
                output_3d_chunk = (255*output_3d_chunk).int()

            fov_mask_chunk = fov_mask_chunk.reshape(is_occupied_chunk.shape)
            is_occupied.append(is_occupied_chunk)
            output_3d.append(output_3d_chunk)
            fov_mask.append(fov_mask_chunk)

        is_occupied = torch.cat(is_occupied, dim=2)
        output_3d = torch.cat(output_3d, dim=2)
        fov_mask = torch.cat(fov_mask, dim=2)

        temp_dir = tempfile.gettempdir()
        ply_path = os.path.join(temp_dir, "output.ply")

        if segmentation:
            # Class ids 10 and 12 are mapped to "unlabeled": drop them from the grid
            is_occupied[output_3d == 10] = 0
            is_occupied[output_3d == 12] = 0

            save_as_voxel_ply(ply_path,
                              is_occupied.detach().cpu(),
                              voxel_size=resolution,
                              size=is_occupied.size(),
                              classes=torch.Tensor(
                                  convert_voxels(
                                      output_3d.detach().cpu(),
                                      label_maps["cityscapes_to_label"])),
                              fov_mask=fov_mask)
        else:
            save_as_voxel_ply(ply_path,
                              is_occupied.detach().cpu(),
                              voxel_size=resolution,
                              size=is_occupied.size(),
                              colors=output_3d.detach().cpu(),
                              fov_mask=fov_mask)
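
        # Convert the voxel PLY into GLB, which gr.Model3D can display directly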
        mesh = o3d.io.read_triangle_mesh(ply_path)
        glb_path = os.path.join(temp_dir, "output.glb")
        o3d.io.write_triangle_mesh(glb_path, mesh, write_ascii=True)
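
        # Release tensors and cached GPU memory between requests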
        del dino_full_2d, depth_2d, seg_2d
        del dino_full_3d_chunk, sigma_3d_chunk, seg_3d_chunk, is_occupied_chunk
        del is_occupied, output_3d, fov_mask
        torch.cuda.empty_cache()

    return output_2d, glb_path
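

# Module-level model handles, filled in once at startup by load_scenedino().
# On ZeroGPU, the @spaces.GPU decorator requests a GPU for each decorated call.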
net_v1, renderer_v1, ray_sampler_v1 = None, None, None
net_v2, renderer_v2, ray_sampler_v2 = None, None, None
load_scenedino()
demo = gr.Interface(
    demo_run,
    inputs=[
        gr.Image(label="Input image", type="filepath"),
        gr.Radio(label="Backbone", choices=["DINO (ViT-B)", "DINOv2 (ViT-B)"]),
        gr.Radio(label="Mode", choices=["Feature PCA 1-3", "Feature PCA 4-6", "Feature PCA 7-9", "SSC (unsup.)", "SSC (linear)"]),
        gr.Slider(label="Density threshold", minimum=0, maximum=1, step=0.05, value=0.2),
        gr.Slider(label="Resolution [m]", minimum=0.05, maximum=0.5, step=0.1, value=0.2),
        gr.Slider(label="X Range [m]", minimum=1, maximum=50, step=1, value=10),
        gr.Slider(label="Y Range [m]", minimum=1, maximum=50, step=1, value=10),
        gr.Slider(label="Z Range [m]", minimum=1, maximum=100, step=1, value=20),
    ],
    outputs=[
        gr.Image(label="Rendered 2D Visualization"),
        gr.Model3D(label="Voxel Surface 3D Visualization",
                   zoom_speed=0.5, pan_speed=0.5,
                   clear_color=[0.0, 0.0, 0.0, 0.0],
                   camera_position=[-90, 80, None],
                   display_mode="solid"),
    ],
    title="SceneDINO Demo",
    examples="demo_utils/examples",
)
demo.launch()