# NOTE(review): the lines that were here ("Spaces: Running on Zero", twice)
# are a Hugging Face Spaces page banner captured during text extraction —
# not Python code. Converted to a comment so the module remains importable.
from typing import Callable

import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idst

from scenedino.common.geometry import distance_to_z
import scenedino.common.metrics as metrics
def create_depth_eval( | |
model: nn.Module, | |
scaling_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | |
| None = None, | |
): | |
def _compute_depth_metrics( | |
data, | |
# TODO: maybe integrate model | |
# model: nn.Module, | |
): | |
return metrics.compute_depth_metrics( | |
data["depths"][0], data["coarse"][0]["depth"][:, :1], scaling_function | |
) | |
return _compute_depth_metrics | |
def create_nvs_eval(model: nn.Module):
    """Build a closure that computes novel-view-synthesis metrics.

    The LPIPS perceptual network is instantiated once, on the current
    distributed device, so it is not rebuilt for every evaluated batch.
    """
    perceptual_fn = lpips.LPIPS().to(idst.device())

    def _compute_nvs_metrics(data):
        return metrics.compute_nvs_metrics(data, perceptual_fn)

    return _compute_nvs_metrics
def create_dino_eval(model: nn.Module):
    """Build a closure that computes DINO feature metrics from a data dict.

    ``model`` is unused; the parameter keeps the factory signature uniform.
    """

    def _compute_dino_metrics(data):
        return metrics.compute_dino_metrics(data)

    return _compute_dino_metrics
def create_seg_eval(model: nn.Module, n_classes: int, gt_classes: int):
    """Build a closure that computes semantic-segmentation metrics.

    Args:
        model: unused; uniform factory signature.
        n_classes: number of predicted classes.
        gt_classes: number of ground-truth classes.
    """

    def _compute_seg_metrics(data):
        # NOTE(review): original carried "Why is this necessary?" here —
        # unclear which part was questioned; confirm with the author.
        return metrics.compute_seg_metrics(data, n_classes, gt_classes)

    return _compute_seg_metrics
def create_stego_eval(model: nn.Module):
    """Build a closure that computes STEGO clustering metrics.

    ``model`` is unused; the parameter keeps the factory signature uniform.
    """

    def _compute_stego_metrics(data):
        # NOTE(review): original carried "Why is this necessary?" here —
        # unclear which part was questioned; confirm with the author.
        return metrics.compute_stego_metrics(data)

    return _compute_stego_metrics
# code for saving voxel grid | |
# def pack(uncompressed): | |
# """convert a boolean array into a bitwise array.""" | |
# uncompressed_r = uncompressed.reshape(-1, 8) | |
# compressed = uncompressed_r.dot( | |
# 1 << np.arange(uncompressed_r.shape[-1] - 1, -1, -1) | |
# ) | |
# return compressed | |
# if self.save_bin_path: | |
# # base_file = "/storage/user/hank/methods_test/semantic-kitti-api/bts_test/sequences/00/voxels" | |
# outside_frustum = ( | |
# ( | |
# (cam_pts[:, 0] < -1.0) | |
# | (cam_pts[:, 0] > 1.0) | |
# | (cam_pts[:, 1] < -1.0) | |
#                 | (cam_pts[:, 1] > 1.0)  # NOTE(review): original repeated cam_pts[:, 0] here; the y-bound check was presumably intended
# ) | |
# .reshape(q_pts_shape) | |
# .permute(1, 2, 0) | |
# .detach() | |
# .cpu() | |
# .numpy() | |
# ) | |
# is_occupied_numpy = ( | |
# is_occupied_pred.reshape(q_pts_shape) | |
# .permute(1, 2, 0) | |
# .detach() | |
# .cpu() | |
# .numpy() | |
# .astype(np.float32) | |
# ) | |
# is_occupied_numpy[outside_frustum] = 0.0 | |
# ## carving out the invisible regions out of view-frustum | |
# # for i_ in range( | |
# # (is_occupied_numpy.shape[0]) // 2 | |
# # ): ## left | right half of the space | |
# # for j_ in range(i_ + 1): | |
# # is_occupied_numpy[i_, j_] = 0 | |
# pack(np.flip(is_occupied_numpy, (0, 1, 2)).reshape(-1)).astype( | |
# np.uint8 | |
# ).tofile( | |
# # f"{base_file}/{self.counter:0>6}.bin" | |
# f"{self.save_bin_path}/{self.counter:0>6}.bin" | |
# ) | |
# # for idx_i, image in enumerate(images[0]): | |
# # torchvision.utils.save_image( | |
# # image, f"{self.save_bin_path}/{self.counter:0>6}_{idx_i}.png" | |
# # ) | |
def project_into_cam(pts, proj, pose):
    """Project world-space points into a camera's image plane.

    Args:
        pts: (N, 3) world-space points.
        proj: camera projection/intrinsics matrix — presumably (3, 3), since
            it multiplies a (3, N) result; TODO confirm against callers.
        pose: camera-to-world pose (4, 4); inverted here to map world -> camera.

    Returns:
        cam_pts: (N, 3) tensor whose x/y are divided by the camera-space z
            (z itself is left untouched).
        dist: (N,) camera-space z of each point (a view into ``cam_pts``).
    """
    homogeneous = torch.cat((pts, torch.ones_like(pts[:, :1])), dim=-1)
    world_to_cam = torch.inverse(pose).squeeze()[:3, :]
    cam_pts = (proj @ (world_to_cam @ homogeneous.T)).T
    # Perspective divide on x/y only; the z column keeps the raw depth.
    cam_pts[:, :2] /= cam_pts[:, 2:3]
    dist = cam_pts[:, 2]
    return cam_pts, dist
def create_occ_eval(
    model: nn.Module,
    occ_threshold: float,
    query_batch_size: int,
):
    """Build a closure that evaluates occupancy metrics against GT labels.

    Args:
        model: density-field model; invoked with ``only_density=True`` so
            only per-point densities are produced.
        occ_threshold: a query point whose predicted density exceeds this
            value is counted as occupied.
        query_batch_size: number of query points fed to the model per
            forward pass, bounding peak memory.

    Returns:
        A function mapping a data dict to occupancy metrics
        (via ``metrics.compute_occ_metrics``).
    """
    # TODO: deal with other models such as IBRnet
    def _compute_occ_metrics(
        data,
    ):
        # Stack the per-view lists into batched tensors along a view axis.
        projs = torch.stack(data["projs"], dim=1)
        images = torch.stack(data["imgs"], dim=1)
        _, _, _, h, w = images.shape
        poses = torch.stack(data["poses"], dim=1)
        device = poses.device
        # TODO: get occ points and occupation from dataset
        # Flatten the occupancy query points to (N, 3) on the model's device.
        occ_pts = data["occ_pts"].permute(0, 2, 1, 3).contiguous()
        occ_pts = occ_pts.to(device).view(-1, 3)
        # Convert predicted ray distance to z-depth for the reference view.
        # NOTE(review): sibling eval fns index data["coarse"][0]["depth"]; here
        # it is data["coarse"]["depth"] — confirm which layout the pipeline emits.
        pred_depth = distance_to_z(data["coarse"]["depth"], projs[:1, :1])
        # is visible? Check whether point is closer than the computed pseudo depth
        cam_pts, dists = project_into_cam(occ_pts, projs[0, 0], poses[0, 0])
        # Sample predicted depth at each projected point; assumes cam_pts x/y
        # are already in the [-1, 1] NDC range grid_sample expects — TODO confirm.
        pred_dist = F.grid_sample(
            pred_depth.view(1, 1, h, w),
            cam_pts[:, :2].view(1, 1, -1, 2),
            mode="nearest",
            padding_mode="border",
            align_corners=True,
        ).view(-1)
        # A point is predicted visible if it lies at or in front of the surface.
        is_visible_pred = dists <= pred_dist
        # Debug/alternative path: treat a 4-unit shell behind the predicted
        # surface as occupied instead of querying the density field.
        depth_plus4meters = False
        if depth_plus4meters:
            mask = (dists >= pred_dist) & (dists < pred_dist + 4)
            densities = torch.zeros_like(occ_pts[..., 0])
            densities[mask] = 1.0
            is_occupied_pred = densities > occ_threshold
        else:
            # Query the density of the query points from the density field
            densities = []
            for i_from in range(0, len(occ_pts), query_batch_size):
                i_to = min(i_from + query_batch_size, len(occ_pts))
                q_pts_ = occ_pts[i_from:i_to]
                _, _, densities_, _ = model(
                    q_pts_.unsqueeze(0), only_density=True
                )  ## ! occupancy estimation
                densities.append(densities_.squeeze(0))
            densities = torch.cat(densities, dim=0).squeeze()
            # Threshold densities into a boolean occupancy prediction.
            is_occupied_pred = densities > occ_threshold
        # Ground-truth occupancy/visibility labels supplied by the dataset.
        is_occupied = data["is_occupied"]
        is_visible = data["is_visible"]
        return metrics.compute_occ_metrics(is_occupied_pred, is_occupied, is_visible)

    return _compute_occ_metrics
def make_eval_fn(
    model: nn.Module,
    conf,
):
    """Resolve and instantiate an evaluation factory by configured name.

    Looks up ``create_<type>_eval`` in this module's globals using
    ``conf["type"]`` and calls it with ``model`` (plus ``conf["args"]`` as
    keyword arguments when present). Returns None if no factory matches.
    """
    factory = globals().get(f"create_{conf['type']}_eval", None)
    if factory is None:
        return None
    extra_args = conf.get("args", None)
    if extra_args:
        return factory(model, **extra_args)
    return factory(model)