Spaces:
Running
on
Zero
Running
on
Zero
from math import isqrt | |
import torch | |
import torch.nn.functional as F | |
from torch.cuda.amp import autocast | |
from omegaconf import ListConfig | |
from scenedino.common import util | |
from scenedino.common.cameras.pinhole import outside_frustum, project_to_image, pts_into_camera | |
class RaySampler:
    """Common base of all ray samplers.

    Only stores the near/far depth bounds; ``sample`` and ``reconstruct``
    are placeholders that concrete samplers override.
    """

    def __init__(self, z_near: float, z_far: float) -> None:
        self.z_near, self.z_far = z_near, z_far

    def sample(self, images, poses, projs):
        """Produce rays (and targets) for a batch; not implemented here."""
        raise NotImplementedError()

    def reconstruct(self, render_dict):
        """Reshape flat render outputs back to sampler layout; not implemented here."""
        raise NotImplementedError()
class RandomRaySampler(RaySampler):
    """Sampler that picks individual rays uniformly at random over all views and pixels."""

    def __init__(
        self, z_near: float, z_far: float, ray_batch_size: int, channels: int = 3
    ) -> None:
        """Store the depth bounds, the number of rays drawn per batch element
        and the default channel count used by ``reconstruct``."""
        super().__init__(z_near, z_far)
        self.ray_batch_size = ray_batch_size
        self.channels = channels

    def sample(self, images, poses, projs, image_ids=None):
        """Draw ``ray_batch_size`` random rays per batch element.

        Args:
            images: tensor unpacked as ``(n, v, c, h, w)``.
            poses: per-element poses, viewed as ``(-1, 4, 4)`` before being
                handed to ``util.gen_rays``.
            projs: projection matrices indexed as ``projs[n_, view, row, col]``;
                entries (0, 0)/(1, 1) are used as focal lengths and
                (0, 2)/(1, 2) as principal point.
            image_ids: optional per-view ids appended to every ray; defaults
                to ``0..v-1``.

        Returns:
            ``(all_rays, all_rgb_gt)`` stacked over the batch dimension, i.e.
            ``(n, ray_batch_size, r_dim)`` and ``(n, ray_batch_size, c)``.
        """
        n, v, c, h, w = images.shape

        # The frame-id plane is identical for every batch element, so build it once.
        if image_ids is None:
            ids = torch.arange(v, device=images.device, dtype=images.dtype)
        else:
            ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt = []
        all_rays = []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]

            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
            )

            # Append frame id and pixel coordinates to every ray.
            rays = torch.cat((rays, ids, xy), dim=-1)
            rays = rays.view(-1, rays.shape[-1])

            # Flatten the images to one row per pixel, in the same order as the rays.
            rgb_gt = images[n_].view(-1, c, h, w)
            rgb_gt = rgb_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, c)

            # Sampling is with replacement over all v * h * w pixels.
            pix_inds = torch.randint(0, v * h * w, (self.ray_batch_size,))

            all_rgb_gt.append(rgb_gt[pix_inds])
            all_rays.append(rays[pix_inds])

        return torch.stack(all_rays), torch.stack(all_rgb_gt)

    def reconstruct(self, render_dict, channels=None):
        """Reshape the flat renderer outputs in ``render_dict`` in place.

        Every dict-valued entry containing an ``"rgb"`` key is treated as one
        render result; its tensors are viewed as ``(n, n_pts, ...)``. The
        top-level ``"rgb_gt"`` entry, if present, is viewed as
        ``(n, n_pts, channels)``.

        Args:
            render_dict: mapping of names to render results / extra tensors.
            channels: channel count per view; defaults to ``self.channels``.

        Returns:
            The same ``render_dict`` object. If it contains no render result
            it is returned unchanged (previously this raised an
            ``UnboundLocalError``).
        """
        if channels is None:
            channels = self.channels

        lead = None  # (n, n_pts) of the last render result seen
        for part in render_dict.values():
            if not isinstance(part, dict) or "rgb" not in part:
                continue

            n, n_pts, v_c = part["rgb"].shape  # n, n_pts, v * channels
            v = v_c // channels
            n_smps = part["invalid"].shape[-2]
            lead = (n, n_pts)

            part["rgb"] = part["rgb"].view(n, n_pts, v, channels)
            part["depth"] = part["depth"].view(n, n_pts)
            part["invalid"] = part["invalid"].view(n, n_pts, n_smps, v)

            if "invalid_features" in part:
                part["invalid_features"] = part["invalid_features"].view(
                    n, n_pts, n_smps, v
                )
            # Per-sample scalars along each ray.
            for key in ("weights", "alphas", "z_samps"):
                if key in part:
                    part[key] = part[key].view(n, n_pts, n_smps)
            if "rgb_samps" in part:
                part["rgb_samps"] = part["rgb_samps"].view(
                    n, n_pts, n_smps, v, channels
                )
            # Per-ray vectors whose last dimension is kept as is.
            for key in ("ray_info", "extras"):
                if key in part:
                    part[key] = part[key].view(n, n_pts, part[key].shape[-1])

        if lead is not None and "rgb_gt" in render_dict:
            render_dict["rgb_gt"] = render_dict["rgb_gt"].view(*lead, channels)

        return render_dict
class PatchRaySampler(RaySampler):
    """Sampler that draws rays in rectangular pixel patches.

    Only the grid-aligned mode (``snap_to_grid=True``) is implemented in
    ``sample``; the free-position mode raises ``NotImplementedError``.
    """

    def __init__(
        self,
        z_near: float,
        z_far: float,
        ray_batch_size: int,
        patch_size: int,
        channels: int = 3,
        snap_to_grid: bool = False,
        dino_upscaled: bool = False,
    ) -> None:
        """Configure the sampler.

        Args:
            z_near, z_far: depth bounds forwarded to ``util.gen_rays``.
            ray_batch_size: rays per batch element; must be a multiple of the
                patch area.
            patch_size: an int (square patch) or a ``(y, x)`` tuple / list /
                ``ListConfig``.
            channels: default channel count used by ``reconstruct``.
            snap_to_grid: place patches on a regular patch-sized grid.
            dino_upscaled: whether DINO features are given at image
                resolution (one feature per pixel) rather than per patch.

        Raises:
            ValueError: if ``patch_size`` has an unsupported type.
        """
        super().__init__(z_near, z_far)
        self.ray_batch_size = ray_batch_size
        self.channels = channels
        self.snap_to_grid = snap_to_grid
        self.dino_upscaled = dino_upscaled

        if isinstance(patch_size, int):
            self.patch_size_x, self.patch_size_y = patch_size, patch_size
        elif isinstance(patch_size, (tuple, list, ListConfig)):
            self.patch_size_y = patch_size[0]
            self.patch_size_x = patch_size[1]
        else:
            raise ValueError("Invalid format for patch size")

        patch_area = self.patch_size_x * self.patch_size_y
        assert (ray_batch_size % patch_area) == 0
        self._patch_count = self.ray_batch_size // patch_area

    def sample(
        self, images, poses, projs, image_ids=None, dino_features=None, loss_feature_grid_shift=None,
    ):  ### dim(images) == nv (ids_loss nv randomly sampled)
        """Sample ``_patch_count`` patches of rays per batch element.

        Args:
            images: tensor unpacked as ``(n, v, c, h, w)``; ``self.channels``
                is updated to ``c``.
            poses, projs: camera data forwarded to ``util.gen_rays`` (see the
                focal / principal-point indexing below).
            image_ids: optional per-view ids appended to every ray.
            dino_features: optional 5-D tensor ``(n, v, dino_channels, ., .)``
                from which feature targets are gathered.
            loss_feature_grid_shift: optional pair ``(shift_y, shift_x)``
                offsetting the patch grid (taken modulo the patch size).

        Returns:
            ``(all_rays, all_rgb_gt)`` or, when ``dino_features`` is given,
            ``(all_rays, all_rgb_gt, all_dino_gt)``, each stacked over ``n``.

        Raises:
            NotImplementedError: if ``snap_to_grid`` is False.
        """
        n, v, c, h, w = images.shape
        self.channels = c

        if dino_features is not None:
            _, _, dino_channels, _, _ = dino_features.shape
            dino_features = dino_features.permute(0, 1, 3, 4, 2)  # channels last

        images = images.permute(0, 1, 3, 4, 2)  # channels last

        # The frame-id plane is identical for every batch element, so build it once.
        if image_ids is None:
            ids = torch.arange(v, device=images.device, dtype=images.dtype)
        else:
            ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        size_y, size_x = self.patch_size_y, self.patch_size_x
        count = (self._patch_count,)

        all_rgb_gt, all_rays, all_dino_gt = [], [], []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]

            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
            )

            # Append frame id and pixel coordinates to every ray.
            rays = torch.cat((rays, ids, xy), dim=-1)
            r_dim = rays.shape[-1]

            patch_coords_v = torch.randint(0, v, count)
            if self.snap_to_grid:
                # With a shifted grid one cell per axis is dropped so the
                # shifted patch still lies inside the image.
                margin = 1 if loss_feature_grid_shift is not None else 0
                patch_coords_y = torch.randint(0, h // size_y - margin, count)
                patch_coords_x = torch.randint(0, w // size_x - margin, count)
            else:
                patch_coords_y = torch.randint(0, h - size_y, count)
                patch_coords_x = torch.randint(0, w - size_x, count)

            sample_rgb_gt = []
            sample_rays = []
            sample_dino_gt = []
            for v_, coord_y, coord_x in zip(patch_coords_v, patch_coords_y, patch_coords_x):
                if not self.snap_to_grid:
                    raise NotImplementedError

                # (patch_y, patch_x) index the per-patch feature map,
                # (y, x) the top-left pixel of the patch.
                patch_y, patch_x = coord_y, coord_x
                if loss_feature_grid_shift is not None:
                    y = (loss_feature_grid_shift[0] % size_y) + size_y * coord_y
                    x = (loss_feature_grid_shift[1] % size_x) + size_x * coord_x
                    # A negative shift selects the next feature cell.
                    if loss_feature_grid_shift[0] < 0:
                        patch_y = coord_y + 1
                    if loss_feature_grid_shift[1] < 0:
                        patch_x = coord_x + 1
                else:
                    y = size_y * coord_y
                    x = size_x * coord_x

                rgb_gt_patch = images[n_][
                    v_, y : y + size_y, x : x + size_x, :
                ].reshape(-1, self.channels)
                rays_patch = rays[
                    v_, y : y + size_y, x : x + size_x, :
                ].reshape(-1, r_dim)

                sample_rgb_gt.append(rgb_gt_patch)
                sample_rays.append(rays_patch)

                if dino_features is not None:
                    if self.dino_upscaled:
                        # One feature per pixel of the patch.
                        dino_gt_patch = dino_features[n_][
                            v_, y : y + size_y, x : x + size_x, :
                        ].reshape(-1, dino_channels)
                    else:
                        # One feature for the whole patch.
                        dino_gt_patch = dino_features[n_][
                            v_, patch_y, patch_x, :
                        ].reshape(-1, dino_channels)
                    sample_dino_gt.append(dino_gt_patch)

            all_rgb_gt.append(torch.cat(sample_rgb_gt, dim=0))
            all_rays.append(torch.cat(sample_rays, dim=0))
            if dino_features is not None:
                all_dino_gt.append(torch.cat(sample_dino_gt, dim=0))

        all_rgb_gt = torch.stack(all_rgb_gt)
        all_rays = torch.stack(all_rays)

        if dino_features is not None:
            return all_rays, all_rgb_gt, torch.stack(all_dino_gt)
        return all_rays, all_rgb_gt

    def reconstruct(self, render_dict, channels=None, dino_channels=None):
        """Reshape the flat renderer outputs in ``render_dict`` to patch layout, in place.

        Every dict-valued entry containing ``"rgb"`` is treated as one render
        result and viewed as ``(n, patch_count, patch_y, patch_x, ...)``.
        The top-level ``"rgb_gt"``, ``"dino_gt"`` and ``"dino_artifacts"``
        entries are reshaped only when present, so the method also works for
        RGB-only sampling (previously a missing ``"dino_gt"`` raised
        ``KeyError``). If there is no render result at all, ``render_dict``
        is returned unchanged.

        Args:
            render_dict: mapping of names to render results / extra tensors.
            channels: channel count per view; defaults to ``self.channels``.
            dino_channels: accepted for interface compatibility; not used.

        Returns:
            The same ``render_dict`` object.
        """
        if channels is None:
            channels = self.channels

        patch_dims = (self._patch_count, self.patch_size_y, self.patch_size_x)
        n = None
        for part in render_dict.values():
            if not isinstance(part, dict) or "rgb" not in part:
                continue

            n, _, v_c = part["rgb"].shape
            v = v_c // channels  # (This can be a different v from the sample method)
            n_smps = part["weights"].shape[-1]
            lead = (n, *patch_dims)

            part["rgb"] = part["rgb"].view(*lead, v, channels)
            part["weights"] = part["weights"].view(*lead, n_smps)
            part["depth"] = part["depth"].view(*lead)
            part["invalid"] = part["invalid"].view(*lead, n_smps, v)

            # TODO: Figure out DINO invalid policy ("invalid_features" is
            # currently left in its flat layout).

            # Per-sample scalars along each ray.
            for key in ("alphas", "z_samps"):
                if key in part:
                    part[key] = part[key].view(*lead, n_smps)
            if "rgb_samps" in part:
                part["rgb_samps"] = part["rgb_samps"].view(*lead, n_smps, v, channels)
            # Per-ray vectors whose last dimension is kept as is.
            for key in ("ray_info", "extras"):
                if key in part:
                    part[key] = part[key].view(*lead, part[key].shape[-1])
            if "dino_features" in part:
                part["dino_features"] = part["dino_features"].view(
                    *lead, 1, part["dino_features"].shape[-1]
                )

        if n is None:
            # No render result found: nothing defines the target layout.
            return render_dict

        lead = (n, *patch_dims)
        if "rgb_gt" in render_dict:
            render_dict["rgb_gt"] = render_dict["rgb_gt"].view(*lead, channels)

        feature_dim = None
        if "dino_gt" in render_dict:
            dino_gt = render_dict["dino_gt"]
            feature_dim = dino_gt.shape[-1]
            if self.dino_upscaled:
                render_dict["dino_gt"] = dino_gt.view(*lead, feature_dim)
            else:
                render_dict["dino_gt"] = dino_gt.view(n, self._patch_count, feature_dim)

        if "dino_artifacts" in render_dict:
            dino_artifacts = render_dict["dino_artifacts"]
            if feature_dim is None:
                feature_dim = dino_artifacts.shape[-1]
            render_dict["dino_artifacts"] = dino_artifacts.view(
                n, self._patch_count, feature_dim
            )

        return render_dict
class PointBasedRaySampler(RandomRaySampler):
    """Sampler that builds one ray per given 3D point for a single view."""

    def sample(
        self, images, poses, projs, xyz, image_ids=None
    ):  ### dim(images) == nv (ids_loss nv randomly sampled)
        """Build rays through the image positions of ``xyz`` and fetch their colors.

        Args:
            images: tensor unpacked as ``(n, v, c, h, w)``; ``v`` must be 1.
            poses: camera poses; inverted for ``pts_into_camera`` and used
                directly for ray origins / rotation.
            projs: projection matrices; the upper-left 3x3 block is inverted
                to obtain ray directions.
            xyz: 3-D tensor of points, unpacked as ``(B, n_pts, _)``.
            image_ids: unused here.

        Returns:
            ``(rays, rgb_gt)`` where each ray is the concatenation of origin,
            direction, near, far, a zero frame id, the clamped ``xy`` and the
            point distance, and ``rgb_gt`` is sampled with ``F.grid_sample``.
        """
        n, v, c, h, w = images.shape
        assert v == 1

        # Matrix inversion is kept out of mixed precision.
        # NOTE(review): extent of this autocast block is inferred — confirm.
        with autocast(enabled=False):
            world_to_cam = torch.inverse(poses)
            inv_K = torch.inverse(projs[:, :, :3, :3])

        _, n_pts, _ = xyz.shape

        pts_cam = pts_into_camera(xyz, world_to_cam)
        distance = torch.norm(pts_cam, dim=-2, keepdim=True)
        xy, z = project_to_image(pts_cam, projs)

        # Keep only the first (and only) view.
        xy = xy[:, 0]
        z = z[:, 0]
        distance = distance[:, 0, 0, :, None]

        # For numerical stability with AMP. Should not affect training outcome
        xy = xy.clamp(-2, 2)

        # Build rays
        cam_centers = poses[:, 0, None, :3, 3].expand(-1, n_pts, -1)

        xy_homogeneous = torch.cat((xy, torch.ones_like(xy[..., :1])), dim=-1)
        unproject = poses[:, 0, :3, :3] @ inv_K[:, 0]
        ray_dir = (unproject @ xy_homogeneous.permute(0, 2, 1)).permute(0, 2, 1)

        ones = torch.ones_like(cam_centers[..., :1])
        cam_nears = ones * self.z_near
        cam_fars = ones * self.z_far
        ids = torch.zeros_like(cam_nears)

        rays = torch.cat(
            (cam_centers, ray_dir, cam_nears, cam_fars, ids, xy, distance), dim=-1
        )

        sample_grid = xy.reshape(n, -1, 1, 2).to(images.dtype)
        sampled = F.grid_sample(
            images[:, 0], sample_grid, padding_mode="border", align_corners=False
        )
        rgb_gt = sampled[..., 0].permute(0, 2, 1)

        return rays, rgb_gt
class ImageRaySampler(RaySampler):
    """Sampler that emits one ray for every pixel of every view (full images)."""

    def __init__(
        self,
        z_near: float,
        z_far: float,
        height: int | None = None,
        width: int | None = None,
        channels: int = 3,
        norm_dir: bool = True,
        dino_upscaled: bool = False,
    ) -> None:
        """Configure the sampler.

        Args:
            z_near, z_far: depth bounds forwarded to ``util.gen_rays``.
            height, width: image size; when ``height`` is None the size is
                taken from the first images passed to ``sample``.
            channels: default channel count used by ``reconstruct``.
            norm_dir: forwarded to ``util.gen_rays``.
            dino_upscaled: whether DINO targets are given at image resolution.
        """
        super().__init__(z_near, z_far)
        self.height = height
        self.width = width
        self.channels = channels
        self.norm_dir = norm_dir
        self.dino_upscaled = dino_upscaled

    def sample(self, images, poses, projs, image_ids=None, dino_features=None, dino_artifacts=None):
        """Generate rays for all pixels of all views.

        Args:
            images: optional tensor ``(n, v, c, h, w)``; when given,
                ``self.channels`` (and, if unset, ``self.height`` /
                ``self.width``) are updated from it.
            poses: tensor unpacked as ``(n, v, _, _)``.
            projs: projection matrices (focal / principal point are read
                from entries (0, 0), (1, 1), (0, 2), (1, 2)).
            image_ids: optional per-view ids appended to every ray.
            dino_features: optional 5-D tensor ``(n, v, dino_channels, ., .)``.
            dino_artifacts: accepted for interface compatibility; not used.

        Returns:
            ``(all_rays, all_rgb_gt)`` or, when ``dino_features`` is given,
            ``(all_rays, all_rgb_gt, all_dino_gt)``. ``all_rgb_gt`` is None
            when ``images`` is None.

        Raises:
            ValueError: if the image size is neither configured nor
                derivable from ``images``.
        """
        n, v, _, _ = poses.shape
        device, dtype = poses.device, poses.dtype

        if images is not None:
            self.channels = images.shape[2]
            if self.height is None:
                self.height, self.width = images.shape[-2:]

        if dino_features is not None:
            _, _, dino_channels, _, _ = dino_features.shape

        h, w = self.height, self.width
        if h is None or w is None:
            raise ValueError(
                "ImageRaySampler needs height/width set or images to infer them from"
            )

        # The frame-id plane is identical for every batch element, so build it once.
        if image_ids is None:
            ids = torch.arange(v, device=device, dtype=dtype)
        else:
            ids = torch.tensor(image_ids, device=device, dtype=dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt = []
        all_dino_gt = []
        all_rays = []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]

            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
                norm_dir=self.norm_dir,
            )

            # Append frame id and pixel coordinates to every ray.
            rays = torch.cat((rays, ids, xy), dim=-1)
            all_rays.append(rays.view(-1, rays.shape[-1]))

            if images is not None:
                rgb_gt = images[n_].view(-1, self.channels, h, w)
                rgb_gt = rgb_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, self.channels)
                all_rgb_gt.append(rgb_gt)

            if dino_features is not None:
                patch_h, patch_w = dino_features[n_].shape[-2:]
                dino_gt = dino_features[n_].view(-1, dino_channels, patch_h, patch_w)
                dino_gt = dino_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, dino_channels)
                all_dino_gt.append(dino_gt)

        all_rays = torch.stack(all_rays)
        all_rgb_gt = torch.stack(all_rgb_gt) if images is not None else None

        if dino_features is not None:
            return all_rays, all_rgb_gt, torch.stack(all_dino_gt)
        return all_rays, all_rgb_gt

    def reconstruct(self, render_dict, channels=None, dino_channels=None):
        """Reshape the flat renderer outputs in ``render_dict`` to image layout, in place.

        Every dict-valued entry containing ``"rgb"`` is treated as one render
        result and viewed as ``(n, v_in, height, width, ...)``. The top-level
        ``"rgb_gt"``, ``"dino_gt"`` and ``"dino_artifacts"`` entries are
        reshaped when present. ``"dino_artifacts"`` no longer requires
        ``"dino_gt"`` to be present (previously a ``NameError``), and a
        ``render_dict`` without any render result is returned unchanged.

        Args:
            render_dict: mapping of names to render results / extra tensors.
            channels: channel count per view; defaults to ``self.channels``.
            dino_channels: accepted for interface compatibility; not used.

        Returns:
            The same ``render_dict`` object.
        """
        if channels is None:
            channels = self.channels

        height, width = self.height, self.width
        lead = None  # (n, v_in, height, width) of the last render result seen
        for part in render_dict.values():
            if not isinstance(part, dict) or "rgb" not in part:
                continue

            n, n_pts, v_c = part["rgb"].shape  # n, n_pts, v * channels
            v_in = n_pts // (height * width)
            v_render = v_c // channels  # (This can be a different v from the sample method)
            n_smps = part["weights"].shape[-1]
            lead = (n, v_in, height, width)

            part["rgb"] = part["rgb"].view(*lead, v_render, channels)
            part["weights"] = part["weights"].view(*lead, n_smps)
            part["depth"] = part["depth"].view(*lead)
            part["invalid"] = part["invalid"].view(*lead, n_smps, v_render)

            if "invalid_features" in part:
                part["invalid_features"] = part["invalid_features"].view(
                    *lead, n_smps, v_render
                )
            # Per-sample scalars along each ray.
            for key in ("alphas", "z_samps"):
                if key in part:
                    part[key] = part[key].view(*lead, n_smps)
            if "rgb_samps" in part:
                part["rgb_samps"] = part["rgb_samps"].view(
                    *lead, n_smps, v_render, channels
                )
            # Per-ray vectors whose last dimension is kept as is.
            for key in ("ray_info", "extras"):
                if key in part:
                    part[key] = part[key].view(*lead, part[key].shape[-1])
            if "dino_features" in part:
                part["dino_features"] = part["dino_features"].view(
                    *lead, 1, part["dino_features"].shape[-1]
                )

        if lead is None:
            # No render result found: nothing defines the target layout.
            return render_dict

        n, v_in = lead[0], lead[1]
        if "rgb_gt" in render_dict:
            render_dict["rgb_gt"] = render_dict["rgb_gt"].view(*lead, channels)

        feature_dim = None
        if "dino_gt" in render_dict:
            dino_gt = render_dict["dino_gt"]
            feature_dim = dino_gt.shape[-1]
            if self.dino_upscaled:
                render_dict["dino_gt"] = dino_gt.view(*lead, feature_dim)
            else:
                # TODO: patch size should not be inferred like this, but parameter
                patch_size = isqrt(
                    (n * v_in * height * width * feature_dim) // dino_gt.numel()
                )
                render_dict["dino_gt"] = dino_gt.view(
                    n, v_in, height // patch_size, width // patch_size, feature_dim
                )

        if "dino_artifacts" in render_dict:
            dino_artifacts = render_dict["dino_artifacts"]
            # Use the dino_gt feature size when available, as before;
            # otherwise fall back to the artifacts' own last dimension.
            if feature_dim is None:
                feature_dim = dino_artifacts.shape[-1]
            # TODO: patch size should not be inferred like this, but parameter
            patch_size = isqrt(
                (n * v_in * height * width * feature_dim) // dino_artifacts.numel()
            )
            render_dict["dino_artifacts"] = dino_artifacts.view(
                n, v_in, height // patch_size, width // patch_size, feature_dim
            )

        return render_dict
class JitteredPatchRaySampler(PatchRaySampler):
    """Patch sampler whose pixel positions are shifted by a random sub-pixel offset.

    Colors are fetched with bilinear ``F.grid_sample`` at the jittered
    positions; the same offset is passed to ``util.gen_rays``.
    """

    def __init__(
        self,
        z_near: float,
        z_far: float,
        ray_batch_size: int,
        patch_size: int,
        jitter_strength: float,  # In pixels, max [0, 1)
        channels: int = 3,
    ) -> None:
        """Configure the sampler and precompute the per-patch pixel grid.

        ``jitter_strength`` must lie in ``[0, 1)``; the remaining arguments
        are forwarded to ``PatchRaySampler``.
        """
        super().__init__(z_near, z_far, ray_batch_size, patch_size, channels)
        assert 0 <= jitter_strength < 1, "Jitter strength is invalid."
        self.jitter_strength = jitter_strength

        # Integer (x, y) offsets of every pixel inside a patch: (1, size_y, size_x, 2).
        x = torch.arange(0, self.patch_size_x).view(1, 1, -1, 1).expand(-1, self.patch_size_y, -1, -1)
        y = torch.arange(0, self.patch_size_y).view(1, -1, 1, 1).expand(-1, -1, self.patch_size_x, -1)
        self._grid = torch.cat((x, y), dim=-1)

    def sample(
        self, images, poses, projs, image_ids=None
    ):  ### dim(images) == nv (ids_loss nv randomly sampled)
        """Sample ``_patch_count`` jittered patches of rays per batch element.

        Args:
            images: tensor unpacked as ``(n, v, c, h, w)``; ``self.channels``
                is updated to ``c`` (as in ``PatchRaySampler.sample``).
            poses, projs: camera data forwarded to ``util.gen_rays``.
            image_ids: optional per-view ids appended to every ray.

        Returns:
            ``(all_rays, all_rgb_gt)`` stacked over ``n``. Unlike
            ``PatchRaySampler.sample``, the pixel coordinates returned by
            ``util.gen_rays`` are not appended to the rays.
        """
        n, v, c, h, w = images.shape
        # Keep the channel count in sync with the data so the reshape below
        # and the inherited reconstruct() agree with the actual images.
        self.channels = c

        size_y, size_x = self.patch_size_y, self.patch_size_x
        count = (self._patch_count,)

        # One jitter offset, within +-jitter_strength / 2, shared by the whole call.
        xy_offset = (torch.rand(2) - .5) * self.jitter_strength
        # Pixel-centre positions of a patch at the origin, including the jitter.
        base_grid = self._grid + xy_offset.view(1, 1, 1, 2) + .5
        image_size = base_grid.new_tensor((w, h)).view(1, 1, 1, 2)

        # The frame-id plane is identical for every batch element, so build it once.
        if image_ids is None:
            ids = torch.arange(v, device=images.device, dtype=images.dtype)
        else:
            ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt, all_rays = [], []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]

            rays, _ = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
                xy_offset=xy_offset,
            )

            # Append frame id to the ray
            rays = torch.cat((rays, ids), dim=-1)
            r_dim = rays.shape[-1]

            patch_coords_v = torch.randint(0, v, count)
            # torch.randint's upper bound is exclusive; "+ 1" makes the last
            # valid top-left corner (h - size_y, w - size_x) reachable and
            # avoids an empty range when the image is exactly one patch.
            patch_coords_y = torch.randint(0, h - size_y + 1, count)
            patch_coords_x = torch.randint(0, w - size_x + 1, count)

            sample_rgb_gt = []
            sample_rays = []
            for v_, y, x in zip(patch_coords_v, patch_coords_y, patch_coords_x):
                xy_patch = torch.stack((x, y)).view(1, 1, 1, 2)

                # Convert pixel positions to grid_sample's [-1, 1] range
                # (align_corners=False convention), computed in float32 ...
                patch_grid = (base_grid + xy_patch) / image_size * 2 - 1
                # ... then matched to the image tensor: grid_sample requires
                # input and grid to share device and dtype.
                patch_grid = patch_grid.to(device=images.device, dtype=images.dtype)

                rgb_gt_patch = F.grid_sample(
                    images[n_ : n_ + 1, v_],
                    patch_grid,
                    padding_mode="border",
                    align_corners=False,
                )
                rgb_gt_patch = rgb_gt_patch.permute(0, 2, 3, 1).reshape(-1, c)

                rays_patch = rays[
                    v_, y : y + size_y, x : x + size_x, :
                ].reshape(-1, r_dim)

                sample_rgb_gt.append(rgb_gt_patch)
                sample_rays.append(rays_patch)

            all_rgb_gt.append(torch.cat(sample_rgb_gt, dim=0))
            all_rays.append(torch.cat(sample_rays, dim=0))

        return torch.stack(all_rays), torch.stack(all_rgb_gt)
def get_ray_sampler(config) -> RaySampler:
    """Build the ray sampler selected by ``config["sample_mode"]``.

    ``config`` must provide ``"z_near"`` and ``"z_far"``; ``"sample_mode"``
    defaults to ``"random"``. For the ``"random"``, ``"patch"`` and
    ``"jitteredpatch"`` modes ``config["args"]`` is expanded into the
    sampler's keyword arguments; the ``"image"`` mode only receives the
    depth bounds.

    Raises:
        NotImplementedError: for any other ``sample_mode``.
    """
    z_near = config["z_near"]
    z_far = config["z_far"]
    sample_mode = config.get("sample_mode", "random")

    # TODO: check channel size
    if sample_mode == "random":
        return RandomRaySampler(z_near, z_far, **config["args"])
    if sample_mode == "patch":
        return PatchRaySampler(z_near, z_far, **config["args"])
    if sample_mode == "jitteredpatch":
        return JitteredPatchRaySampler(z_near, z_far, **config["args"])
    if sample_mode == "image":
        return ImageRaySampler(z_near, z_far)
    raise NotImplementedError