# SceneDINO: scenedino/common/ray_sampler.py
# (commit 9e15541, "scenedino init", by jev-aleks)
from math import isqrt
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from omegaconf import ListConfig
from scenedino.common import util
from scenedino.common.cameras.pinhole import outside_frustum, project_to_image, pts_into_camera
class RaySampler:
    """Base interface for strategies that turn posed images into rays.

    Subclasses implement ``sample`` (build rays plus matching targets) and
    ``reconstruct`` (reshape flat render outputs back into the layout that
    ``sample`` produced).
    """

    def __init__(self, z_near: float, z_far: float) -> None:
        # Near/far distances that subclasses forward to ray generation.
        self.z_near, self.z_far = z_near, z_far

    def sample(self, images, poses, projs):
        """Generate rays for a batch. Must be overridden by subclasses."""
        raise NotImplementedError

    def reconstruct(self, render_dict):
        """Reshape flat render outputs. Must be overridden by subclasses."""
        raise NotImplementedError
class RandomRaySampler(RaySampler):
    """Sample ``ray_batch_size`` rays uniformly at random from all pixels of all views.

    Each ray row is the output of ``util.gen_rays`` followed by one view-id
    column and the ``xy`` values returned by ``util.gen_rays``.
    """

    def __init__(
        self, z_near: float, z_far: float, ray_batch_size: int, channels: int = 3
    ) -> None:
        super().__init__(z_near, z_far)
        self.ray_batch_size = ray_batch_size  # rays drawn per batch item
        self.channels = channels  # default channel count used by reconstruct()

    def sample(self, images, poses, projs, image_ids=None):
        """Draw random rays together with their ground-truth pixel values.

        Args:
            images: Tensor unpacked as ``(n, v, c, h, w)``.
            poses: Per-view poses; ``poses[i]`` is viewed as ``(-1, 4, 4)``.
            projs: Per-view intrinsics. Focal lengths are read from entries
                ``[0, 0]``/``[1, 1]`` and the principal point from
                ``[0, 2]``/``[1, 2]``.
            image_ids: Optional values for the view-id column (one per view).
                Defaults to ``0..v-1``.

        Returns:
            ``(rays, rgb_gt)`` of shapes ``(n, ray_batch_size, r_dim)`` and
            ``(n, ray_batch_size, c)``. Pixels are drawn with replacement.
        """
        n, v, c, h, w = images.shape

        # The view-id column does not depend on the batch item, so build it once.
        if image_ids is None:
            ids = torch.arange(v, device=images.device, dtype=images.dtype)
        else:
            ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt = []
        all_rays = []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
            )
            # Append view id, then pixel coordinates, to every ray.
            rays = torch.cat((rays, ids, xy), dim=-1)
            rays = rays.view(-1, rays.shape[-1])

            # (v, c, h, w) -> (v * h * w, c), matching the flattened (v, h, w) ray order.
            rgb_gt = images[n_].view(-1, c, h, w)
            rgb_gt = rgb_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, c)

            pix_inds = torch.randint(0, v * h * w, (self.ray_batch_size,))
            all_rgb_gt.append(rgb_gt[pix_inds])
            all_rays.append(rays[pix_inds])

        return torch.stack(all_rays), torch.stack(all_rgb_gt)

    def reconstruct(self, render_dict, channels=None):
        """Reshape flat per-ray render outputs to ``(n, n_pts, ...)`` in place.

        Every dict-valued entry of ``render_dict`` that contains ``"rgb"`` is
        treated as a render output. ``render_dict["rgb_gt"]`` is viewed as
        ``(n, n_pts, channels)`` using the sizes of the last such entry. If no
        render output is present, ``render_dict`` is returned unchanged because
        the target shape cannot be determined.

        Args:
            render_dict: Mapping of render outputs; modified in place.
            channels: Channels per view in ``"rgb"``. Defaults to ``self.channels``.

        Returns:
            The same ``render_dict``.
        """
        if channels is None:
            channels = self.channels

        rgb_gt = None
        for name, part in render_dict.items():
            if not isinstance(part, dict) or "rgb" not in part:
                continue
            rgb_gt = render_dict["rgb_gt"]

            rgb = part["rgb"]  # n, n_pts, v * channels
            n, n_pts, v_c = rgb.shape
            v = v_c // channels
            n_smps = part["invalid"].shape[-2]

            part["rgb"] = rgb.view(n, n_pts, v, channels)
            part["depth"] = part["depth"].view(n, n_pts)
            part["invalid"] = part["invalid"].view(n, n_pts, n_smps, v)
            if "invalid_features" in part:
                part["invalid_features"] = part["invalid_features"].view(n, n_pts, n_smps, v)
            for key in ("weights", "alphas", "z_samps"):
                if key in part:
                    part[key] = part[key].view(n, n_pts, n_smps)
            if "rgb_samps" in part:
                part["rgb_samps"] = part["rgb_samps"].view(n, n_pts, n_smps, v, channels)
            # Trailing feature size is kept as-is for these entries.
            for key in ("ray_info", "extras"):
                if key in part:
                    part[key] = part[key].view(n, n_pts, part[key].shape[-1])
            render_dict[name] = part

        if rgb_gt is not None:
            render_dict["rgb_gt"] = rgb_gt.view(n, n_pts, channels)
        return render_dict
class PatchRaySampler(RaySampler):
    """Sample rays in rectangular pixel patches aligned to a patch grid.

    ``ray_batch_size`` must be a multiple of the patch area. Each batch item
    receives ``ray_batch_size // (patch_size_y * patch_size_x)`` patches, each
    from a randomly chosen view.
    """

    def __init__(
        self,
        z_near: float,
        z_far: float,
        ray_batch_size: int,
        patch_size: int,
        channels: int = 3,
        snap_to_grid: bool = False,
        dino_upscaled: bool = False,
    ) -> None:
        """
        Args:
            patch_size: Either an int (square patch) or a ``(height, width)``
                tuple/list/ListConfig.
            snap_to_grid: Place patches on multiples of the patch size. ``sample``
                currently only supports ``True``.
            dino_upscaled: If True, DINO targets are cut out per pixel. Otherwise
                one DINO feature per patch is taken from the feature grid.
        """
        super().__init__(z_near, z_far)
        self.ray_batch_size = ray_batch_size
        self.channels = channels
        self.snap_to_grid = snap_to_grid
        self.dino_upscaled = dino_upscaled

        if isinstance(patch_size, int):
            self.patch_size_x, self.patch_size_y = patch_size, patch_size
        elif isinstance(patch_size, (tuple, list, ListConfig)):
            # Sequences are given as (height, width).
            self.patch_size_y = patch_size[0]
            self.patch_size_x = patch_size[1]
        else:
            raise ValueError(f"Invalid format for patch size: {patch_size!r}")

        patch_area = self.patch_size_x * self.patch_size_y
        assert ray_batch_size % patch_area == 0, (
            f"ray_batch_size ({ray_batch_size}) must be a multiple of the patch area ({patch_area})"
        )
        self._patch_count = self.ray_batch_size // patch_area

    def sample(
        self, images, poses, projs, image_ids=None, dino_features=None, loss_feature_grid_shift=None,
    ):
        """Sample grid-aligned patches of rays and their targets.

        Args:
            images: Tensor unpacked as ``(n, v, c, h, w)``.
            poses: Per-view poses; ``poses[i]`` is viewed as ``(-1, 4, 4)``.
            projs: Per-view intrinsics (focal from ``[0,0]``/``[1,1]``,
                principal point from ``[0,2]``/``[1,2]``).
            image_ids: Optional values for the view-id column. Defaults to ``0..v-1``.
            dino_features: Optional tensor unpacked as ``(n, v, C_dino, H', W')``.
            loss_feature_grid_shift: Optional ``(shift_y, shift_x)`` pixel shift
                of the patch grid. With a negative shift, the non-upscaled DINO
                target is taken from the next feature cell.

        Returns:
            ``(rays, rgb_gt)``, or ``(rays, rgb_gt, dino_gt)`` when
            ``dino_features`` is given. Each is stacked over the batch, with
            patches concatenated along dim 1.

        Raises:
            NotImplementedError: if ``snap_to_grid`` is False.
        """
        n, v, c, h, w = images.shape
        self.channels = c
        if not self.snap_to_grid:
            raise NotImplementedError("PatchRaySampler.sample only supports snap_to_grid=True")

        psy, psx = self.patch_size_y, self.patch_size_x

        if dino_features is not None:
            _, _, dino_channels, _, _ = dino_features.shape
            dino_features = dino_features.permute(0, 1, 3, 4, 2)  # -> (n, v, H', W', C_dino)
        images = images.permute(0, 1, 3, 4, 2)  # -> (n, v, h, w, c)

        # Grid-shift handling is the same for every patch, so resolve it once.
        if loss_feature_grid_shift is not None:
            shift_y = int(loss_feature_grid_shift[0])
            shift_x = int(loss_feature_grid_shift[1])
            offset_y, offset_x = shift_y % psy, shift_x % psx
            cell_bump_y, cell_bump_x = int(shift_y < 0), int(shift_x < 0)
            # One cell per axis is excluded so shifted patches stay inside the image.
            n_cells_y, n_cells_x = h // psy - 1, w // psx - 1
        else:
            offset_y = offset_x = cell_bump_y = cell_bump_x = 0
            n_cells_y, n_cells_x = h // psy, w // psx

        # The view-id column does not depend on the batch item.
        if image_ids is None:
            ids = torch.arange(v, device=images.device, dtype=images.dtype)
        else:
            ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt, all_rays, all_dino_gt = [], [], []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
            )
            rays = torch.cat((rays, ids, xy), dim=-1)
            r_dim = rays.shape[-1]

            patch_coords_v = torch.randint(0, v, (self._patch_count,))
            patch_coords_y = torch.randint(0, n_cells_y, (self._patch_count,))
            patch_coords_x = torch.randint(0, n_cells_x, (self._patch_count,))

            sample_rgb_gt, sample_rays, sample_dino_gt = [], [], []
            # Plain ints avoid in-place arithmetic on tensor views of the coordinates.
            for v_, cell_y, cell_x in zip(
                patch_coords_v.tolist(), patch_coords_y.tolist(), patch_coords_x.tolist()
            ):
                y = offset_y + psy * cell_y
                x = offset_x + psx * cell_x
                sample_rgb_gt.append(images[n_][v_, y : y + psy, x : x + psx, :].reshape(-1, c))
                sample_rays.append(rays[v_, y : y + psy, x : x + psx, :].reshape(-1, r_dim))

                if dino_features is not None:
                    if self.dino_upscaled:
                        dino_gt_patch = dino_features[n_][v_, y : y + psy, x : x + psx, :]
                    else:
                        dino_gt_patch = dino_features[n_][
                            v_, cell_y + cell_bump_y, cell_x + cell_bump_x, :
                        ]
                    sample_dino_gt.append(dino_gt_patch.reshape(-1, dino_channels))

            all_rgb_gt.append(torch.cat(sample_rgb_gt, dim=0))
            all_rays.append(torch.cat(sample_rays, dim=0))
            if dino_features is not None:
                all_dino_gt.append(torch.cat(sample_dino_gt, dim=0))

        all_rgb_gt = torch.stack(all_rgb_gt)
        all_rays = torch.stack(all_rays)
        if dino_features is not None:
            return all_rays, all_rgb_gt, torch.stack(all_dino_gt)
        return all_rays, all_rgb_gt

    def reconstruct(self, render_dict, channels=None, dino_channels=None):
        """Reshape flat render outputs to ``(n, patch_count, patch_h, patch_w, ...)`` in place.

        Every dict-valued entry containing ``"rgb"`` is treated as a render output.
        ``rgb_gt`` is reshaped to patches. ``dino_gt`` and ``dino_artifacts`` are
        reshaped when present. If no render output exists, ``render_dict`` is
        returned unchanged. ``dino_channels`` is accepted for interface
        compatibility and is not used.
        """
        if channels is None:
            channels = self.channels

        rgb_gt = None
        for name, part in render_dict.items():
            if not isinstance(part, dict) or "rgb" not in part:
                continue
            rgb_gt = render_dict["rgb_gt"]

            n, _, v_c = part["rgb"].shape
            # (This can be a different v from the sample method)
            v = v_c // channels
            n_smps = part["weights"].shape[-1]
            lead = (n, self._patch_count, self.patch_size_y, self.patch_size_x)

            part["rgb"] = part["rgb"].view(*lead, v, channels)
            part["weights"] = part["weights"].view(*lead, n_smps)
            part["depth"] = part["depth"].view(*lead)
            part["invalid"] = part["invalid"].view(*lead, n_smps, v)
            # TODO: Figure out DINO invalid policy ("invalid_features" is left flat for now).
            for key in ("alphas", "z_samps"):
                if key in part:
                    part[key] = part[key].view(*lead, n_smps)
            if "rgb_samps" in part:
                part["rgb_samps"] = part["rgb_samps"].view(*lead, n_smps, v, channels)
            for key in ("ray_info", "extras"):
                if key in part:
                    part[key] = part[key].view(*lead, part[key].shape[-1])
            if "dino_features" in part:
                dino_shape = part["dino_features"].shape[-1]
                part["dino_features"] = part["dino_features"].view(*lead, 1, dino_shape)
            render_dict[name] = part

        if rgb_gt is None:
            # Nothing was rendered, so the batch size needed below is unknown.
            return render_dict

        render_dict["rgb_gt"] = rgb_gt.view(
            n, self._patch_count, self.patch_size_y, self.patch_size_x, channels
        )
        if "dino_gt" in render_dict:
            dino_gt = render_dict["dino_gt"]
            dino_gt_shape = dino_gt.shape[-1]
            if self.dino_upscaled:
                # One feature per pixel of each patch.
                render_dict["dino_gt"] = dino_gt.view(
                    n, self._patch_count, self.patch_size_y, self.patch_size_x, dino_gt_shape
                )
            else:
                # One feature per patch.
                render_dict["dino_gt"] = dino_gt.view(n, self._patch_count, dino_gt_shape)
            if "dino_artifacts" in render_dict:
                render_dict["dino_artifacts"] = render_dict["dino_artifacts"].view(
                    n, self._patch_count, dino_gt_shape
                )
        return render_dict
class PointBasedRaySampler(RandomRaySampler):
    """Build one ray per query point, cast from the (single) camera through its projection.

    Reshaping of render outputs is inherited from ``RandomRaySampler.reconstruct``.
    """

    def sample(
        self, images, poses, projs, xyz, image_ids=None
    ):
        """Create rays towards ``xyz`` and sample the image colour at their projections.

        Args:
            images: Tensor unpacked as ``(n, v, c, h, w)``. Only ``v == 1`` is supported.
            poses: Camera poses. Their inverse is used as world-to-camera
                transform, and ``poses[:, 0, :3, 3]`` as the ray origin.
            projs: Intrinsics. The top-left 3x3 block is inverted to back-project ``xy``.
            xyz: Query points, unpacked as ``(batch, n_pts, _)``.
            image_ids: Unused; the view-id column is always zero.

        Returns:
            ``(rays, rgb_gt)``. Each ray row is
            ``[origin(3), direction(3), z_near, z_far, id, xy(2), distance]``.
            The direction is not normalised here. ``rgb_gt`` has shape
            ``(n, n_pts, c)``.
        """
        n, v, _, _, _ = images.shape
        assert v == 1

        # Matrix inversion is done outside any enclosing CUDA autocast region.
        with torch.autocast(device_type="cuda", enabled=False):
            poses_w2c = torch.inverse(poses)
            inv_K = torch.inverse(projs[:, :, :3, :3])

        _, n_pts, _ = xyz.shape

        xyz_projected = pts_into_camera(xyz, poses_w2c)
        distance = torch.norm(xyz_projected, dim=-2, keepdim=True)
        xy, _ = project_to_image(xyz_projected, projs)
        # NOTE(review): assumes a view axis at dim 1 of xy/distance — confirm against pinhole helpers.
        xy = xy[:, 0]
        distance = distance[:, 0, 0, :, None]

        # For numerical stability with AMP. Should not affect training outcome
        xy = xy.clamp(-2, 2)

        # Build rays
        cam_centers = poses[:, 0, None, :3, 3].expand(-1, n_pts, -1)
        # NOTE(review): xy is used both with inv(K) and as grid_sample coordinates,
        # so it is presumably in the normalised [-1, 1] frame of `projs` — confirm.
        xy_h = torch.cat((xy, torch.ones_like(xy[..., :1])), dim=-1)
        ray_dir = ((poses[:, 0, :3, :3] @ inv_K[:, 0]) @ xy_h.permute(0, 2, 1)).permute(0, 2, 1)
        cam_nears = torch.ones_like(cam_centers[..., :1]) * self.z_near
        cam_fars = torch.ones_like(cam_centers[..., :1]) * self.z_far
        ids = torch.zeros_like(cam_nears)
        rays = torch.cat((cam_centers, ray_dir, cam_nears, cam_fars, ids, xy, distance), dim=-1)

        rgb_gt = F.grid_sample(
            images[:, 0],
            xy.reshape(n, -1, 1, 2).to(images.dtype),
            padding_mode="border",
            align_corners=False,
        )[..., 0].permute(0, 2, 1)

        return rays, rgb_gt
class ImageRaySampler(RaySampler):
    """Generate one ray for every pixel of every view (no subsampling).

    ``height``/``width`` may be left as ``None``. They are then taken from the
    first ``images`` passed to ``sample`` and cached on the instance.
    """

    def __init__(
        self,
        z_near: float,
        z_far: float,
        height: int | None = None,
        width: int | None = None,
        channels: int = 3,
        norm_dir: bool = True,
        dino_upscaled: bool = False,
    ) -> None:
        super().__init__(z_near, z_far)
        self.height = height
        self.width = width
        self.channels = channels
        self.norm_dir = norm_dir  # forwarded to util.gen_rays
        # dino_gt is reshaped per pixel (True) or per feature-map cell (False) in reconstruct().
        self.dino_upscaled = dino_upscaled

    def sample(self, images, poses, projs, image_ids=None, dino_features=None, dino_artifacts=None):
        """Generate rays for all pixels, plus flattened targets when available.

        Args:
            images: Optional tensor with channels at dim 2 and ``(h, w)`` as the
                last two dims.
            poses: Tensor unpacked as ``(n, v, 4, 4)``.
            projs: Per-view intrinsics (focal from ``[0,0]``/``[1,1]``,
                principal point from ``[0,2]``/``[1,2]``).
            image_ids: Optional values for the view-id column. Defaults to ``0..v-1``.
            dino_features: Optional tensor unpacked as ``(n, v, C_dino, H', W')``.
            dino_artifacts: Accepted for interface compatibility; not used here.

        Returns:
            ``(rays, rgb_gt)``, or ``(rays, rgb_gt, dino_gt)`` when
            ``dino_features`` is given. ``rgb_gt`` is None without ``images``.

        Raises:
            ValueError: if no image size is configured and ``images`` is None.
        """
        n, v, _, _ = poses.shape
        device = poses.device
        dtype = poses.dtype

        if images is not None:
            self.channels = images.shape[2]
            if self.height is None:
                self.height, self.width = images.shape[-2:]
        if self.height is None or self.width is None:
            raise ValueError("ImageRaySampler needs an image size: pass height/width or images.")
        if dino_features is not None:
            _, _, dino_channels, _, _ = dino_features.shape

        h = self.height
        w = self.width

        # The view-id column does not depend on the batch item.
        if image_ids is None:
            ids = torch.arange(v, device=device, dtype=dtype)
        else:
            ids = torch.tensor(image_ids, device=device, dtype=dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt, all_dino_gt, all_rays = [], [], []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            rays, xy = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                self.width,
                self.height,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
                norm_dir=self.norm_dir,
            )
            rays = torch.cat((rays, ids, xy), dim=-1)
            all_rays.append(rays.view(-1, rays.shape[-1]))

            if images is not None:
                # (v, c, h, w) -> (v * h * w, c)
                rgb_gt = images[n_].view(-1, self.channels, self.height, self.width)
                all_rgb_gt.append(
                    rgb_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, self.channels)
                )

            if dino_features is not None:
                patch_h, patch_w = dino_features[n_].shape[-2], dino_features[n_].shape[-1]
                dino_gt = dino_features[n_].view(-1, dino_channels, patch_h, patch_w)
                all_dino_gt.append(
                    dino_gt.permute(0, 2, 3, 1).contiguous().reshape(-1, dino_channels)
                )

        all_rays = torch.stack(all_rays)
        all_rgb_gt = torch.stack(all_rgb_gt) if images is not None else None
        if dino_features is not None:
            return all_rays, all_rgb_gt, torch.stack(all_dino_gt)
        return all_rays, all_rgb_gt

    def reconstruct(self, render_dict, channels=None, dino_channels=None):
        """Reshape flat render outputs to ``(n, v_in, height, width, ...)`` in place.

        Every dict-valued entry containing ``"rgb"`` is treated as a render output.
        ``rgb_gt``, ``dino_gt`` and ``dino_artifacts`` are reshaped when present,
        using the sizes of the last render output. If there is no render output,
        ``render_dict`` is returned unchanged. ``dino_channels`` is accepted for
        interface compatibility and is not used.
        """
        if channels is None:
            channels = self.channels

        n = v_in = None
        for name, part in render_dict.items():
            if not isinstance(part, dict) or "rgb" not in part:
                continue
            rgb = part["rgb"]  # n, n_pts, v * channels
            weights = part["weights"]
            n, n_pts, v_c = rgb.shape
            v_in = n_pts // (self.height * self.width)
            # (This can be a different v from the sample method)
            v_render = v_c // channels
            n_smps = weights.shape[-1]
            lead = (n, v_in, self.height, self.width)

            part["rgb"] = rgb.view(*lead, v_render, channels)
            part["weights"] = weights.view(*lead, n_smps)
            part["depth"] = part["depth"].view(*lead)
            part["invalid"] = part["invalid"].view(*lead, n_smps, v_render)
            if "invalid_features" in part:
                part["invalid_features"] = part["invalid_features"].view(*lead, n_smps, v_render)
            for key in ("alphas", "z_samps"):
                if key in part:
                    part[key] = part[key].view(*lead, n_smps)
            if "rgb_samps" in part:
                part["rgb_samps"] = part["rgb_samps"].view(*lead, n_smps, v_render, channels)
            for key in ("ray_info", "extras"):
                if key in part:
                    part[key] = part[key].view(*lead, part[key].shape[-1])
            if "dino_features" in part:
                dino_shape = part["dino_features"].shape[-1]
                part["dino_features"] = part["dino_features"].view(*lead, 1, dino_shape)
            render_dict[name] = part

        if n is None:
            # No render output: batch/view sizes for the targets are unknown.
            return render_dict

        if "rgb_gt" in render_dict:
            render_dict["rgb_gt"] = render_dict["rgb_gt"].view(
                n, v_in, self.height, self.width, channels
            )
        if "dino_gt" in render_dict:
            dino_gt = render_dict["dino_gt"]
            dino_gt_shape = dino_gt.shape[-1]
            if self.dino_upscaled:
                render_dict["dino_gt"] = dino_gt.view(
                    n, v_in, self.height, self.width, dino_gt_shape
                )
            else:
                # TODO: patch size should not be inferred like this, but parameter.
                # numel = n * v_in * (H / p) * (W / p) * C, so the ratio below is p ** 2.
                patch_size = isqrt((n * v_in * self.height * self.width * dino_gt_shape) // dino_gt.numel())
                render_dict["dino_gt"] = dino_gt.view(
                    n, v_in, self.height // patch_size, self.width // patch_size, dino_gt_shape
                )
            if "dino_artifacts" in render_dict:
                dino_artifacts = render_dict["dino_artifacts"]
                # TODO: patch size should not be inferred like this, but parameter.
                patch_size = isqrt((n * v_in * self.height * self.width * dino_gt_shape) // dino_artifacts.numel())
                render_dict["dino_artifacts"] = dino_artifacts.view(
                    n, v_in, self.height // patch_size, self.width // patch_size, dino_gt_shape
                )
        return render_dict
class JitteredPatchRaySampler(PatchRaySampler):
    """Patch sampler whose pixel grid is shifted by a random sub-pixel offset on each call.

    The offset is forwarded to ``util.gen_rays``. Ground-truth colours are
    resampled at the offset pixel centres with ``F.grid_sample``, using
    bilinear interpolation and border padding.
    """

    def __init__(
        self,
        z_near: float,
        z_far: float,
        ray_batch_size: int,
        patch_size: int,
        jitter_strength: float,  # In pixels, max [0, 1)
        channels: int = 3,
    ) -> None:
        super().__init__(z_near, z_far, ray_batch_size, patch_size, channels)
        assert 0 <= jitter_strength < 1, "Jitter strength is invalid."
        self.jitter_strength = jitter_strength

        # Integer pixel offsets within one patch, shape (1, patch_h, patch_w, 2), last dim = (x, y).
        x = torch.arange(0, self.patch_size_x).view(1, 1, -1, 1).expand(-1, self.patch_size_y, -1, -1)
        y = torch.arange(0, self.patch_size_y).view(1, -1, 1, 1).expand(-1, -1, self.patch_size_x, -1)
        self._grid = torch.cat((x, y), dim=-1)

    def sample(
        self, images, poses, projs, image_ids=None
    ):
        """Sample jittered patches of rays and bilinearly resampled colours.

        Args:
            images: Tensor unpacked as ``(n, v, c, h, w)``.
            poses: Per-view poses; ``poses[i]`` is viewed as ``(-1, 4, 4)``.
            projs: Per-view intrinsics (focal from ``[0,0]``/``[1,1]``,
                principal point from ``[0,2]``/``[1,2]``).
            image_ids: Optional values for the view-id column. Defaults to ``0..v-1``.

        Returns:
            ``(rays, rgb_gt)`` stacked over the batch, with patches concatenated
            along dim 1.
        """
        n, v, c, h, w = images.shape
        # Keep the channel count in sync with the input, as PatchRaySampler.sample does.
        self.channels = c
        psy, psx = self.patch_size_y, self.patch_size_x

        # One (x, y) offset in [-jitter/2, jitter/2) pixels, shared by the whole batch.
        xy_offset = (torch.rand(2) - 0.5) * self.jitter_strength

        # The view-id column does not depend on the batch item.
        if image_ids is None:
            ids = torch.arange(v, device=images.device, dtype=images.dtype)
        else:
            ids = torch.tensor(image_ids, device=images.device, dtype=images.dtype)
        ids = ids.view(v, 1, 1, 1).expand(v, h, w, 1)

        all_rgb_gt, all_rays = [], []
        for n_ in range(n):
            focals = projs[n_, :, [0, 1], [0, 1]]
            centers = projs[n_, :, [0, 1], [2, 2]]
            # NOTE(review): unlike the other samplers, the returned xy is not appended to the rays.
            rays, _ = util.gen_rays(
                poses[n_].view(-1, 4, 4),
                w,
                h,
                focal=focals,
                c=centers,
                z_near=self.z_near,
                z_far=self.z_far,
                xy_offset=xy_offset,
            )
            rays = torch.cat((rays, ids), dim=-1)
            r_dim = rays.shape[-1]

            # randint's upper bound is exclusive. The "+ 1" makes the last in-bounds
            # start reachable and avoids an empty range when a patch spans the image.
            patch_coords_v = torch.randint(0, v, (self._patch_count,))
            patch_coords_y = torch.randint(0, h - psy + 1, (self._patch_count,))
            patch_coords_x = torch.randint(0, w - psx + 1, (self._patch_count,))

            sample_rgb_gt, sample_rays = [], []
            for v_, y, x in zip(
                patch_coords_v.tolist(), patch_coords_y.tolist(), patch_coords_x.tolist()
            ):
                xy_patch = torch.tensor((x, y)).view(1, 1, 1, 2)
                # "+ 0.5" moves from the pixel corner to the pixel centre.
                patch_grid = (self._grid + xy_patch + xy_offset.view(1, 1, 1, 2) + 0.5).to(images.device)
                # Map pixel coordinates to grid_sample's [-1, 1] range (align_corners=False).
                patch_grid[..., 0] = (patch_grid[..., 0] / w) * 2 - 1
                patch_grid[..., 1] = (patch_grid[..., 1] / h) * 2 - 1
                # grid_sample requires the grid to share the input's dtype (e.g. under AMP).
                rgb_gt_patch = F.grid_sample(
                    images[n_:n_ + 1, v_],
                    patch_grid.to(images.dtype),
                    padding_mode="border",
                    align_corners=False,
                )
                sample_rgb_gt.append(rgb_gt_patch.permute(0, 2, 3, 1).reshape(-1, c))
                sample_rays.append(rays[v_, y : y + psy, x : x + psx, :].reshape(-1, r_dim))

            all_rgb_gt.append(torch.cat(sample_rgb_gt, dim=0))
            all_rays.append(torch.cat(sample_rays, dim=0))

        return torch.stack(all_rays), torch.stack(all_rgb_gt)
def get_ray_sampler(config) -> RaySampler:
z_near = config["z_near"]
z_far = config["z_far"]
sample_mode = config.get("sample_mode", "random")
# TODO: check channel size
match sample_mode:
case "random":
return RandomRaySampler(z_near, z_far, **config["args"])
case "patch":
return PatchRaySampler(z_near, z_far, **config["args"])
case "jitteredpatch":
return JitteredPatchRaySampler(z_near, z_far, **config["args"])
case "image":
return ImageRaySampler(z_near, z_far)
case _:
raise NotImplementedError