|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import random
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.utils.data as data
|
|
import torchvision
|
|
import torchvision.transforms.v2 as VT
|
|
from torch.utils.data import default_collate
|
|
from torchvision.transforms.v2 import InterpolationMode
|
|
from torchvision.transforms.v2 import functional as VF
|
|
|
|
from ..core import register
|
|
|
|
# Suppress torchvision's warning about the beta ``transforms.v2`` API (imported
# above as ``VT`` / ``VF``). NOTE: this is a side effect executed whenever this
# module is imported.
torchvision.disable_beta_transforms_warning()

# Public names exported by ``from <this module> import *``.
__all__ = [
    "DataLoader",
    "BaseCollateFunction",
    "BatchImageCollateFunction",
    "batch_image_collate_fn",
]
|
|
|
|
|
|
@register()
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` that tracks the epoch and has a readable repr.

    ``dataset`` and ``collate_fn`` are listed in ``__inject__``; presumably the
    ``..core`` registry uses this to supply them as configured objects — TODO
    confirm against the registry implementation.
    """

    __inject__ = ["dataset", "collate_fn"]

    def __repr__(self) -> str:
        """Return the class name followed by the main loader settings, one per line."""
        fields = ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]
        body = "".join("\n" + " {0}: {1}".format(n, getattr(self, n)) for n in fields)
        return self.__class__.__name__ + "(" + body + "\n)"

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to the dataset and the collate function.

        The epoch is forwarded only to objects that define a callable
        ``set_epoch``. Plain collate functions (e.g. ``batch_image_collate_fn``
        in this module, or torch's ``default_collate`` used when
        ``collate_fn=None``) have no such method. They are skipped instead of
        raising ``AttributeError``; an object without ``set_epoch`` cannot react
        to the epoch anyway.
        """
        self._epoch = epoch
        for receiver in (self.dataset, self.collate_fn):
            hook = getattr(receiver, "set_epoch", None)
            if callable(hook):
                hook(epoch)

    @property
    def epoch(self):
        """Last epoch passed to :meth:`set_epoch`, or ``-1`` if it was never called."""
        return getattr(self, "_epoch", -1)

    @property
    def shuffle(self):
        """Shuffle flag stored by the setter.

        Raises ``AttributeError`` if read before it has been assigned (nothing
        in this class sets ``_shuffle`` otherwise).
        """
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle):
        # The value is only stored here; nothing in this class reads it back.
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be a boolean")
        self._shuffle = shuffle
|
|
|
|
|
|
@register()
def batch_image_collate_fn(items):
    """Batch only the images of ``items``; pass the second elements through.

    Element 0 of every item is stacked along a new leading dimension into one
    tensor. Element 1 of every item is collected, unchanged and in input order,
    into a list.
    """
    batched = torch.stack([sample[0] for sample in items], dim=0)
    extras = [sample[1] for sample in items]
    return batched, extras
|
|
|
|
|
|
class BaseCollateFunction:
    """Base class for collate callables that can vary their behavior by epoch.

    Subclasses implement ``__call__`` to turn a list of dataset items into a
    batch and may consult :attr:`epoch`, which reflects the last
    :meth:`set_epoch` call.
    """

    def set_epoch(self, epoch):
        """Store ``epoch`` so later reads of :attr:`epoch` return it."""
        self._epoch = epoch

    @property
    def epoch(self):
        """Most recently stored epoch, or ``-1`` before any :meth:`set_epoch` call."""
        return getattr(self, "_epoch", -1)

    def __call__(self, items):
        """Collate ``items`` into a batch; subclasses must override this."""
        raise NotImplementedError("")
|
|
|
|
|
|
def generate_scales(base_size, base_size_repeat):
    """Return the list of input sizes used for multi-scale training.

    The list is built from three runs:

    * an ascending ramp in steps of 32, starting at ``0.75 * base_size``
      rounded down to a multiple of 32 and stopping below ``base_size``;
    * ``base_size`` repeated ``base_size_repeat`` times;
    * a descending ramp in steps of 32, starting at ``1.25 * base_size``
      rounded down to a multiple of 32, with as many entries as the first ramp.

    Example: ``generate_scales(640, 2)`` returns
    ``[480, 512, 544, 576, 608, 640, 640, 800, 768, 736, 704, 672]``.
    """
    stride = 32
    low = int(base_size * 0.75 / stride) * stride
    steps = (base_size - low) // stride
    result = [low + k * stride for k in range(steps)]
    result.extend([base_size] * base_size_repeat)
    high = int(base_size * 1.25 / stride) * stride
    result.extend(high - k * stride for k in range(steps))
    return result
|
|
|
|
|
|
@register()
class BatchImageCollateFunction(BaseCollateFunction):
    """Collate ``(image, target)`` items, with optional multi-scale image resizing.

    Images (element 0 of each item) are concatenated along a new leading batch
    dimension. Targets (element 1) are returned as a list in input order.

    When ``base_size_repeat`` is given, ``generate_scales(base_size,
    base_size_repeat)`` supplies a list of sizes. While ``self.epoch <
    stop_epoch``, each batch of images is resized to one size drawn at random
    from that list.

    Args:
        stop_epoch: first epoch at which multi-scale resizing is switched off.
            ``None`` stores a large sentinel, so resizing effectively never stops.
        ema_restart_decay: stored on the instance; not used by this class.
        base_size: nominal input size passed to ``generate_scales``.
        base_size_repeat: how often ``base_size`` appears in the scale list;
            ``None`` disables multi-scale resizing entirely.
    """

    def __init__(
        self,
        stop_epoch=None,
        ema_restart_decay=0.9999,
        base_size=640,
        base_size_repeat=None,
    ) -> None:
        super().__init__()
        self.base_size = base_size
        # ``None`` means multi-scale resizing is disabled.
        self.scales = (
            generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
        )
        # Large sentinel so that, without an explicit stop epoch, the
        # ``epoch < stop_epoch`` check in ``__call__`` stays true.
        self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
        # Presumably read by training code elsewhere (e.g. an EMA schedule) — TODO confirm.
        self.ema_restart_decay = ema_restart_decay

    def __call__(self, items):
        """Batch ``items`` and, if multi-scale resizing is active, resize the images.

        Args:
            items: sequence of ``(image, target, ...)`` items. Images are assumed
                to be same-shaped ``(C, H, W)`` tensors and targets dict-like —
                TODO confirm against the dataset.

        Returns:
            ``(images, targets)``: the batched image tensor and the list of
            per-item targets.

        Raises:
            NotImplementedError: multi-scale resizing is active and the targets
                contain ``"masks"``; resizing masks is not implemented.
        """
        images = torch.cat([x[0][None] for x in items], dim=0)
        targets = [x[1] for x in items]

        # ``epoch`` is -1 until ``set_epoch`` is called, so resizing is active
        # from the start whenever scales are configured.
        if self.scales is not None and self.epoch < self.stop_epoch:
            # Fail before resizing anything so the caller's target dicts are
            # left unmodified.
            if "masks" in targets[0]:
                raise NotImplementedError(
                    "multi-scale resizing of targets with 'masks' is not implemented"
                )
            sz = random.choice(self.scales)
            # An int ``size`` resizes every spatial dim to ``sz``;
            # F.interpolate's default mode is "nearest".
            images = F.interpolate(images, size=sz)

        return images, targets
|
|
|