Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import itertools | |
| import logging | |
| import numpy as np | |
| import operator | |
| import pickle | |
| from collections import OrderedDict, defaultdict | |
| from typing import Any, Callable, Dict, List, Optional, Union | |
| import torch | |
| import torch.utils.data as torchdata | |
| from tabulate import tabulate | |
| from termcolor import colored | |
| from detectron2.config import configurable | |
| from detectron2.structures import BoxMode | |
| from detectron2.utils.comm import get_world_size | |
| from detectron2.utils.env import seed_all_rng | |
| from detectron2.utils.file_io import PathManager | |
| from detectron2.utils.logger import _log_api_usage, log_first_n | |
| from .catalog import DatasetCatalog, MetadataCatalog | |
| from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset | |
| from .dataset_mapper import DatasetMapper | |
| from .detection_utils import check_metadata_consistency | |
| from .samplers import ( | |
| InferenceSampler, | |
| RandomSubsetTrainingSampler, | |
| RepeatFactorTrainingSampler, | |
| TrainingSampler, | |
| ) | |
| """ | |
| This file contains the default logic to build a dataloader for training or testing. | |
| """ | |
| __all__ = [ | |
| "build_batch_data_loader", | |
| "build_detection_train_loader", | |
| "build_detection_test_loader", | |
| "get_detection_dataset_dicts", | |
| "load_proposals_into_dataset", | |
| "print_instances_class_histogram", | |
| ] | |
| def filter_images_with_only_crowd_annotations(dataset_dicts): | |
| """ | |
| Filter out images with none annotations or only crowd annotations | |
| (i.e., images without non-crowd annotations). | |
| A common training-time preprocessing on COCO dataset. | |
| Args: | |
| dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. | |
| Returns: | |
| list[dict]: the same format, but filtered. | |
| """ | |
| num_before = len(dataset_dicts) | |
| def valid(anns): | |
| for ann in anns: | |
| if ann.get("iscrowd", 0) == 0: | |
| return True | |
| return False | |
| dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])] | |
| num_after = len(dataset_dicts) | |
| logger = logging.getLogger(__name__) | |
| logger.info( | |
| "Removed {} images with no usable annotations. {} images left.".format( | |
| num_before - num_after, num_after | |
| ) | |
| ) | |
| return dataset_dicts | |
| def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image): | |
| """ | |
| Filter out images with too few number of keypoints. | |
| Args: | |
| dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. | |
| Returns: | |
| list[dict]: the same format as dataset_dicts, but filtered. | |
| """ | |
| num_before = len(dataset_dicts) | |
| def visible_keypoints_in_image(dic): | |
| # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility | |
| annotations = dic["annotations"] | |
| return sum( | |
| (np.array(ann["keypoints"][2::3]) > 0).sum() | |
| for ann in annotations | |
| if "keypoints" in ann | |
| ) | |
| dataset_dicts = [ | |
| x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image | |
| ] | |
| num_after = len(dataset_dicts) | |
| logger = logging.getLogger(__name__) | |
| logger.info( | |
| "Removed {} images with fewer than {} keypoints.".format( | |
| num_before - num_after, min_keypoints_per_image | |
| ) | |
| ) | |
| return dataset_dicts | |
| def load_proposals_into_dataset(dataset_dicts, proposal_file): | |
| """ | |
| Load precomputed object proposals into the dataset. | |
| The proposal file should be a pickled dict with the following keys: | |
| - "ids": list[int] or list[str], the image ids | |
| - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id | |
| - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores | |
| corresponding to the boxes. | |
| - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``. | |
| Args: | |
| dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. | |
| proposal_file (str): file path of pre-computed proposals, in pkl format. | |
| Returns: | |
| list[dict]: the same format as dataset_dicts, but added proposal field. | |
| """ | |
| logger = logging.getLogger(__name__) | |
| logger.info("Loading proposals from: {}".format(proposal_file)) | |
| with PathManager.open(proposal_file, "rb") as f: | |
| proposals = pickle.load(f, encoding="latin1") | |
| # Rename the key names in D1 proposal files | |
| rename_keys = {"indexes": "ids", "scores": "objectness_logits"} | |
| for key in rename_keys: | |
| if key in proposals: | |
| proposals[rename_keys[key]] = proposals.pop(key) | |
| # Fetch the indexes of all proposals that are in the dataset | |
| # Convert image_id to str since they could be int. | |
| img_ids = set({str(record["image_id"]) for record in dataset_dicts}) | |
| id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids} | |
| # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS' | |
| bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS | |
| for record in dataset_dicts: | |
| # Get the index of the proposal | |
| i = id_to_index[str(record["image_id"])] | |
| boxes = proposals["boxes"][i] | |
| objectness_logits = proposals["objectness_logits"][i] | |
| # Sort the proposals in descending order of the scores | |
| inds = objectness_logits.argsort()[::-1] | |
| record["proposal_boxes"] = boxes[inds] | |
| record["proposal_objectness_logits"] = objectness_logits[inds] | |
| record["proposal_bbox_mode"] = bbox_mode | |
| return dataset_dicts | |
| def print_instances_class_histogram(dataset_dicts, class_names): | |
| """ | |
| Args: | |
| dataset_dicts (list[dict]): list of dataset dicts. | |
| class_names (list[str]): list of class names (zero-indexed). | |
| """ | |
| num_classes = len(class_names) | |
| hist_bins = np.arange(num_classes + 1) | |
| histogram = np.zeros((num_classes,), dtype=int) | |
| for entry in dataset_dicts: | |
| annos = entry["annotations"] | |
| classes = np.asarray( | |
| [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int | |
| ) | |
| if len(classes): | |
| assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}" | |
| assert ( | |
| classes.max() < num_classes | |
| ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes" | |
| histogram += np.histogram(classes, bins=hist_bins)[0] | |
| N_COLS = min(6, len(class_names) * 2) | |
| def short_name(x): | |
| # make long class names shorter. useful for lvis | |
| if len(x) > 13: | |
| return x[:11] + ".." | |
| return x | |
| data = list( | |
| itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]) | |
| ) | |
| total_num_instances = sum(data[1::2]) | |
| data.extend([None] * (N_COLS - (len(data) % N_COLS))) | |
| if num_classes > 1: | |
| data.extend(["total", total_num_instances]) | |
| data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) | |
| table = tabulate( | |
| data, | |
| headers=["category", "#instances"] * (N_COLS // 2), | |
| tablefmt="pipe", | |
| numalign="left", | |
| stralign="center", | |
| ) | |
| log_first_n( | |
| logging.INFO, | |
| "Distribution of instances among all {} categories:\n".format(num_classes) | |
| + colored(table, "cyan"), | |
| key="message", | |
| ) | |
| def get_detection_dataset_dicts( | |
| names, | |
| filter_empty=True, | |
| min_keypoints=0, | |
| proposal_files=None, | |
| check_consistency=True, | |
| ): | |
| """ | |
| Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation. | |
| Args: | |
| names (str or list[str]): a dataset name or a list of dataset names | |
| filter_empty (bool): whether to filter out images without instance annotations | |
| min_keypoints (int): filter out images with fewer keypoints than | |
| `min_keypoints`. Set to 0 to do nothing. | |
| proposal_files (list[str]): if given, a list of object proposal files | |
| that match each dataset in `names`. | |
| check_consistency (bool): whether to check if datasets have consistent metadata. | |
| Returns: | |
| list[dict]: a list of dicts following the standard dataset dict format. | |
| """ | |
| if isinstance(names, str): | |
| names = [names] | |
| assert len(names), names | |
| available_datasets = DatasetCatalog.keys() | |
| names_set = set(names) | |
| if not names_set.issubset(available_datasets): | |
| logger = logging.getLogger(__name__) | |
| logger.warning( | |
| "The following dataset names are not registered in the DatasetCatalog: " | |
| f"{names_set - available_datasets}. " | |
| f"Available datasets are {available_datasets}" | |
| ) | |
| dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names] | |
| if isinstance(dataset_dicts[0], torchdata.Dataset): | |
| if len(dataset_dicts) > 1: | |
| # ConcatDataset does not work for iterable style dataset. | |
| # We could support concat for iterable as well, but it's often | |
| # not a good idea to concat iterables anyway. | |
| return torchdata.ConcatDataset(dataset_dicts) | |
| return dataset_dicts[0] | |
| for dataset_name, dicts in zip(names, dataset_dicts): | |
| assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) | |
| if proposal_files is not None: | |
| assert len(names) == len(proposal_files) | |
| # load precomputed proposals from proposal files | |
| dataset_dicts = [ | |
| load_proposals_into_dataset(dataset_i_dicts, proposal_file) | |
| for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files) | |
| ] | |
| dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) | |
| has_instances = "annotations" in dataset_dicts[0] | |
| if filter_empty and has_instances: | |
| dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) | |
| if min_keypoints > 0 and has_instances: | |
| dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) | |
| if check_consistency and has_instances: | |
| try: | |
| class_names = MetadataCatalog.get(names[0]).thing_classes | |
| check_metadata_consistency("thing_classes", names) | |
| print_instances_class_histogram(dataset_dicts, class_names) | |
| except AttributeError: # class names are not available for this dataset | |
| pass | |
| assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names)) | |
| return dataset_dicts | |
| def build_batch_data_loader( | |
| dataset, | |
| sampler, | |
| total_batch_size, | |
| *, | |
| aspect_ratio_grouping=False, | |
| num_workers=0, | |
| collate_fn=None, | |
| drop_last: bool = True, | |
| single_gpu_batch_size=None, | |
| seed=None, | |
| **kwargs, | |
| ): | |
| """ | |
| Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are: | |
| 1. support aspect ratio grouping options | |
| 2. use no "batch collation", because this is common for detection training | |
| Args: | |
| dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset. | |
| sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices. | |
| Must be provided iff. ``dataset`` is a map-style dataset. | |
| total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see | |
| :func:`build_detection_train_loader`. | |
| single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`. | |
| `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process. | |
| `total_batch_size` allows you to specify the total aggregate batch size across gpus. | |
| It is an error to supply a value for both. | |
| drop_last (bool): if ``True``, the dataloader will drop incomplete batches. | |
| Returns: | |
| iterable[list]. Length of each list is the batch size of the current | |
| GPU. Each element in the list comes from the dataset. | |
| """ | |
| if single_gpu_batch_size: | |
| if total_batch_size: | |
| raise ValueError( | |
| """total_batch_size and single_gpu_batch_size are mutually incompatible. | |
| Please specify only one. """ | |
| ) | |
| batch_size = single_gpu_batch_size | |
| else: | |
| world_size = get_world_size() | |
| assert ( | |
| total_batch_size > 0 and total_batch_size % world_size == 0 | |
| ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( | |
| total_batch_size, world_size | |
| ) | |
| batch_size = total_batch_size // world_size | |
| logger = logging.getLogger(__name__) | |
| logger.info("Making batched data loader with batch_size=%d", batch_size) | |
| if isinstance(dataset, torchdata.IterableDataset): | |
| assert sampler is None, "sampler must be None if dataset is IterableDataset" | |
| else: | |
| dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size) | |
| generator = None | |
| if seed is not None: | |
| generator = torch.Generator() | |
| generator.manual_seed(seed) | |
| if aspect_ratio_grouping: | |
| assert drop_last, "Aspect ratio grouping will drop incomplete batches." | |
| data_loader = torchdata.DataLoader( | |
| dataset, | |
| num_workers=num_workers, | |
| collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements | |
| worker_init_fn=worker_init_reset_seed, | |
| generator=generator, | |
| **kwargs | |
| ) # yield individual mapped dict | |
| data_loader = AspectRatioGroupedDataset(data_loader, batch_size) | |
| if collate_fn is None: | |
| return data_loader | |
| return MapDataset(data_loader, collate_fn) | |
| else: | |
| return torchdata.DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| drop_last=drop_last, | |
| num_workers=num_workers, | |
| collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, | |
| worker_init_fn=worker_init_reset_seed, | |
| generator=generator, | |
| **kwargs | |
| ) | |
| def _get_train_datasets_repeat_factors(cfg) -> Dict[str, float]: | |
| repeat_factors = cfg.DATASETS.TRAIN_REPEAT_FACTOR | |
| assert all(len(tup) == 2 for tup in repeat_factors) | |
| name_to_weight = defaultdict(lambda: 1, dict(repeat_factors)) | |
| # The sampling weights map should only contain datasets in train config | |
| unrecognized = set(name_to_weight.keys()) - set(cfg.DATASETS.TRAIN) | |
| assert not unrecognized, f"unrecognized datasets: {unrecognized}" | |
| logger = logging.getLogger(__name__) | |
| logger.info(f"Found repeat factors: {list(name_to_weight.items())}") | |
| # pyre-fixme[7]: Expected `Dict[str, float]` but got `DefaultDict[typing.Any, int]`. | |
| return name_to_weight | |
| def _build_weighted_sampler(cfg, enable_category_balance=False): | |
| dataset_repeat_factors = _get_train_datasets_repeat_factors(cfg) | |
| # OrderedDict to guarantee order of values() consistent with repeat factors | |
| dataset_name_to_dicts = OrderedDict( | |
| { | |
| name: get_detection_dataset_dicts( | |
| [name], | |
| filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, | |
| min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE | |
| if cfg.MODEL.KEYPOINT_ON | |
| else 0, | |
| proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN | |
| if cfg.MODEL.LOAD_PROPOSALS | |
| else None, | |
| ) | |
| for name in cfg.DATASETS.TRAIN | |
| } | |
| ) | |
| # Repeat factor for every sample in the dataset | |
| repeat_factors = [ | |
| [dataset_repeat_factors[dsname]] * len(dataset_name_to_dicts[dsname]) | |
| for dsname in cfg.DATASETS.TRAIN | |
| ] | |
| repeat_factors = list(itertools.chain.from_iterable(repeat_factors)) | |
| repeat_factors = torch.tensor(repeat_factors) | |
| logger = logging.getLogger(__name__) | |
| if enable_category_balance: | |
| """ | |
| 1. Calculate repeat factors using category frequency for each dataset and then merge them. | |
| 2. Element wise dot producting the dataset frequency repeat factors with | |
| the category frequency repeat factors gives the final repeat factors. | |
| """ | |
| category_repeat_factors = [ | |
| RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( | |
| dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD | |
| ) | |
| for dataset_dict in dataset_name_to_dicts.values() | |
| ] | |
| # flatten the category repeat factors from all datasets | |
| category_repeat_factors = list(itertools.chain.from_iterable(category_repeat_factors)) | |
| category_repeat_factors = torch.tensor(category_repeat_factors) | |
| repeat_factors = torch.mul(category_repeat_factors, repeat_factors) | |
| repeat_factors = repeat_factors / torch.min(repeat_factors) | |
| logger.info( | |
| "Using WeightedCategoryTrainingSampler with repeat_factors={}".format( | |
| cfg.DATASETS.TRAIN_REPEAT_FACTOR | |
| ) | |
| ) | |
| else: | |
| logger.info( | |
| "Using WeightedTrainingSampler with repeat_factors={}".format( | |
| cfg.DATASETS.TRAIN_REPEAT_FACTOR | |
| ) | |
| ) | |
| sampler = RepeatFactorTrainingSampler(repeat_factors) | |
| return sampler | |
| def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): | |
| if dataset is None: | |
| dataset = get_detection_dataset_dicts( | |
| cfg.DATASETS.TRAIN, | |
| filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, | |
| min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE | |
| if cfg.MODEL.KEYPOINT_ON | |
| else 0, | |
| proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, | |
| ) | |
| _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0]) | |
| if mapper is None: | |
| mapper = DatasetMapper(cfg, True) | |
| if sampler is None: | |
| sampler_name = cfg.DATALOADER.SAMPLER_TRAIN | |
| logger = logging.getLogger(__name__) | |
| if isinstance(dataset, torchdata.IterableDataset): | |
| logger.info("Not using any sampler since the dataset is IterableDataset.") | |
| sampler = None | |
| else: | |
| logger.info("Using training sampler {}".format(sampler_name)) | |
| if sampler_name == "TrainingSampler": | |
| sampler = TrainingSampler(len(dataset)) | |
| elif sampler_name == "RepeatFactorTrainingSampler": | |
| repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( | |
| dataset, cfg.DATALOADER.REPEAT_THRESHOLD | |
| ) | |
| sampler = RepeatFactorTrainingSampler(repeat_factors) | |
| elif sampler_name == "RandomSubsetTrainingSampler": | |
| sampler = RandomSubsetTrainingSampler( | |
| len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO | |
| ) | |
| elif sampler_name == "WeightedTrainingSampler": | |
| sampler = _build_weighted_sampler(cfg) | |
| elif sampler_name == "WeightedCategoryTrainingSampler": | |
| sampler = _build_weighted_sampler(cfg, enable_category_balance=True) | |
| else: | |
| raise ValueError("Unknown training sampler: {}".format(sampler_name)) | |
| return { | |
| "dataset": dataset, | |
| "sampler": sampler, | |
| "mapper": mapper, | |
| "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, | |
| "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, | |
| "num_workers": cfg.DATALOADER.NUM_WORKERS, | |
| } | |
| def build_detection_train_loader( | |
| dataset, | |
| *, | |
| mapper, | |
| sampler=None, | |
| total_batch_size, | |
| aspect_ratio_grouping=True, | |
| num_workers=0, | |
| collate_fn=None, | |
| **kwargs | |
| ): | |
| """ | |
| Build a dataloader for object detection with some default features. | |
| Args: | |
| dataset (list or torch.utils.data.Dataset): a list of dataset dicts, | |
| or a pytorch dataset (either map-style or iterable). It can be obtained | |
| by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. | |
| mapper (callable): a callable which takes a sample (dict) from dataset and | |
| returns the format to be consumed by the model. | |
| When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. | |
| sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces | |
| indices to be applied on ``dataset``. | |
| If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`, | |
| which coordinates an infinite random shuffle sequence across all workers. | |
| Sampler must be None if ``dataset`` is iterable. | |
| total_batch_size (int): total batch size across all workers. | |
| aspect_ratio_grouping (bool): whether to group images with similar | |
| aspect ratio for efficiency. When enabled, it requires each | |
| element in dataset be a dict with keys "width" and "height". | |
| num_workers (int): number of parallel data loading workers | |
| collate_fn: a function that determines how to do batching, same as the argument of | |
| `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of | |
| data. No collation is OK for small batch size and simple data structures. | |
| If your batch size is large and each sample contains too many small tensors, | |
| it's more efficient to collate them in data loader. | |
| Returns: | |
| torch.utils.data.DataLoader: | |
| a dataloader. Each output from it is a ``list[mapped_element]`` of length | |
| ``total_batch_size / num_workers``, where ``mapped_element`` is produced | |
| by the ``mapper``. | |
| """ | |
| if isinstance(dataset, list): | |
| dataset = DatasetFromList(dataset, copy=False) | |
| if mapper is not None: | |
| dataset = MapDataset(dataset, mapper) | |
| if isinstance(dataset, torchdata.IterableDataset): | |
| assert sampler is None, "sampler must be None if dataset is IterableDataset" | |
| else: | |
| if sampler is None: | |
| sampler = TrainingSampler(len(dataset)) | |
| assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}" | |
| return build_batch_data_loader( | |
| dataset, | |
| sampler, | |
| total_batch_size, | |
| aspect_ratio_grouping=aspect_ratio_grouping, | |
| num_workers=num_workers, | |
| collate_fn=collate_fn, | |
| **kwargs | |
| ) | |
| def _test_loader_from_config(cfg, dataset_name, mapper=None): | |
| """ | |
| Uses the given `dataset_name` argument (instead of the names in cfg), because the | |
| standard practice is to evaluate each test set individually (not combining them). | |
| """ | |
| if isinstance(dataset_name, str): | |
| dataset_name = [dataset_name] | |
| dataset = get_detection_dataset_dicts( | |
| dataset_name, | |
| filter_empty=False, | |
| proposal_files=[ | |
| cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name | |
| ] | |
| if cfg.MODEL.LOAD_PROPOSALS | |
| else None, | |
| ) | |
| if mapper is None: | |
| mapper = DatasetMapper(cfg, False) | |
| return { | |
| "dataset": dataset, | |
| "mapper": mapper, | |
| "num_workers": cfg.DATALOADER.NUM_WORKERS, | |
| "sampler": InferenceSampler(len(dataset)) | |
| if not isinstance(dataset, torchdata.IterableDataset) | |
| else None, | |
| } | |
| def build_detection_test_loader( | |
| dataset: Union[List[Any], torchdata.Dataset], | |
| *, | |
| mapper: Callable[[Dict[str, Any]], Any], | |
| sampler: Optional[torchdata.Sampler] = None, | |
| batch_size: int = 1, | |
| num_workers: int = 0, | |
| collate_fn: Optional[Callable[[List[Any]], Any]] = None, | |
| ) -> torchdata.DataLoader: | |
| """ | |
| Similar to `build_detection_train_loader`, with default batch size = 1, | |
| and sampler = :class:`InferenceSampler`. This sampler coordinates all workers | |
| to produce the exact set of all samples. | |
| Args: | |
| dataset: a list of dataset dicts, | |
| or a pytorch dataset (either map-style or iterable). They can be obtained | |
| by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. | |
| mapper: a callable which takes a sample (dict) from dataset | |
| and returns the format to be consumed by the model. | |
| When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. | |
| sampler: a sampler that produces | |
| indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, | |
| which splits the dataset across all workers. Sampler must be None | |
| if `dataset` is iterable. | |
| batch_size: the batch size of the data loader to be created. | |
| Default to 1 image per worker since this is the standard when reporting | |
| inference time in papers. | |
| num_workers: number of parallel data loading workers | |
| collate_fn: same as the argument of `torch.utils.data.DataLoader`. | |
| Defaults to do no collation and return a list of data. | |
| Returns: | |
| DataLoader: a torch DataLoader, that loads the given detection | |
| dataset, with test-time transformation and batching. | |
| Examples: | |
| :: | |
| data_loader = build_detection_test_loader( | |
| DatasetRegistry.get("my_test"), | |
| mapper=DatasetMapper(...)) | |
| # or, instantiate with a CfgNode: | |
| data_loader = build_detection_test_loader(cfg, "my_test") | |
| """ | |
| if isinstance(dataset, list): | |
| dataset = DatasetFromList(dataset, copy=False) | |
| if mapper is not None: | |
| dataset = MapDataset(dataset, mapper) | |
| if isinstance(dataset, torchdata.IterableDataset): | |
| assert sampler is None, "sampler must be None if dataset is IterableDataset" | |
| else: | |
| if sampler is None: | |
| sampler = InferenceSampler(len(dataset)) | |
| return torchdata.DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| sampler=sampler, | |
| drop_last=False, | |
| num_workers=num_workers, | |
| collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, | |
| ) | |
| def trivial_batch_collator(batch): | |
| """ | |
| A batch collator that does nothing. | |
| """ | |
| return batch | |
| def worker_init_reset_seed(worker_id): | |
| initial_seed = torch.initial_seed() % 2**31 | |
| seed_all_rng(initial_seed + worker_id) | |