File size: 3,753 Bytes
e85fecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import random
from functools import partial
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import torchvision.transforms.v2 as VT
from torch.utils.data import default_collate
from torchvision.transforms.v2 import InterpolationMode
from torchvision.transforms.v2 import functional as VF
from ..core import register
# Turn off torchvision's warning about its beta transforms API
# (``torchvision.transforms.v2`` is imported above).
torchvision.disable_beta_transforms_warning()

# Names exported by ``from <this module> import *``.
__all__ = [
    "DataLoader",
    "BaseCollateFunction",
    "BatchImageCollateFunction",
    "batch_image_collate_fn",
]
@register()
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` with epoch tracking and a settable ``shuffle`` flag.

    ``dataset`` and ``collate_fn`` are listed in ``__inject__`` so the ``register``
    machinery can supply them (presumably resolved from config — confirm in ``..core``).
    """

    __inject__ = ["dataset", "collate_fn"]

    def __repr__(self) -> str:
        """Return the class name followed by one line per key loader setting."""
        fields = ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]
        body = "".join("\n" + " {0}: {1}".format(n, getattr(self, n)) for n in fields)
        return self.__class__.__name__ + "(" + body + "\n)"

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to the dataset and the collate function.

        The epoch is forwarded only to components that define a callable
        ``set_epoch``. Plain collate functions (e.g. ``batch_image_collate_fn``
        or torch's default collate) and epoch-agnostic datasets are skipped
        instead of raising ``AttributeError``.

        NOTE(review): with ``num_workers > 0`` the collate function runs on copies
        inside worker processes; whether they observe this update depends on
        when workers are (re)started — confirm for persistent workers.
        """
        self._epoch = epoch
        # Same order as before: dataset first, then collate_fn.
        for component in (self.dataset, self.collate_fn):
            hook = getattr(component, "set_epoch", None)
            if callable(hook):
                hook(epoch)

    @property
    def epoch(self):
        """Last epoch passed to ``set_epoch``, or ``-1`` if it was never called."""
        return self._epoch if hasattr(self, "_epoch") else -1

    @property
    def shuffle(self):
        """Shuffle flag stored by the setter (raises AttributeError if never set)."""
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be a boolean"
        self._shuffle = shuffle
@register()
def batch_image_collate_fn(items):
    """Collate ``(image, target, ...)`` samples, batching only the images.

    The first element of every sample is stacked along a new leading batch
    dimension. The second element is collected, in sample order, into a
    plain list and left otherwise untouched.
    """
    images = torch.stack([sample[0] for sample in items], dim=0)
    targets = [sample[1] for sample in items]
    return images, targets
class BaseCollateFunction:
    """Base class for epoch-aware collate callables.

    Subclasses provide ``__call__``. The current epoch is recorded with
    ``set_epoch`` and read back through the ``epoch`` property.
    """

    def set_epoch(self, epoch):
        """Remember ``epoch`` as the current epoch."""
        self._epoch = epoch

    @property
    def epoch(self):
        """The value last given to ``set_epoch``, or ``-1`` before the first call."""
        return getattr(self, "_epoch", -1)

    def __call__(self, items):
        """Turn ``items`` into a batch; subclasses must override this."""
        raise NotImplementedError("")
def generate_scales(base_size, base_size_repeat):
    """Return the candidate sizes used for multi-scale resizing.

    The list is, in order:
      * multiples of 32 counting up from ``int(base_size * 0.75 / 32) * 32``;
      * ``base_size`` repeated ``base_size_repeat`` times;
      * multiples of 32 counting down from ``int(base_size * 1.25 / 32) * 32``.
    The up and down runs each have ``(base_size - lower_start) // 32`` entries.

    Example: ``generate_scales(640, 3)`` ->
    ``[480, 512, 544, 576, 608, 640, 640, 640, 800, 768, 736, 704, 672]``
    """
    low = int(base_size * 0.75 / 32) * 32
    high = int(base_size * 1.25 / 32) * 32
    steps = (base_size - low) // 32
    ascending = [low + 32 * k for k in range(steps)]
    descending = [high - 32 * k for k in range(steps)]
    return ascending + [base_size] * base_size_repeat + descending
@register()
class BatchImageCollateFunction(BaseCollateFunction):
    """Collate function that stacks images and optionally applies multi-scale resizing.

    When ``base_size_repeat`` is given, every batch collated while
    ``self.epoch < stop_epoch`` is resized with ``F.interpolate`` to a size drawn
    at random from ``generate_scales(base_size, base_size_repeat)``.

    Args:
        stop_epoch: first epoch at which multi-scale resizing is switched off;
            ``None`` maps to 100000000, i.e. effectively never.
        ema_restart_decay: stored on the instance; not used by this class
            (presumably read by the training loop — confirm).
        base_size: reference size passed to ``generate_scales``.
        base_size_repeat: number of times ``base_size`` appears in the scale
            list; ``None`` disables multi-scale resizing.
    """

    def __init__(
        self,
        stop_epoch=None,
        ema_restart_decay=0.9999,
        base_size=640,
        base_size_repeat=None,
    ) -> None:
        super().__init__()
        self.base_size = base_size
        self.scales = (
            generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
        )
        self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
        self.ema_restart_decay = ema_restart_decay

    def __call__(self, items):
        """Batch ``(image, target, ...)`` samples.

        Returns:
            ``(images, targets)``. ``images`` stacks the per-sample images along
            a new leading dimension, possibly resized. ``targets`` is the list
            of per-sample targets.

        Raises:
            NotImplementedError: if multi-scale resizing is active and the
                targets carry ``"masks"``. The check runs before any resizing,
                so the caller's targets are never modified on this path.
        """
        images = torch.cat([x[0][None] for x in items], dim=0)
        targets = [x[1] for x in items]

        if self.scales is not None and self.epoch < self.stop_epoch:
            # Resizing masks together with the images is not supported. Fail
            # up front rather than rewriting each target's "masks" and then
            # raising, which left callers with partially-mutated targets.
            if "masks" in targets[0]:
                raise NotImplementedError(
                    "multi-scale resizing of targets with 'masks' is not supported"
                )
            sz = random.choice(self.scales)
            # F.interpolate's default mode is "nearest". With an int size, both
            # spatial dims become sz (assumes a 4-D N, C, H, W batch — confirm).
            images = F.interpolate(images, size=sz)
            # NOTE(review): target boxes are not rescaled here. That is correct
            # only if they are stored in normalized coordinates — confirm.

        return images, targets
|