Spaces:
Running
on
Zero
Running
on
Zero
"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright (c) 2023 lyuwenyu. All Rights Reserved.
"""
from typing import Any, Dict, List, Optional | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.transforms.v2 as T | |
from ...core import GLOBAL_CONFIG, register | |
from ._transforms import EmptyTransform | |
# Intended to silence torchvision's warning about the beta status of the
# transforms.v2 API used by this module.
# NOTE(review): in some torchvision releases that warning is emitted when
# ``torchvision.transforms.v2`` is first imported, so calling this after the
# import above may be too late to suppress it here; in later releases the call
# is believed to be a no-op kept for compatibility — verify for the pinned version.
torchvision.disable_beta_transforms_warning()
class Compose(T.Compose):
    """Sequential transform container with an optional "stop" policy.

    Extends ``torchvision.transforms.v2.Compose`` so that transforms whose
    class name is listed in the policy can be skipped once a threshold epoch
    (``"stop_epoch"``) or a threshold number of calls (``"stop_sample"``) has
    been reached.

    Args:
        ops: Iterable of transform specs, or ``None``. Each item is either an
            ``nn.Module`` (used as-is) or a dict whose ``"type"`` key names an
            entry in ``GLOBAL_CONFIG``; the remaining keys are forwarded to
            that entry's constructor as keyword arguments. The dict is not
            modified. When ``ops`` is ``None`` the pipeline is a single
            ``EmptyTransform``.
        policy: Dict selecting the forward variant via ``policy["name"]``
            (``"default"``, ``"stop_epoch"`` or ``"stop_sample"``).
            ``"stop_epoch"`` also reads ``policy["ops"]`` (class names to skip)
            and ``policy["epoch"]``. ``"stop_sample"`` also reads
            ``policy["ops"]`` and ``policy["sample"]``. Defaults to
            ``{"name": "default"}``.

    Raises:
        KeyError: If a dict spec has no ``"type"`` key, or names an entry
            missing from ``GLOBAL_CONFIG``.
        ValueError: If an item of ``ops`` is neither a dict nor an ``nn.Module``.
    """

    def __init__(self, ops, policy=None) -> None:
        transforms = []
        if ops is not None:
            for op in ops:
                if isinstance(op, dict):
                    transforms.append(self._build_from_config(op))
                elif isinstance(op, nn.Module):
                    transforms.append(op)
                else:
                    raise ValueError(
                        f"Unsupported transform spec of type {type(op).__name__!r}; "
                        "expected a dict with a 'type' key or an nn.Module."
                    )
        else:
            transforms = [EmptyTransform()]

        super().__init__(transforms=transforms)

        if policy is None:
            policy = {"name": "default"}
        self.policy = policy
        # Number of calls made so far; updated only by ``stop_sample_forward``.
        self.global_samples = 0

    @staticmethod
    def _build_from_config(op: Dict[str, Any]) -> Any:
        """Instantiate the ``GLOBAL_CONFIG`` entry described by the config dict ``op``.

        ``op`` is only read (never popped or reassigned), so the caller's
        config stays intact even if the transform constructor raises.
        """
        name = op["type"]
        kwargs = {key: value for key, value in op.items() if key != "type"}
        entry = GLOBAL_CONFIG[name]
        return getattr(entry["_pymodule"], entry["_name"])(**kwargs)

    @staticmethod
    def _pack_inputs(inputs):
        """Return the sole positional input unchanged, or all inputs as a tuple."""
        return inputs if len(inputs) > 1 else inputs[0]

    def forward(self, *inputs: Any) -> Any:
        """Run the forward variant selected by ``self.policy["name"]``."""
        return self.get_forward(self.policy["name"])(*inputs)

    def get_forward(self, name):
        """Return the bound forward method registered under policy ``name``.

        Raises:
            KeyError: If ``name`` is not one of the known policy names.
        """
        forwards = {
            "default": self.default_forward,
            "stop_epoch": self.stop_epoch_forward,
            "stop_sample": self.stop_sample_forward,
        }
        try:
            return forwards[name]
        except KeyError:
            raise KeyError(
                f"Unknown Compose policy {name!r}; expected one of {sorted(forwards)}"
            ) from None

    def default_forward(self, *inputs: Any) -> Any:
        """Apply every transform in order and return the result."""
        sample = self._pack_inputs(inputs)
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def stop_epoch_forward(self, *inputs: Any) -> Any:
        """Apply transforms, skipping ``policy["ops"]`` from ``policy["epoch"]`` on.

        The current epoch is read from ``sample[-1].epoch``. The last element
        of the sample is assumed to be the dataset object — TODO confirm
        against the caller.
        """
        sample = self._pack_inputs(inputs)
        dataset = sample[-1]
        cur_epoch = dataset.epoch
        policy_ops = self.policy["ops"]
        # The epoch test does not depend on the transform; evaluate it once.
        stop_active = cur_epoch >= self.policy["epoch"]
        for transform in self.transforms:
            if stop_active and type(transform).__name__ in policy_ops:
                continue
            sample = transform(sample)
        return sample

    def stop_sample_forward(self, *inputs: Any) -> Any:
        """Apply transforms, skipping ``policy["ops"]`` once ``policy["sample"]`` calls were made.

        Increments ``self.global_samples`` by one per call.
        """
        sample = self._pack_inputs(inputs)
        policy_ops = self.policy["ops"]
        # The counter only changes after the loop, so the test is loop-invariant.
        stop_active = self.global_samples >= self.policy["sample"]
        for transform in self.transforms:
            if stop_active and type(transform).__name__ in policy_ops:
                continue
            sample = transform(sample)
        # NOTE(review): the counter is per-instance state; if this object is
        # copied into several DataLoader worker processes, each copy counts
        # independently — confirm that is intended.
        self.global_samples += 1
        return sample