# scenedino/common/sampling_strategies.py
from random import shuffle
import random
from typing import Callable, Optional

import numpy as np
import torch

EncoderSamplingStrategy = Callable[[int], list[int]]
LossSamplingStrategy = Callable[
    [int], tuple[list[int], list[int], Optional[np.ndarray]]
]
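# A LossSamplingStrategy maps the number of available frames to
# (ids_loss, ids_renderer, mapping):
#   ids_loss     - frame indices the reconstruction loss is evaluated on,
#   ids_renderer - frame indices rendered by the model,
#   mapping      - optional int64 array giving, per mapping row, the positions
#                  in ids_renderer a loss frame is compared against.
# (The description of `mapping` is inferred from the samplers below.)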
# ============================================ ENCODING SAMPLING STRATEGIES ============================================


def default_encoder_sampler() -> EncoderSamplingStrategy:
    def _sampling_strategy(num_frames: int) -> list[int]:
        return [0]

    return _sampling_strategy
def kitti_360_full_encoder_sampler(
    num_encoder_frames: int, always_use_base_frame: bool = True
) -> EncoderSamplingStrategy:
    def _sampling_strategy(num_frames: int) -> list[int]:
        if always_use_base_frame:
            encoder_perm = (torch.randperm(num_frames - 1) + 1)[
                : num_encoder_frames - 1
            ].tolist()
            ids_encoder = [0]
            ids_encoder.extend(encoder_perm)
        else:
            ids_encoder = (torch.randperm(num_frames - 1) + 1)[
                :num_encoder_frames
            ].tolist()
        return ids_encoder

    return _sampling_strategy
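# Illustrative call (the frame count is an arbitrary assumption): with
# num_encoder_frames=2 and 8 available frames, _sampling_strategy(8) returns
# the base frame plus one random non-base frame, e.g. [0, 5].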
def kitti_360_stereo_encoder_sampler(
    num_encoder_frames: int, num_stereo_frames: int, always_use_base_frame: bool = True
) -> EncoderSamplingStrategy:
    def _sampling_strategy(num_frames: int) -> list[int]:
        # Restrict random selection to the first num_stereo_frames frames
        # (presumably the stereo views) before sampling.
        num_frames = min(num_frames, num_stereo_frames)
        if always_use_base_frame:
            encoder_perm = (torch.randperm(num_frames - 1) + 1)[
                : num_encoder_frames - 1
            ].tolist()
            ids_encoder = [0]
            ids_encoder.extend(encoder_perm)
        else:
            ids_encoder = (torch.randperm(num_frames - 1) + 1)[
                :num_encoder_frames
            ].tolist()
        return ids_encoder

    return _sampling_strategy
def get_encoder_sampling(config) -> EncoderSamplingStrategy:
    strategy = config.get("name", None)
    match strategy:
        case "kitti_360_full":
            return kitti_360_full_encoder_sampler(**config["args"])
        case "kitti_360_stereo":
            return kitti_360_stereo_encoder_sampler(**config["args"])
        case _:
            return default_encoder_sampler()
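# Expected config shape (inferred from the factory above; the values are
# illustrative, not taken from the shipped configs):
#   {"name": "kitti_360_full", "args": {"num_encoder_frames": 2}}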
# =============================================== LOSS SAMPLING STRATEGIES =============================================


def single_view_loss_sampler(
    shuffle_frames: bool = False, all_frames: bool = False
) -> LossSamplingStrategy:
    # With all_frames=True the loss frame is rendered as well; otherwise it
    # only serves the loss.
    starting_frame = 0 if all_frames else 1

    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        return frames[0:1], frames[starting_frame:], None

    return _sampling_strategy
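# Example (illustrative): with 4 frames and no shuffling, the split is
# ids_loss=[0], ids_renderer=[1, 2, 3] (or [0, 1, 2, 3] with all_frames=True).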
def single_view_renderer_sampler(
    shuffle_frames: bool = False, all_frames: bool = False
) -> LossSamplingStrategy:
    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        if all_frames:
            return frames, frames[0:1], None
        return frames[0:-1], frames[0:1], None

    return _sampling_strategy
def stereo_view_loss_sampler(shuffle_frames: bool = False) -> LossSamplingStrategy:
    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        # The (possibly shuffled) first frame decides which half serves the loss.
        if frames[0] < num_frames // 2:
            ids_loss = list(range(num_frames // 2))
            ids_renderer = list(range(num_frames // 2, num_frames))
        else:
            ids_renderer = list(range(num_frames // 2))
            ids_loss = list(range(num_frames // 2, num_frames))
        return ids_loss, ids_renderer, None

    return _sampling_strategy
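# Example (illustrative): with 4 frames, either ids_loss=[0, 1] and
# ids_renderer=[2, 3], or the reverse, depending on the shuffled first frame.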
def kitti_360_loss_sampler() -> LossSamplingStrategy:
    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        for cam_pair_base_id in range(0, num_frames, 2):
            # Fair coin flip: pick which camera of the stereo pair serves the
            # loss; the other one is rendered.
            if random.randint(0, 1):
                ids_loss.append(cam_pair_base_id)
                ids_renderer.append(cam_pair_base_id + 1)
            else:
                ids_loss.append(cam_pair_base_id + 1)
                ids_renderer.append(cam_pair_base_id)
        return ids_loss, ids_renderer, None

    return _sampling_strategy
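# Example (illustrative): with 4 frames (two stereo pairs), one possible draw
# is ids_loss=[0, 3], ids_renderer=[1, 2].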
def kitti_360_with_mapping_loss_sampler() -> LossSamplingStrategy:
    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        mapping: list[list[int]] = []
        for cam_pair_base_id in range(0, num_frames, 2):
            # Fair coin flip: pick which camera of the stereo pair serves the loss.
            if random.randint(0, 1):
                ids_loss.append(cam_pair_base_id)
                ids_renderer.append(cam_pair_base_id + 1)
            else:
                ids_loss.append(cam_pair_base_id + 1)
                ids_renderer.append(cam_pair_base_id)
            # Each loss frame is compared against its stereo partner, which was
            # just appended to ids_renderer.
            mapping.append([len(ids_renderer) - 1])
        return ids_loss, ids_renderer, np.array(mapping, dtype=np.int64)

    return _sampling_strategy
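# Example (illustrative): with 4 frames, a draw of ids_loss=[1, 2],
# ids_renderer=[0, 3] yields mapping=[[0], [1]]: loss frame 1 is compared
# against renderer slot 0, loss frame 2 against renderer slot 1.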
def waymo_with_mapping_loss_sampler() -> LossSamplingStrategy:
    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        mapping: list[list[int]] = []
        for cam_pair_base_id in range(0, num_frames, 2):
            # Fair coin flip: pick which camera of the pair serves the loss.
            if random.randint(0, 1):
                ids_loss.append(cam_pair_base_id)
                ids_renderer.append(cam_pair_base_id + 1)
            else:
                ids_loss.append(cam_pair_base_id + 1)
                ids_renderer.append(cam_pair_base_id)
            # Two mapping rows per camera pair point at the same renderer slot.
            mapping.extend([[len(ids_renderer) - 1], [len(ids_renderer) - 1]])
        return ids_loss, ids_renderer, np.array(mapping, dtype=np.int64)

    return _sampling_strategy
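# Example (illustrative): with 4 frames, one possible draw is ids_loss=[1, 2],
# ids_renderer=[0, 3], mapping=[[0], [0], [1], [1]]. Note that mapping has two
# rows per camera pair, i.e. len(mapping) == 2 * len(ids_loss).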
def alternate_loss_sampler() -> LossSamplingStrategy:
    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], Optional[np.ndarray]]:
        # Fair coin flip: even-indexed frames serve the loss and odd-indexed
        # frames are rendered, or vice versa.
        if random.randint(0, 1):
            return list(range(0, num_frames, 2)), list(range(1, num_frames, 2)), None
        return list(range(1, num_frames, 2)), list(range(0, num_frames, 2)), None

    return _sampling_strategy
def get_loss_renderer_sampling(config) -> LossSamplingStrategy:
    strategy = config.get("name", None)
    match strategy:
        case "single_loss":
            return single_view_loss_sampler(**config.get("args", {}))
        case "single_renderer":
            return single_view_renderer_sampler(**config.get("args", {}))
        case "stereo_loss":
            return stereo_view_loss_sampler(**config.get("args", {}))
        case "kitti_360":
            return kitti_360_loss_sampler()
        case "kitti_360_with_mapping":
            return kitti_360_with_mapping_loss_sampler()
        case "waymo_with_mapping":
            return waymo_with_mapping_loss_sampler()
        case "alternate":
            return alternate_loss_sampler()
        case _:
            return single_view_loss_sampler(shuffle_frames=False)
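# Expected config shape (inferred from the factory above; the values are
# illustrative, not taken from the shipped configs):
#   {"name": "single_loss", "args": {"shuffle_frames": True, "all_frames": False}}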
# old sampling strategies
# if self.training:
#     frame_perm = torch.randperm(v)
# else:
#     frame_perm = torch.arange(v)  ## eval
#
# if self.enc_style == "random":  ## encoded views
#     encoder_perm = (torch.randperm(v - 1) + 1)[
#         : self.nv_ - 1
#     ].tolist()  ## nv-1 for mono [0] idx
#     ids_encoder = [0]  ## always starts sampling from mono cam
#     ids_encoder.extend(encoder_perm)  ## add more cam_views randomly incl. fe
# elif self.enc_style == "default":
#     ids_encoder = [
#         v_ for v_ in range(self.nv_)
#     ]  ## iterating view(v_) over num_views(nv_)
# elif self.enc_style == "stereo":
#     if self.training:
#         # if v < 8: raise RuntimeError(f"__number of views should be more than 4 when excluding fisheye views")
#         encoder_perm = (torch.randperm(v - (1 + 4)) + 1)[
#             : self.nv_ - 1
#         ].tolist()
#         ids_encoder = [0]
#         ids_encoder.extend(encoder_perm)
#     else:
#         ids_encoder = [0]
# else:
#     raise NotImplementedError(f"__unrecognized enc_style: {self.enc_style}")
# ## default: ids_encoder = [0,1,2,3] <=> front stereo for 1st + 2nd time stamps
#
# if (
#     not self.training and self.ids_enc_viz_eval
# ):  ## when eval in viz to be standardized with test: it's eval from line 354, base_trainer.py
#     ids_encoder = self.ids_enc_viz_eval  ## fixed during eval
#
# ids_render = torch.sort(
#     frame_perm[[i for i in self.frames_render if i < v]]
# ).values  ## ? ### tensor([0, 4])
# combine_ids = None
#
# if self.training:
#     if self.frame_sample_mode == "only":
#         ids_loss = [0]
#         ids_render = ids_render[ids_render != 0]
#     elif self.frame_sample_mode == "not":
#         frame_perm = torch.randperm(v - 1) + 1
#         ids_loss = torch.sort(
#             frame_perm[[i for i in self.frames_render if i < v - 1]]
#         ).values
#         ids_render = [i for i in range(v) if i not in ids_loss]
#     elif self.frame_sample_mode == "stereo":
#         if frame_perm[0] < v // 2:
#             ids_loss = list(range(v // 2))
#             ids_render = list(range(v // 2, v))
#         else:
#             ids_loss = list(range(v // 2, v))
#             ids_render = list(range(v // 2))
#     elif self.frame_sample_mode == "mono":
#         split_i = v // 2
#         if frame_perm[0] < v // 2:
#             ids_loss = list(range(0, split_i, 2)) + list(
#                 range(split_i + 1, v, 2)
#             )
#             ids_render = list(range(1, split_i, 2)) + list(range(split_i, v, 2))
#         else:
#             ids_loss = list(range(1, split_i, 2)) + list(range(split_i, v, 2))
#             ids_render = list(range(0, split_i, 2)) + list(
#                 range(split_i + 1, v, 2)
#             )
#     elif self.frame_sample_mode == "kitti360-mono":
#         steps = v // 4
#         start_from = 0 if frame_perm[0] < v // 2 else 1
#         ids_loss, ids_render = [], []
#         for cam in range(
#             4
#         ):  ## stereo cam sampled for each time ## ! c.f. paper: N_{render}, N_{loss}
#             ids_loss += [cam * steps + i for i in range(start_from, steps, 2)]
#             ids_render += [
#                 cam * steps + i for i in range(1 - start_from, steps, 2)
#             ]
#             start_from = 1 - start_from
#         if self.enc_style == "test":
#             ids_encoder = ids_loss[: self.nv_]
#     elif self.frame_sample_mode.startswith("waymo"):
#         num_views = int(self.frame_sample_mode.split("-")[-1])
#         steps = v // num_views
#         split = steps // 2
#         # Predict features from half-left, center, half-right
#         ids_encoder = [0, steps, steps * 2]
#         # Combine all frames half-left, center, half-right for efficiency reasons
#         combine_ids = [(i, steps + i, steps * 2 + i) for i in range(steps)]
#         if self.training:
#             step_perm = torch.randperm(steps)
#         else:
#             step_perm = torch.arange(steps)  ## eval
#         step_perm = step_perm.tolist()
#         ids_loss = sum(
#             [
#                 [i + j * steps for j in range(num_views)]
#                 for i in step_perm[:split]
#             ],
#             [],
#         )
#         ids_render = sum(
#             [
#                 [i + j * steps for j in range(num_views)]
#                 for i in step_perm[split:]
#             ],
#             [],
#         )
#     elif self.frame_sample_mode == "default":
#         ids_loss = frame_perm[
#             [i for i in range(v) if frame_perm[i] not in ids_render]
#         ]
#     else:
#         raise NotImplementedError
# else:  ## eval (!= self.training)
#     ids_loss = torch.arange(v)
#     ids_render = [0]
#     if self.frame_sample_mode.startswith("waymo"):
#         num_views = int(self.frame_sample_mode.split("-")[-1])
#         steps = v // num_views
#         split = steps // 2
#         # Predict features from half-left, center, half-right
#         ids_encoder = [0, steps, steps * 2]
#         ids_render = [0, steps, steps * 2]
#         combine_ids = [(i, steps + i, steps * 2 + i) for i in range(steps)]
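

if __name__ == "__main__":
    # Minimal smoke-test sketch. The frame counts and config values below are
    # arbitrary assumptions for illustration, not taken from the training configs.
    enc_sampler = get_encoder_sampling(
        {"name": "kitti_360_full", "args": {"num_encoder_frames": 2}}
    )
    print("encoder ids:", enc_sampler(8))

    loss_sampler = get_loss_renderer_sampling({"name": "kitti_360_with_mapping"})
    ids_loss, ids_renderer, mapping = loss_sampler(8)
    print("loss ids:    ", ids_loss)
    print("renderer ids:", ids_renderer)
    print("mapping:\n", mapping)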