from typing import Callable

import lpips
import torch
import torch.nn as nn
import torch.nn.functional as F

import ignite.distributed as idst

from scenedino.common.geometry import distance_to_z
import scenedino.common.metrics as metrics


# Each create_*_eval factory returns a closure that maps a batch dict from the
# evaluation loop to the corresponding metric dict in scenedino.common.metrics.


def create_depth_eval(
    model: nn.Module,
    scaling_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
):
    def _compute_depth_metrics(
        data,
        # TODO: maybe integrate model
        # model: nn.Module,
    ):
        return metrics.compute_depth_metrics(
            data["depths"][0], data["coarse"][0]["depth"][:, :1], scaling_function
        )

    return _compute_depth_metrics


def create_nvs_eval(model: nn.Module):
    # Instantiate LPIPS once per process and move it to the local device.
    lpips_fn = lpips.LPIPS().to(idst.device())

    def _compute_nvs_metrics(
        data,
        # model: nn.Module,
    ):
        return metrics.compute_nvs_metrics(data, lpips_fn)

    return _compute_nvs_metrics


def create_dino_eval(model: nn.Module):
    def _compute_dino_metrics(
        data,
    ):
        return metrics.compute_dino_metrics(data)

    return _compute_dino_metrics


def create_seg_eval(model: nn.Module, n_classes: int, gt_classes: int):
    def _compute_seg_metrics(
        data,
    ):
        return metrics.compute_seg_metrics(data, n_classes, gt_classes)

    return _compute_seg_metrics


def create_stego_eval(model: nn.Module):
    def _compute_stego_metrics(
        data,
    ):
        return metrics.compute_stego_metrics(data)

    return _compute_stego_metrics


# Code for saving the predicted voxel grid in the SemanticKITTI .bin format.
# def pack(uncompressed):
#     """Convert a boolean array into a bitwise array."""
#     uncompressed_r = uncompressed.reshape(-1, 8)
#     compressed = uncompressed_r.dot(
#         1 << np.arange(uncompressed_r.shape[-1] - 1, -1, -1)
#     )
#     return compressed

# if self.save_bin_path:
#     # base_file = "/storage/user/hank/methods_test/semantic-kitti-api/bts_test/sequences/00/voxels"
#     outside_frustum = (
#         (
#             (cam_pts[:, 0] < -1.0)
#             | (cam_pts[:, 0] > 1.0)
#             | (cam_pts[:, 1] < -1.0)
#             | (cam_pts[:, 1] > 1.0)
#         )
#         .reshape(q_pts_shape)
#         .permute(1, 2, 0)
#         .detach()
#         .cpu()
#         .numpy()
#     )
#     is_occupied_numpy = (
#         is_occupied_pred.reshape(q_pts_shape)
#         .permute(1, 2, 0)
#         .detach()
#         .cpu()
#         .numpy()
#         .astype(np.float32)
#     )
#     is_occupied_numpy[outside_frustum] = 0.0
#     ## carving out the invisible regions outside the view frustum
#     # for i_ in range(
#     #     (is_occupied_numpy.shape[0]) // 2
#     # ):  ## left | right half of the space
#     #     for j_ in range(i_ + 1):
#     #         is_occupied_numpy[i_, j_] = 0
#     pack(np.flip(is_occupied_numpy, (0, 1, 2)).reshape(-1)).astype(
#         np.uint8
#     ).tofile(
#         # f"{base_file}/{self.counter:0>6}.bin"
#         f"{self.save_bin_path}/{self.counter:0>6}.bin"
#     )
#     # for idx_i, image in enumerate(images[0]):
#     #     torchvision.utils.save_image(
#     #         image, f"{self.save_bin_path}/{self.counter:0>6}_{idx_i}.png"
#     #     )


def project_into_cam(pts, proj, pose):
    """Project world-space points (N, 3) into a camera with projection `proj`
    and camera-to-world pose `pose`. Returns the projected points, with the
    first two coordinates normalized by depth, and the z-distance per point."""
    pts = torch.cat((pts, torch.ones_like(pts[:, :1])), dim=-1)
    cam_pts = (proj @ (torch.inverse(pose).squeeze()[:3, :] @ pts.T)).T
    cam_pts[:, :2] /= cam_pts[:, 2:3]
    dist = cam_pts[:, 2]
    return cam_pts, dist
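

# Example usage of project_into_cam (a minimal sketch with hypothetical inputs;
# the [-1, 1] NDC convention is inferred from the commented frustum check above
# and from the F.grid_sample call below, both of which treat cam_pts[:, :2] as
# normalized image coordinates):
#
#   pts = torch.randn(1024, 3)   # world-space query points
#   proj = torch.eye(3)          # hypothetical projection to NDC
#   pose = torch.eye(4)          # camera-to-world pose (identity camera)
#   cam_pts, dist = project_into_cam(pts, proj, pose)
#   in_frustum = (cam_pts[:, :2].abs() <= 1.0).all(dim=-1) & (dist > 0)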
def create_occ_eval(
    model: nn.Module,
    occ_threshold: float,
    query_batch_size: int,
):
    # TODO: deal with other models such as IBRNet
    def _compute_occ_metrics(
        data,
    ):
        projs = torch.stack(data["projs"], dim=1)
        images = torch.stack(data["imgs"], dim=1)
        _, _, _, h, w = images.shape
        poses = torch.stack(data["poses"], dim=1)
        device = poses.device

        # TODO: get occupancy query points and ground-truth occupancy from the dataset
        occ_pts = data["occ_pts"].permute(0, 2, 1, 3).contiguous()
        occ_pts = occ_pts.to(device).view(-1, 3)

        # Convert the predicted ray distances to z-depth in the reference camera.
        pred_depth = distance_to_z(data["coarse"]["depth"], projs[:1, :1])

        # Visibility check: a query point is predicted visible if it lies closer
        # to the camera than the predicted pseudo-depth at its pixel.
        cam_pts, dists = project_into_cam(occ_pts, projs[0, 0], poses[0, 0])
        pred_dist = F.grid_sample(
            pred_depth.view(1, 1, h, w),
            cam_pts[:, :2].view(1, 1, -1, 2),
            mode="nearest",
            padding_mode="border",
            align_corners=True,
        ).view(-1)

        # Note: is_visible_pred is currently unused; the metrics below rely on
        # the dataset-provided visibility mask instead.
        is_visible_pred = dists <= pred_dist

        # Debug toggle: treat the 4 m band behind the predicted surface as
        # occupied instead of querying the density field.
        depth_plus4meters = False
        if depth_plus4meters:
            mask = (dists >= pred_dist) & (dists < pred_dist + 4)
            densities = torch.zeros_like(occ_pts[..., 0])
            densities[mask] = 1.0
            is_occupied_pred = densities > occ_threshold
        else:
            # Query the density field in batches to bound memory usage.
            densities = []
            for i_from in range(0, len(occ_pts), query_batch_size):
                i_to = min(i_from + query_batch_size, len(occ_pts))
                q_pts_ = occ_pts[i_from:i_to]
                _, _, densities_, _ = model(
                    q_pts_.unsqueeze(0), only_density=True
                )  # occupancy estimation from the density head only
                densities.append(densities_.squeeze(0))
            densities = torch.cat(densities, dim=0).squeeze()

            is_occupied_pred = densities > occ_threshold

        is_occupied = data["is_occupied"]
        is_visible = data["is_visible"]

        return metrics.compute_occ_metrics(is_occupied_pred, is_occupied, is_visible)

    return _compute_occ_metrics


def make_eval_fn(
    model: nn.Module,
    conf,
):
    # Resolve the factory by naming convention: conf["type"] == "depth" maps to
    # create_depth_eval, "nvs" to create_nvs_eval, and so on.
    eval_type = conf["type"]
    eval_fn = globals().get(f"create_{eval_type}_eval", None)
    if eval_fn:
        if conf.get("args", None):
            return eval_fn(model, **conf["args"])
        else:
            return eval_fn(model)
    else:
        return None
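

# Example usage of make_eval_fn (a minimal sketch; the config values shown are
# hypothetical, but the keys "type" and "args" match how make_eval_fn reads the
# config, and the arguments match create_occ_eval above):
#
#   conf = {"type": "occ", "args": {"occ_threshold": 0.5, "query_batch_size": 65536}}
#   eval_fn = make_eval_fn(model, conf)  # resolves to create_occ_eval(model, **args)
#   if eval_fn is not None:
#       occ_metrics = eval_fn(data)      # data: a batch dict from the eval loader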