Spaces:
Running
on
Zero
Running
on
Zero
from random import shuffle | |
import random | |
from typing import Callable, Optional | |
import numpy as np | |
import torch | |
# An encoder sampler maps the number of available frames to the list of frame
# indices that are passed to the encoder.
EncoderSamplingStrategy = Callable[[int], list[int]]
# A loss sampler maps the number of available frames to
# ``(ids_loss, ids_renderer, mapping)``. ``mapping`` is ``None`` for most
# samplers in this module; the "*_with_mapping" samplers return an int64 numpy
# array whose entries are indices into ``ids_renderer``.
LossSamplingStrategy = Callable[
    [int], tuple[list[int], list[int], Optional[np.ndarray]]
]
# ============================================ ENCODING SAMPLING STRATEGIES ============================================ | |
def default_encoder_sampler() -> EncoderSamplingStrategy:
    """Build a sampler that only ever encodes the first frame.

    Returns:
        A callable that ignores how many frames are available and returns a
        fresh ``[0]`` list on every call.
    """

    def _first_frame_only(num_frames: int) -> list[int]:
        # The frame count is deliberately unused: frame 0 is the sole encoder input.
        first_frame = 0
        return [first_frame]

    return _first_frame_only
def kitti_360_full_encoder_sampler(
    num_encoder_frames: int, always_use_base_frame: bool = True
) -> EncoderSamplingStrategy:
    """Build a sampler that picks encoder frames uniformly at random.

    Args:
        num_encoder_frames: Maximum number of frame indices to return.
        always_use_base_frame: If True, frame 0 is always the first returned
            index and up to ``num_encoder_frames - 1`` further distinct indices
            are drawn from frames ``1 .. num_frames - 1``. If False, up to
            ``num_encoder_frames`` distinct indices are drawn from all frames
            ``0 .. num_frames - 1``, so frame 0 is an ordinary candidate.

    Returns:
        A callable mapping ``num_frames`` to a list of distinct frame indices.
        Fewer than ``num_encoder_frames`` indices are returned when not enough
        frames are available.
    """

    def _sampling_strategy(num_frames: int) -> list[int]:
        if always_use_base_frame:
            # Clamp so a non-positive num_encoder_frames cannot become a
            # negative slice bound (which would select almost every frame).
            num_extra = max(num_encoder_frames - 1, 0)
            extra_frames = (torch.randperm(num_frames - 1) + 1)[:num_extra].tolist()
            return [0, *extra_frames]
        # Frame 0 is sampled like any other frame when it is not forced.
        return torch.randperm(num_frames)[: max(num_encoder_frames, 0)].tolist()

    return _sampling_strategy
def kitti_360_stereo_encoder_sampler(
    num_encoder_frames: int, num_stereo_frames: int, always_use_base_frame: bool = True
) -> EncoderSamplingStrategy:
    """Build a sampler that picks encoder frames among the first stereo frames.

    Only frames ``0 .. min(num_frames, num_stereo_frames) - 1`` are candidates.

    Args:
        num_encoder_frames: Maximum number of frame indices to return.
        num_stereo_frames: Cap on the number of leading frames considered.
        always_use_base_frame: If True, frame 0 is always the first returned
            index and up to ``num_encoder_frames - 1`` further distinct indices
            are drawn from the other candidates. If False, up to
            ``num_encoder_frames`` distinct indices are drawn from all
            candidates, frame 0 included.

    Returns:
        A callable mapping ``num_frames`` to a list of distinct frame indices.
        Fewer than ``num_encoder_frames`` indices are returned when not enough
        candidates are available.
    """

    def _sampling_strategy(num_frames: int) -> list[int]:
        num_candidates = min(num_frames, num_stereo_frames)
        if always_use_base_frame:
            # Clamp so a non-positive num_encoder_frames cannot become a
            # negative slice bound.
            num_extra = max(num_encoder_frames - 1, 0)
            extra_frames = (torch.randperm(num_candidates - 1) + 1)[:num_extra].tolist()
            return [0, *extra_frames]
        # Frame 0 is sampled like any other candidate when it is not forced.
        return torch.randperm(num_candidates)[: max(num_encoder_frames, 0)].tolist()

    return _sampling_strategy
def get_encoder_sampling(config) -> EncoderSamplingStrategy: | |
strategy = config.get("name", None) | |
match strategy: | |
case "kitti_360_full": | |
return kitti_360_full_encoder_sampler(**config["args"]) | |
case "kitti_360_stereo": | |
return kitti_360_stereo_encoder_sampler(**config["args"]) | |
case _: | |
return default_encoder_sampler() | |
# =============================================== LOSS SAMPLING STRATEGIES ============================================= | |
def single_view_loss_sampler(
    shuffle_frames: bool = False, all_frames: bool = False
) -> LossSamplingStrategy:
    """Build a sampler that computes the loss on a single frame.

    Args:
        shuffle_frames: If True, the frame order is shuffled with
            ``random.shuffle`` before splitting, so the loss frame is random.
        all_frames: If True, the loss frame is also included in the renderer
            frames; otherwise the renderer frames are all remaining frames.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, None)``,
        where ``ids_loss`` holds the first (possibly shuffled) frame index.
    """
    # Position in the (possibly shuffled) frame list where renderer frames start.
    starting_frame = 0 if all_frames else 1

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        return frames[:1], frames[starting_frame:], None

    return _sampling_strategy
def single_view_renderer_sampler(
    shuffle_frames: bool = False, all_frames: bool = False
) -> LossSamplingStrategy:
    """Build a sampler that renders from a single frame.

    The first (possibly shuffled) frame is the only renderer frame; the loss is
    computed on the other frames, or on all frames when ``all_frames`` is set.

    Args:
        shuffle_frames: If True, the frame order is shuffled with
            ``random.shuffle`` before splitting, so the renderer frame is random.
        all_frames: If True, the renderer frame is also a loss frame.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, None)``.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        frames = list(range(num_frames))
        if shuffle_frames:
            shuffle(frames)
        # Without all_frames, exclude exactly the renderer frame (frames[0])
        # from the loss frames. Slicing frames[0:-1] would instead keep the
        # renderer frame and drop an unrelated trailing frame.
        ids_loss = frames if all_frames else frames[1:]
        return ids_loss, frames[:1], None

    return _sampling_strategy
def stereo_view_loss_sampler(shuffle_frames: bool = False) -> LossSamplingStrategy:
    """Build a sampler that splits the frames into two contiguous halves.

    One half (``0 .. num_frames // 2 - 1``) and the other half
    (``num_frames // 2 .. num_frames - 1``) are assigned to loss and renderer.
    The assignment is decided by the leading frame of a (possibly shuffled)
    frame ordering: if it lies in the first half, the first half is the loss
    half. Without shuffling that leading frame is 0.

    Args:
        shuffle_frames: If True, the ordering is shuffled with ``random.shuffle``.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, None)``.
    """

    def _split_halves(num_frames: int) -> tuple[list[int], list[int], None]:
        half = num_frames // 2
        order = list(range(num_frames))
        if shuffle_frames:
            shuffle(order)
        lower = list(range(half))
        upper = list(range(half, num_frames))
        # The leading frame of the ordering picks which half is supervised.
        if order[0] >= half:
            return upper, lower, None
        return lower, upper, None

    return _split_halves
def kitti_360_loss_sampler() -> LossSamplingStrategy:
    """Build a sampler that splits each consecutive frame pair into loss/renderer.

    Frames are grouped as ``(0, 1), (2, 3), ...``. For every pair a fair coin
    decides which member becomes a loss frame; the other becomes a renderer
    frame. A trailing unpaired frame (odd ``num_frames``) is ignored.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, None)``;
        ``ids_loss[i]`` and ``ids_renderer[i]`` come from the same pair.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        # Stop before a trailing unpaired frame so ``pair_base + 1`` stays in range.
        for pair_base in range(0, num_frames - 1, 2):
            # random.randint includes its upper bound: (0, 1) is a fair coin,
            # whereas (0, 2) would favour the first branch two times out of three.
            if random.randint(0, 1):
                ids_loss.append(pair_base)
                ids_renderer.append(pair_base + 1)
            else:
                ids_loss.append(pair_base + 1)
                ids_renderer.append(pair_base)
        return ids_loss, ids_renderer, None

    return _sampling_strategy
def kitti_360_with_mapping_loss_sampler() -> LossSamplingStrategy:
    """Build a pairwise loss/renderer sampler that also returns an index mapping.

    Frames are grouped as ``(0, 1), (2, 3), ...``. For every pair a fair coin
    decides which member becomes a loss frame; the other becomes a renderer
    frame. A trailing unpaired frame (odd ``num_frames``) is ignored.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, mapping)``.
        ``mapping`` is an int64 array of shape ``(len(ids_loss), 1)``; row ``i``
        holds the position in ``ids_renderer`` of the partner of ``ids_loss[i]``.
    """

    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], np.ndarray]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        mapping: list[list[int]] = []
        # Stop before a trailing unpaired frame so ``pair_base + 1`` stays in range.
        for pair_base in range(0, num_frames - 1, 2):
            # random.randint includes its upper bound: (0, 1) is a fair coin.
            if random.randint(0, 1):
                ids_loss.append(pair_base)
                ids_renderer.append(pair_base + 1)
            else:
                ids_loss.append(pair_base + 1)
                ids_renderer.append(pair_base)
            # The partner was just appended, so it sits at the last position.
            mapping.append([len(ids_renderer) - 1])
        # reshape keeps the (N, 1) layout even when no pairs were sampled.
        return ids_loss, ids_renderer, np.array(mapping, dtype=np.int64).reshape(-1, 1)

    return _sampling_strategy
def waymo_with_mapping_loss_sampler() -> LossSamplingStrategy:
    """Build a pairwise loss/renderer sampler whose mapping has two rows per pair.

    Frames are grouped as ``(0, 1), (2, 3), ...``. For every pair a fair coin
    decides which member becomes a loss frame; the other becomes a renderer
    frame. A trailing unpaired frame (odd ``num_frames``) is ignored.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, mapping)``.
        ``mapping`` is an int64 array of shape ``(2 * len(ids_loss), 1)``; rows
        ``2 * i`` and ``2 * i + 1`` both hold the position in ``ids_renderer`` of
        the partner of ``ids_loss[i]``.
        NOTE(review): why each pair contributes two rows is not visible here;
        confirm against the consumer of ``mapping``.
    """

    def _sampling_strategy(
        num_frames: int,
    ) -> tuple[list[int], list[int], np.ndarray]:
        ids_loss: list[int] = []
        ids_renderer: list[int] = []
        mapping: list[list[int]] = []
        # Stop before a trailing unpaired frame so ``pair_base + 1`` stays in range.
        for pair_base in range(0, num_frames - 1, 2):
            # random.randint includes its upper bound: (0, 1) is a fair coin.
            if random.randint(0, 1):
                ids_loss.append(pair_base)
                ids_renderer.append(pair_base + 1)
            else:
                ids_loss.append(pair_base + 1)
                ids_renderer.append(pair_base)
            partner_pos = len(ids_renderer) - 1
            mapping.extend([[partner_pos], [partner_pos]])
        # reshape keeps the (N, 1) layout even when no pairs were sampled.
        return ids_loss, ids_renderer, np.array(mapping, dtype=np.int64).reshape(-1, 1)

    return _sampling_strategy
def alternate_loss_sampler() -> LossSamplingStrategy:
    """Build a sampler that splits frames by index parity.

    A fair coin decides whether the even-indexed frames are the loss frames and
    the odd-indexed frames the renderer frames, or the other way round.

    Returns:
        A callable mapping ``num_frames`` to ``(ids_loss, ids_renderer, None)``.
    """

    def _sampling_strategy(num_frames: int) -> tuple[list[int], list[int], None]:
        even_frames = list(range(0, num_frames, 2))
        odd_frames = list(range(1, num_frames, 2))
        # random.randint includes its upper bound: (0, 1) is a fair coin,
        # whereas (0, 2) would favour the first branch two times out of three.
        if random.randint(0, 1):
            return even_frames, odd_frames, None
        return odd_frames, even_frames, None

    return _sampling_strategy
def get_loss_renderer_sampling(config) -> EncoderSamplingStrategy: | |
strategy = config.get("name", None) | |
match strategy: | |
case "single_loss": | |
return single_view_loss_sampler(**config.get("args", {})) | |
case "single_renderer": | |
return single_view_renderer_sampler(**config.get("args", {})) | |
case "stereo_loss": | |
return stereo_view_loss_sampler(**config.get("args", {})) | |
case "kitti_360": | |
return kitti_360_loss_sampler() | |
case "kitti_360_with_mapping": | |
return kitti_360_with_mapping_loss_sampler() | |
case "waymo_with_mapping": | |
return waymo_with_mapping_loss_sampler() | |
case "alternate": | |
return alternate_loss_sampler() | |
case _: | |
return single_view_loss_sampler(False) | |
# old sampling strategies | |
# if self.training: | |
# frame_perm = torch.randperm(v) | |
# else: | |
# frame_perm = torch.arange(v) ## eval | |
# if self.enc_style == "random": ## encoded views | |
# encoder_perm = (torch.randperm(v - 1) + 1)[ | |
# : self.nv_ - 1 | |
# ].tolist() ## nv-1 for mono [0] idx | |
# ids_encoder = [0] ## always starts sampling from mono cam | |
# ids_encoder.extend(encoder_perm) ## add more cam_views randomly incl. fe | |
# elif self.enc_style == "default": | |
# ids_encoder = [ | |
# v_ for v_ in range(self.nv_) | |
# ] ## iterating view(v_) over num_views(nv_) | |
# elif self.enc_style == "stereo": | |
# if self.training: | |
# # if v < 8: raise RuntimeError(f"__number of views should be more than 4 when excluding fisheye views") | |
# # if v < 8: raise RuntimeError(f"__number of views should be more than 4 when excluding fisheye views") | |
# encoder_perm = (torch.randperm(v - (1 + 4)) + 1)[ | |
# : self.nv_ - 1 | |
# ].tolist() | |
# ids_encoder = [0] | |
# ids_encoder.extend(encoder_perm) | |
# else: | |
# ids_encoder = [0] | |
# else: | |
# raise NotImplementedError(f"__unrecognized enc_style: {self.enc_style}") | |
# ## default: ids_encoder = [0,1,2,3] <=> front stereo for 1st + 2nd time stamps | |
# if ( | |
# not self.training and self.ids_enc_viz_eval | |
# ): ## when eval in viz to be standardized with test: it's eval from line 354, base_trainer.py | |
# ids_encoder = self.ids_enc_viz_eval ## fixed during eval | |
# ids_render = torch.sort( | |
# frame_perm[[i for i in self.frames_render if i < v]] | |
# ).values ## ? ### tensor([0, 4]) | |
# combine_ids = None | |
# if self.training: | |
# if self.frame_sample_mode == "only": | |
# ids_loss = [0] | |
# ids_render = ids_render[ids_render != 0] | |
# elif self.frame_sample_mode == "not": | |
# frame_perm = torch.randperm(v - 1) + 1 | |
# ids_loss = torch.sort( | |
# frame_perm[[i for i in self.frames_render if i < v - 1]] | |
# ).values | |
# ids_render = [i for i in range(v) if i not in ids_loss] | |
# elif self.frame_sample_mode == "stereo": | |
# if frame_perm[0] < v // 2: | |
# ids_loss = list(range(v // 2)) | |
# ids_render = list(range(v // 2, v)) | |
# else: | |
# ids_loss = list(range(v // 2, v)) | |
# ids_render = list(range(v // 2)) | |
# elif self.frame_sample_mode == "mono": | |
# split_i = v // 2 | |
# if frame_perm[0] < v // 2: | |
# ids_loss = list(range(0, split_i, 2)) + list( | |
# range(split_i + 1, v, 2) | |
# ) | |
# ids_render = list(range(1, split_i, 2)) + list(range(split_i, v, 2)) | |
# else: | |
# ids_loss = list(range(1, split_i, 2)) + list(range(split_i, v, 2)) | |
# ids_render = list(range(0, split_i, 2)) + list( | |
# range(split_i + 1, v, 2) | |
# ) | |
# elif self.frame_sample_mode == "kitti360-mono": | |
# steps = v // 4 | |
# start_from = 0 if frame_perm[0] < v // 2 else 1 | |
# ids_loss, ids_render = [], [] | |
# for cam in range( | |
# 4 | |
# ): ## stereo cam sampled for each time ## ! c.f. paper: N_{render}, N_{loss} | |
# ids_loss += [cam * steps + i for i in range(start_from, steps, 2)] | |
# ids_render += [ | |
# cam * steps + i for i in range(1 - start_from, steps, 2) | |
# ] | |
# start_from = 1 - start_from | |
# if self.enc_style == "test": | |
# ids_encoder = ids_loss[: self.nv_] | |
# elif self.frame_sample_mode.startswith("waymo"): | |
# num_views = int(self.frame_sample_mode.split("-")[-1]) | |
# steps = v // num_views | |
# split = steps // 2 | |
# # Predict features from half-left, center, half-right | |
# ids_encoder = [0, steps, steps * 2] | |
# # Combine all frames half-left, center, half-right for efficiency reasons | |
# combine_ids = [(i, steps + i, steps * 2 + i) for i in range(steps)] | |
# if self.training: | |
# step_perm = torch.randperm(steps) | |
# else: | |
# step_perm = torch.arange(steps) ## eval | |
# step_perm = step_perm.tolist() | |
# ids_loss = sum( | |
# [ | |
# [i + j * steps for j in range(num_views)] | |
# for i in step_perm[:split] | |
# ], | |
# [], | |
# ) | |
# ids_render = sum( | |
# [ | |
# [i + j * steps for j in range(num_views)] | |
# for i in step_perm[split:] | |
# ], | |
# [], | |
# ) | |
# elif self.frame_sample_mode == "default": | |
# ids_loss = frame_perm[ | |
# [i for i in range(v) if frame_perm[i] not in ids_render] | |
# ] | |
# else: | |
# raise NotImplementedError | |
# else: ## eval (!= self.training) | |
# ids_loss = torch.arange(v) | |
# ids_render = [0] | |
# if self.frame_sample_mode.startswith("waymo"): | |
# num_views = int(self.frame_sample_mode.split("-")[-1]) | |
# steps = v // num_views | |
# split = steps // 2 | |
# # Predict features from half-left, center, half-right | |
# ids_encoder = [0, steps, steps * 2] | |
# ids_render = [0, steps, steps * 2] | |
# combine_ids = [(i, steps + i, steps * 2 + i) for i in range(steps)] | |