# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
import operator
import torch
import torch.utils.data
import json
from detectron2.utils.comm import get_world_size
from detectron2.data import samplers
from torch.utils.data.sampler import BatchSampler, Sampler
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler
from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
from detectron2.data.build import filter_images_with_only_crowd_annotations
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.data.build import check_metadata_consistency
from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
from detectron2.utils import comm
import itertools
import math
from collections import defaultdict
from typing import Optional
# from .custom_build_augmentation import build_custom_augmentation
def build_custom_train_loader(cfg, mapper=None):
"""
    Modified from detectron2.data.build.build_detection_train_loader, but
    supports different samplers.
"""
source_aware = cfg.DATALOADER.SOURCE_AWARE
if source_aware:
dataset_dicts = get_detection_dataset_dicts_with_source(
cfg.DATASETS.TRAIN,
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
)
sizes = [0 for _ in range(len(cfg.DATASETS.TRAIN))]
for d in dataset_dicts:
sizes[d['dataset_source']] += 1
print('dataset sizes', sizes)
else:
dataset_dicts = get_detection_dataset_dicts(
cfg.DATASETS.TRAIN,
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
)
dataset = DatasetFromList(dataset_dicts, copy=False)
    if mapper is None:
        # A mapper must be supplied explicitly, e.g. DatasetMapper(cfg, True).
        raise ValueError("build_custom_train_loader requires an explicit mapper")
dataset = MapDataset(dataset, mapper)
sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
logger = logging.getLogger(__name__)
logger.info("Using training sampler {}".format(sampler_name))
# TODO avoid if-else?
if sampler_name == "TrainingSampler":
sampler = TrainingSampler(len(dataset))
elif sampler_name == "MultiDatasetSampler":
assert source_aware
sampler = MultiDatasetSampler(cfg, sizes, dataset_dicts)
elif sampler_name == "RepeatFactorTrainingSampler":
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
)
sampler = RepeatFactorTrainingSampler(repeat_factors)
elif sampler_name == "ClassAwareSampler":
sampler = ClassAwareSampler(dataset_dicts)
else:
raise ValueError("Unknown training sampler: {}".format(sampler_name))
return build_batch_data_loader(
dataset,
sampler,
cfg.SOLVER.IMS_PER_BATCH,
aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING,
num_workers=cfg.DATALOADER.NUM_WORKERS,
)
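

# Usage sketch (hypothetical, not part of the original file): the loader
# requires an explicit mapper; Detectron2's default DatasetMapper is a
# reasonable choice when no custom augmentation is needed.
#
#   from detectron2.data.dataset_mapper import DatasetMapper
#   loader = build_custom_train_loader(cfg, mapper=DatasetMapper(cfg, True))
#   for batched_inputs in loader:  # list[dict], one dict per image in the batch
#       ...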
class ClassAwareSampler(Sampler):
def __init__(self, dataset_dicts, seed: Optional[int] = None):
"""
Args:
            dataset_dicts (list[dict]): annotations in Detectron2 dataset format;
                per-image sampling weights are derived from category frequencies.
seed (int): the initial seed of the shuffle. Must be the same
across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
"""
self._size = len(dataset_dicts)
assert self._size > 0
if seed is None:
seed = comm.shared_random_seed()
self._seed = int(seed)
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
self.weights = self._get_class_balance_factor(dataset_dicts)
def __iter__(self):
start = self._rank
yield from itertools.islice(
self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
ids = torch.multinomial(
self.weights, self._size, generator=g,
replacement=True)
yield from ids
    def _get_class_balance_factor(self, dataset_dicts, exponent=1.0):
        # 1. For each category c, count the number of images that contain it: f(c).
        ret = []
        category_freq = defaultdict(int)
        for dataset_dict in dataset_dicts:  # For each image (without repeats)
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        # 2. Weight each image by the sum of 1 / f(c)^exponent over the
        #    categories it contains, so rare-category images are sampled more often.
        for dataset_dict in dataset_dicts:
            cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
            ret.append(sum(
                1.0 / (category_freq[cat_id] ** exponent) for cat_id in cat_ids))
        return torch.tensor(ret).float()
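

# Worked example for ClassAwareSampler._get_class_balance_factor (toy numbers,
# illustration only): with the default exponent of 1, if category A appears in
# 100 images and category B in 4, an image containing only A gets weight
# 1/100 = 0.01 while an image containing both A and B gets 1/100 + 1/4 = 0.26,
# so torch.multinomial draws rare-category images far more often.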
def get_detection_dataset_dicts_with_source(
dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
assert len(dataset_names)
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts)):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
for d in dicts:
d['dataset_source'] = source_id
if "annotations" in dicts[0]:
try:
class_names = MetadataCatalog.get(dataset_name).thing_classes
check_metadata_consistency("thing_classes", dataset_name)
print_instances_class_histogram(dicts, class_names)
except AttributeError: # class names are not available for this dataset
pass
assert proposal_files is None
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
has_instances = "annotations" in dataset_dicts[0]
if filter_empty and has_instances:
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
if min_keypoints > 0 and has_instances:
dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
return dataset_dicts
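

# Sketch of the returned records (dataset names used for illustration): each
# dict follows Detectron2's dataset format plus a 'dataset_source' index, e.g.
# with dataset_names = ("lvis_v1_train", "coco_2017_train") a record from the
# second dataset looks like
#   {"file_name": ..., "annotations": [...], "dataset_source": 1}
# which MultiDatasetSampler later uses to balance the two sources.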
class MultiDatasetSampler(Sampler):
def __init__(self, cfg, sizes, dataset_dicts, seed: Optional[int] = None):
"""
Args:
            cfg: a Detectron2 config; cfg.DATALOADER.DATASET_RATIO gives the
                target sampling ratio between the source datasets.
            sizes (list[int]): number of images in each source dataset.
            dataset_dicts (list[dict]): annotations in Detectron2 dataset format,
                each tagged with a 'dataset_source' index.
seed (int): the initial seed of the shuffle. Must be the same
across all workers. If None, will use a random seed shared
among workers (require synchronization among all workers).
"""
self.sizes = sizes
dataset_ratio = cfg.DATALOADER.DATASET_RATIO
self._batch_size = cfg.SOLVER.IMS_PER_BATCH
assert len(dataset_ratio) == len(sizes), \
            'length of dataset ratio {} should be equal to number of datasets {}'.format(
len(dataset_ratio), len(sizes)
)
if seed is None:
seed = comm.shared_random_seed()
self._seed = int(seed)
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
self._ims_per_gpu = self._batch_size // self._world_size
self.dataset_ids = torch.tensor(
[d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
        # Per-image weight for dataset i: max(sizes) / sizes[i] * ratio[i] / sum(ratio),
        # i.e. each dataset is first rescaled to the size of the largest one and
        # then reweighted by its configured ratio.
        dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio)
                          for r, s in zip(dataset_ratio, sizes)]
        dataset_weight = torch.cat(dataset_weight)
self.weights = dataset_weight
self.sample_epoch_size = len(self.weights)
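        # Worked example (toy numbers, illustration only): sizes = [1000, 100]
        # with DATASET_RATIO = [1, 1] yields per-image weights of
        # 1000/1000 * 1/2 = 0.5 for the first dataset and 1000/100 * 1/2 = 5.0
        # for the second, so both sources contribute roughly equally per epoch
        # despite the 10x size gap.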
def __iter__(self):
start = self._rank
yield from itertools.islice(
self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
ids = torch.multinomial(
self.weights, self.sample_epoch_size, generator=g,
replacement=True)
nums = [(self.dataset_ids[ids] == i).sum().int().item() \
for i in range(len(self.sizes))]
print('_rank, len, nums', self._rank, len(ids), nums, flush=True)
yield from ids
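

if __name__ == "__main__":
    # Minimal self-check sketch (toy data, not part of the original file):
    # build a ClassAwareSampler over three fake images and print the per-image
    # weights from _get_class_balance_factor. detectron2's comm utilities fall
    # back to single-process defaults when distributed mode is not initialized.
    toy_dicts = [
        {"annotations": [{"category_id": 0}]},
        {"annotations": [{"category_id": 0}]},
        {"annotations": [{"category_id": 0}, {"category_id": 1}]},
    ]
    sampler = ClassAwareSampler(toy_dicts, seed=0)
    # Category 0 appears in all 3 images, category 1 in one,
    # so the weights are approximately [0.33, 0.33, 1.33].
    print("per-image weights:", sampler.weights.tolist())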