import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms.functional import center_crop, resize
import os
from hydra import compose, initialize
from scenedino.models.bts import BTSNet
from scenedino.common.ray_sampler import ImageRaySampler
from scenedino.models import make_model
from scenedino.renderer.nerf import NeRFRenderer
from scenedino.training.trainer import BTSWrapper
from scenedino.datasets import make_datasets
from scenedino.common.array_operations import map_fn, unsqueezer
device = "cuda"
def load_modules(
ckpt_path: str,
ckpt_name: str
) -> tuple[BTSNet, NeRFRenderer, ImageRaySampler]:
"""
    Loads the relevant modules from a SceneDINO checkpoint (*.pt) and its corresponding config file (training_config.yaml).
Args:
ckpt_path (str): Relative path to the directory containing the files.
ckpt_name (str): File name of the checkpoint.
Returns:
net (BTSNet): The SceneDINO network.
renderer (NeRFRenderer): Volume rendering module.
ray_sampler (ImageRaySampler): Camera ray sampler for whole images.
"""
with initialize(version_base=None, config_path="../" + ckpt_path, job_name="demo_script"):
config = compose(config_name="training_config")
net = make_model(config["model"], config.get("downstream", None))
renderer = NeRFRenderer.from_conf(config["renderer"])
renderer.hard_alpha_cap = False
renderer = renderer.bind_parallel(net, gpus=None).eval()
height, width = config["dataset"]["image_size"]
ray_sampler = ImageRaySampler(z_near=3, z_far=80, width=width, height=height)
model = BTSWrapper(renderer, ray_sampler, config["model"])
    cp = torch.load(os.path.join(ckpt_path, ckpt_name), map_location=device)
# cp = cp["model"] # Some older checkpoints have this structure
model.load_state_dict(cp, strict=False)
model = model.to(device)
# test_dataset = make_datasets(config["dataset"])[1]
return net, renderer, ray_sampler # , test_dataset
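
# Example usage (a minimal sketch): the checkpoint directory and file name below are
# placeholders for illustration, not files shipped with this script.
#
#     net, renderer, ray_sampler = load_modules(
#         ckpt_path="out/scenedino-kitti-360/", ckpt_name="checkpoint.pt"
#     )
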
def load_sample_from_path(
path: str,
    intrinsic: Tensor | None = None
) -> tuple[Tensor, Tensor, Tensor]:
"""
Loads a test image from a provided path.
Args:
        path (str): Image path.
        intrinsic (Tensor | None): Normalized 3x3 camera intrinsic matrix, or None to fall back to KITTI-360 values.
    Returns:
        images (Tensor): RGB image normalized to [-1, 1].
        poses (Tensor): Camera pose (identity matrix).
        projs (Tensor): Normalized camera intrinsic matrix.
"""
images = read_image(path)
if not (images.size(1) == 192 and images.size(2) == 640):
scale = max(192 / images.size(1), 640 / images.size(2))
new_h, new_w = int(images.size(1) * scale), int(images.size(2) * scale)
images_resized = resize(images, [new_h, new_w])
images = center_crop(images_resized, (192, 640))
print("WARNING: Custom image does not have correct dimensions! Taking center crop.")
if images.dtype == torch.uint8:
images = 2 * (images / 255) - 1
elif images.dtype == torch.uint16:
images = 2 * (images / (2**16 - 1)) - 1
    # Drop the alpha channel if present
    if images.size(0) == 4:
        images = images[:3]
images = images.unsqueeze(0).unsqueeze(1)
poses = torch.eye(4).unsqueeze(0).unsqueeze(1)
    if intrinsic is not None:
projs = intrinsic.unsqueeze(0).unsqueeze(1)
else:
projs = torch.Tensor([
[0.7849, 0.0000, -0.0312],
[0.0000, 2.9391, 0.2701],
[0.0000, 0.0000, 1.0000]]).unsqueeze(0).unsqueeze(1)
print("WARNING: Custom image has no provided intrinsics! Using KITTI-360 values.")
return images.to(device), poses.to(device), projs.to(device)
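
# Example usage (a minimal sketch): "example_image.png" is a placeholder path. An
# intrinsic matrix, if provided, is assumed to follow the same normalized convention
# as the KITTI-360 default above.
#
#     images, poses, projs = load_sample_from_path("example_image.png", intrinsic=None)
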
def load_sample_from_dataset(
idx: int,
dataset: Dataset
) -> tuple[Tensor, Tensor, Tensor]:
"""
Loads a data point from the provided dataset. In this demo, we just load the front view.
Args:
idx (int): Index in the dataset.
dataset (Dataset): The dataset.
Returns:
images (Tensor): RGB image normalized to [-1, 1].
        poses (Tensor): Camera pose (identity matrix, since only the front view is used).
        projs (Tensor): Normalized camera intrinsic matrix.
"""
data = dataset[idx]
data_batch = map_fn(map_fn(data, torch.tensor), unsqueezer)
images = torch.stack(data_batch["imgs"], dim=1)
poses = torch.stack(data_batch["poses"], dim=1)
projs = torch.stack(data_batch["projs"], dim=1)
    # Express all poses relative to the first (front) camera
    poses = torch.inverse(poses[:, :1, :, :]) @ poses
# Just front view
images = images[:, :1]
poses = poses[:, :1]
projs = projs[:, :1]
return images.to(device), poses.to(device), projs.to(device)
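
# Example usage (a minimal sketch): requires a dataset instance, e.g. the test split
# that load_modules can build via make_datasets (currently commented out there). The
# config used below is assumed to be the same Hydra config loaded in load_modules.
#
#     test_dataset = make_datasets(config["dataset"])[1]
#     images, poses, projs = load_sample_from_dataset(0, test_dataset)
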
def inference_3d(
net: BTSNet,
x_range: tuple[float, float],
y_range: tuple[float, float],
z_range: tuple[float, float],
resolution: float,
prediction_mode: str = "stego_kmeans"
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Inference in a uniform 3D grid. All units are provided in meters.
Args:
net (BTSNet): The SceneDINO network.
x_range (tuple[float, float]): Range along the X dimension.
y_range (tuple[float, float]): Range along the Y dimension.
z_range (tuple[float, float]): Range along the Z dimension, the viewing direction.
        resolution (float): Resolution of the grid.
        prediction_mode (str): Prediction mode of the segmentation head (default: "stego_kmeans").
    Returns:
        xyz (Tensor): Flattened grid coordinates [1, n_X * n_Y * n_Z, 3].
        dino_full (Tensor): SceneDINO features [n_X, n_Y, n_Z, 768].
        sigma (Tensor): Volumetric density [n_X, n_Y, n_Z].
        seg (Tensor): Predicted semantic classes [n_X, n_Y, n_Z].
"""
n_pts_x = int((x_range[1] - x_range[0]) / resolution) + 1
n_pts_y = int((y_range[1] - y_range[0]) / resolution) + 1
n_pts_z = int((z_range[1] - z_range[0]) / resolution) + 1
x = torch.linspace(x_range[0], x_range[1], n_pts_x)
y = torch.linspace(y_range[0], y_range[1], n_pts_y)
z = torch.linspace(z_range[0], z_range[1], n_pts_z)
grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing='ij')
xyz = torch.stack((grid_x, grid_y, grid_z), dim=-1).reshape(-1, 3).unsqueeze(0).to(device)
dino_full, invalid, sigma, seg = net(xyz, predict_segmentation=True, prediction_mode=prediction_mode)
dino_full = dino_full.reshape(n_pts_x, n_pts_y, n_pts_z, -1)
sigma = sigma.reshape(n_pts_x, n_pts_y, n_pts_z)
if seg is not None:
seg = seg.reshape(n_pts_x, n_pts_y, n_pts_z, -1).argmax(-1)
return xyz, dino_full, sigma, seg
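
# Example usage (a minimal sketch): queries a volume of roughly 50 m x 10 m x 80 m in
# front of the camera at 0.5 m resolution. Note that the network has to be conditioned
# on an input image before it can be queried; how that is done depends on the BTSNet
# API and is not shown here.
#
#     xyz, dino_full, sigma, seg = inference_3d(
#         net, x_range=(-25, 25), y_range=(-5, 5), z_range=(0, 80), resolution=0.5
#     )
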
def get_fov_mask(proj_matrix: Tensor, xyz: Tensor) -> Tensor:
    """Returns a boolean mask of the points in xyz that project inside the image plane."""
    # With normalized intrinsics, points inside the field of view project to roughly (-1, 1).
    proj_xyz = xyz @ proj_matrix.T
    proj_xyz = proj_xyz / proj_xyz[..., 2:3]
    fov_mask = (proj_xyz[..., 0] > -0.99) & (proj_xyz[..., 0] < 0.99) & (proj_xyz[..., 1] > -0.99) & (proj_xyz[..., 1] < 0.99)
    return fov_mask
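
# Example usage (a minimal sketch): restrict the grid from inference_3d to points that
# fall inside the camera frustum, using the normalized intrinsics from one of the
# loaders above.
#
#     fov_mask = get_fov_mask(projs[0, 0], xyz[0])
#     sigma_in_view = sigma.reshape(-1)[fov_mask]
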
def inference_rendered_2d(
net: BTSNet,
poses: Tensor,
projs: Tensor,
ray_sampler: ImageRaySampler,
renderer: NeRFRenderer,
prediction_mode: str = "stego_kmeans"
) -> tuple[Tensor, Tensor, Tensor]:
"""
Inference in 3D, rendered back into a 2D image, based on a provided camera pose and matrix.
Args:
net (BTSNet): The SceneDINO network.
poses (Tensor): Camera pose.
projs (Tensor): Camera matrix.
ray_sampler (ImageRaySampler): Camera ray sampler for whole images.
        renderer (NeRFRenderer): Volume rendering module.
        prediction_mode (str): Prediction mode of the segmentation head (default: "stego_kmeans").
Returns:
dino_full (Tensor): SceneDINO features [H, W, 768].
depth (Tensor): Ray termination depth [H, W].
seg (Tensor): Predicted semantic classes [H, W].
"""
    all_rays, _ = ray_sampler.sample(None, poses, projs)
render_dict = renderer(all_rays, want_weights=True, want_alphas=True)
render_dict = ray_sampler.reconstruct(render_dict)
depth = render_dict["coarse"]["depth"].squeeze()
dino_distilled = render_dict["coarse"]["dino_features"].squeeze()
dino_full = net.encoder.expand_dim(dino_distilled)
if net.downstream_head is not None:
seg = net.downstream_head(dino_full, mode=prediction_mode)
else:
seg = None
return dino_full, depth, seg
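
if __name__ == "__main__":
    # Minimal end-to-end sketch. The checkpoint directory, checkpoint name, and image
    # path are placeholders; point them at an actual SceneDINO checkpoint and input.
    net, renderer, ray_sampler = load_modules("out/scenedino-kitti-360/", "checkpoint.pt")
    images, poses, projs = load_sample_from_path("example_image.png", intrinsic=None)
    # NOTE (assumption): before rendering, the network has to be conditioned on the
    # input image via its encoder; that call depends on the BTSNet API and is omitted here.
    with torch.no_grad():
        dino_full, depth, seg = inference_rendered_2d(net, poses, projs, ray_sampler, renderer)
    print("Features:", tuple(dino_full.shape), "Depth:", tuple(depth.shape))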