| """ | |
| Adapted from SparseFusion. | |
| Wrapper for the full CO3Dv2 dataset. | |
| Modified from https://github.com/facebookresearch/pytorch3d | |
| """ | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import random | |
| import time | |
| import warnings | |
| from collections import defaultdict | |
| from itertools import islice | |
| from typing import ( | |
| Any, | |
| ClassVar, | |
| List, | |
| Mapping, | |
| Optional, | |
| Sequence, | |
| Tuple, | |
| Type, | |
| TypedDict, | |
| Union, | |
| ) | |
| from einops import rearrange, repeat | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| from pytorch3d.utils import opencv_from_cameras_projection | |
| from pytorch3d.implicitron.dataset import types | |
| from pytorch3d.implicitron.dataset.dataset_base import DatasetBase | |
| from sgm.data.json_index_dataset import ( | |
| FrameAnnotsEntry, | |
| _bbox_xywh_to_xyxy, | |
| _bbox_xyxy_to_xywh, | |
| _clamp_box_to_image_bounds_and_round, | |
| _crop_around_box, | |
| _get_1d_bounds, | |
| _get_bbox_from_mask, | |
| _get_clamp_bbox, | |
| _load_1bit_png_mask, | |
| _load_16big_png_depth, | |
| _load_depth, | |
| _load_depth_mask, | |
| _load_image, | |
| _load_mask, | |
| _load_pointcloud, | |
| _rescale_bbox, | |
| _safe_as_tensor, | |
| _seq_name_to_seed, | |
| ) | |
| from sgm.data.objaverse import video_collate_fn | |
| from dataclasses import dataclass, field, fields | |
| from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import ( | |
| get_available_subset_names, | |
| ) | |
| from pytorch3d.renderer.camera_utils import join_cameras_as_batch | |
| from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras | |
| from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch | |
| from pytorch_lightning import LightningDataModule | |
| from torch.utils.data import DataLoader | |
| logger = logging.getLogger(__name__) | |
| CO3D_ALL_CATEGORIES = list( | |
| reversed( | |
| [ | |
| "baseballbat", | |
| "banana", | |
| "bicycle", | |
| "microwave", | |
| "tv", | |
| "cellphone", | |
| "toilet", | |
| "hairdryer", | |
| "couch", | |
| "kite", | |
| "pizza", | |
| "umbrella", | |
| "wineglass", | |
| "laptop", | |
| "hotdog", | |
| "stopsign", | |
| "frisbee", | |
| "baseballglove", | |
| "cup", | |
| "parkingmeter", | |
| "backpack", | |
| "toyplane", | |
| "toybus", | |
| "handbag", | |
| "chair", | |
| "keyboard", | |
| "car", | |
| "motorcycle", | |
| "carrot", | |
| "bottle", | |
| "sandwich", | |
| "remote", | |
| "bowl", | |
| "skateboard", | |
| "toaster", | |
| "mouse", | |
| "toytrain", | |
| "book", | |
| "toytruck", | |
| "orange", | |
| "broccoli", | |
| "plant", | |
| "teddybear", | |
| "suitcase", | |
| "bench", | |
| "ball", | |
| "cake", | |
| "vase", | |
| "hydrant", | |
| "apple", | |
| "donut", | |
| ] | |
| ) | |
| ) | |
| CO3D_ALL_TEN = [ | |
| "donut", | |
| "apple", | |
| "hydrant", | |
| "vase", | |
| "cake", | |
| "ball", | |
| "bench", | |
| "suitcase", | |
| "teddybear", | |
| "plant", | |
| ] | |
| # @ FROM https://github.com/facebookresearch/pytorch3d | |
| @dataclass | |
| class FrameData(Mapping[str, Any]): | |
| """ | |
| A type of the elements returned by indexing the dataset object. | |
| It can represent both individual frames and batches thereof; | |
| in this documentation, the sizes of tensors refer to single frames; | |
| add the first batch dimension for the collation result. | |
| Args: | |
| frame_number: The number of the frame within its sequence. | |
| 0-based continuous integers. | |
| sequence_name: The unique name of the frame's sequence. | |
| sequence_category: The object category of the sequence. | |
| frame_timestamp: The time elapsed since the start of a sequence in sec. | |
| image_size_hw: The size of the image in pixels; (height, width) tensor | |
| of shape (2,). | |
| image_path: The qualified path to the loaded image (with dataset_root). | |
| image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image | |
| of the frame; elements are floats in [0, 1]. | |
| mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image | |
| regions. Regions can be invalid (mask_crop[i,j]=0) in case they | |
| are a result of zero-padding of the image after cropping around | |
| the object bounding box; elements are floats in {0.0, 1.0}. | |
| depth_path: The qualified path to the frame's depth map. | |
| depth_map: A float Tensor of shape `(1, H, W)` holding the depth map | |
| of the frame; values correspond to distances from the camera; | |
| use `depth_mask` and `mask_crop` to filter for valid pixels. | |
| depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the | |
| depth map that are valid for evaluation, they have been checked for | |
| consistency across views; elements are floats in {0.0, 1.0}. | |
| mask_path: A qualified path to the foreground probability mask. | |
| fg_probability: A Tensor of `(1, H, W)` denoting the probability of the | |
| pixels belonging to the captured object; elements are floats | |
| in [0, 1]. | |
| bbox_xywh: The bounding box tightly enclosing the foreground object in the | |
| format (x0, y0, width, height). The convention assumes that | |
| `x0+width` and `y0+height` includes the boundary of the box. | |
| I.e., to slice out the corresponding crop from an image tensor `I` | |
| we execute `crop = I[..., y0:y0+height, x0:x0+width]` | |
| crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` | |
| in the original image coordinates in the format (x0, y0, width, height). | |
| The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs | |
| from `bbox_xywh` due to padding (which can happen e.g. due to | |
| setting `JsonIndexDataset.box_crop_context > 0`) | |
| camera: A PyTorch3D camera object corresponding to the frame's viewpoint, | |
| corrected for cropping if it happened. | |
| camera_quality_score: The score proportional to the confidence of the | |
| frame's camera estimation (the higher the more accurate). | |
| point_cloud_quality_score: The score proportional to the accuracy of the | |
| frame's sequence point cloud (the higher the more accurate). | |
| sequence_point_cloud_path: The path to the sequence's point cloud. | |
| sequence_point_cloud: A PyTorch3D Pointclouds object holding the | |
| point cloud corresponding to the frame's sequence. When the object | |
| represents a batch of frames, point clouds may be deduplicated; | |
| see `sequence_point_cloud_idx`. | |
| sequence_point_cloud_idx: Integer indices mapping frame indices to the | |
| corresponding point clouds in `sequence_point_cloud`; to get the | |
| corresponding point cloud to `image_rgb[i]`, use | |
| `sequence_point_cloud[sequence_point_cloud_idx[i]]`. | |
| frame_type: The type of the loaded frame specified in | |
| `subset_lists_file`, if provided. | |
| meta: A dict for storing additional frame information. | |
| """ | |
| frame_number: Optional[torch.LongTensor] | |
| sequence_name: Union[str, List[str]] | |
| sequence_category: Union[str, List[str]] | |
| frame_timestamp: Optional[torch.Tensor] = None | |
| image_size_hw: Optional[torch.Tensor] = None | |
| image_path: Union[str, List[str], None] = None | |
| image_rgb: Optional[torch.Tensor] = None | |
| # masks out the zero-padding added when cropping/resizing to a square image | |
| mask_crop: Optional[torch.Tensor] = None | |
| depth_path: Union[str, List[str], None] = "" | |
| depth_map: Optional[torch.Tensor] = torch.zeros(1) | |
| depth_mask: Optional[torch.Tensor] = torch.zeros(1) | |
| mask_path: Union[str, List[str], None] = None | |
| fg_probability: Optional[torch.Tensor] = None | |
| bbox_xywh: Optional[torch.Tensor] = None | |
| crop_bbox_xywh: Optional[torch.Tensor] = None | |
| camera: Optional[PerspectiveCameras] = None | |
| camera_quality_score: Optional[torch.Tensor] = None | |
| point_cloud_quality_score: Optional[torch.Tensor] = None | |
| sequence_point_cloud_path: Union[str, List[str], None] = "" | |
| sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1) | |
| sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1) | |
| frame_type: Union[str, List[str], None] = "" # known | unseen | |
| meta: dict = field(default_factory=lambda: {}) | |
| valid_region: Optional[torch.Tensor] = None | |
| category_one_hot: Optional[torch.Tensor] = None | |
| def to(self, *args, **kwargs): | |
| new_params = {} | |
| for f in fields(self): | |
| value = getattr(self, f.name) | |
| if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): | |
| new_params[f.name] = value.to(*args, **kwargs) | |
| else: | |
| new_params[f.name] = value | |
| return type(self)(**new_params) | |
| def cpu(self): | |
| return self.to(device=torch.device("cpu")) | |
| def cuda(self): | |
| return self.to(device=torch.device("cuda")) | |
| # the following functions make sure **frame_data can be passed to functions | |
| def __iter__(self): | |
| for f in fields(self): | |
| yield f.name | |
| def __getitem__(self, key): | |
| return getattr(self, key) | |
| def __len__(self): | |
| return len(fields(self)) | |
| @classmethod | |
| def collate(cls, batch): | |
| """ | |
| Given a list `batch` of objects of class `cls`, collates them into a batched | |
| representation suitable for processing with deep networks. | |
| """ | |
| elem = batch[0] | |
| if isinstance(elem, cls): | |
| pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] | |
| id_to_idx = defaultdict(list) | |
| for i, pc_id in enumerate(pointcloud_ids): | |
| id_to_idx[pc_id].append(i) | |
| sequence_point_cloud = [] | |
| sequence_point_cloud_idx = -np.ones((len(batch),)) | |
| for i, ind in enumerate(id_to_idx.values()): | |
| sequence_point_cloud_idx[ind] = i | |
| sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) | |
| assert (sequence_point_cloud_idx >= 0).all() | |
| override_fields = { | |
| "sequence_point_cloud": sequence_point_cloud, | |
| "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), | |
| } | |
| # note that the pre-collate value of sequence_point_cloud_idx is unused | |
| collated = {} | |
| for f in fields(elem): | |
| list_values = override_fields.get( | |
| f.name, [getattr(d, f.name) for d in batch] | |
| ) | |
| collated[f.name] = ( | |
| cls.collate(list_values) | |
| if all(list_value is not None for list_value in list_values) | |
| else None | |
| ) | |
| return cls(**collated) | |
| elif isinstance(elem, Pointclouds): | |
| return join_pointclouds_as_batch(batch) | |
| elif isinstance(elem, CamerasBase): | |
| # TODO: don't store K; enforce working in NDC space | |
| return join_cameras_as_batch(batch) | |
| else: | |
| return torch.utils.data._utils.collate.default_collate(batch) | |
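| # Usage sketch (illustrative, not part of the upstream code): | |
| #   batched = FrameData.collate([frame_a, frame_b]) | |
| # stacks tensor fields along a new batch dimension, joins cameras and point | |
| # clouds into batched PyTorch3D structures, and deduplicates point clouds that | |
| # are the same object across frames via sequence_point_cloud_idx. | |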
| # @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d | |
| class CO3Dv2Wrapper(torch.utils.data.Dataset): | |
| def __init__( | |
| self, | |
| root_dir="/drive/datasets/co3d/", | |
| category="hydrant", | |
| subset="fewview_train", | |
| stage="train", | |
| sample_batch_size=20, | |
| image_size=256, | |
| masked=False, | |
| deprecated_val_region=False, | |
| return_frame_data_list=False, | |
| reso: int = 256, | |
| mask_type: str = "random", | |
| cond_aug_mean=-3.0, | |
| cond_aug_std=0.5, | |
| condition_on_elevation=False, | |
| fps_id=0.0, | |
| motion_bucket_id=300.0, | |
| num_frames: int = 20, | |
| use_mask: bool = True, | |
| load_pixelnerf: bool = True, | |
| scale_pose: bool = True, | |
| max_n_cond: int = 5, | |
| min_n_cond: int = 2, | |
| cond_on_multi: bool = False, | |
| ): | |
| root = root_dir | |
| from typing import List | |
| from co3d.dataset.data_types import ( | |
| FrameAnnotation, | |
| SequenceAnnotation, | |
| load_dataclass_jgzip, | |
| ) | |
| self.dataset_root = root | |
| self.path_manager = None | |
| self.subset = subset | |
| self.stage = stage | |
| self.subset_lists_file: List[str] = [ | |
| f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json" | |
| ] | |
| self.subsets: Optional[List[str]] = [subset] | |
| self.sample_batch_size = sample_batch_size | |
| self.limit_to: int = 0 | |
| self.limit_sequences_to: int = 0 | |
| self.pick_sequence: Tuple[str, ...] = () | |
| self.exclude_sequence: Tuple[str, ...] = () | |
| self.limit_category_to: Tuple[int, ...] = () | |
| self.load_images: bool = True | |
| self.load_depths: bool = False | |
| self.load_depth_masks: bool = False | |
| self.load_masks: bool = True | |
| self.load_point_clouds: bool = False | |
| self.max_points: int = 0 | |
| self.mask_images: bool = False | |
| self.mask_depths: bool = False | |
| self.image_height: Optional[int] = image_size | |
| self.image_width: Optional[int] = image_size | |
| self.box_crop: bool = True | |
| self.box_crop_mask_thr: float = 0.4 | |
| self.box_crop_context: float = 0.3 | |
| self.remove_empty_masks: bool = True | |
| self.n_frames_per_sequence: int = -1 | |
| self.seed: int = 0 | |
| self.sort_frames: bool = False | |
| self.eval_batches: Any = None | |
| self.img_h = self.image_height | |
| self.img_w = self.image_width | |
| self.masked = masked | |
| self.deprecated_val_region = deprecated_val_region | |
| self.return_frame_data_list = return_frame_data_list | |
| self.reso = reso | |
| self.num_frames = num_frames | |
| self.cond_aug_mean = cond_aug_mean | |
| self.cond_aug_std = cond_aug_std | |
| self.condition_on_elevation = condition_on_elevation | |
| self.fps_id = fps_id | |
| self.motion_bucket_id = motion_bucket_id | |
| self.mask_type = mask_type | |
| self.use_mask = use_mask | |
| self.load_pixelnerf = load_pixelnerf | |
| self.scale_pose = scale_pose | |
| self.max_n_cond = max_n_cond | |
| self.min_n_cond = min_n_cond | |
| self.cond_on_multi = cond_on_multi | |
| if self.cond_on_multi: | |
| assert self.min_n_cond == self.max_n_cond | |
| start_time = time.time() | |
| if "all_" in category or category == "all": | |
| self.category_frame_annotations = [] | |
| self.category_sequence_annotations = [] | |
| self.subset_lists_file = [] | |
| if category == "all": | |
| cats = CO3D_ALL_CATEGORIES | |
| elif category == "all_four": | |
| cats = ["hydrant", "teddybear", "motorcycle", "bench"] | |
| elif category == "all_ten": | |
| cats = [ | |
| "donut", | |
| "apple", | |
| "hydrant", | |
| "vase", | |
| "cake", | |
| "ball", | |
| "bench", | |
| "suitcase", | |
| "teddybear", | |
| "plant", | |
| ] | |
| elif category == "all_15": | |
| cats = [ | |
| "hydrant", | |
| "teddybear", | |
| "motorcycle", | |
| "bench", | |
| "hotdog", | |
| "remote", | |
| "suitcase", | |
| "donut", | |
| "plant", | |
| "toaster", | |
| "keyboard", | |
| "handbag", | |
| "toyplane", | |
| "tv", | |
| "orange", | |
| ] | |
| else: | |
| print("UNSPECIFIED CATEGORY SUBSET") | |
| cats = ["hydrant", "teddybear"] | |
| print("loading", cats) | |
| for cat in cats: | |
| self.category_frame_annotations.extend( | |
| load_dataclass_jgzip( | |
| f"{self.dataset_root}/{cat}/frame_annotations.jgz", | |
| List[FrameAnnotation], | |
| ) | |
| ) | |
| self.category_sequence_annotations.extend( | |
| load_dataclass_jgzip( | |
| f"{self.dataset_root}/{cat}/sequence_annotations.jgz", | |
| List[SequenceAnnotation], | |
| ) | |
| ) | |
| self.subset_lists_file.append( | |
| f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json" | |
| ) | |
| else: | |
| self.category_frame_annotations = load_dataclass_jgzip( | |
| f"{self.dataset_root}/{category}/frame_annotations.jgz", | |
| List[FrameAnnotation], | |
| ) | |
| self.category_sequence_annotations = load_dataclass_jgzip( | |
| f"{self.dataset_root}/{category}/sequence_annotations.jgz", | |
| List[SequenceAnnotation], | |
| ) | |
| self.subset_to_image_path = None | |
| self._load_frames() | |
| self._load_sequences() | |
| self._sort_frames() | |
| self._load_subset_lists() | |
| self._filter_db() # also computes sequence indices | |
| # self._extract_and_set_eval_batches() | |
| # print(self.eval_batches) | |
| logger.info(str(self)) | |
| self.seq_to_frames = {} | |
| for fi, item in enumerate(self.frame_annots): | |
| if item["frame_annotation"].sequence_name in self.seq_to_frames: | |
| self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi) | |
| else: | |
| self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi] | |
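| # outside the fewview_test test split, keep only sequences with more than 10 frames | |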
| if self.stage != "test" or self.subset != "fewview_test": | |
| count = 0 | |
| new_seq_to_frames = {} | |
| for item in self.seq_to_frames: | |
| if len(self.seq_to_frames[item]) > 10: | |
| count += 1 | |
| new_seq_to_frames[item] = self.seq_to_frames[item] | |
| self.seq_to_frames = new_seq_to_frames | |
| self.seq_list = list(self.seq_to_frames.keys()) | |
| # @ REMOVE A FEW TRAINING SEQUENCES THAT CAUSE BUGS | |
| remove_list = ["411_55952_107659", "376_42884_85882"] | |
| for remove_idx in remove_list: | |
| if remove_idx in self.seq_to_frames: | |
| self.seq_list.remove(remove_idx) | |
| print("removing", remove_idx) | |
| print("total training seq", len(self.seq_to_frames)) | |
| print("data loading took", time.time() - start_time, "seconds") | |
| self.all_category_list = list(CO3D_ALL_CATEGORIES) | |
| self.all_category_list.sort() | |
| self.cat_to_idx = {} | |
| for ci, cname in enumerate(self.all_category_list): | |
| self.cat_to_idx[cname] = ci | |
| def __len__(self): | |
| return len(self.seq_list) | |
| def __getitem__(self, index): | |
| seq_index = self.seq_list[index] | |
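| # frame selection: the fewview_test test split uses every frame of the sequence, | |
| # other test runs take sample_batch_size evenly spaced frames, and training | |
| # draws a random subset of at most sample_batch_size frames | |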
| if self.subset == "fewview_test" and self.stage == "test": | |
| batch_idx = torch.arange(len(self.seq_to_frames[seq_index])) | |
| elif self.stage == "test": | |
| batch_idx = ( | |
| torch.linspace( | |
| 0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size | |
| ) | |
| .long() | |
| .tolist() | |
| ) | |
| else: | |
| rand = torch.randperm(len(self.seq_to_frames[seq_index])) | |
| batch_idx = rand[: min(len(rand), self.sample_batch_size)] | |
| frame_data_list = [] | |
| idx_list = [] | |
| timestamp_list = [] | |
| for idx in batch_idx: | |
| idx_list.append(self.seq_to_frames[seq_index][idx]) | |
| timestamp_list.append( | |
| self.frame_annots[self.seq_to_frames[seq_index][idx]][ | |
| "frame_annotation" | |
| ].frame_timestamp | |
| ) | |
| frame_data_list.append( | |
| self._get_frame(int(self.seq_to_frames[seq_index][idx])) | |
| ) | |
| time_order = torch.argsort(torch.tensor(timestamp_list)) | |
| frame_data_list = [frame_data_list[i] for i in time_order] | |
| frame_data = FrameData.collate(frame_data_list) | |
| image_size = torch.Tensor([self.image_height]).repeat( | |
| frame_data.camera.R.shape[0], 2 | |
| ) | |
| frame_dict = { | |
| "R": frame_data.camera.R, | |
| "T": frame_data.camera.T, | |
| "f": frame_data.camera.focal_length, | |
| "c": frame_data.camera.principal_point, | |
| "images": frame_data.image_rgb * frame_data.fg_probability | |
| + (1 - frame_data.fg_probability), | |
| "valid_region": frame_data.mask_crop, | |
| "bbox": frame_data.valid_region, | |
| "image_size": image_size, | |
| "frame_type": frame_data.frame_type, | |
| "idx": seq_index, | |
| "category": frame_data.category_one_hot, | |
| } | |
| if not self.masked: | |
| frame_dict["images_full"] = frame_data.image_rgb | |
| frame_dict["masks"] = frame_data.fg_probability | |
| frame_dict["mask_crop"] = frame_data.mask_crop | |
| cond_aug = np.exp( | |
| np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean | |
| ) | |
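| # if the sequence yields fewer than num_frames frames, pad it by appending a | |
| # time-reversed copy and truncating to num_frames | |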
| def _pad(input): | |
| return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[ | |
| : self.num_frames | |
| ] | |
| if len(frame_dict["images"]) < self.num_frames: | |
| for k in frame_dict: | |
| if isinstance(frame_dict[k], torch.Tensor): | |
| frame_dict[k] = _pad(frame_dict[k]) | |
| data = dict() | |
| if "images_full" in frame_dict: | |
| frames = frame_dict["images_full"] * 2 - 1 | |
| else: | |
| frames = frame_dict["images"] * 2 - 1 | |
| data["frames"] = frames | |
| cond = frames[0] | |
| data["cond_frames_without_noise"] = cond | |
| data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames) | |
| data["cond_frames"] = cond + cond_aug * torch.randn_like(cond) | |
| data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames) | |
| data["motion_bucket_id"] = torch.as_tensor( | |
| [self.motion_bucket_id] * self.num_frames | |
| ) | |
| data["num_video_frames"] = self.num_frames | |
| data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames) | |
| if self.load_pixelnerf: | |
| data["pixelnerf_input"] = dict() | |
| # Rs = frame_dict["R"].transpose(-1, -2) | |
| # Ts = frame_dict["T"] | |
| # Rs[:, :, 2] *= -1 | |
| # Rs[:, :, 0] *= -1 | |
| # Ts[:, 2] *= -1 | |
| # Ts[:, 0] *= -1 | |
| # c2ws = torch.zeros(Rs.shape[0], 4, 4) | |
| # c2ws[:, :3, :3] = Rs | |
| # c2ws[:, :3, 3] = Ts | |
| # c2ws[:, 3, 3] = 1 | |
| # c2ws = c2ws.inverse() | |
| # # c2ws[..., 0] *= -1 | |
| # # c2ws[..., 2] *= -1 | |
| # cx = frame_dict["c"][:, 0] | |
| # cy = frame_dict["c"][:, 1] | |
| # fx = frame_dict["f"][:, 0] | |
| # fy = frame_dict["f"][:, 1] | |
| # intrinsics = torch.zeros(cx.shape[0], 3, 3) | |
| # intrinsics[:, 2, 2] = 1 | |
| # intrinsics[:, 0, 0] = fx | |
| # intrinsics[:, 1, 1] = fy | |
| # intrinsics[:, 0, 2] = cx | |
| # intrinsics[:, 1, 2] = cy | |
| scene_cameras = PerspectiveCameras( | |
| R=frame_dict["R"], | |
| T=frame_dict["T"], | |
| focal_length=frame_dict["f"], | |
| principal_point=frame_dict["c"], | |
| image_size=frame_dict["image_size"], | |
| ) | |
| R, T, intrinsics = opencv_from_cameras_projection( | |
| scene_cameras, frame_dict["image_size"] | |
| ) | |
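| # opencv_from_cameras_projection yields world-to-camera rotations/translations | |
| # in the OpenCV convention plus pixel-space intrinsics; pack them into 4x4 | |
| # matrices and invert to obtain camera-to-world (c2w) poses | |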
| c2ws = torch.zeros(R.shape[0], 4, 4) | |
| c2ws[:, :3, :3] = R | |
| c2ws[:, :3, 3] = T | |
| c2ws[:, 3, 3] = 1.0 | |
| c2ws = c2ws.inverse() | |
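| # flip the camera y/z axes (OpenCV convention -> the OpenGL-style convention | |
| # assumed downstream) and normalize the pixel intrinsics; the hard-coded 256 | |
| # assumes 256x256 crops | |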
| c2ws[..., 1:3] *= -1 | |
| intrinsics[:, :2] /= 256 | |
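| # pack each camera as a flat 25-dim vector: the 16 entries of the 4x4 c2w | |
| # matrix followed by the 9 entries of the 3x3 intrinsics | |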
| cameras = torch.zeros(c2ws.shape[0], 25) | |
| cameras[..., :16] = c2ws.reshape(-1, 16) | |
| cameras[..., 16:] = intrinsics.reshape(-1, 9) | |
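| # optionally recenter the trajectory and rescale so that all camera centers | |
| # fit inside a sphere of radius 1.5 | |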
| if self.scale_pose: | |
| c2ws = cameras[..., :16].reshape(-1, 4, 4) | |
| center = c2ws[:, :3, 3].mean(0) | |
| radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() | |
| scale = 1.5 / radius | |
| c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale | |
| cameras[..., :16] = c2ws.reshape(-1, 16) | |
| data["pixelnerf_input"]["frames"] = frames | |
| data["pixelnerf_input"]["cameras"] = cameras | |
| data["pixelnerf_input"]["rgb"] = ( | |
| F.interpolate( | |
| frames, | |
| (self.image_width // 8, self.image_height // 8), | |
| mode="bilinear", | |
| align_corners=False, | |
| ) | |
| + 1 | |
| ) * 0.5 | |
| return data | |
| # if self.return_frame_data_list: | |
| # return (frame_dict, frame_data_list) | |
| # return frame_dict | |
| def collate_fn(self, batch): | |
| # a hack to add source indices and keep them consistent within a batch | |
| if self.max_n_cond > 1: | |
| # TODO implement this | |
| n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1) | |
| # debug | |
| # source_index = [0] | |
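| # frame 0 is always kept as a conditioning view; max_n_cond - 1 additional | |
| # view indices are sampled at random, and n_cond is passed along, presumably | |
| # for the downstream model to decide how many of them to use | |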
| if n_cond > 1: | |
| for b in batch: | |
| source_index = [0] + np.random.choice( | |
| np.arange(1, self.num_frames), | |
| self.max_n_cond - 1, | |
| replace=False, | |
| ).tolist() | |
| b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) | |
| b["pixelnerf_input"]["n_cond"] = n_cond | |
| b["pixelnerf_input"]["source_images"] = b["frames"][source_index] | |
| b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ | |
| "cameras" | |
| ][source_index] | |
| if self.cond_on_multi: | |
| b["cond_frames_without_noise"] = b["frames"][source_index] | |
| ret = video_collate_fn(batch) | |
| if self.cond_on_multi: | |
| ret["cond_frames_without_noise"] = rearrange( | |
| ret["cond_frames_without_noise"], "b t ... -> (b t) ..." | |
| ) | |
| return ret | |
| def _get_frame(self, index): | |
| # if index >= len(self.frame_annots): | |
| # raise IndexError(f"index {index} out of range {len(self.frame_annots)}") | |
| entry = self.frame_annots[index]["frame_annotation"] | |
| # pyre-ignore[16] | |
| point_cloud = self.seq_annots[entry.sequence_name].point_cloud | |
| frame_data = FrameData( | |
| frame_number=_safe_as_tensor(entry.frame_number, torch.long), | |
| frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), | |
| sequence_name=entry.sequence_name, | |
| sequence_category=self.seq_annots[entry.sequence_name].category, | |
| camera_quality_score=_safe_as_tensor( | |
| self.seq_annots[entry.sequence_name].viewpoint_quality_score, | |
| torch.float, | |
| ), | |
| point_cloud_quality_score=_safe_as_tensor( | |
| point_cloud.quality_score, torch.float | |
| ) | |
| if point_cloud is not None | |
| else None, | |
| ) | |
| # The rest of the fields are optional | |
| frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) | |
| ( | |
| frame_data.fg_probability, | |
| frame_data.mask_path, | |
| frame_data.bbox_xywh, | |
| clamp_bbox_xyxy, | |
| frame_data.crop_bbox_xywh, | |
| ) = self._load_crop_fg_probability(entry) | |
| scale = 1.0 | |
| if self.load_images and entry.image is not None: | |
| # original image size | |
| frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) | |
| ( | |
| frame_data.image_rgb, | |
| frame_data.image_path, | |
| frame_data.mask_crop, | |
| scale, | |
| ) = self._load_crop_images( | |
| entry, frame_data.fg_probability, clamp_bbox_xyxy | |
| ) | |
| # print(frame_data.fg_probability.sum()) | |
| # print('scale', scale) | |
| #! INSERT | |
| if self.deprecated_val_region: | |
| # print(frame_data.crop_bbox_xywh) | |
| valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float() | |
| # print(valid_bbox, frame_data.image_size_hw) | |
| valid_bbox[0] = torch.clip( | |
| ( | |
| valid_bbox[0] | |
| - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor") | |
| ) | |
| / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"), | |
| -1.0, | |
| 1.0, | |
| ) | |
| valid_bbox[1] = torch.clip( | |
| ( | |
| valid_bbox[1] | |
| - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor") | |
| ) | |
| / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"), | |
| -1.0, | |
| 1.0, | |
| ) | |
| valid_bbox[2] = torch.clip( | |
| ( | |
| valid_bbox[2] | |
| - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor") | |
| ) | |
| / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"), | |
| -1.0, | |
| 1.0, | |
| ) | |
| valid_bbox[3] = torch.clip( | |
| ( | |
| valid_bbox[3] | |
| - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor") | |
| ) | |
| / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"), | |
| -1.0, | |
| 1.0, | |
| ) | |
| # print(valid_bbox) | |
| frame_data.valid_region = valid_bbox | |
| else: | |
| #! UPDATED VALID BBOX | |
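| # the valid region is the tight bounding box of the non-padded pixels in | |
| # mask_crop, expressed in normalized [-1, 1] image coordinates | |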
| if self.stage == "train": | |
| assert self.image_height == 256 and self.image_width == 256 | |
| valid = torch.nonzero(frame_data.mask_crop[0]) | |
| min_y = valid[:, 0].min() | |
| min_x = valid[:, 1].min() | |
| max_y = valid[:, 0].max() | |
| max_x = valid[:, 1].max() | |
| valid_bbox = torch.tensor( | |
| [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device | |
| ).unsqueeze(0) | |
| valid_bbox = torch.clip( | |
| (valid_bbox - (256 // 2)) / (256 // 2), -1.0, 1.0 | |
| ) | |
| frame_data.valid_region = valid_bbox[0] | |
| else: | |
| valid = torch.nonzero(frame_data.mask_crop[0]) | |
| min_y = valid[:, 0].min() | |
| min_x = valid[:, 1].min() | |
| max_y = valid[:, 0].max() | |
| max_x = valid[:, 1].max() | |
| valid_bbox = torch.tensor( | |
| [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device | |
| ).unsqueeze(0) | |
| valid_bbox = torch.clip( | |
| (valid_bbox - (self.image_height // 2)) / (self.image_height // 2), | |
| -1.0, | |
| 1.0, | |
| ) | |
| frame_data.valid_region = valid_bbox[0] | |
| #! SET CLASS ONEHOT | |
| frame_data.category_one_hot = torch.zeros( | |
| (len(self.all_category_list)), device=frame_data.image_rgb.device | |
| ) | |
| frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1 | |
| if self.load_depths and entry.depth is not None: | |
| ( | |
| frame_data.depth_map, | |
| frame_data.depth_path, | |
| frame_data.depth_mask, | |
| ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) | |
| if entry.viewpoint is not None: | |
| frame_data.camera = self._get_pytorch3d_camera( | |
| entry, | |
| scale, | |
| clamp_bbox_xyxy, | |
| ) | |
| if self.load_point_clouds and point_cloud is not None: | |
| frame_data.sequence_point_cloud_path = pcl_path = os.path.join( | |
| self.dataset_root, point_cloud.path | |
| ) | |
| frame_data.sequence_point_cloud = _load_pointcloud( | |
| self._local_path(pcl_path), max_points=self.max_points | |
| ) | |
| # for key in frame_data: | |
| # if frame_data[key] == None: | |
| # print(key) | |
| return frame_data | |
| def _extract_and_set_eval_batches(self): | |
| """ | |
| Sets eval_batches based on input eval_batch_index. | |
| """ | |
| if self.eval_batch_index is not None: | |
| if self.eval_batches is not None: | |
| raise ValueError( | |
| "Cannot define both eval_batch_index and eval_batches." | |
| ) | |
| self.eval_batches = self.seq_frame_index_to_dataset_index( | |
| self.eval_batch_index | |
| ) | |
| def _load_crop_fg_probability( | |
| self, entry: types.FrameAnnotation | |
| ) -> Tuple[ | |
| Optional[torch.Tensor], | |
| Optional[str], | |
| Optional[torch.Tensor], | |
| Optional[torch.Tensor], | |
| Optional[torch.Tensor], | |
| ]: | |
| fg_probability = None | |
| full_path = None | |
| bbox_xywh = None | |
| clamp_bbox_xyxy = None | |
| crop_box_xywh = None | |
| if (self.load_masks or self.box_crop) and entry.mask is not None: | |
| full_path = os.path.join(self.dataset_root, entry.mask.path) | |
| mask = _load_mask(self._local_path(full_path)) | |
| if mask.shape[-2:] != entry.image.size: | |
| raise ValueError( | |
| f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" | |
| ) | |
| bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) | |
| if self.box_crop: | |
| clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( | |
| _get_clamp_bbox( | |
| bbox_xywh, | |
| image_path=entry.image.path, | |
| box_crop_context=self.box_crop_context, | |
| ), | |
| image_size_hw=tuple(mask.shape[-2:]), | |
| ) | |
| crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) | |
| mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) | |
| fg_probability, _, _ = self._resize_image(mask, mode="nearest") | |
| return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh | |
| def _load_crop_images( | |
| self, | |
| entry: types.FrameAnnotation, | |
| fg_probability: Optional[torch.Tensor], | |
| clamp_bbox_xyxy: Optional[torch.Tensor], | |
| ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: | |
| assert self.dataset_root is not None and entry.image is not None | |
| path = os.path.join(self.dataset_root, entry.image.path) | |
| image_rgb = _load_image(self._local_path(path)) | |
| if image_rgb.shape[-2:] != entry.image.size: | |
| raise ValueError( | |
| f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" | |
| ) | |
| if self.box_crop: | |
| assert clamp_bbox_xyxy is not None | |
| image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) | |
| image_rgb, scale, mask_crop = self._resize_image(image_rgb) | |
| if self.mask_images: | |
| assert fg_probability is not None | |
| image_rgb *= fg_probability | |
| return image_rgb, path, mask_crop, scale | |
| def _load_mask_depth( | |
| self, | |
| entry: types.FrameAnnotation, | |
| clamp_bbox_xyxy: Optional[torch.Tensor], | |
| fg_probability: Optional[torch.Tensor], | |
| ) -> Tuple[torch.Tensor, str, torch.Tensor]: | |
| entry_depth = entry.depth | |
| assert entry_depth is not None | |
| path = os.path.join(self.dataset_root, entry_depth.path) | |
| depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) | |
| if self.box_crop: | |
| assert clamp_bbox_xyxy is not None | |
| depth_bbox_xyxy = _rescale_bbox( | |
| clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] | |
| ) | |
| depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) | |
| depth_map, _, _ = self._resize_image(depth_map, mode="nearest") | |
| if self.mask_depths: | |
| assert fg_probability is not None | |
| depth_map *= fg_probability | |
| if self.load_depth_masks: | |
| assert entry_depth.mask_path is not None | |
| mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) | |
| depth_mask = _load_depth_mask(self._local_path(mask_path)) | |
| if self.box_crop: | |
| assert clamp_bbox_xyxy is not None | |
| depth_mask_bbox_xyxy = _rescale_bbox( | |
| clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] | |
| ) | |
| depth_mask = _crop_around_box( | |
| depth_mask, depth_mask_bbox_xyxy, mask_path | |
| ) | |
| depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") | |
| else: | |
| depth_mask = torch.ones_like(depth_map) | |
| return depth_map, path, depth_mask | |
| def _get_pytorch3d_camera( | |
| self, | |
| entry: types.FrameAnnotation, | |
| scale: float, | |
| clamp_bbox_xyxy: Optional[torch.Tensor], | |
| ) -> PerspectiveCameras: | |
| entry_viewpoint = entry.viewpoint | |
| assert entry_viewpoint is not None | |
| # principal point and focal length | |
| principal_point = torch.tensor( | |
| entry_viewpoint.principal_point, dtype=torch.float | |
| ) | |
| focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) | |
| half_image_size_wh_orig = ( | |
| torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 | |
| ) | |
| # first, we convert from the dataset's NDC convention to pixels | |
| format = entry_viewpoint.intrinsics_format | |
| if format.lower() == "ndc_norm_image_bounds": | |
| # this is e.g. currently used in CO3D for storing intrinsics | |
| rescale = half_image_size_wh_orig | |
| elif format.lower() == "ndc_isotropic": | |
| rescale = half_image_size_wh_orig.min() | |
| else: | |
| raise ValueError(f"Unknown intrinsics format: {format}") | |
| # principal point and focal length in pixels | |
| principal_point_px = half_image_size_wh_orig - principal_point * rescale | |
| focal_length_px = focal_length * rescale | |
| if self.box_crop: | |
| assert clamp_bbox_xyxy is not None | |
| principal_point_px -= clamp_bbox_xyxy[:2] | |
| # now, convert from pixels to PyTorch3D v0.5+ NDC convention | |
| if self.image_height is None or self.image_width is None: | |
| out_size = list(reversed(entry.image.size)) | |
| else: | |
| out_size = [self.image_width, self.image_height] | |
| half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 | |
| half_min_image_size_output = half_image_size_output.min() | |
| # rescaled principal point and focal length in ndc | |
| principal_point = ( | |
| half_image_size_output - principal_point_px * scale | |
| ) / half_min_image_size_output | |
| focal_length = focal_length_px * scale / half_min_image_size_output | |
| return PerspectiveCameras( | |
| focal_length=focal_length[None], | |
| principal_point=principal_point[None], | |
| R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], | |
| T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], | |
| ) | |
| def _load_frames(self) -> None: | |
| self.frame_annots = [ | |
| FrameAnnotsEntry(frame_annotation=a, subset=None) | |
| for a in self.category_frame_annotations | |
| ] | |
| def _load_sequences(self) -> None: | |
| self.seq_annots = { | |
| entry.sequence_name: entry for entry in self.category_sequence_annotations | |
| } | |
| def _load_subset_lists(self) -> None: | |
| logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.") | |
| if not self.subset_lists_file: | |
| return | |
| frame_path_to_subset = {} | |
| for subset_list_file in self.subset_lists_file: | |
| with open(self._local_path(subset_list_file), "r") as f: | |
| subset_to_seq_frame = json.load(f) | |
| #! PRINT SUBSET_LIST STATS | |
| # if len(self.subset_lists_file) == 1: | |
| # print('train frames', len(subset_to_seq_frame['train'])) | |
| # print('val frames', len(subset_to_seq_frame['val'])) | |
| # print('test frames', len(subset_to_seq_frame['test'])) | |
| for set_ in subset_to_seq_frame: | |
| for _, _, path in subset_to_seq_frame[set_]: | |
| if path in frame_path_to_subset: | |
| frame_path_to_subset[path].add(set_) | |
| else: | |
| frame_path_to_subset[path] = {set_} | |
| # pyre-ignore[16] | |
| for frame in self.frame_annots: | |
| frame["subset"] = frame_path_to_subset.get( | |
| frame["frame_annotation"].image.path, None | |
| ) | |
| if frame["subset"] is None: | |
| continue | |
| warnings.warn( | |
| "Subset lists are given but don't include " | |
| + frame["frame_annotation"].image.path | |
| ) | |
| def _sort_frames(self) -> None: | |
| # Sort frames to have them grouped by sequence, ordered by timestamp | |
| # pyre-ignore[16] | |
| self.frame_annots = sorted( | |
| self.frame_annots, | |
| key=lambda f: ( | |
| f["frame_annotation"].sequence_name, | |
| f["frame_annotation"].frame_timestamp or 0, | |
| ), | |
| ) | |
| def _filter_db(self) -> None: | |
| if self.remove_empty_masks: | |
| logger.info("Removing images with empty masks.") | |
| # pyre-ignore[16] | |
| old_len = len(self.frame_annots) | |
| msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." | |
| def positive_mass(frame_annot: types.FrameAnnotation) -> bool: | |
| mask = frame_annot.mask | |
| if mask is None: | |
| return False | |
| if mask.mass is None: | |
| raise ValueError(msg) | |
| return mask.mass > 1 | |
| self.frame_annots = [ | |
| frame | |
| for frame in self.frame_annots | |
| if positive_mass(frame["frame_annotation"]) | |
| ] | |
| logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots))) | |
| # this has to be called after joining with categories!! | |
| subsets = self.subsets | |
| if subsets: | |
| if not self.subset_lists_file: | |
| raise ValueError( | |
| "Subset filter is on but subset_lists_file was not given" | |
| ) | |
| logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.") | |
| # truncate the list of subsets to the valid one | |
| self.frame_annots = [ | |
| entry | |
| for entry in self.frame_annots | |
| if (entry["subset"] is not None and self.stage in entry["subset"]) | |
| ] | |
| if len(self.frame_annots) == 0: | |
| raise ValueError(f"There are no frames in the '{subsets}' subsets!") | |
| self._invalidate_indexes(filter_seq_annots=True) | |
| if len(self.limit_category_to) > 0: | |
| logger.info(f"Limiting dataset to categories: {self.limit_category_to}") | |
| # pyre-ignore[16] | |
| self.seq_annots = { | |
| name: entry | |
| for name, entry in self.seq_annots.items() | |
| if entry.category in self.limit_category_to | |
| } | |
| # sequence filters | |
| for prefix in ("pick", "exclude"): | |
| orig_len = len(self.seq_annots) | |
| attr = f"{prefix}_sequence" | |
| arr = getattr(self, attr) | |
| if len(arr) > 0: | |
| logger.info(f"{attr}: {str(arr)}") | |
| self.seq_annots = { | |
| name: entry | |
| for name, entry in self.seq_annots.items() | |
| if (name in arr) == (prefix == "pick") | |
| } | |
| logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) | |
| if self.limit_sequences_to > 0: | |
| self.seq_annots = dict( | |
| islice(self.seq_annots.items(), self.limit_sequences_to) | |
| ) | |
| # retain only frames from retained sequences | |
| self.frame_annots = [ | |
| f | |
| for f in self.frame_annots | |
| if f["frame_annotation"].sequence_name in self.seq_annots | |
| ] | |
| self._invalidate_indexes() | |
| if self.n_frames_per_sequence > 0: | |
| logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") | |
| keep_idx = [] | |
| # pyre-ignore[16] | |
| for seq, seq_indices in self._seq_to_idx.items(): | |
| # infer the seed from the sequence name, this is reproducible | |
| # and makes the selection differ for different sequences | |
| seed = _seq_name_to_seed(seq) + self.seed | |
| seq_idx_shuffled = random.Random(seed).sample( | |
| sorted(seq_indices), len(seq_indices) | |
| ) | |
| keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) | |
| logger.info( | |
| "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)) | |
| ) | |
| self.frame_annots = [self.frame_annots[i] for i in keep_idx] | |
| self._invalidate_indexes(filter_seq_annots=False) | |
| # sequences are not decimated, so self.seq_annots is valid | |
| if self.limit_to > 0 and self.limit_to < len(self.frame_annots): | |
| logger.info( | |
| "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) | |
| ) | |
| self.frame_annots = self.frame_annots[: self.limit_to] | |
| self._invalidate_indexes(filter_seq_annots=True) | |
| def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: | |
| # update _seq_to_idx and filter seq_meta according to frame_annots change | |
| # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx | |
| self._invalidate_seq_to_idx() | |
| if filter_seq_annots: | |
| # pyre-ignore[16] | |
| self.seq_annots = { | |
| k: v | |
| for k, v in self.seq_annots.items() | |
| # pyre-ignore[16] | |
| if k in self._seq_to_idx | |
| } | |
| def _invalidate_seq_to_idx(self) -> None: | |
| seq_to_idx = defaultdict(list) | |
| # pyre-ignore[16] | |
| for idx, entry in enumerate(self.frame_annots): | |
| seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) | |
| # pyre-ignore[16] | |
| self._seq_to_idx = seq_to_idx | |
| def _resize_image( | |
| self, image, mode="bilinear" | |
| ) -> Tuple[torch.Tensor, float, torch.Tensor]: | |
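| # isotropically rescale the (numpy) image so it fits inside | |
| # (image_height, image_width), zero-pad the bottom/right, and return the | |
| # padded image, the scale factor, and a mask of the valid (non-padded) region | |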
| image_height, image_width = self.image_height, self.image_width | |
| if image_height is None or image_width is None: | |
| # skip the resizing | |
| imre_ = torch.from_numpy(image) | |
| return imre_, 1.0, torch.ones_like(imre_[:1]) | |
| # takes numpy array, returns pytorch tensor | |
| minscale = min( | |
| image_height / image.shape[-2], | |
| image_width / image.shape[-1], | |
| ) | |
| imre = torch.nn.functional.interpolate( | |
| torch.from_numpy(image)[None], | |
| scale_factor=minscale, | |
| mode=mode, | |
| align_corners=False if mode == "bilinear" else None, | |
| recompute_scale_factor=True, | |
| )[0] | |
| # pyre-fixme[19]: Expected 1 positional argument. | |
| imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) | |
| imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre | |
| # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. | |
| # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. | |
| mask = torch.zeros(1, self.image_height, self.image_width) | |
| mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 | |
| return imre_, minscale, mask | |
| def _local_path(self, path: str) -> str: | |
| if self.path_manager is None: | |
| return path | |
| return self.path_manager.get_local_path(path) | |
| def get_frame_numbers_and_timestamps( | |
| self, idxs: Sequence[int] | |
| ) -> List[Tuple[int, float]]: | |
| out: List[Tuple[int, float]] = [] | |
| for idx in idxs: | |
| # pyre-ignore[16] | |
| frame_annotation = self.frame_annots[idx]["frame_annotation"] | |
| out.append( | |
| (frame_annotation.frame_number, frame_annotation.frame_timestamp) | |
| ) | |
| return out | |
| def get_eval_batches(self) -> Optional[List[List[int]]]: | |
| return self.eval_batches | |
| def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: | |
| return entry["frame_annotation"].meta["frame_type"] | |
| class CO3DDataset(LightningDataModule): | |
| def __init__( | |
| self, | |
| root_dir, | |
| batch_size=2, | |
| shuffle=True, | |
| num_workers=10, | |
| prefetch_factor=2, | |
| category="hydrant", | |
| **kwargs, | |
| ): | |
| super().__init__() | |
| self.batch_size = batch_size | |
| self.num_workers = num_workers | |
| self.prefetch_factor = prefetch_factor | |
| self.shuffle = shuffle | |
| self.train_dataset = CO3Dv2Wrapper( | |
| root_dir=root_dir, | |
| stage="train", | |
| category=category, | |
| **kwargs, | |
| ) | |
| self.test_dataset = CO3Dv2Wrapper( | |
| root_dir=root_dir, | |
| stage="test", | |
| subset="fewview_dev", | |
| category=category, | |
| **kwargs, | |
| ) | |
| def train_dataloader(self): | |
| return DataLoader( | |
| self.train_dataset, | |
| batch_size=self.batch_size, | |
| shuffle=self.shuffle, | |
| num_workers=self.num_workers, | |
| prefetch_factor=self.prefetch_factor, | |
| collate_fn=self.train_dataset.collate_fn, | |
| ) | |
| def test_dataloader(self): | |
| return DataLoader( | |
| self.test_dataset, | |
| batch_size=self.batch_size, | |
| shuffle=self.shuffle, | |
| num_workers=self.num_workers, | |
| prefetch_factor=self.prefetch_factor, | |
| collate_fn=self.test_dataset.collate_fn, | |
| ) | |
| def val_dataloader(self): | |
| return DataLoader( | |
| self.test_dataset, | |
| batch_size=self.batch_size, | |
| shuffle=self.shuffle, | |
| num_workers=self.num_workers, | |
| prefetch_factor=self.prefetch_factor, | |
| collate_fn=video_collate_fn, | |
| ) | |
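| # Minimal usage sketch (illustrative; the arguments shown are the constructor defaults): | |
| #   datamodule = CO3DDataset(root_dir="/drive/datasets/co3d/", category="hydrant", batch_size=2) | |
| #   batch = next(iter(datamodule.train_dataloader())) | |
| #   # batch carries the keys built in CO3Dv2Wrapper.__getitem__ and collate_fn, | |
| #   # e.g. "frames", "cond_frames", "cond_aug", "pixelnerf_input" | |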