|
|
""" |
|
|
Adapted from SparseFusion.
|
|
Wrapper for the full CO3Dv2 dataset |
|
|
Modified from https://github.com/facebookresearch/pytorch3d
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
import math |
|
|
import os |
|
|
import random |
|
|
import time |
|
|
import warnings |
|
|
from collections import defaultdict |
|
|
from itertools import islice |
|
|
from typing import ( |
|
|
Any, |
|
|
ClassVar, |
|
|
List, |
|
|
Mapping, |
|
|
Optional, |
|
|
Sequence, |
|
|
Tuple, |
|
|
Type, |
|
|
TypedDict, |
|
|
Union, |
|
|
) |
|
|
from einops import rearrange, repeat |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms.functional as TF |
|
|
from pytorch3d.utils import opencv_from_cameras_projection |
|
|
from pytorch3d.implicitron.dataset import types |
|
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase |
|
|
from sgm.data.json_index_dataset import ( |
|
|
FrameAnnotsEntry, |
|
|
_bbox_xywh_to_xyxy, |
|
|
_bbox_xyxy_to_xywh, |
|
|
_clamp_box_to_image_bounds_and_round, |
|
|
_crop_around_box, |
|
|
_get_1d_bounds, |
|
|
_get_bbox_from_mask, |
|
|
_get_clamp_bbox, |
|
|
_load_1bit_png_mask, |
|
|
_load_16big_png_depth, |
|
|
_load_depth, |
|
|
_load_depth_mask, |
|
|
_load_image, |
|
|
_load_mask, |
|
|
_load_pointcloud, |
|
|
_rescale_bbox, |
|
|
_safe_as_tensor, |
|
|
_seq_name_to_seed, |
|
|
) |
|
|
from sgm.data.objaverse import video_collate_fn |
|
|
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import ( |
|
|
get_available_subset_names, |
|
|
) |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field, fields |
|
|
|
|
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch |
|
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras |
|
|
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch |
|
|
from pytorch_lightning import LightningDataModule |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
CO3D_ALL_CATEGORIES = list( |
|
|
reversed( |
|
|
[ |
|
|
"baseballbat", |
|
|
"banana", |
|
|
"bicycle", |
|
|
"microwave", |
|
|
"tv", |
|
|
"cellphone", |
|
|
"toilet", |
|
|
"hairdryer", |
|
|
"couch", |
|
|
"kite", |
|
|
"pizza", |
|
|
"umbrella", |
|
|
"wineglass", |
|
|
"laptop", |
|
|
"hotdog", |
|
|
"stopsign", |
|
|
"frisbee", |
|
|
"baseballglove", |
|
|
"cup", |
|
|
"parkingmeter", |
|
|
"backpack", |
|
|
"toyplane", |
|
|
"toybus", |
|
|
"handbag", |
|
|
"chair", |
|
|
"keyboard", |
|
|
"car", |
|
|
"motorcycle", |
|
|
"carrot", |
|
|
"bottle", |
|
|
"sandwich", |
|
|
"remote", |
|
|
"bowl", |
|
|
"skateboard", |
|
|
"toaster", |
|
|
"mouse", |
|
|
"toytrain", |
|
|
"book", |
|
|
"toytruck", |
|
|
"orange", |
|
|
"broccoli", |
|
|
"plant", |
|
|
"teddybear", |
|
|
"suitcase", |
|
|
"bench", |
|
|
"ball", |
|
|
"cake", |
|
|
"vase", |
|
|
"hydrant", |
|
|
"apple", |
|
|
"donut", |
|
|
] |
|
|
) |
|
|
) |
|
|
|
|
|
CO3D_ALL_TEN = [ |
|
|
"donut", |
|
|
"apple", |
|
|
"hydrant", |
|
|
"vase", |
|
|
"cake", |
|
|
"ball", |
|
|
"bench", |
|
|
"suitcase", |
|
|
"teddybear", |
|
|
"plant", |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class FrameData(Mapping[str, Any]): |
|
|
""" |
|
|
A type of the elements returned by indexing the dataset object. |
|
|
    It can represent both individual frames and batches thereof;
|
|
in this documentation, the sizes of tensors refer to single frames; |
|
|
add the first batch dimension for the collation result. |
|
|
Args: |
|
|
frame_number: The number of the frame within its sequence. |
|
|
0-based continuous integers. |
|
|
sequence_name: The unique name of the frame's sequence. |
|
|
sequence_category: The object category of the sequence. |
|
|
frame_timestamp: The time elapsed since the start of a sequence in sec. |
|
|
image_size_hw: The size of the image in pixels; (height, width) tensor |
|
|
of shape (2,). |
|
|
image_path: The qualified path to the loaded image (with dataset_root). |
|
|
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image |
|
|
of the frame; elements are floats in [0, 1]. |
|
|
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image |
|
|
regions. Regions can be invalid (mask_crop[i,j]=0) in case they |
|
|
are a result of zero-padding of the image after cropping around |
|
|
the object bounding box; elements are floats in {0.0, 1.0}. |
|
|
depth_path: The qualified path to the frame's depth map. |
|
|
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map |
|
|
of the frame; values correspond to distances from the camera; |
|
|
use `depth_mask` and `mask_crop` to filter for valid pixels. |
|
|
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the |
|
|
depth map that are valid for evaluation, they have been checked for |
|
|
consistency across views; elements are floats in {0.0, 1.0}. |
|
|
mask_path: A qualified path to the foreground probability mask. |
|
|
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the |
|
|
pixels belonging to the captured object; elements are floats |
|
|
in [0, 1]. |
|
|
bbox_xywh: The bounding box tightly enclosing the foreground object in the |
|
|
format (x0, y0, width, height). The convention assumes that |
|
|
            `x0+width` and `y0+height` include the boundary of the box.
|
|
I.e., to slice out the corresponding crop from an image tensor `I` |
|
|
we execute `crop = I[..., y0:y0+height, x0:x0+width]` |
|
|
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` |
|
|
in the original image coordinates in the format (x0, y0, width, height). |
|
|
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs |
|
|
from `bbox_xywh` due to padding (which can happen e.g. due to |
|
|
setting `JsonIndexDataset.box_crop_context > 0`) |
|
|
        camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
|
|
corrected for cropping if it happened. |
|
|
camera_quality_score: The score proportional to the confidence of the |
|
|
frame's camera estimation (the higher the more accurate). |
|
|
point_cloud_quality_score: The score proportional to the accuracy of the |
|
|
frame's sequence point cloud (the higher the more accurate). |
|
|
sequence_point_cloud_path: The path to the sequence's point cloud. |
|
|
sequence_point_cloud: A PyTorch3D Pointclouds object holding the |
|
|
point cloud corresponding to the frame's sequence. When the object |
|
|
represents a batch of frames, point clouds may be deduplicated; |
|
|
see `sequence_point_cloud_idx`. |
|
|
sequence_point_cloud_idx: Integer indices mapping frame indices to the |
|
|
corresponding point clouds in `sequence_point_cloud`; to get the |
|
|
corresponding point cloud to `image_rgb[i]`, use |
|
|
`sequence_point_cloud[sequence_point_cloud_idx[i]]`. |
|
|
frame_type: The type of the loaded frame specified in |
|
|
`subset_lists_file`, if provided. |
|
|
meta: A dict for storing additional frame information. |
|
|
""" |
|
|
|
|
|
frame_number: Optional[torch.LongTensor] |
|
|
sequence_name: Union[str, List[str]] |
|
|
sequence_category: Union[str, List[str]] |
|
|
frame_timestamp: Optional[torch.Tensor] = None |
|
|
image_size_hw: Optional[torch.Tensor] = None |
|
|
image_path: Union[str, List[str], None] = None |
|
|
image_rgb: Optional[torch.Tensor] = None |
|
|
|
|
|
mask_crop: Optional[torch.Tensor] = None |
|
|
depth_path: Union[str, List[str], None] = "" |
|
|
depth_map: Optional[torch.Tensor] = torch.zeros(1) |
|
|
depth_mask: Optional[torch.Tensor] = torch.zeros(1) |
|
|
mask_path: Union[str, List[str], None] = None |
|
|
fg_probability: Optional[torch.Tensor] = None |
|
|
bbox_xywh: Optional[torch.Tensor] = None |
|
|
crop_bbox_xywh: Optional[torch.Tensor] = None |
|
|
camera: Optional[PerspectiveCameras] = None |
|
|
camera_quality_score: Optional[torch.Tensor] = None |
|
|
point_cloud_quality_score: Optional[torch.Tensor] = None |
|
|
sequence_point_cloud_path: Union[str, List[str], None] = "" |
|
|
sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1) |
|
|
sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1) |
|
|
frame_type: Union[str, List[str], None] = "" |
|
|
meta: dict = field(default_factory=lambda: {}) |
|
|
valid_region: Optional[torch.Tensor] = None |
|
|
category_one_hot: Optional[torch.Tensor] = None |
|
|
|
|
|
def to(self, *args, **kwargs): |
|
|
new_params = {} |
|
|
for f in fields(self): |
|
|
value = getattr(self, f.name) |
|
|
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): |
|
|
new_params[f.name] = value.to(*args, **kwargs) |
|
|
else: |
|
|
new_params[f.name] = value |
|
|
return type(self)(**new_params) |
|
|
|
|
|
def cpu(self): |
|
|
return self.to(device=torch.device("cpu")) |
|
|
|
|
|
def cuda(self): |
|
|
return self.to(device=torch.device("cuda")) |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
for f in fields(self): |
|
|
yield f.name |
|
|
|
|
|
def __getitem__(self, key): |
|
|
return getattr(self, key) |
|
|
|
|
|
def __len__(self): |
|
|
return len(fields(self)) |
|
|
|
|
|
@classmethod |
|
|
def collate(cls, batch): |
|
|
""" |
|
|
Given a list objects `batch` of class `cls`, collates them into a batched |
|
|
representation suitable for processing with deep networks. |
|
|
""" |
|
|
|
|
|
elem = batch[0] |
|
|
|
|
|
if isinstance(elem, cls): |
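            # Deduplicate point clouds: frames of the same sequence share one
            # Pointclouds object, so group batch elements by its id().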
|
|
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] |
|
|
id_to_idx = defaultdict(list) |
|
|
for i, pc_id in enumerate(pointcloud_ids): |
|
|
id_to_idx[pc_id].append(i) |
|
|
|
|
|
sequence_point_cloud = [] |
|
|
sequence_point_cloud_idx = -np.ones((len(batch),)) |
|
|
for i, ind in enumerate(id_to_idx.values()): |
|
|
sequence_point_cloud_idx[ind] = i |
|
|
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) |
|
|
assert (sequence_point_cloud_idx >= 0).all() |
|
|
|
|
|
override_fields = { |
|
|
"sequence_point_cloud": sequence_point_cloud, |
|
|
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), |
|
|
} |
|
|
|
|
|
|
|
|
collated = {} |
|
|
for f in fields(elem): |
|
|
list_values = override_fields.get( |
|
|
f.name, [getattr(d, f.name) for d in batch] |
|
|
) |
|
|
collated[f.name] = ( |
|
|
cls.collate(list_values) |
|
|
if all(list_value is not None for list_value in list_values) |
|
|
else None |
|
|
) |
|
|
return cls(**collated) |
|
|
|
|
|
elif isinstance(elem, Pointclouds): |
|
|
return join_pointclouds_as_batch(batch) |
|
|
|
|
|
elif isinstance(elem, CamerasBase): |
|
|
|
|
|
return join_cameras_as_batch(batch) |
|
|
else: |
|
|
return torch.utils.data._utils.collate.default_collate(batch) |
|
|
|
|
|
|
|
|
|
|
|
class CO3Dv2Wrapper(torch.utils.data.Dataset): |
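    """Serves CO3Dv2 sequences as fixed-length video clips.

    Each item corresponds to one sequence: a subset of its frames is selected,
    cropped around the object, resized, sorted by timestamp, and packed into the
    dict consumed by the video-diffusion / PixelNeRF training code (see
    `__getitem__` and `collate_fn`).
    """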
|
|
def __init__( |
|
|
self, |
|
|
root_dir="/drive/datasets/co3d/", |
|
|
category="hydrant", |
|
|
subset="fewview_train", |
|
|
stage="train", |
|
|
sample_batch_size=20, |
|
|
image_size=256, |
|
|
masked=False, |
|
|
deprecated_val_region=False, |
|
|
return_frame_data_list=False, |
|
|
reso: int = 256, |
|
|
mask_type: str = "random", |
|
|
cond_aug_mean=-3.0, |
|
|
cond_aug_std=0.5, |
|
|
condition_on_elevation=False, |
|
|
fps_id=0.0, |
|
|
motion_bucket_id=300.0, |
|
|
num_frames: int = 20, |
|
|
use_mask: bool = True, |
|
|
load_pixelnerf: bool = True, |
|
|
scale_pose: bool = True, |
|
|
max_n_cond: int = 5, |
|
|
min_n_cond: int = 2, |
|
|
cond_on_multi: bool = False, |
|
|
): |
|
|
root = root_dir |
|
|
|
|
|
|
|
from co3d.dataset.data_types import ( |
|
|
FrameAnnotation, |
|
|
SequenceAnnotation, |
|
|
load_dataclass_jgzip, |
|
|
) |
|
|
|
|
|
self.dataset_root = root |
|
|
self.path_manager = None |
|
|
self.subset = subset |
|
|
self.stage = stage |
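        # The fields below mirror pytorch3d's JsonIndexDataset configuration and
        # are hard-coded here rather than exposed as constructor arguments.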
|
|
self.subset_lists_file: List[str] = [ |
|
|
f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json" |
|
|
] |
|
|
self.subsets: Optional[List[str]] = [subset] |
|
|
self.sample_batch_size = sample_batch_size |
|
|
self.limit_to: int = 0 |
|
|
self.limit_sequences_to: int = 0 |
|
|
self.pick_sequence: Tuple[str, ...] = () |
|
|
self.exclude_sequence: Tuple[str, ...] = () |
|
|
self.limit_category_to: Tuple[int, ...] = () |
|
|
self.load_images: bool = True |
|
|
self.load_depths: bool = False |
|
|
self.load_depth_masks: bool = False |
|
|
self.load_masks: bool = True |
|
|
self.load_point_clouds: bool = False |
|
|
self.max_points: int = 0 |
|
|
self.mask_images: bool = False |
|
|
self.mask_depths: bool = False |
|
|
self.image_height: Optional[int] = image_size |
|
|
self.image_width: Optional[int] = image_size |
|
|
self.box_crop: bool = True |
|
|
self.box_crop_mask_thr: float = 0.4 |
|
|
self.box_crop_context: float = 0.3 |
|
|
self.remove_empty_masks: bool = True |
|
|
self.n_frames_per_sequence: int = -1 |
|
|
self.seed: int = 0 |
|
|
self.sort_frames: bool = False |
|
|
self.eval_batches: Any = None |
|
|
|
|
|
self.img_h = self.image_height |
|
|
self.img_w = self.image_width |
|
|
self.masked = masked |
|
|
self.deprecated_val_region = deprecated_val_region |
|
|
self.return_frame_data_list = return_frame_data_list |
|
|
|
|
|
self.reso = reso |
|
|
self.num_frames = num_frames |
|
|
self.cond_aug_mean = cond_aug_mean |
|
|
self.cond_aug_std = cond_aug_std |
|
|
self.condition_on_elevation = condition_on_elevation |
|
|
self.fps_id = fps_id |
|
|
self.motion_bucket_id = motion_bucket_id |
|
|
self.mask_type = mask_type |
|
|
self.use_mask = use_mask |
|
|
self.load_pixelnerf = load_pixelnerf |
|
|
self.scale_pose = scale_pose |
|
|
self.max_n_cond = max_n_cond |
|
|
self.min_n_cond = min_n_cond |
|
|
self.cond_on_multi = cond_on_multi |
|
|
|
|
|
if self.cond_on_multi: |
|
|
assert self.min_n_cond == self.max_n_cond |
|
|
|
|
|
start_time = time.time() |
|
|
if "all_" in category or category == "all": |
|
|
self.category_frame_annotations = [] |
|
|
self.category_sequence_annotations = [] |
|
|
self.subset_lists_file = [] |
|
|
|
|
|
if category == "all": |
|
|
cats = CO3D_ALL_CATEGORIES |
|
|
elif category == "all_four": |
|
|
cats = ["hydrant", "teddybear", "motorcycle", "bench"] |
|
|
elif category == "all_ten": |
|
|
cats = [ |
|
|
"donut", |
|
|
"apple", |
|
|
"hydrant", |
|
|
"vase", |
|
|
"cake", |
|
|
"ball", |
|
|
"bench", |
|
|
"suitcase", |
|
|
"teddybear", |
|
|
"plant", |
|
|
] |
|
|
elif category == "all_15": |
|
|
cats = [ |
|
|
"hydrant", |
|
|
"teddybear", |
|
|
"motorcycle", |
|
|
"bench", |
|
|
"hotdog", |
|
|
"remote", |
|
|
"suitcase", |
|
|
"donut", |
|
|
"plant", |
|
|
"toaster", |
|
|
"keyboard", |
|
|
"handbag", |
|
|
"toyplane", |
|
|
"tv", |
|
|
"orange", |
|
|
] |
|
|
else: |
|
|
print("UNSPECIFIED CATEGORY SUBSET") |
|
|
cats = ["hydrant", "teddybear"] |
|
|
print("loading", cats) |
|
|
for cat in cats: |
|
|
self.category_frame_annotations.extend( |
|
|
load_dataclass_jgzip( |
|
|
f"{self.dataset_root}/{cat}/frame_annotations.jgz", |
|
|
List[FrameAnnotation], |
|
|
) |
|
|
) |
|
|
self.category_sequence_annotations.extend( |
|
|
load_dataclass_jgzip( |
|
|
f"{self.dataset_root}/{cat}/sequence_annotations.jgz", |
|
|
List[SequenceAnnotation], |
|
|
) |
|
|
) |
|
|
self.subset_lists_file.append( |
|
|
f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json" |
|
|
) |
|
|
|
|
|
else: |
|
|
self.category_frame_annotations = load_dataclass_jgzip( |
|
|
f"{self.dataset_root}/{category}/frame_annotations.jgz", |
|
|
List[FrameAnnotation], |
|
|
) |
|
|
self.category_sequence_annotations = load_dataclass_jgzip( |
|
|
f"{self.dataset_root}/{category}/sequence_annotations.jgz", |
|
|
List[SequenceAnnotation], |
|
|
) |
|
|
|
|
|
self.subset_to_image_path = None |
|
|
self._load_frames() |
|
|
self._load_sequences() |
|
|
self._sort_frames() |
|
|
self._load_subset_lists() |
|
|
self._filter_db() |
|
|
|
|
|
|
|
|
logger.info(str(self)) |
|
|
|
|
|
self.seq_to_frames = {} |
|
|
for fi, item in enumerate(self.frame_annots): |
|
|
if item["frame_annotation"].sequence_name in self.seq_to_frames: |
|
|
self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi) |
|
|
else: |
|
|
self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi] |
|
|
|
|
|
if self.stage != "test" or self.subset != "fewview_test": |
|
|
count = 0 |
|
|
new_seq_to_frames = {} |
|
|
for item in self.seq_to_frames: |
|
|
if len(self.seq_to_frames[item]) > 10: |
|
|
count += 1 |
|
|
new_seq_to_frames[item] = self.seq_to_frames[item] |
|
|
self.seq_to_frames = new_seq_to_frames |
|
|
|
|
|
self.seq_list = list(self.seq_to_frames.keys()) |
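        # Remove manually excluded sequences from the sequence list.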
|
|
|
|
|
|
|
|
remove_list = ["411_55952_107659", "376_42884_85882"] |
|
|
for remove_idx in remove_list: |
|
|
if remove_idx in self.seq_to_frames: |
|
|
self.seq_list.remove(remove_idx) |
|
|
print("removing", remove_idx) |
|
|
|
|
|
        print("total training seq", len(self.seq_list))
|
|
print("data loading took", time.time() - start_time, "seconds") |
|
|
|
|
|
self.all_category_list = list(CO3D_ALL_CATEGORIES) |
|
|
self.all_category_list.sort() |
|
|
self.cat_to_idx = {} |
|
|
for ci, cname in enumerate(self.all_category_list): |
|
|
self.cat_to_idx[cname] = ci |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.seq_list) |
|
|
|
|
|
def __getitem__(self, index): |
|
|
seq_index = self.seq_list[index] |
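        # Choose which frames of the sequence to load: all frames for
        # fewview_test evaluation, an evenly spaced sample of `sample_batch_size`
        # frames for other test runs, and a random subset of up to
        # `sample_batch_size` frames for training.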
|
|
|
|
|
if self.subset == "fewview_test" and self.stage == "test": |
|
|
batch_idx = torch.arange(len(self.seq_to_frames[seq_index])) |
|
|
|
|
|
elif self.stage == "test": |
|
|
batch_idx = ( |
|
|
torch.linspace( |
|
|
0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size |
|
|
) |
|
|
.long() |
|
|
.tolist() |
|
|
) |
|
|
else: |
|
|
rand = torch.randperm(len(self.seq_to_frames[seq_index])) |
|
|
batch_idx = rand[: min(len(rand), self.sample_batch_size)] |
|
|
|
|
|
frame_data_list = [] |
|
|
idx_list = [] |
|
|
timestamp_list = [] |
|
|
for idx in batch_idx: |
|
|
idx_list.append(self.seq_to_frames[seq_index][idx]) |
|
|
timestamp_list.append( |
|
|
self.frame_annots[self.seq_to_frames[seq_index][idx]][ |
|
|
"frame_annotation" |
|
|
].frame_timestamp |
|
|
) |
|
|
frame_data_list.append( |
|
|
self._get_frame(int(self.seq_to_frames[seq_index][idx])) |
|
|
) |
|
|
|
|
|
time_order = torch.argsort(torch.tensor(timestamp_list)) |
|
|
frame_data_list = [frame_data_list[i] for i in time_order] |
|
|
|
|
|
frame_data = FrameData.collate(frame_data_list) |
|
|
image_size = torch.Tensor([self.image_height]).repeat( |
|
|
frame_data.camera.R.shape[0], 2 |
|
|
) |
|
|
frame_dict = { |
|
|
"R": frame_data.camera.R, |
|
|
"T": frame_data.camera.T, |
|
|
"f": frame_data.camera.focal_length, |
|
|
"c": frame_data.camera.principal_point, |
|
|
"images": frame_data.image_rgb * frame_data.fg_probability |
|
|
+ (1 - frame_data.fg_probability), |
|
|
"valid_region": frame_data.mask_crop, |
|
|
"bbox": frame_data.valid_region, |
|
|
"image_size": image_size, |
|
|
"frame_type": frame_data.frame_type, |
|
|
"idx": seq_index, |
|
|
"category": frame_data.category_one_hot, |
|
|
} |
|
|
if not self.masked: |
|
|
frame_dict["images_full"] = frame_data.image_rgb |
|
|
frame_dict["masks"] = frame_data.fg_probability |
|
|
frame_dict["mask_crop"] = frame_data.mask_crop |
|
|
|
|
|
cond_aug = np.exp( |
|
|
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean |
|
|
) |
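        # Pad clips shorter than `num_frames` by appending the temporally
        # reversed clip and truncating (this at most doubles the length).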
|
|
|
|
|
        def _pad(t):
            return torch.cat([t, torch.flip(t, dims=[0])], dim=0)[: self.num_frames]
|
|
|
|
|
if len(frame_dict["images"]) < self.num_frames: |
|
|
for k in frame_dict: |
|
|
if isinstance(frame_dict[k], torch.Tensor): |
|
|
frame_dict[k] = _pad(frame_dict[k]) |
|
|
|
|
|
data = dict() |
|
|
if "images_full" in frame_dict: |
|
|
frames = frame_dict["images_full"] * 2 - 1 |
|
|
else: |
|
|
frames = frame_dict["images"] * 2 - 1 |
|
|
data["frames"] = frames |
|
|
cond = frames[0] |
|
|
data["cond_frames_without_noise"] = cond |
|
|
data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames) |
|
|
data["cond_frames"] = cond + cond_aug * torch.randn_like(cond) |
|
|
data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames) |
|
|
data["motion_bucket_id"] = torch.as_tensor( |
|
|
[self.motion_bucket_id] * self.num_frames |
|
|
) |
|
|
data["num_video_frames"] = self.num_frames |
|
|
data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames) |
|
|
|
|
|
if self.load_pixelnerf: |
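            # Build the PixelNeRF-style conditioning inputs: per-frame cameras
            # packed as 25-dim vectors (flattened 4x4 c2w + flattened 3x3 K) and
            # a low-resolution copy of the frames.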
|
|
data["pixelnerf_input"] = dict() |
|
|
|
scene_cameras = PerspectiveCameras( |
|
|
R=frame_dict["R"], |
|
|
T=frame_dict["T"], |
|
|
focal_length=frame_dict["f"], |
|
|
principal_point=frame_dict["c"], |
|
|
image_size=frame_dict["image_size"], |
|
|
) |
|
|
R, T, intrinsics = opencv_from_cameras_projection( |
|
|
scene_cameras, frame_dict["image_size"] |
|
|
) |
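            # Assemble 4x4 world-to-camera matrices from the OpenCV R, T and
            # invert them to camera-to-world.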
|
|
c2ws = torch.zeros(R.shape[0], 4, 4) |
|
|
c2ws[:, :3, :3] = R |
|
|
c2ws[:, :3, 3] = T |
|
|
c2ws[:, 3, 3] = 1.0 |
|
|
c2ws = c2ws.inverse() |
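            # Flip the y/z camera axes (OpenCV -> OpenGL/NeRF convention) and
            # normalize the intrinsics by the crop resolution (256x256 crops are
            # assumed here).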
|
|
c2ws[..., 1:3] *= -1 |
|
|
intrinsics[:, :2] /= 256 |
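            # Pack each camera as a 25-dim vector: the flattened 4x4 c2w matrix
            # (16 values) followed by the flattened 3x3 intrinsics (9 values).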
|
|
|
|
|
cameras = torch.zeros(c2ws.shape[0], 25) |
|
|
cameras[..., :16] = c2ws.reshape(-1, 16) |
|
|
cameras[..., 16:] = intrinsics.reshape(-1, 9) |
|
|
if self.scale_pose: |
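                # Re-center the camera trajectory at the origin and rescale it
                # so the farthest camera lies at radius 1.5.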
|
|
c2ws = cameras[..., :16].reshape(-1, 4, 4) |
|
|
center = c2ws[:, :3, 3].mean(0) |
|
|
radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() |
|
|
scale = 1.5 / radius |
|
|
c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale |
|
|
cameras[..., :16] = c2ws.reshape(-1, 16) |
|
|
|
|
|
data["pixelnerf_input"]["frames"] = frames |
|
|
data["pixelnerf_input"]["cameras"] = cameras |
|
|
data["pixelnerf_input"]["rgb"] = ( |
|
|
F.interpolate( |
|
|
frames, |
|
|
(self.image_width // 8, self.image_height // 8), |
|
|
mode="bilinear", |
|
|
align_corners=False, |
|
|
) |
|
|
+ 1 |
|
|
) * 0.5 |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(self, batch): |
|
|
|
|
|
if self.max_n_cond > 1: |
|
|
|
|
|
n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1) |
|
|
|
|
|
|
|
|
if n_cond > 1: |
|
|
for b in batch: |
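                    # Frame 0 always serves as a conditioning view; the remaining
                    # conditioning views are drawn at random from the other frames.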
|
|
source_index = [0] + np.random.choice( |
|
|
np.arange(1, self.num_frames), |
|
|
self.max_n_cond - 1, |
|
|
replace=False, |
|
|
).tolist() |
|
|
b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) |
|
|
b["pixelnerf_input"]["n_cond"] = n_cond |
|
|
b["pixelnerf_input"]["source_images"] = b["frames"][source_index] |
|
|
b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ |
|
|
"cameras" |
|
|
][source_index] |
|
|
|
|
|
if self.cond_on_multi: |
|
|
b["cond_frames_without_noise"] = b["frames"][source_index] |
|
|
|
|
|
ret = video_collate_fn(batch) |
|
|
|
|
|
if self.cond_on_multi: |
|
|
ret["cond_frames_without_noise"] = rearrange( |
|
|
ret["cond_frames_without_noise"], "b t ... -> (b t) ..." |
|
|
) |
|
|
|
|
|
return ret |
|
|
|
|
|
def _get_frame(self, index): |
|
|
|
|
|
|
|
|
|
|
|
entry = self.frame_annots[index]["frame_annotation"] |
|
|
|
|
|
point_cloud = self.seq_annots[entry.sequence_name].point_cloud |
|
|
frame_data = FrameData( |
|
|
frame_number=_safe_as_tensor(entry.frame_number, torch.long), |
|
|
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), |
|
|
sequence_name=entry.sequence_name, |
|
|
sequence_category=self.seq_annots[entry.sequence_name].category, |
|
|
camera_quality_score=_safe_as_tensor( |
|
|
self.seq_annots[entry.sequence_name].viewpoint_quality_score, |
|
|
torch.float, |
|
|
), |
|
|
point_cloud_quality_score=_safe_as_tensor( |
|
|
point_cloud.quality_score, torch.float |
|
|
) |
|
|
if point_cloud is not None |
|
|
else None, |
|
|
) |
|
|
|
|
|
|
|
|
frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) |
|
|
|
|
|
( |
|
|
frame_data.fg_probability, |
|
|
frame_data.mask_path, |
|
|
frame_data.bbox_xywh, |
|
|
clamp_bbox_xyxy, |
|
|
frame_data.crop_bbox_xywh, |
|
|
) = self._load_crop_fg_probability(entry) |
|
|
|
|
|
scale = 1.0 |
|
|
if self.load_images and entry.image is not None: |
|
|
|
|
|
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) |
|
|
|
|
|
( |
|
|
frame_data.image_rgb, |
|
|
frame_data.image_path, |
|
|
frame_data.mask_crop, |
|
|
scale, |
|
|
) = self._load_crop_images( |
|
|
entry, frame_data.fg_probability, clamp_bbox_xyxy |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.deprecated_val_region: |
|
|
|
|
|
                valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float()
                half_w = torch.div(
                    frame_data.image_size_hw[1], 2, rounding_mode="floor"
                )
                half_h = torch.div(
                    frame_data.image_size_hw[0], 2, rounding_mode="floor"
                )
                # Normalize the (x0, y0, x1, y1) corners into [-1, 1] relative to
                # the original image size.
                for i, half in enumerate((half_w, half_h, half_w, half_h)):
                    valid_bbox[i] = torch.clip(
                        (valid_bbox[i] - half) / half, -1.0, 1.0
                    )

                frame_data.valid_region = valid_bbox
|
|
            else:
                if self.stage == "train":
                    assert self.image_height == 256 and self.image_width == 256
                # Derive the valid (non-padded) region from `mask_crop` and
                # normalize its corners into [-1, 1].
                valid = torch.nonzero(frame_data.mask_crop[0])
                min_y, min_x = valid[:, 0].min(), valid[:, 1].min()
                max_y, max_x = valid[:, 0].max(), valid[:, 1].max()
                valid_bbox = torch.tensor(
                    [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
                ).unsqueeze(0)
                half = self.image_height // 2
                valid_bbox = torch.clip((valid_bbox - half) / half, -1.0, 1.0)
                frame_data.valid_region = valid_bbox[0]
|
|
|
|
|
|
|
|
frame_data.category_one_hot = torch.zeros( |
|
|
(len(self.all_category_list)), device=frame_data.image_rgb.device |
|
|
) |
|
|
frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1 |
|
|
|
|
|
if self.load_depths and entry.depth is not None: |
|
|
( |
|
|
frame_data.depth_map, |
|
|
frame_data.depth_path, |
|
|
frame_data.depth_mask, |
|
|
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) |
|
|
|
|
|
if entry.viewpoint is not None: |
|
|
frame_data.camera = self._get_pytorch3d_camera( |
|
|
entry, |
|
|
scale, |
|
|
clamp_bbox_xyxy, |
|
|
) |
|
|
|
|
|
if self.load_point_clouds and point_cloud is not None: |
|
|
frame_data.sequence_point_cloud_path = pcl_path = os.path.join( |
|
|
self.dataset_root, point_cloud.path |
|
|
) |
|
|
frame_data.sequence_point_cloud = _load_pointcloud( |
|
|
self._local_path(pcl_path), max_points=self.max_points |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return frame_data |
|
|
|
|
|
def _extract_and_set_eval_batches(self): |
|
|
""" |
|
|
Sets eval_batches based on input eval_batch_index. |
|
|
""" |
|
|
if self.eval_batch_index is not None: |
|
|
if self.eval_batches is not None: |
|
|
raise ValueError( |
|
|
"Cannot define both eval_batch_index and eval_batches." |
|
|
) |
|
|
self.eval_batches = self.seq_frame_index_to_dataset_index( |
|
|
self.eval_batch_index |
|
|
) |
|
|
|
|
|
def _load_crop_fg_probability( |
|
|
self, entry: types.FrameAnnotation |
|
|
) -> Tuple[ |
|
|
Optional[torch.Tensor], |
|
|
Optional[str], |
|
|
Optional[torch.Tensor], |
|
|
Optional[torch.Tensor], |
|
|
Optional[torch.Tensor], |
|
|
]: |
|
|
fg_probability = None |
|
|
full_path = None |
|
|
bbox_xywh = None |
|
|
clamp_bbox_xyxy = None |
|
|
crop_box_xywh = None |
|
|
|
|
|
if (self.load_masks or self.box_crop) and entry.mask is not None: |
|
|
full_path = os.path.join(self.dataset_root, entry.mask.path) |
|
|
mask = _load_mask(self._local_path(full_path)) |
|
|
|
|
|
if mask.shape[-2:] != entry.image.size: |
|
|
raise ValueError( |
|
|
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" |
|
|
) |
|
|
|
|
|
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) |
|
|
|
|
|
if self.box_crop: |
|
|
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( |
|
|
_get_clamp_bbox( |
|
|
bbox_xywh, |
|
|
image_path=entry.image.path, |
|
|
box_crop_context=self.box_crop_context, |
|
|
), |
|
|
image_size_hw=tuple(mask.shape[-2:]), |
|
|
) |
|
|
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) |
|
|
|
|
|
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) |
|
|
|
|
|
fg_probability, _, _ = self._resize_image(mask, mode="nearest") |
|
|
|
|
|
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh |
|
|
|
|
|
def _load_crop_images( |
|
|
self, |
|
|
entry: types.FrameAnnotation, |
|
|
fg_probability: Optional[torch.Tensor], |
|
|
clamp_bbox_xyxy: Optional[torch.Tensor], |
|
|
) -> Tuple[torch.Tensor, str, torch.Tensor, float]: |
|
|
assert self.dataset_root is not None and entry.image is not None |
|
|
path = os.path.join(self.dataset_root, entry.image.path) |
|
|
image_rgb = _load_image(self._local_path(path)) |
|
|
|
|
|
if image_rgb.shape[-2:] != entry.image.size: |
|
|
raise ValueError( |
|
|
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" |
|
|
) |
|
|
|
|
|
if self.box_crop: |
|
|
assert clamp_bbox_xyxy is not None |
|
|
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) |
|
|
|
|
|
image_rgb, scale, mask_crop = self._resize_image(image_rgb) |
|
|
|
|
|
if self.mask_images: |
|
|
assert fg_probability is not None |
|
|
image_rgb *= fg_probability |
|
|
|
|
|
return image_rgb, path, mask_crop, scale |
|
|
|
|
|
def _load_mask_depth( |
|
|
self, |
|
|
entry: types.FrameAnnotation, |
|
|
clamp_bbox_xyxy: Optional[torch.Tensor], |
|
|
fg_probability: Optional[torch.Tensor], |
|
|
) -> Tuple[torch.Tensor, str, torch.Tensor]: |
|
|
entry_depth = entry.depth |
|
|
assert entry_depth is not None |
|
|
path = os.path.join(self.dataset_root, entry_depth.path) |
|
|
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) |
|
|
|
|
|
if self.box_crop: |
|
|
assert clamp_bbox_xyxy is not None |
|
|
depth_bbox_xyxy = _rescale_bbox( |
|
|
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] |
|
|
) |
|
|
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) |
|
|
|
|
|
depth_map, _, _ = self._resize_image(depth_map, mode="nearest") |
|
|
|
|
|
if self.mask_depths: |
|
|
assert fg_probability is not None |
|
|
depth_map *= fg_probability |
|
|
|
|
|
if self.load_depth_masks: |
|
|
assert entry_depth.mask_path is not None |
|
|
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) |
|
|
depth_mask = _load_depth_mask(self._local_path(mask_path)) |
|
|
|
|
|
if self.box_crop: |
|
|
assert clamp_bbox_xyxy is not None |
|
|
depth_mask_bbox_xyxy = _rescale_bbox( |
|
|
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] |
|
|
) |
|
|
depth_mask = _crop_around_box( |
|
|
depth_mask, depth_mask_bbox_xyxy, mask_path |
|
|
) |
|
|
|
|
|
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") |
|
|
else: |
|
|
depth_mask = torch.ones_like(depth_map) |
|
|
|
|
|
return depth_map, path, depth_mask |
|
|
|
|
|
def _get_pytorch3d_camera( |
|
|
self, |
|
|
entry: types.FrameAnnotation, |
|
|
scale: float, |
|
|
clamp_bbox_xyxy: Optional[torch.Tensor], |
|
|
) -> PerspectiveCameras: |
|
|
entry_viewpoint = entry.viewpoint |
|
|
assert entry_viewpoint is not None |
|
|
|
|
|
principal_point = torch.tensor( |
|
|
entry_viewpoint.principal_point, dtype=torch.float |
|
|
) |
|
|
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) |
|
|
|
|
|
half_image_size_wh_orig = ( |
|
|
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 |
|
|
) |
|
|
|
|
|
|
|
|
format = entry_viewpoint.intrinsics_format |
|
|
if format.lower() == "ndc_norm_image_bounds": |
|
|
|
|
|
rescale = half_image_size_wh_orig |
|
|
elif format.lower() == "ndc_isotropic": |
|
|
rescale = half_image_size_wh_orig.min() |
|
|
else: |
|
|
raise ValueError(f"Unknown intrinsics format: {format}") |
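        # Convert the principal point and focal length from PyTorch3D NDC to
        # pixel coordinates of the original image.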
|
|
|
|
|
|
|
|
principal_point_px = half_image_size_wh_orig - principal_point * rescale |
|
|
focal_length_px = focal_length * rescale |
|
|
if self.box_crop: |
|
|
assert clamp_bbox_xyxy is not None |
|
|
principal_point_px -= clamp_bbox_xyxy[:2] |
|
|
|
|
|
|
|
|
if self.image_height is None or self.image_width is None: |
|
|
out_size = list(reversed(entry.image.size)) |
|
|
else: |
|
|
out_size = [self.image_width, self.image_height] |
|
|
|
|
|
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 |
|
|
half_min_image_size_output = half_image_size_output.min() |
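        # Rescale the intrinsics to the cropped/resized output image and convert
        # back to NDC relative to its shorter side.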
|
|
|
|
|
|
|
|
principal_point = ( |
|
|
half_image_size_output - principal_point_px * scale |
|
|
) / half_min_image_size_output |
|
|
focal_length = focal_length_px * scale / half_min_image_size_output |
|
|
|
|
|
return PerspectiveCameras( |
|
|
focal_length=focal_length[None], |
|
|
principal_point=principal_point[None], |
|
|
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], |
|
|
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], |
|
|
) |
|
|
|
|
|
def _load_frames(self) -> None: |
|
|
self.frame_annots = [ |
|
|
FrameAnnotsEntry(frame_annotation=a, subset=None) |
|
|
for a in self.category_frame_annotations |
|
|
] |
|
|
|
|
|
def _load_sequences(self) -> None: |
|
|
self.seq_annots = { |
|
|
entry.sequence_name: entry for entry in self.category_sequence_annotations |
|
|
} |
|
|
|
|
|
def _load_subset_lists(self) -> None: |
|
|
logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.") |
|
|
if not self.subset_lists_file: |
|
|
return |
|
|
|
|
|
frame_path_to_subset = {} |
|
|
|
|
|
for subset_list_file in self.subset_lists_file: |
|
|
with open(self._local_path(subset_list_file), "r") as f: |
|
|
subset_to_seq_frame = json.load(f) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for set_ in subset_to_seq_frame: |
|
|
for _, _, path in subset_to_seq_frame[set_]: |
|
|
if path in frame_path_to_subset: |
|
|
frame_path_to_subset[path].add(set_) |
|
|
else: |
|
|
frame_path_to_subset[path] = {set_} |
|
|
|
|
|
|
|
|
for frame in self.frame_annots: |
|
|
frame["subset"] = frame_path_to_subset.get( |
|
|
frame["frame_annotation"].image.path, None |
|
|
) |
|
|
|
|
|
            if frame["subset"] is None:
                # Frames missing from the subset lists are filtered out later in
                # `_filter_db`, so skip them here without emitting a warning.
                continue
|
|
|
|
|
def _sort_frames(self) -> None: |
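        # Sort frames by (sequence_name, frame_timestamp) for a deterministic order.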
|
|
|
|
|
|
|
|
self.frame_annots = sorted( |
|
|
self.frame_annots, |
|
|
key=lambda f: ( |
|
|
f["frame_annotation"].sequence_name, |
|
|
f["frame_annotation"].frame_timestamp or 0, |
|
|
), |
|
|
) |
|
|
|
|
|
def _filter_db(self) -> None: |
|
|
if self.remove_empty_masks: |
|
|
logger.info("Removing images with empty masks.") |
|
|
|
|
|
old_len = len(self.frame_annots) |
|
|
|
|
|
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." |
|
|
|
|
|
def positive_mass(frame_annot: types.FrameAnnotation) -> bool: |
|
|
mask = frame_annot.mask |
|
|
if mask is None: |
|
|
return False |
|
|
if mask.mass is None: |
|
|
raise ValueError(msg) |
|
|
return mask.mass > 1 |
|
|
|
|
|
self.frame_annots = [ |
|
|
frame |
|
|
for frame in self.frame_annots |
|
|
if positive_mass(frame["frame_annotation"]) |
|
|
] |
|
|
logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots))) |
|
|
|
|
|
|
|
|
subsets = self.subsets |
|
|
if subsets: |
|
|
if not self.subset_lists_file: |
|
|
raise ValueError( |
|
|
"Subset filter is on but subset_lists_file was not given" |
|
|
) |
|
|
|
|
|
logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.") |
|
|
|
|
|
|
|
|
self.frame_annots = [ |
|
|
entry |
|
|
for entry in self.frame_annots |
|
|
if (entry["subset"] is not None and self.stage in entry["subset"]) |
|
|
] |
|
|
|
|
|
if len(self.frame_annots) == 0: |
|
|
raise ValueError(f"There are no frames in the '{subsets}' subsets!") |
|
|
|
|
|
self._invalidate_indexes(filter_seq_annots=True) |
|
|
|
|
|
if len(self.limit_category_to) > 0: |
|
|
logger.info(f"Limiting dataset to categories: {self.limit_category_to}") |
|
|
|
|
|
self.seq_annots = { |
|
|
name: entry |
|
|
for name, entry in self.seq_annots.items() |
|
|
if entry.category in self.limit_category_to |
|
|
} |
|
|
|
|
|
|
|
|
for prefix in ("pick", "exclude"): |
|
|
orig_len = len(self.seq_annots) |
|
|
attr = f"{prefix}_sequence" |
|
|
arr = getattr(self, attr) |
|
|
if len(arr) > 0: |
|
|
logger.info(f"{attr}: {str(arr)}") |
|
|
self.seq_annots = { |
|
|
name: entry |
|
|
for name, entry in self.seq_annots.items() |
|
|
if (name in arr) == (prefix == "pick") |
|
|
} |
|
|
logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) |
|
|
|
|
|
if self.limit_sequences_to > 0: |
|
|
self.seq_annots = dict( |
|
|
islice(self.seq_annots.items(), self.limit_sequences_to) |
|
|
) |
|
|
|
|
|
|
|
|
self.frame_annots = [ |
|
|
f |
|
|
for f in self.frame_annots |
|
|
if f["frame_annotation"].sequence_name in self.seq_annots |
|
|
] |
|
|
|
|
|
self._invalidate_indexes() |
|
|
|
|
|
if self.n_frames_per_sequence > 0: |
|
|
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") |
|
|
keep_idx = [] |
|
|
|
|
|
for seq, seq_indices in self._seq_to_idx.items(): |
|
|
|
|
|
|
|
|
seed = _seq_name_to_seed(seq) + self.seed |
|
|
seq_idx_shuffled = random.Random(seed).sample( |
|
|
sorted(seq_indices), len(seq_indices) |
|
|
) |
|
|
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) |
|
|
|
|
|
logger.info( |
|
|
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)) |
|
|
) |
|
|
self.frame_annots = [self.frame_annots[i] for i in keep_idx] |
|
|
self._invalidate_indexes(filter_seq_annots=False) |
|
|
|
|
|
|
|
|
if self.limit_to > 0 and self.limit_to < len(self.frame_annots): |
|
|
logger.info( |
|
|
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) |
|
|
) |
|
|
self.frame_annots = self.frame_annots[: self.limit_to] |
|
|
self._invalidate_indexes(filter_seq_annots=True) |
|
|
|
|
|
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: |
|
|
|
|
|
|
|
|
self._invalidate_seq_to_idx() |
|
|
|
|
|
if filter_seq_annots: |
|
|
|
|
|
self.seq_annots = { |
|
|
k: v |
|
|
for k, v in self.seq_annots.items() |
|
|
|
|
|
if k in self._seq_to_idx |
|
|
} |
|
|
|
|
|
def _invalidate_seq_to_idx(self) -> None: |
|
|
seq_to_idx = defaultdict(list) |
|
|
|
|
|
for idx, entry in enumerate(self.frame_annots): |
|
|
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) |
|
|
|
|
|
self._seq_to_idx = seq_to_idx |
|
|
|
|
|
def _resize_image( |
|
|
self, image, mode="bilinear" |
|
|
) -> Tuple[torch.Tensor, float, torch.Tensor]: |
|
|
image_height, image_width = self.image_height, self.image_width |
|
|
if image_height is None or image_width is None: |
|
|
|
|
|
imre_ = torch.from_numpy(image) |
|
|
return imre_, 1.0, torch.ones_like(imre_[:1]) |
|
|
|
|
|
minscale = min( |
|
|
image_height / image.shape[-2], |
|
|
image_width / image.shape[-1], |
|
|
) |
|
|
imre = torch.nn.functional.interpolate( |
|
|
torch.from_numpy(image)[None], |
|
|
scale_factor=minscale, |
|
|
mode=mode, |
|
|
align_corners=False if mode == "bilinear" else None, |
|
|
recompute_scale_factor=True, |
|
|
)[0] |
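        # Paste the isotropically resized image into the top-left corner of a
        # zero canvas; `mask` marks the valid (non-padded) region.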
|
|
|
|
|
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) |
|
|
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre |
|
|
|
|
|
|
|
|
mask = torch.zeros(1, self.image_height, self.image_width) |
|
|
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 |
|
|
return imre_, minscale, mask |
|
|
|
|
|
def _local_path(self, path: str) -> str: |
|
|
if self.path_manager is None: |
|
|
return path |
|
|
return self.path_manager.get_local_path(path) |
|
|
|
|
|
def get_frame_numbers_and_timestamps( |
|
|
self, idxs: Sequence[int] |
|
|
) -> List[Tuple[int, float]]: |
|
|
out: List[Tuple[int, float]] = [] |
|
|
for idx in idxs: |
|
|
|
|
|
frame_annotation = self.frame_annots[idx]["frame_annotation"] |
|
|
out.append( |
|
|
(frame_annotation.frame_number, frame_annotation.frame_timestamp) |
|
|
) |
|
|
return out |
|
|
|
|
|
def get_eval_batches(self) -> Optional[List[List[int]]]: |
|
|
return self.eval_batches |
|
|
|
|
|
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: |
|
|
return entry["frame_annotation"].meta["frame_type"] |
|
|
|
|
|
|
|
|
class CO3DDataset(LightningDataModule): |
|
|
def __init__( |
|
|
self, |
|
|
root_dir, |
|
|
batch_size=2, |
|
|
shuffle=True, |
|
|
num_workers=10, |
|
|
prefetch_factor=2, |
|
|
category="hydrant", |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.batch_size = batch_size |
|
|
self.num_workers = num_workers |
|
|
self.prefetch_factor = prefetch_factor |
|
|
self.shuffle = shuffle |
|
|
|
|
|
self.train_dataset = CO3Dv2Wrapper( |
|
|
root_dir=root_dir, |
|
|
stage="train", |
|
|
category=category, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
self.test_dataset = CO3Dv2Wrapper( |
|
|
root_dir=root_dir, |
|
|
stage="test", |
|
|
subset="fewview_dev", |
|
|
category=category, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
def train_dataloader(self): |
|
|
return DataLoader( |
|
|
self.train_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=self.shuffle, |
|
|
num_workers=self.num_workers, |
|
|
prefetch_factor=self.prefetch_factor, |
|
|
collate_fn=self.train_dataset.collate_fn, |
|
|
) |
|
|
|
|
|
def test_dataloader(self): |
|
|
return DataLoader( |
|
|
self.test_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=self.shuffle, |
|
|
num_workers=self.num_workers, |
|
|
prefetch_factor=self.prefetch_factor, |
|
|
collate_fn=self.test_dataset.collate_fn, |
|
|
) |
|
|
|
|
|
def val_dataloader(self): |
|
|
return DataLoader( |
|
|
self.test_dataset, |
|
|
batch_size=self.batch_size, |
|
|
shuffle=self.shuffle, |
|
|
num_workers=self.num_workers, |
|
|
prefetch_factor=self.prefetch_factor, |
|
|
collate_fn=video_collate_fn, |
|
|
) |
|
|
|