D-FINE / src /data /transforms /container.py
developer0hye's picture
Upload 76 files
e85fecb verified
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as T
from ...core import GLOBAL_CONFIG, register
from ._transforms import EmptyTransform
torchvision.disable_beta_transforms_warning()
@register()
class Compose(T.Compose):
    """Transform pipeline that can disable selected transforms via a policy.

    ``ops`` is either ``None`` (the pipeline is a single ``EmptyTransform``)
    or an iterable whose entries are:

    * a ``dict`` with a ``"type"`` key naming an entry of ``GLOBAL_CONFIG``;
      the remaining items are passed as keyword arguments to that class, or
    * an already constructed ``nn.Module``, used as is.

    ``policy`` is a dict whose ``"name"`` selects the forward implementation:

    * ``"default"``     - apply every transform.
    * ``"stop_epoch"``  - skip transforms whose class name is listed in
      ``policy["ops"]`` once the dataset's ``epoch`` reaches
      ``policy["epoch"]``.
    * ``"stop_sample"`` - skip transforms whose class name is listed in
      ``policy["ops"]`` once ``self.global_samples`` reaches
      ``policy["sample"]``.
    """

    def __init__(self, ops, policy=None) -> None:
        """Build the transform list from ``ops`` and store ``policy``.

        Args:
            ops: ``None`` or an iterable of config dicts / ``nn.Module``s.
            policy: Optional policy dict; defaults to ``{"name": "default"}``.

        Raises:
            KeyError: If a config dict has no ``"type"`` key or names an
                entry missing from ``GLOBAL_CONFIG``.
            ValueError: If an entry is neither a dict nor an ``nn.Module``.
        """
        if ops is None:
            transforms = [EmptyTransform()]
        else:
            transforms = [self._build_transform(op) for op in ops]

        super().__init__(transforms=transforms)

        if policy is None:
            policy = {"name": "default"}

        self.policy = policy
        # Number of completed ``stop_sample_forward`` calls on this instance.
        # NOTE(review): presumably each dataloader worker holds its own copy
        # of this counter - confirm against how the dataset is loaded.
        self.global_samples = 0

    @staticmethod
    def _build_transform(op):
        """Turn one ``ops`` entry into a transform instance."""
        if isinstance(op, dict):
            # Work on a shallow copy so the caller's config dict is never
            # left without its "type" key, even if construction raises.
            kwargs = dict(op)
            name = kwargs.pop("type")
            entry = GLOBAL_CONFIG[name]
            transform_cls = getattr(entry["_pymodule"], entry["_name"])
            return transform_cls(**kwargs)

        if isinstance(op, nn.Module):
            return op

        raise ValueError(
            "Each entry of `ops` must be a config dict or an nn.Module, "
            f"got {type(op).__name__}"
        )

    def forward(self, *inputs: Any) -> Any:
        """Run the forward implementation selected by ``policy["name"]``."""
        return self.get_forward(self.policy["name"])(*inputs)

    def get_forward(self, name):
        """Return the bound forward method registered under ``name``.

        Raises:
            KeyError: If ``name`` is not a known policy name.
        """
        forwards = {
            "default": self.default_forward,
            "stop_epoch": self.stop_epoch_forward,
            "stop_sample": self.stop_sample_forward,
        }
        try:
            return forwards[name]
        except KeyError:
            raise KeyError(
                f"Unknown policy name {name!r}; expected one of {sorted(forwards)}"
            ) from None

    def default_forward(self, *inputs: Any) -> Any:
        """Apply every transform in order and return the result."""
        # Several positional inputs travel through the pipeline as one tuple.
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def stop_epoch_forward(self, *inputs: Any):
        """Apply the transforms, skipping policy ops after the stop epoch.

        The last element of the sample is read as the dataset and must expose
        an ``epoch`` attribute.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        dataset = sample[-1]

        skipped_ops = self.policy["ops"]
        # The threshold check does not change while iterating the transforms.
        stop_reached = dataset.epoch >= self.policy["epoch"]

        for transform in self.transforms:
            if stop_reached and type(transform).__name__ in skipped_ops:
                continue
            sample = transform(sample)
        return sample

    def stop_sample_forward(self, *inputs: Any):
        """Apply the transforms, skipping policy ops after N processed samples.

        ``self.global_samples`` is incremented once per call, after the
        transforms have run.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]

        skipped_ops = self.policy["ops"]
        # The counter only changes after the loop, so evaluate the check once.
        stop_reached = self.global_samples >= self.policy["sample"]

        for transform in self.transforms:
            if stop_reached and type(transform).__name__ in skipped_ops:
                continue
            sample = transform(sample)

        self.global_samples += 1
        return sample