diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4af7a3ef436ef75661194b9b6ce03d93809210b --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,6 @@ +""" +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +""" + +# for register purpose +from . import data, nn, optim, zoo diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..244e86960908ab23d912a0338e396c201307bb28 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,9 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from ._config import BaseConfig +from .workspace import GLOBAL_CONFIG, create, register +from .yaml_config import YAMLConfig +from .yaml_utils import * diff --git a/src/core/_config.py b/src/core/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f4abdad33e8382c56b4dd8a798a9be334e98dba8 --- /dev/null +++ b/src/core/_config.py @@ -0,0 +1,299 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from pathlib import Path +from typing import Callable, Dict, List + +import torch +import torch.nn as nn +from torch.cuda.amp.grad_scaler import GradScaler +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data import DataLoader, Dataset +from torch.utils.tensorboard import SummaryWriter + +__all__ = [ + "BaseConfig", +] + + +class BaseConfig(object): + # TODO property + + def __init__(self) -> None: + super().__init__() + + self.task: str = None + + # instance / function + self._model: nn.Module = None + self._postprocessor: nn.Module = None + self._criterion: nn.Module = None + self._optimizer: Optimizer = None + self._lr_scheduler: LRScheduler = None + self._lr_warmup_scheduler: LRScheduler = None + self._train_dataloader: DataLoader = None + self._val_dataloader: DataLoader = None + self._ema: nn.Module = None + self._scaler: GradScaler = None + self._train_dataset: Dataset = None + self._val_dataset: Dataset = None + self._collate_fn: Callable = None + self._evaluator: Callable[[nn.Module, DataLoader, str],] = None + self._writer: SummaryWriter = None + + # dataset + self.num_workers: int = 0 + self.batch_size: int = None + self._train_batch_size: int = None + self._val_batch_size: int = None + self._train_shuffle: bool = None + self._val_shuffle: bool = None + + # runtime + self.resume: str = None + self.tuning: str = None + + self.epochs: int = None + self.last_epoch: int = -1 + + self.use_amp: bool = False + self.use_ema: bool = False + self.ema_decay: float = 0.9999 + self.ema_warmups: int = 2000 + self.sync_bn: bool = False + self.clip_max_norm: float = 0.0 + self.find_unused_parameters: bool = None + + self.seed: int = None + self.print_freq: int = None + self.checkpoint_freq: int = 1 + self.output_dir: str = None + self.summary_dir: str = None + self.device: str = "" + + @property + def model(self) -> nn.Module: + return self._model + + @model.setter + def model(self, m): + assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class" + self._model = m + + @property + def postprocessor(self) -> nn.Module: + return self._postprocessor + + @postprocessor.setter + def postprocessor(self, m): + assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class" + self._postprocessor = m + + @property + def criterion(self) -> nn.Module: + return self._criterion + + @criterion.setter + def criterion(self, m): + assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class" + self._criterion = m + + @property + def optimizer(self) -> Optimizer: + return self._optimizer + + @optimizer.setter + def optimizer(self, m): + assert isinstance( + m, Optimizer + ), f"{type(m)} != optim.Optimizer, please check your model class" + self._optimizer = m + + @property + def lr_scheduler(self) -> LRScheduler: + return self._lr_scheduler + + @lr_scheduler.setter + def lr_scheduler(self, m): + assert isinstance( + m, LRScheduler + ), f"{type(m)} != LRScheduler, please check your model class" + self._lr_scheduler = m + + @property + def lr_warmup_scheduler(self) -> LRScheduler: + return self._lr_warmup_scheduler + + @lr_warmup_scheduler.setter + def lr_warmup_scheduler(self, m): + self._lr_warmup_scheduler = m + + @property + def train_dataloader(self) -> DataLoader: + if self._train_dataloader is None and self.train_dataset is not None: + loader = DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + shuffle=self.train_shuffle, + ) + loader.shuffle = self.train_shuffle + self._train_dataloader = loader + + return self._train_dataloader + + @train_dataloader.setter + def train_dataloader(self, loader): + self._train_dataloader = loader + + @property + def val_dataloader(self) -> DataLoader: + if self._val_dataloader is None and self.val_dataset is not None: + loader = DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + num_workers=self.num_workers, + drop_last=False, + collate_fn=self.collate_fn, + shuffle=self.val_shuffle, + persistent_workers=True, + ) + loader.shuffle = self.val_shuffle + self._val_dataloader = loader + + return self._val_dataloader + + @val_dataloader.setter + def val_dataloader(self, loader): + self._val_dataloader = loader + + @property + def ema(self) -> nn.Module: + if self._ema is None and self.use_ema and self.model is not None: + from ..optim import ModelEMA + + self._ema = ModelEMA(self.model, self.ema_decay, self.ema_warmups) + return self._ema + + @ema.setter + def ema(self, obj): + self._ema = obj + + @property + def scaler(self) -> GradScaler: + if self._scaler is None and self.use_amp and torch.cuda.is_available(): + self._scaler = GradScaler() + return self._scaler + + @scaler.setter + def scaler(self, obj: GradScaler): + self._scaler = obj + + @property + def val_shuffle(self) -> bool: + if self._val_shuffle is None: + print("warning: set default val_shuffle=False") + return False + return self._val_shuffle + + @val_shuffle.setter + def val_shuffle(self, shuffle): + assert isinstance(shuffle, bool), "shuffle must be bool" + self._val_shuffle = shuffle + + @property + def train_shuffle(self) -> bool: + if self._train_shuffle is None: + print("warning: set default train_shuffle=True") + return True + return self._train_shuffle + + @train_shuffle.setter + def train_shuffle(self, shuffle): + assert isinstance(shuffle, bool), "shuffle must be bool" + self._train_shuffle = shuffle + + @property + def train_batch_size(self) -> int: + if self._train_batch_size is None and isinstance(self.batch_size, int): + print(f"warning: set train_batch_size=batch_size={self.batch_size}") + return self.batch_size + return self._train_batch_size + + @train_batch_size.setter + def train_batch_size(self, batch_size): + assert isinstance(batch_size, int), "batch_size must be int" + self._train_batch_size = batch_size + + @property + def val_batch_size(self) -> int: + if self._val_batch_size is None: + print(f"warning: set val_batch_size=batch_size={self.batch_size}") + return self.batch_size + return self._val_batch_size + + @val_batch_size.setter + def val_batch_size(self, batch_size): + assert isinstance(batch_size, int), "batch_size must be int" + self._val_batch_size = batch_size + + @property + def train_dataset(self) -> Dataset: + return self._train_dataset + + @train_dataset.setter + def train_dataset(self, dataset): + assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset" + self._train_dataset = dataset + + @property + def val_dataset(self) -> Dataset: + return self._val_dataset + + @val_dataset.setter + def val_dataset(self, dataset): + assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset" + self._val_dataset = dataset + + @property + def collate_fn(self) -> Callable: + return self._collate_fn + + @collate_fn.setter + def collate_fn(self, fn): + assert isinstance(fn, Callable), f"{type(fn)} must be Callable" + self._collate_fn = fn + + @property + def evaluator(self) -> Callable: + return self._evaluator + + @evaluator.setter + def evaluator(self, fn): + assert isinstance(fn, Callable), f"{type(fn)} must be Callable" + self._evaluator = fn + + @property + def writer(self) -> SummaryWriter: + if self._writer is None: + if self.summary_dir: + self._writer = SummaryWriter(self.summary_dir) + elif self.output_dir: + self._writer = SummaryWriter(Path(self.output_dir) / "summary") + return self._writer + + @writer.setter + def writer(self, m): + assert isinstance(m, SummaryWriter), f"{type(m)} must be SummaryWriter" + self._writer = m + + def __repr__(self): + s = "" + for k, v in self.__dict__.items(): + if not k.startswith("_"): + s += f"{k}: {v}\n" + return s diff --git a/src/core/workspace.py b/src/core/workspace.py new file mode 100644 index 0000000000000000000000000000000000000000..993f6fe4d744cc9ecf0a531ba8c1f42d0b092611 --- /dev/null +++ b/src/core/workspace.py @@ -0,0 +1,178 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import functools +import importlib +import inspect +from collections import defaultdict +from typing import Any, Dict, List, Optional + +GLOBAL_CONFIG = defaultdict(dict) + + +def register(dct: Any = GLOBAL_CONFIG, name=None, force=False): + """ + dct: + if dct is Dict, register foo into dct as key-value pair + if dct is Clas, register as modules attibute + force + whether force register. + """ + + def decorator(foo): + register_name = foo.__name__ if name is None else name + if not force: + if inspect.isclass(dct): + assert not hasattr(dct, foo.__name__), f"module {dct.__name__} has {foo.__name__}" + else: + assert foo.__name__ not in dct, f"{foo.__name__} has been already registered" + + if inspect.isfunction(foo): + + @functools.wraps(foo) + def wrap_func(*args, **kwargs): + return foo(*args, **kwargs) + + if isinstance(dct, dict): + dct[foo.__name__] = wrap_func + elif inspect.isclass(dct): + setattr(dct, foo.__name__, wrap_func) + else: + raise AttributeError("") + return wrap_func + + elif inspect.isclass(foo): + dct[register_name] = extract_schema(foo) + + else: + raise ValueError(f"Do not support {type(foo)} register") + + return foo + + return decorator + + +def extract_schema(module: type): + """ + Args: + module (type), + Return: + Dict, + """ + argspec = inspect.getfullargspec(module.__init__) + arg_names = [arg for arg in argspec.args if arg != "self"] + num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0 + num_requires = len(arg_names) - num_defualts + + schame = dict() + schame["_name"] = module.__name__ + schame["_pymodule"] = importlib.import_module(module.__module__) + schame["_inject"] = getattr(module, "__inject__", []) + schame["_share"] = getattr(module, "__share__", []) + schame["_kwargs"] = {} + for i, name in enumerate(arg_names): + if name in schame["_share"]: + assert i >= num_requires, "share config must have default value." + value = argspec.defaults[i - num_requires] + + elif i >= num_requires: + value = argspec.defaults[i - num_requires] + + else: + value = None + + schame[name] = value + schame["_kwargs"][name] = value + + return schame + + +def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs): + """ """ + assert type(type_or_name) in (type, str), "create should be modules or name." + + name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__ + + if name in global_cfg: + if hasattr(global_cfg[name], "__dict__"): + return global_cfg[name] + else: + raise ValueError("The module {} is not registered".format(name)) + + cfg = global_cfg[name] + + if isinstance(cfg, dict) and "type" in cfg: + _cfg: dict = global_cfg[cfg["type"]] + # clean args + _keys = [k for k in _cfg.keys() if not k.startswith("_")] + for _arg in _keys: + del _cfg[_arg] + _cfg.update(_cfg["_kwargs"]) # restore default args + _cfg.update(cfg) # load config args + _cfg.update(kwargs) # TODO recive extra kwargs + name = _cfg.pop("type") # pop extra key `type` (from cfg) + + return create(name, global_cfg) + + module = getattr(cfg["_pymodule"], name) + module_kwargs = {} + module_kwargs.update(cfg) + + # shared var + for k in cfg["_share"]: + if k in global_cfg: + module_kwargs[k] = global_cfg[k] + else: + module_kwargs[k] = cfg[k] + + # inject + for k in cfg["_inject"]: + _k = cfg[k] + + if _k is None: + continue + + if isinstance(_k, str): + if _k not in global_cfg: + raise ValueError(f"Missing inject config of {_k}.") + + _cfg = global_cfg[_k] + + if isinstance(_cfg, dict): + module_kwargs[k] = create(_cfg["_name"], global_cfg) + else: + module_kwargs[k] = _cfg + + elif isinstance(_k, dict): + if "type" not in _k.keys(): + raise ValueError("Missing inject for `type` style.") + + _type = str(_k["type"]) + if _type not in global_cfg: + raise ValueError(f"Missing {_type} in inspect stage.") + + # TODO + _cfg: dict = global_cfg[_type] + # clean args + _keys = [k for k in _cfg.keys() if not k.startswith("_")] + for _arg in _keys: + del _cfg[_arg] + _cfg.update(_cfg["_kwargs"]) # restore default values + _cfg.update(_k) # load config args + name = _cfg.pop("type") # pop extra key (`type` from _k) + module_kwargs[k] = create(name, global_cfg) + + else: + raise ValueError(f"Inject does not support {_k}") + + # TODO hard code + module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")} + + # TODO for **kwargs + # extra_args = set(module_kwargs.keys()) - set(arg_names) + # if len(extra_args) > 0: + # raise RuntimeError(f'Error: unknown args {extra_args} for {module}') + + return module(**module_kwargs) diff --git a/src/core/yaml_config.py b/src/core/yaml_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d1b62acf021aaa8e98360843ce2354de02cbc4 --- /dev/null +++ b/src/core/yaml_config.py @@ -0,0 +1,187 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import copy +import re + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader + +from ._config import BaseConfig +from .workspace import create +from .yaml_utils import load_config, merge_config, merge_dict + + +class YAMLConfig(BaseConfig): + def __init__(self, cfg_path: str, **kwargs) -> None: + super().__init__() + + cfg = load_config(cfg_path) + cfg = merge_dict(cfg, kwargs) + + self.yaml_cfg = copy.deepcopy(cfg) + + for k in super().__dict__: + if not k.startswith("_") and k in cfg: + self.__dict__[k] = cfg[k] + + @property + def global_cfg(self): + return merge_config(self.yaml_cfg, inplace=False, overwrite=False) + + @property + def model(self) -> torch.nn.Module: + if self._model is None and "model" in self.yaml_cfg: + self._model = create(self.yaml_cfg["model"], self.global_cfg) + return super().model + + @property + def postprocessor(self) -> torch.nn.Module: + if self._postprocessor is None and "postprocessor" in self.yaml_cfg: + self._postprocessor = create(self.yaml_cfg["postprocessor"], self.global_cfg) + return super().postprocessor + + @property + def criterion(self) -> torch.nn.Module: + if self._criterion is None and "criterion" in self.yaml_cfg: + self._criterion = create(self.yaml_cfg["criterion"], self.global_cfg) + return super().criterion + + @property + def optimizer(self) -> optim.Optimizer: + if self._optimizer is None and "optimizer" in self.yaml_cfg: + params = self.get_optim_params(self.yaml_cfg["optimizer"], self.model) + self._optimizer = create("optimizer", self.global_cfg, params=params) + return super().optimizer + + @property + def lr_scheduler(self) -> optim.lr_scheduler.LRScheduler: + if self._lr_scheduler is None and "lr_scheduler" in self.yaml_cfg: + self._lr_scheduler = create("lr_scheduler", self.global_cfg, optimizer=self.optimizer) + print(f"Initial lr: {self._lr_scheduler.get_last_lr()}") + return super().lr_scheduler + + @property + def lr_warmup_scheduler(self) -> optim.lr_scheduler.LRScheduler: + if self._lr_warmup_scheduler is None and "lr_warmup_scheduler" in self.yaml_cfg: + self._lr_warmup_scheduler = create( + "lr_warmup_scheduler", self.global_cfg, lr_scheduler=self.lr_scheduler + ) + return super().lr_warmup_scheduler + + @property + def train_dataloader(self) -> DataLoader: + if self._train_dataloader is None and "train_dataloader" in self.yaml_cfg: + self._train_dataloader = self.build_dataloader("train_dataloader") + return super().train_dataloader + + @property + def val_dataloader(self) -> DataLoader: + if self._val_dataloader is None and "val_dataloader" in self.yaml_cfg: + self._val_dataloader = self.build_dataloader("val_dataloader") + return super().val_dataloader + + @property + def ema(self) -> torch.nn.Module: + if self._ema is None and self.yaml_cfg.get("use_ema", False): + self._ema = create("ema", self.global_cfg, model=self.model) + return super().ema + + @property + def scaler(self): + if self._scaler is None and self.yaml_cfg.get("use_amp", False): + self._scaler = create("scaler", self.global_cfg) + return super().scaler + + @property + def evaluator(self): + if self._evaluator is None and "evaluator" in self.yaml_cfg: + if self.yaml_cfg["evaluator"]["type"] == "CocoEvaluator": + from ..data import get_coco_api_from_dataset + + base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset) + self._evaluator = create("evaluator", self.global_cfg, coco_gt=base_ds) + else: + raise NotImplementedError(f"{self.yaml_cfg['evaluator']['type']}") + return super().evaluator + + @property + def use_wandb(self) -> bool: + return self.yaml_cfg.get("use_wandb", False) + + @staticmethod + def get_optim_params(cfg: dict, model: nn.Module): + """ + E.g.: + ^(?=.*a)(?=.*b).*$ means including a and b + ^(?=.*(?:a|b)).*$ means including a or b + ^(?=.*a)(?!.*b).*$ means including a, but not b + """ + assert "type" in cfg, "" + cfg = copy.deepcopy(cfg) + + if "params" not in cfg: + return model.parameters() + + assert isinstance(cfg["params"], list), "" + + param_groups = [] + visited = [] + for pg in cfg["params"]: + pattern = pg["params"] + params = { + k: v + for k, v in model.named_parameters() + if v.requires_grad and len(re.findall(pattern, k)) > 0 + } + pg["params"] = params.values() + param_groups.append(pg) + visited.extend(list(params.keys())) + # print(params.keys()) + + names = [k for k, v in model.named_parameters() if v.requires_grad] + + if len(visited) < len(names): + unseen = set(names) - set(visited) + params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen} + param_groups.append({"params": params.values()}) + visited.extend(list(params.keys())) + # print(params.keys()) + + assert len(visited) == len(names), "" + + return param_groups + + @staticmethod + def get_rank_batch_size(cfg): + """compute batch size for per rank if total_batch_size is provided.""" + assert ("total_batch_size" in cfg or "batch_size" in cfg) and not ( + "total_batch_size" in cfg and "batch_size" in cfg + ), "`batch_size` or `total_batch_size` should be choosed one" + + total_batch_size = cfg.get("total_batch_size", None) + if total_batch_size is None: + bs = cfg.get("batch_size") + else: + from ..misc import dist_utils + + assert ( + total_batch_size % dist_utils.get_world_size() == 0 + ), "total_batch_size should be divisible by world size" + bs = total_batch_size // dist_utils.get_world_size() + return bs + + def build_dataloader(self, name: str): + bs = self.get_rank_batch_size(self.yaml_cfg[name]) + global_cfg = self.global_cfg + if "total_batch_size" in global_cfg[name]: + # pop unexpected key for dataloader init + _ = global_cfg[name].pop("total_batch_size") + print(f"building {name} with batch_size={bs}...") + loader = create(name, global_cfg, batch_size=bs) + loader.shuffle = self.yaml_cfg[name].get("shuffle", False) + return loader diff --git a/src/core/yaml_utils.py b/src/core/yaml_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0cf17418131ca601df4d3f8c2c7e3be4f3450e1 --- /dev/null +++ b/src/core/yaml_utils.py @@ -0,0 +1,126 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import copy +import os +from typing import Any, Dict, List, Optional + +import yaml + +from .workspace import GLOBAL_CONFIG + +__all__ = [ + "load_config", + "merge_config", + "merge_dict", + "parse_cli", +] + + +INCLUDE_KEY = "__include__" + + +def load_config(file_path, cfg=dict()): + """load config""" + _, ext = os.path.splitext(file_path) + assert ext in [".yml", ".yaml"], "only support yaml files" + + with open(file_path) as f: + file_cfg = yaml.load(f, Loader=yaml.Loader) + if file_cfg is None: + return {} + + if INCLUDE_KEY in file_cfg: + base_yamls = list(file_cfg[INCLUDE_KEY]) + for base_yaml in base_yamls: + if base_yaml.startswith("~"): + base_yaml = os.path.expanduser(base_yaml) + + if not base_yaml.startswith("/"): + base_yaml = os.path.join(os.path.dirname(file_path), base_yaml) + + with open(base_yaml) as f: + base_cfg = load_config(base_yaml, cfg) + merge_dict(cfg, base_cfg) + + return merge_dict(cfg, file_cfg) + + +def merge_dict(dct, another_dct, inplace=True) -> Dict: + """merge another_dct into dct""" + + def _merge(dct, another) -> Dict: + for k in another: + if k in dct and isinstance(dct[k], dict) and isinstance(another[k], dict): + _merge(dct[k], another[k]) + else: + dct[k] = another[k] + + return dct + + if not inplace: + dct = copy.deepcopy(dct) + + return _merge(dct, another_dct) + + +def dictify(s: str, v: Any) -> Dict: + if "." not in s: + return {s: v} + key, rest = s.split(".", 1) + return {key: dictify(rest, v)} + + +def parse_cli(nargs: List[str]) -> Dict: + """ + parse command-line arguments + convert `a.c=3 b=10` to `{'a': {'c': 3}, 'b': 10}` + """ + cfg = {} + if nargs is None or len(nargs) == 0: + return cfg + + for s in nargs: + s = s.strip() + k, v = s.split("=", 1) + d = dictify(k, yaml.load(v, Loader=yaml.Loader)) + cfg = merge_dict(cfg, d) + + return cfg + + +def merge_config(cfg, another_cfg=GLOBAL_CONFIG, inplace: bool = False, overwrite: bool = False): + """ + Merge another_cfg into cfg, return the merged config + + Example: + + cfg1 = load_config('./dfine_r18vd_6x_coco.yml') + cfg1 = merge_config(cfg, inplace=True) + + cfg2 = load_config('./dfine_r50vd_6x_coco.yml') + cfg2 = merge_config(cfg2, inplace=True) + + model1 = create(cfg1['model'], cfg1) + model2 = create(cfg2['model'], cfg2) + """ + + def _merge(dct, another): + for k in another: + if k not in dct: + dct[k] = another[k] + + elif isinstance(dct[k], dict) and isinstance(another[k], dict): + _merge(dct[k], another[k]) + + elif overwrite: + dct[k] = another[k] + + return cfg + + if not inplace: + cfg = copy.deepcopy(cfg) + + return _merge(cfg, another_cfg) diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..318a7da3162a231a4a388a6c53192f3b777fbe36 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,20 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from ._misc import convert_to_tv_tensor +from .dataloader import * +from .dataset import * +from .transforms import * + + +# def set_epoch(self, epoch) -> None: +# self.epoch = epoch +# def _set_epoch_func(datasets): +# """Add `set_epoch` for datasets +# """ +# from ..core import register +# for ds in datasets: +# register(ds)(set_epoch) +# _set_epoch_func([CIFAR10, VOCDetection, CocoDetection]) diff --git a/src/data/_misc.py b/src/data/_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..343d1038ab23af89f71b157ef3ad6a201d62e6f0 --- /dev/null +++ b/src/data/_misc.py @@ -0,0 +1,62 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import importlib.metadata + +from torch import Tensor + +if "0.15.2" in importlib.metadata.version("torchvision"): + import torchvision + + torchvision.disable_beta_transforms_warning() + + from torchvision.datapoints import BoundingBox as BoundingBoxes + from torchvision.datapoints import BoundingBoxFormat, Image, Mask, Video + from torchvision.transforms.v2 import SanitizeBoundingBox as SanitizeBoundingBoxes + + _boxes_keys = ["format", "spatial_size"] + +elif "0.17" > importlib.metadata.version("torchvision") >= "0.16": + import torchvision + + torchvision.disable_beta_transforms_warning() + + from torchvision.transforms.v2 import SanitizeBoundingBoxes + from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video + + _boxes_keys = ["format", "canvas_size"] + +elif importlib.metadata.version("torchvision") >= "0.17": + import torchvision + from torchvision.transforms.v2 import SanitizeBoundingBoxes + from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video + + _boxes_keys = ["format", "canvas_size"] + +else: + raise RuntimeError("Please make sure torchvision version >= 0.15.2") + + +def convert_to_tv_tensor(tensor: Tensor, key: str, box_format="xyxy", spatial_size=None) -> Tensor: + """ + Args: + tensor (Tensor): input tensor + key (str): transform to key + + Return: + Dict[str, TV_Tensor] + """ + assert key in ( + "boxes", + "masks", + ), "Only support 'boxes' and 'masks'" + + if key == "boxes": + box_format = getattr(BoundingBoxFormat, box_format.upper()) + _kwargs = dict(zip(_boxes_keys, [box_format, spatial_size])) + return BoundingBoxes(tensor, **_kwargs) + + if key == "masks": + return Mask(tensor) diff --git a/src/data/dataloader.py b/src/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..0439a30976e2ec6123445f8b81ce6df2881a6c02 --- /dev/null +++ b/src/data/dataloader.py @@ -0,0 +1,122 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import random +from functools import partial + +import torch +import torch.nn.functional as F +import torch.utils.data as data +import torchvision +import torchvision.transforms.v2 as VT +from torch.utils.data import default_collate +from torchvision.transforms.v2 import InterpolationMode +from torchvision.transforms.v2 import functional as VF + +from ..core import register + +torchvision.disable_beta_transforms_warning() + + +__all__ = [ + "DataLoader", + "BaseCollateFunction", + "BatchImageCollateFunction", + "batch_image_collate_fn", +] + + +@register() +class DataLoader(data.DataLoader): + __inject__ = ["dataset", "collate_fn"] + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + for n in ["dataset", "batch_size", "num_workers", "drop_last", "collate_fn"]: + format_string += "\n" + format_string += " {0}: {1}".format(n, getattr(self, n)) + format_string += "\n)" + return format_string + + def set_epoch(self, epoch): + self._epoch = epoch + self.dataset.set_epoch(epoch) + self.collate_fn.set_epoch(epoch) + + @property + def epoch(self): + return self._epoch if hasattr(self, "_epoch") else -1 + + @property + def shuffle(self): + return self._shuffle + + @shuffle.setter + def shuffle(self, shuffle): + assert isinstance(shuffle, bool), "shuffle must be a boolean" + self._shuffle = shuffle + + +@register() +def batch_image_collate_fn(items): + """only batch image""" + return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items] + + +class BaseCollateFunction(object): + def set_epoch(self, epoch): + self._epoch = epoch + + @property + def epoch(self): + return self._epoch if hasattr(self, "_epoch") else -1 + + def __call__(self, items): + raise NotImplementedError("") + + +def generate_scales(base_size, base_size_repeat): + scale_repeat = (base_size - int(base_size * 0.75 / 32) * 32) // 32 + scales = [int(base_size * 0.75 / 32) * 32 + i * 32 for i in range(scale_repeat)] + scales += [base_size] * base_size_repeat + scales += [int(base_size * 1.25 / 32) * 32 - i * 32 for i in range(scale_repeat)] + return scales + + +@register() +class BatchImageCollateFunction(BaseCollateFunction): + def __init__( + self, + stop_epoch=None, + ema_restart_decay=0.9999, + base_size=640, + base_size_repeat=None, + ) -> None: + super().__init__() + self.base_size = base_size + self.scales = ( + generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None + ) + self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000 + self.ema_restart_decay = ema_restart_decay + # self.interpolation = interpolation + + def __call__(self, items): + images = torch.cat([x[0][None] for x in items], dim=0) + targets = [x[1] for x in items] + + if self.scales is not None and self.epoch < self.stop_epoch: + # sz = random.choice(self.scales) + # sz = [sz] if isinstance(sz, int) else list(sz) + # VF.resize(inpt, sz, interpolation=self.interpolation) + + sz = random.choice(self.scales) + images = F.interpolate(images, size=sz) + if "masks" in targets[0]: + for tg in targets: + tg["masks"] = F.interpolate(tg["masks"], size=sz, mode="nearest") + raise NotImplementedError("") + + return images, targets diff --git a/src/data/dataset/__init__.py b/src/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00c6ce3d632e87fd4496cc5ba8d5f0a7d14eba39 --- /dev/null +++ b/src/data/dataset/__init__.py @@ -0,0 +1,17 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +# from ._dataset import DetDataset +from .cifar_dataset import CIFAR10 +from .coco_dataset import ( + CocoDetection, + mscoco_category2label, + mscoco_category2name, + mscoco_label2category, +) +from .coco_eval import CocoEvaluator +from .coco_utils import get_coco_api_from_dataset +from .voc_detection import VOCDetection +from .voc_eval import VOCEvaluator diff --git a/src/data/dataset/_dataset.py b/src/data/dataset/_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef1dadf4f783b2eb3971199d5251717c2ff8761 --- /dev/null +++ b/src/data/dataset/_dataset.py @@ -0,0 +1,27 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.utils.data as data + + +class DetDataset(data.Dataset): + def __getitem__(self, index): + img, target = self.load_item(index) + if self.transforms is not None: + img, target, _ = self.transforms(img, target, self) + return img, target + + def load_item(self, index): + raise NotImplementedError( + "Please implement this function to return item before `transforms`." + ) + + def set_epoch(self, epoch) -> None: + self._epoch = epoch + + @property + def epoch(self): + return self._epoch if hasattr(self, "_epoch") else -1 diff --git a/src/data/dataset/cifar_dataset.py b/src/data/dataset/cifar_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..da98485f450cd762f39bca80ffb9c1da56c3132c --- /dev/null +++ b/src/data/dataset/cifar_dataset.py @@ -0,0 +1,25 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from typing import Callable, Optional + +import torchvision + +from ...core import register + + +@register() +class CIFAR10(torchvision.datasets.CIFAR10): + __inject__ = ["transform", "target_transform"] + + def __init__( + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, train, transform, target_transform, download) diff --git a/src/data/dataset/coco_dataset.py b/src/data/dataset/coco_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fb1f1f0500089b6c504736433cfa3eeceef8f879 --- /dev/null +++ b/src/data/dataset/coco_dataset.py @@ -0,0 +1,282 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py + +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import faster_coco_eval +import faster_coco_eval.core.mask as coco_mask +import torch +import torch.utils.data +import torchvision +import os +from PIL import Image + +from ...core import register +from .._misc import convert_to_tv_tensor +from ._dataset import DetDataset + +torchvision.disable_beta_transforms_warning() +faster_coco_eval.init_as_pycocotools() +Image.MAX_IMAGE_PIXELS = None + +__all__ = ["CocoDetection"] + + +@register() +class CocoDetection(torchvision.datasets.CocoDetection, DetDataset): + __inject__ = [ + "transforms", + ] + __share__ = ["remap_mscoco_category"] + + def __init__( + self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False + ): + super(CocoDetection, self).__init__(img_folder, ann_file) + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks) + self.img_folder = img_folder + self.ann_file = ann_file + self.return_masks = return_masks + self.remap_mscoco_category = remap_mscoco_category + + def __getitem__(self, idx): + img, target = self.load_item(idx) + if self._transforms is not None: + img, target, _ = self._transforms(img, target, self) + return img, target + + def load_item(self, idx): + image, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + image_path = os.path.join(self.img_folder, self.coco.loadImgs(image_id)[0]["file_name"]) + target = {"image_id": image_id, "image_path": image_path, "annotations": target} + + if self.remap_mscoco_category: + image, target = self.prepare(image, target, category2label=mscoco_category2label) + else: + image, target = self.prepare(image, target) + + target["idx"] = torch.tensor([idx]) + + if "boxes" in target: + target["boxes"] = convert_to_tv_tensor( + target["boxes"], key="boxes", spatial_size=image.size[::-1] + ) + + if "masks" in target: + target["masks"] = convert_to_tv_tensor(target["masks"], key="masks") + + return image, target + + def extra_repr(self) -> str: + s = f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n" + s += f" return_masks: {self.return_masks}\n" + if hasattr(self, "_transforms") and self._transforms is not None: + s += f" transforms:\n {repr(self._transforms)}" + if hasattr(self, "_preset") and self._preset is not None: + s += f" preset:\n {repr(self._preset)}" + return s + + @property + def categories( + self, + ): + return self.coco.dataset["categories"] + + @property + def category2name( + self, + ): + return {cat["id"]: cat["name"] for cat in self.categories} + + @property + def category2label( + self, + ): + return {cat["id"]: i for i, cat in enumerate(self.categories)} + + @property + def label2category( + self, + ): + return {i: cat["id"] for i, cat in enumerate(self.categories)} + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False): + self.return_masks = return_masks + + def __call__(self, image: Image.Image, target, **kwargs): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + image_path = target["image_path"] + + anno = target["annotations"] + + anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + category2label = kwargs.get("category2label", None) + if category2label is not None: + labels = [category2label[obj["category_id"]] for obj in anno] + else: + labels = [obj["category_id"] for obj in anno] + + labels = torch.tensor(labels, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + labels = labels[keep] + if self.return_masks: + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = labels + if self.return_masks: + target["masks"] = masks + target["image_id"] = image_id + target["image_path"] = image_path + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(w), int(h)]) + # target["size"] = torch.as_tensor([int(w), int(h)]) + + return image, target + + +mscoco_category2name = { + 1: "person", + 2: "bicycle", + 3: "car", + 4: "motorcycle", + 5: "airplane", + 6: "bus", + 7: "train", + 8: "truck", + 9: "boat", + 10: "traffic light", + 11: "fire hydrant", + 13: "stop sign", + 14: "parking meter", + 15: "bench", + 16: "bird", + 17: "cat", + 18: "dog", + 19: "horse", + 20: "sheep", + 21: "cow", + 22: "elephant", + 23: "bear", + 24: "zebra", + 25: "giraffe", + 27: "backpack", + 28: "umbrella", + 31: "handbag", + 32: "tie", + 33: "suitcase", + 34: "frisbee", + 35: "skis", + 36: "snowboard", + 37: "sports ball", + 38: "kite", + 39: "baseball bat", + 40: "baseball glove", + 41: "skateboard", + 42: "surfboard", + 43: "tennis racket", + 44: "bottle", + 46: "wine glass", + 47: "cup", + 48: "fork", + 49: "knife", + 50: "spoon", + 51: "bowl", + 52: "banana", + 53: "apple", + 54: "sandwich", + 55: "orange", + 56: "broccoli", + 57: "carrot", + 58: "hot dog", + 59: "pizza", + 60: "donut", + 61: "cake", + 62: "chair", + 63: "couch", + 64: "potted plant", + 65: "bed", + 67: "dining table", + 70: "toilet", + 72: "tv", + 73: "laptop", + 74: "mouse", + 75: "remote", + 76: "keyboard", + 77: "cell phone", + 78: "microwave", + 79: "oven", + 80: "toaster", + 81: "sink", + 82: "refrigerator", + 84: "book", + 85: "clock", + 86: "vase", + 87: "scissors", + 88: "teddy bear", + 89: "hair drier", + 90: "toothbrush", +} + +mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())} +mscoco_label2category = {v: k for k, v in mscoco_category2label.items()} diff --git a/src/data/dataset/coco_eval.py b/src/data/dataset/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2593d6ad89ac694ae5569528bb4d8dc12df2b9e9 --- /dev/null +++ b/src/data/dataset/coco_eval.py @@ -0,0 +1,214 @@ +""" +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +COCO evaluator that works in distributed mode. +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" + +import contextlib +import copy +import os + +import faster_coco_eval.core.mask as mask_util +import numpy as np +import torch +from faster_coco_eval import COCO, COCOeval_faster + +from ...core import register +from ...misc import dist_utils + +__all__ = [ + "CocoEvaluator", +] + + +@register() +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt: COCO = coco_gt + self.iou_types = iou_types + + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval_faster( + coco_gt, iouType=iou_type, print_function=print, separate_eval=True + ) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def cleanup(self): + self.coco_eval = {} + for iou_type in self.iou_types: + self.coco_eval[iou_type] = COCOeval_faster( + self.coco_gt, iouType=iou_type, print_function=print, separate_eval=True + ) + self.img_ids = [] + self.eval_imgs = {k: [] for k in self.iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + coco_eval = self.coco_eval[iou_type] + + # suppress pycocotools prints + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = self.coco_gt.loadRes(results) if results else COCO() + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + coco_eval.evaluate() + + self.eval_imgs[iou_type].append( + np.array(coco_eval._evalImgs_cpp).reshape( + len(coco_eval.params.catIds), + len(coco_eval.params.areaRng), + len(coco_eval.params.imgIds), + ) + ) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + img_ids, eval_imgs = merge(self.img_ids, self.eval_imgs[iou_type]) + + coco_eval = self.coco_eval[iou_type] + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + coco_eval._evalImgs_cpp = eval_imgs + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "keypoints": keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = dist_utils.all_gather(img_ids) + all_eval_imgs = dist_utils.all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.extend(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, axis=2).ravel() + # merged_eval_imgs = np.array(merged_eval_imgs).T.ravel() + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + + return merged_img_ids.tolist(), merged_eval_imgs.tolist() diff --git a/src/data/dataset/coco_utils.py b/src/data/dataset/coco_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb25baa3fa40b2fec4caf75c59b6e3b5c02fc0b9 --- /dev/null +++ b/src/data/dataset/coco_utils.py @@ -0,0 +1,191 @@ +""" +copy and modified https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py + +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import faster_coco_eval.core.mask as coco_mask +import torch +import torch.utils.data +import torchvision +import torchvision.transforms.functional as TVF +from faster_coco_eval import COCO + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask: + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + + anno = target["annotations"] + + anno = [obj for obj in anno if obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + target["area"] = area + target["iscrowd"] = iscrowd + + return image, target + + +def _coco_remove_images_without_annotations(dataset, cat_list=None): + def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + min_keypoints_per_image = 10 + + def _has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different criteria for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if cat_list: + anno = [obj for obj in anno if obj["category_id"] in cat_list] + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +def convert_to_coco_api(ds): + coco_ds = COCO() + # annotation IDs need to start at 1, not 0, see torchvision issue #1530 + ann_id = 1 + dataset = {"images": [], "categories": [], "annotations": []} + categories = set() + for img_idx in range(len(ds)): + # find better way to get target + # targets = ds.get_annotations(img_idx) + # img, targets = ds[img_idx] + + img, targets = ds.load_item(img_idx) + width, height = img.size + + image_id = targets["image_id"].item() + img_dict = {} + img_dict["id"] = image_id + img_dict["width"] = width + img_dict["height"] = height + dataset["images"].append(img_dict) + bboxes = targets["boxes"].clone() + bboxes[:, 2:] -= bboxes[:, :2] # xyxy -> xywh + bboxes = bboxes.tolist() + labels = targets["labels"].tolist() + areas = targets["area"].tolist() + iscrowd = targets["iscrowd"].tolist() + if "masks" in targets: + masks = targets["masks"] + # make masks Fortran contiguous for coco_mask + masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) + if "keypoints" in targets: + keypoints = targets["keypoints"] + keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() + num_objs = len(bboxes) + for i in range(num_objs): + ann = {} + ann["image_id"] = image_id + ann["bbox"] = bboxes[i] + ann["category_id"] = labels[i] + categories.add(labels[i]) + ann["area"] = areas[i] + ann["iscrowd"] = iscrowd[i] + ann["id"] = ann_id + if "masks" in targets: + ann["segmentation"] = coco_mask.encode(masks[i].numpy()) + if "keypoints" in targets: + ann["keypoints"] = keypoints[i] + ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3]) + dataset["annotations"].append(ann) + ann_id += 1 + dataset["categories"] = [{"id": i} for i in sorted(categories)] + coco_ds.dataset = dataset + coco_ds.createIndex() + return coco_ds + + +def get_coco_api_from_dataset(dataset): + # FIXME: This is... awful? + for _ in range(10): + if isinstance(dataset, torchvision.datasets.CocoDetection): + break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, torchvision.datasets.CocoDetection): + return dataset.coco + return convert_to_coco_api(dataset) diff --git a/src/data/dataset/voc_detection.py b/src/data/dataset/voc_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..e652a36fa5248edc9eb24bf40a6942edeb3aa11d --- /dev/null +++ b/src/data/dataset/voc_detection.py @@ -0,0 +1,86 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import os +from typing import Callable, Optional + +import torch +import torchvision +import torchvision.transforms.functional as TVF +from PIL import Image +from sympy import im + +try: + from defusedxml.ElementTree import parse as ET_parse +except ImportError: + from xml.etree.ElementTree import parse as ET_parse + +from ...core import register +from .._misc import convert_to_tv_tensor +from ._dataset import DetDataset + + +@register() +class VOCDetection(torchvision.datasets.VOCDetection, DetDataset): + __inject__ = [ + "transforms", + ] + + def __init__( + self, + root: str, + ann_file: str = "trainval.txt", + label_file: str = "label_list.txt", + transforms: Optional[Callable] = None, + ): + with open(os.path.join(root, ann_file), "r") as f: + lines = [x.strip() for x in f.readlines()] + lines = [x.split(" ") for x in lines] + + self.images = [os.path.join(root, lin[0]) for lin in lines] + self.targets = [os.path.join(root, lin[1]) for lin in lines] + assert len(self.images) == len(self.targets) + + with open(os.path.join(root + label_file), "r") as f: + labels = f.readlines() + labels = [lab.strip() for lab in labels] + + self.transforms = transforms + self.labels_map = {lab: i for i, lab in enumerate(labels)} + + def __getitem__(self, index: int): + image, target = self.load_item(index) + if self.transforms is not None: + image, target, _ = self.transforms(image, target, self) + # target["orig_size"] = torch.tensor(TVF.get_image_size(image)) + return image, target + + def load_item(self, index: int): + image = Image.open(self.images[index]).convert("RGB") + target = self.parse_voc_xml(ET_parse(self.annotations[index]).getroot()) + + output = {} + output["image_id"] = torch.tensor([index]) + for k in ["area", "boxes", "labels", "iscrowd"]: + output[k] = [] + + for blob in target["annotation"]["object"]: + box = [float(v) for v in blob["bndbox"].values()] + output["boxes"].append(box) + output["labels"].append(blob["name"]) + output["area"].append((box[2] - box[0]) * (box[3] - box[1])) + output["iscrowd"].append(0) + + w, h = image.size + boxes = torch.tensor(output["boxes"]) if len(output["boxes"]) > 0 else torch.zeros(0, 4) + output["boxes"] = convert_to_tv_tensor( + boxes, "boxes", box_format="xyxy", spatial_size=[h, w] + ) + output["labels"] = torch.tensor([self.labels_map[lab] for lab in output["labels"]]) + output["area"] = torch.tensor(output["area"]) + output["iscrowd"] = torch.tensor(output["iscrowd"]) + output["orig_size"] = torch.tensor([w, h]) + + return image, output diff --git a/src/data/dataset/voc_eval.py b/src/data/dataset/voc_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8c3980104857c95289dc177106ebc24cf760e9 --- /dev/null +++ b/src/data/dataset/voc_eval.py @@ -0,0 +1,12 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torchvision + + +class VOCEvaluator(object): + def __init__(self) -> None: + pass diff --git a/src/data/transforms/__init__.py b/src/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad83cc0a99397f79451ce3f483b1b9f79494d1b0 --- /dev/null +++ b/src/data/transforms/__init__.py @@ -0,0 +1,21 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from ._transforms import ( + ConvertBoxes, + ConvertPILImage, + EmptyTransform, + Normalize, + PadToSize, + RandomCrop, + RandomHorizontalFlip, + RandomIoUCrop, + RandomPhotometricDistort, + RandomZoomOut, + Resize, + SanitizeBoundingBoxes, +) +from .container import Compose +from .mosaic import Mosaic diff --git a/src/data/transforms/_transforms.py b/src/data/transforms/_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..632c6caecc9da41efcfd65ed1a21a5cd573bde79 --- /dev/null +++ b/src/data/transforms/_transforms.py @@ -0,0 +1,161 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from typing import Any, Dict, List, Optional + +import PIL +import PIL.Image +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms.v2 as T +import torchvision.transforms.v2.functional as F + +from ...core import register +from .._misc import ( + BoundingBoxes, + Image, + Mask, + SanitizeBoundingBoxes, + Video, + _boxes_keys, + convert_to_tv_tensor, +) + +torchvision.disable_beta_transforms_warning() + + +RandomPhotometricDistort = register()(T.RandomPhotometricDistort) +RandomZoomOut = register()(T.RandomZoomOut) +RandomHorizontalFlip = register()(T.RandomHorizontalFlip) +Resize = register()(T.Resize) +# ToImageTensor = register()(T.ToImageTensor) +# ConvertDtype = register()(T.ConvertDtype) +# PILToTensor = register()(T.PILToTensor) +SanitizeBoundingBoxes = register(name="SanitizeBoundingBoxes")(SanitizeBoundingBoxes) +RandomCrop = register()(T.RandomCrop) +Normalize = register()(T.Normalize) + + +@register() +class EmptyTransform(T.Transform): + def __init__( + self, + ) -> None: + super().__init__() + + def forward(self, *inputs): + inputs = inputs if len(inputs) > 1 else inputs[0] + return inputs + + +@register() +class PadToSize(T.Pad): + _transformed_types = ( + PIL.Image.Image, + Image, + Video, + Mask, + BoundingBoxes, + ) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + sp = F.get_spatial_size(flat_inputs[0]) + h, w = self.size[1] - sp[0], self.size[0] - sp[1] + self.padding = [0, 0, w, h] + return dict(padding=self.padding) + + def __init__(self, size, fill=0, padding_mode="constant") -> None: + if isinstance(size, int): + size = (size, size) + self.size = size + super().__init__(0, fill, padding_mode) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self._fill[type(inpt)] + padding = params["padding"] + return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + + def __call__(self, *inputs: Any) -> Any: + outputs = super().forward(*inputs) + if len(outputs) > 1 and isinstance(outputs[1], dict): + outputs[1]["padding"] = torch.tensor(self.padding) + return outputs + + +@register() +class RandomIoUCrop(T.RandomIoUCrop): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + p: float = 1.0, + ): + super().__init__( + min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials + ) + self.p = p + + def __call__(self, *inputs: Any) -> Any: + if torch.rand(1) >= self.p: + return inputs if len(inputs) > 1 else inputs[0] + + return super().forward(*inputs) + + +@register() +class ConvertBoxes(T.Transform): + _transformed_types = (BoundingBoxes,) + + def __init__(self, fmt="", normalize=False) -> None: + super().__init__() + self.fmt = fmt + self.normalize = normalize + + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._transform(inpt, params) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + spatial_size = getattr(inpt, _boxes_keys[1]) + if self.fmt: + in_fmt = inpt.format.value.lower() + inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.fmt.lower()) + inpt = convert_to_tv_tensor( + inpt, key="boxes", box_format=self.fmt.upper(), spatial_size=spatial_size + ) + + if self.normalize: + inpt = inpt / torch.tensor(spatial_size[::-1]).tile(2)[None] + + return inpt + + +@register() +class ConvertPILImage(T.Transform): + _transformed_types = (PIL.Image.Image,) + + def __init__(self, dtype="float32", scale=True) -> None: + super().__init__() + self.dtype = dtype + self.scale = scale + + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._transform(inpt, params) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + inpt = F.pil_to_tensor(inpt) + if self.dtype == "float32": + inpt = inpt.float() + + if self.scale: + inpt = inpt / 255.0 + + inpt = Image(inpt) + + return inpt diff --git a/src/data/transforms/container.py b/src/data/transforms/container.py new file mode 100644 index 0000000000000000000000000000000000000000..6d94b6cfa972e823e75b093399b18976eaab9e26 --- /dev/null +++ b/src/data/transforms/container.py @@ -0,0 +1,99 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms.v2 as T + +from ...core import GLOBAL_CONFIG, register +from ._transforms import EmptyTransform + +torchvision.disable_beta_transforms_warning() + + +@register() +class Compose(T.Compose): + def __init__(self, ops, policy=None) -> None: + transforms = [] + if ops is not None: + for op in ops: + if isinstance(op, dict): + name = op.pop("type") + transform = getattr( + GLOBAL_CONFIG[name]["_pymodule"], GLOBAL_CONFIG[name]["_name"] + )(**op) + transforms.append(transform) + op["type"] = name + + elif isinstance(op, nn.Module): + transforms.append(op) + + else: + raise ValueError("") + else: + transforms = [ + EmptyTransform(), + ] + + super().__init__(transforms=transforms) + + if policy is None: + policy = {"name": "default"} + + self.policy = policy + self.global_samples = 0 + + def forward(self, *inputs: Any) -> Any: + return self.get_forward(self.policy["name"])(*inputs) + + def get_forward(self, name): + forwards = { + "default": self.default_forward, + "stop_epoch": self.stop_epoch_forward, + "stop_sample": self.stop_sample_forward, + } + return forwards[name] + + def default_forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + for transform in self.transforms: + sample = transform(sample) + return sample + + def stop_epoch_forward(self, *inputs: Any): + sample = inputs if len(inputs) > 1 else inputs[0] + dataset = sample[-1] + cur_epoch = dataset.epoch + policy_ops = self.policy["ops"] + policy_epoch = self.policy["epoch"] + + for transform in self.transforms: + if type(transform).__name__ in policy_ops and cur_epoch >= policy_epoch: + pass + else: + sample = transform(sample) + + return sample + + def stop_sample_forward(self, *inputs: Any): + sample = inputs if len(inputs) > 1 else inputs[0] + dataset = sample[-1] + + cur_epoch = dataset.epoch + policy_ops = self.policy["ops"] + policy_sample = self.policy["sample"] + + for transform in self.transforms: + if type(transform).__name__ in policy_ops and self.global_samples >= policy_sample: + pass + else: + sample = transform(sample) + + self.global_samples += 1 + + return sample diff --git a/src/data/transforms/functional.py b/src/data/transforms/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..0762bcf2a88ad95b4f70c97480cd9a30bfe68798 --- /dev/null +++ b/src/data/transforms/functional.py @@ -0,0 +1,172 @@ +from typing import List, Optional + +import torch + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +import torchvision.transforms.functional as F +from packaging import version +from torch import Tensor + +if version.parse(torchvision.__version__) < version.parse("0.7"): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse("0.7"): + if input.numel() > 0: + return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target["masks"] = target["masks"][:, i : i + h, j : j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target["boxes"].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target["masks"].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( + [w, 0, w, 0] + ) + target["boxes"] = boxes + + if "masks" in target: + target["masks"] = target["masks"].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + # r = min(size / min(h, w), max_size / max(h, w)) + # ow = int(w * r) + # oh = int(h * r) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor( + [ratio_width, ratio_height, ratio_width, ratio_height] + ) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target["masks"] = ( + interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + ) + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image.size[::-1]) + if "masks" in target: + target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) + return padded_image, target diff --git a/src/data/transforms/mosaic.py b/src/data/transforms/mosaic.py new file mode 100644 index 0000000000000000000000000000000000000000..90ede954039a1c41789cf6dd9fe91c2e526cf657 --- /dev/null +++ b/src/data/transforms/mosaic.py @@ -0,0 +1,83 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import random + +import torch +import torchvision +import torchvision.transforms.v2 as T +import torchvision.transforms.v2.functional as F +from PIL import Image + +from ...core import register +from .._misc import convert_to_tv_tensor + +torchvision.disable_beta_transforms_warning() + + +@register() +class Mosaic(T.Transform): + def __init__( + self, + size, + max_size=None, + ) -> None: + super().__init__() + self.resize = T.Resize(size=size, max_size=max_size) + self.crop = T.RandomCrop(size=max_size if max_size else size) + + # TODO add arg `output_size` for affine` + # self.random_perspective = T.RandomPerspective(distortion_scale=0.5, p=1., ) + self.random_affine = T.RandomAffine( + degrees=0, translate=(0.1, 0.1), scale=(0.5, 1.5), fill=114 + ) + + def forward(self, *inputs): + inputs = inputs if len(inputs) > 1 else inputs[0] + image, target, dataset = inputs + + images = [] + targets = [] + indices = random.choices(range(len(dataset)), k=3) + for i in indices: + image, target = dataset.load_item(i) + image, target = self.resize(image, target) + images.append(image) + targets.append(target) + + h, w = F.get_spatial_size(images[0]) + offset = [[0, 0], [w, 0], [0, h], [w, h]] + image = Image.new(mode=images[0].mode, size=(w * 2, h * 2), color=0) + for i, im in enumerate(images): + image.paste(im, offset[i]) + + offset = torch.tensor([[0, 0], [w, 0], [0, h], [w, h]]).repeat(1, 2) + target = {} + for k in targets[0]: + if k == "boxes": + v = [t[k] + offset[i] for i, t in enumerate(targets)] + else: + v = [t[k] for t in targets] + + if isinstance(v[0], torch.Tensor): + v = torch.cat(v, dim=0) + + target[k] = v + + if "boxes" in target: + # target['boxes'] = target['boxes'].clamp(0, 640 * 2 - 1) + w, h = image.size + target["boxes"] = convert_to_tv_tensor( + target["boxes"], "boxes", box_format="xyxy", spatial_size=[h, w] + ) + + if "masks" in target: + target["masks"] = convert_to_tv_tensor(target["masks"], "masks") + + image, target = self.random_affine(image, target) + # image, target = self.resize(image, target) + image, target = self.crop(image, target) + + return image, target, dataset diff --git a/src/data/transforms/presets.py b/src/data/transforms/presets.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcf4ae9894038e8b18a07419e4c2a3f2a44bf83 --- /dev/null +++ b/src/data/transforms/presets.py @@ -0,0 +1,4 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" diff --git a/src/misc/__init__.py b/src/misc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa6819166d705ad4a4489ec14a01dec9f2e2575 --- /dev/null +++ b/src/misc/__init__.py @@ -0,0 +1,9 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .dist_utils import setup_print, setup_seed +from .logger import * +from .profiler_utils import stats +from .visualizer import * diff --git a/src/misc/box_ops.py b/src/misc/box_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..25ebea60b00b283a67bbc5a9db756ee9dbec08bb --- /dev/null +++ b/src/misc/box_ops.py @@ -0,0 +1,106 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from typing import List, Tuple + +import torch +import torchvision +from torch import Tensor + + +def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + return torchvision.ops.generalized_box_iou(boxes1, boxes2) + + +# elementwise +def elementwise_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """ + Args: + boxes1, [N, 4] + boxes2, [N, 4] + Returns: + iou, [N, ] + union, [N, ] + """ + area1 = torchvision.ops.box_area(boxes1) # [N, ] + area2 = torchvision.ops.box_area(boxes2) # [N, ] + lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N, 2] + rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N, 2] + wh = (rb - lt).clamp(min=0) # [N, 2] + inter = wh[:, 0] * wh[:, 1] # [N, ] + union = area1 + area2 - inter + iou = inter / union + return iou, union + + +def elementwise_generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """ + Args: + boxes1, [N, 4] with [x1, y1, x2, y2] + boxes2, [N, 4] with [x1, y1, x2, y2] + Returns: + giou, [N, ] + """ + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = elementwise_box_iou(boxes1, boxes2) + lt = torch.min(boxes1[:, :2], boxes2[:, :2]) # [N, 2] + rb = torch.max(boxes1[:, 2:], boxes2[:, 2:]) # [N, 2] + wh = (rb - lt).clamp(min=0) # [N, 2] + area = wh[:, 0] * wh[:, 1] + return iou - (area - union) / area + + +def check_point_inside_box(points: Tensor, boxes: Tensor, eps=1e-9) -> Tensor: + """ + Args: + points, [K, 2], (x, y) + boxes, [N, 4], (x1, y1, y2, y2) + Returns: + Tensor (bool), [K, N] + """ + x, y = [p.unsqueeze(-1) for p in points.unbind(-1)] + x1, y1, x2, y2 = [x.unsqueeze(0) for x in boxes.unbind(-1)] + + l = x - x1 + t = y - y1 + r = x2 - x + b = y2 - y + + ltrb = torch.stack([l, t, r, b], dim=-1) + mask = ltrb.min(dim=-1).values > eps + + return mask + + +def point_box_distance(points: Tensor, boxes: Tensor) -> Tensor: + """ + Args: + boxes, [N, 4], (x1, y1, x2, y2) + points, [N, 2], (x, y) + Returns: + Tensor (N, 4), (l, t, r, b) + """ + x1y1, x2y2 = torch.split(boxes, 2, dim=-1) + lt = points - x1y1 + rb = x2y2 - points + return torch.concat([lt, rb], dim=-1) + + +def point_distance_box(points: Tensor, distances: Tensor) -> Tensor: + """ + Args: + points (Tensor), [N, 2], (x, y) + distances (Tensor), [N, 4], (l, t, r, b) + Returns: + boxes (Tensor), (N, 4), (x1, y1, x2, y2) + """ + lt, rb = torch.split(distances, 2, dim=-1) + x1y1 = -lt + points + x2y2 = rb + points + boxes = torch.concat([x1y1, x2y2], dim=-1) + return boxes diff --git a/src/misc/dist_utils.py b/src/misc/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0422814c5b6de3db97b6868fa8164ca72438a120 --- /dev/null +++ b/src/misc/dist_utils.py @@ -0,0 +1,281 @@ +""" +reference +- https://github.com/pytorch/vision/blob/main/references/detection/utils.py +- https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406 + +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import atexit +import os +import random +import time + +import numpy as np +import torch +import torch.backends.cudnn +import torch.distributed +import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.nn.parallel import DataParallel as DP +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DistributedSampler + +# from torch.utils.data.dataloader import DataLoader +from ..data import DataLoader + + +def setup_distributed( + print_rank: int = 0, + print_method: str = "builtin", + seed: int = None, +): + """ + env setup + args: + print_rank, + print_method, (builtin, rich) + seed, + """ + try: + # https://pytorch.org/docs/stable/elastic/run.html + RANK = int(os.getenv("RANK", -1)) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) + + # torch.distributed.init_process_group(backend=backend, init_method='env://') + torch.distributed.init_process_group(init_method="env://") + torch.distributed.barrier() + + rank = torch.distributed.get_rank() + torch.cuda.set_device(rank) + torch.cuda.empty_cache() + enabled_dist = True + if get_rank() == print_rank: + print("Initialized distributed mode...") + + except Exception: + enabled_dist = False + print("Not init distributed mode.") + + setup_print(get_rank() == print_rank, method=print_method) + if seed is not None: + setup_seed(seed) + + return enabled_dist + + +def setup_print(is_main, method="builtin"): + """This function disables printing when not in master process""" + import builtins as __builtin__ + + if method == "builtin": + builtin_print = __builtin__.print + + elif method == "rich": + import rich + + builtin_print = rich.print + + else: + raise AttributeError("") + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_available_and_initialized(): + if not torch.distributed.is_available(): + return False + if not torch.distributed.is_initialized(): + return False + return True + + +@atexit.register +def cleanup(): + """cleanup distributed environment""" + if is_dist_available_and_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +def get_rank(): + if not is_dist_available_and_initialized(): + return 0 + return torch.distributed.get_rank() + + +def get_world_size(): + if not is_dist_available_and_initialized(): + return 1 + return torch.distributed.get_world_size() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def warp_model( + model: torch.nn.Module, + sync_bn: bool = False, + dist_mode: str = "ddp", + find_unused_parameters: bool = False, + compile: bool = False, + compile_mode: str = "reduce-overhead", + **kwargs, +): + if is_dist_available_and_initialized(): + rank = get_rank() + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model + if dist_mode == "dp": + model = DP(model, device_ids=[rank], output_device=rank) + elif dist_mode == "ddp": + model = DDP( + model, + device_ids=[rank], + output_device=rank, + find_unused_parameters=find_unused_parameters, + ) + else: + raise AttributeError("") + + if compile: + model = torch.compile(model, mode=compile_mode) + + return model + + +def de_model(model): + return de_parallel(de_complie(model)) + + +def warp_loader(loader, shuffle=False): + if is_dist_available_and_initialized(): + sampler = DistributedSampler(loader.dataset, shuffle=shuffle) + loader = DataLoader( + loader.dataset, + loader.batch_size, + sampler=sampler, + drop_last=loader.drop_last, + collate_fn=loader.collate_fn, + pin_memory=loader.pin_memory, + num_workers=loader.num_workers, + ) + return loader + + +def is_parallel(model) -> bool: + # Returns True if model is of type DP or DDP + return type(model) in ( + torch.nn.parallel.DataParallel, + torch.nn.parallel.DistributedDataParallel, + ) + + +def de_parallel(model) -> nn.Module: + # De-parallelize a model: returns single-GPU model if model is of type DP or DDP + return model.module if is_parallel(model) else model + + +def reduce_dict(data, avg=True): + """ + Args + data dict: input, {k: v, ...} + avg bool: true + """ + world_size = get_world_size() + if world_size < 2: + return data + + with torch.no_grad(): + keys, values = [], [] + for k in sorted(data.keys()): + keys.append(k) + values.append(data[k]) + + values = torch.stack(values, dim=0) + torch.distributed.all_reduce(values) + + if avg is True: + values /= world_size + + return {k: v for k, v in zip(keys, values)} + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + data_list = [None] * world_size + torch.distributed.all_gather_object(data_list, data) + return data_list + + +def sync_time(): + """sync_time""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + + return time.time() + + +def setup_seed(seed: int, deterministic=False): + """setup_seed for reproducibility + torch.manual_seed(3407) is all you need. https://arxiv.org/abs/2109.08203 + """ + seed = seed + get_rank() + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # memory will be large when setting deterministic to True + if torch.backends.cudnn.is_available() and deterministic: + torch.backends.cudnn.deterministic = True + + +# for torch.compile +def check_compile(): + import warnings + + import torch + + gpu_ok = False + if torch.cuda.is_available(): + device_cap = torch.cuda.get_device_capability() + if device_cap in ((7, 0), (8, 0), (9, 0)): + gpu_ok = True + if not gpu_ok: + warnings.warn( + "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower " "than expected." + ) + return gpu_ok + + +def is_compile(model): + import torch._dynamo + + return type(model) in (torch._dynamo.OptimizedModule,) + + +def de_complie(model): + return model._orig_mod if is_compile(model) else model diff --git a/src/misc/lazy_loader.py b/src/misc/lazy_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc54ab2144de8555b0a44492cf57ad944b170be --- /dev/null +++ b/src/misc/lazy_loader.py @@ -0,0 +1,70 @@ +""" +https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/util/lazy_loader.py +""" + +import importlib +import types + + +class LazyLoader(types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies. + + `paddle`, and `ffmpeg` are examples of modules that are large and not always + needed, and this allows them to only be loaded when they are used. + """ + + # The lint error here is incorrect. + def __init__(self, local_name, parent_module_globals, name, warning=None): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._warning = warning + + # These members allows doctest correctly process this module member without + # triggering self._load(). self._load() mutates parant_module_globals and + # triggers a dict mutated during iteration error from doctest.py. + # - for from_module() + self.__module__ = name.rsplit(".", 1)[0] + # - for is_routine() + self.__wrapped__ = None + + super(LazyLoader, self).__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + + # Emit a warning if one was specified + if self._warning: + # logging.warning(self._warning) + # Make sure to only warn once. + self._warning = None + + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on lookups + # that fail). + self.__dict__.update(module.__dict__) + + return module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + def __repr__(self): + # Carefully to not trigger _load, since repr may be called in very + # sensitive places. + return f"" + + def __dir__(self): + module = self._load() + return dir(module) + + +# import paddle.nn as nn +# nn = LazyLoader("nn", globals(), "paddle.nn") + +# class M(nn.Layer): +# def __init__(self) -> None: +# super().__init__() diff --git a/src/misc/logger.py b/src/misc/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ab2078a0bda1ab4adcf6c1393c1d7b718ae7bd --- /dev/null +++ b/src/misc/logger.py @@ -0,0 +1,255 @@ +""" +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +https://github.com/facebookresearch/detr/blob/main/util/misc.py +Mostly copy-paste from torchvision references. +""" + +import datetime +import pickle +import time +from collections import defaultdict, deque +from typing import Dict + +import torch +import torch.distributed as tdist + +from .dist_utils import get_world_size, is_dist_available_and_initialized + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_available_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + tdist.barrier() + tdist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + tdist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + tdist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]: + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + tdist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) diff --git a/src/misc/profiler_utils.py b/src/misc/profiler_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c6e0930b9f1be854b61b3f69b0cef5e44b5c5e --- /dev/null +++ b/src/misc/profiler_utils.py @@ -0,0 +1,30 @@ +""" +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +""" + +import copy +from typing import Tuple + +from calflops import calculate_flops + + +def stats( + cfg, + input_shape: Tuple = (1, 3, 640, 640), +) -> Tuple[int, dict]: + base_size = cfg.train_dataloader.collate_fn.base_size + input_shape = (1, 3, base_size, base_size) + + model_for_info = copy.deepcopy(cfg.model).deploy() + + flops, macs, _ = calculate_flops( + model=model_for_info, + input_shape=input_shape, + output_as_string=True, + output_precision=4, + print_detailed=False, + ) + params = sum(p.numel() for p in model_for_info.parameters()) + del model_for_info + + return params, {"Model FLOPs:%s MACs:%s Params:%s" % (flops, macs, params)} diff --git a/src/misc/visualizer.py b/src/misc/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..414874f9219d28e63efb21d897a2ed01d7dc2af4 --- /dev/null +++ b/src/misc/visualizer.py @@ -0,0 +1,121 @@ +""" " +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import PIL +import numpy as np +import torch +import torch.utils.data +import torchvision +from typing import List, Dict + +torchvision.disable_beta_transforms_warning() + +__all__ = ["show_sample", "save_samples"] + +def save_samples(samples: torch.Tensor, targets: List[Dict], output_dir: str, split: str, normalized: bool, box_fmt: str): + ''' + normalized: whether the boxes are normalized to [0, 1] + box_fmt: 'xyxy', 'xywh', 'cxcywh', D-FINE uses 'cxcywh' for training, 'xyxy' for validation + ''' + from torchvision.transforms.functional import to_pil_image + from torchvision.ops import box_convert + from pathlib import Path + from PIL import ImageDraw, ImageFont + import os + + os.makedirs(Path(output_dir) / Path(f"{split}_samples"), exist_ok=True) + # Predefined colors (standard color names recognized by PIL) + BOX_COLORS = [ + "red", "blue", "green", "orange", "purple", + "cyan", "magenta", "yellow", "lime", "pink", + "teal", "lavender", "brown", "beige", "maroon", + "navy", "olive", "coral", "turquoise", "gold" + ] + + LABEL_TEXT_COLOR = "white" + + font = ImageFont.load_default() + font.size = 32 + + for i, (sample, target) in enumerate(zip(samples, targets)): + sample_visualization = sample.clone().cpu() + target_boxes = target["boxes"].clone().cpu() + target_labels = target["labels"].clone().cpu() + target_image_id = target["image_id"].item() + target_image_path = target["image_path"] + target_image_path_stem = Path(target_image_path).stem + + sample_visualization = to_pil_image(sample_visualization) + sample_visualization_w, sample_visualization_h = sample_visualization.size + + # normalized to pixel space + if normalized: + target_boxes[:, 0] = target_boxes[:, 0] * sample_visualization_w + target_boxes[:, 2] = target_boxes[:, 2] * sample_visualization_w + target_boxes[:, 1] = target_boxes[:, 1] * sample_visualization_h + target_boxes[:, 3] = target_boxes[:, 3] * sample_visualization_h + + # any box format -> xyxy + target_boxes = box_convert(target_boxes, in_fmt=box_fmt, out_fmt="xyxy") + + # clip to image size + target_boxes[:, 0] = torch.clamp(target_boxes[:, 0], 0, sample_visualization_w) + target_boxes[:, 1] = torch.clamp(target_boxes[:, 1], 0, sample_visualization_h) + target_boxes[:, 2] = torch.clamp(target_boxes[:, 2], 0, sample_visualization_w) + target_boxes[:, 3] = torch.clamp(target_boxes[:, 3], 0, sample_visualization_h) + + target_boxes = target_boxes.numpy().astype(np.int32) + target_labels = target_labels.numpy().astype(np.int32) + + draw = ImageDraw.Draw(sample_visualization) + + # draw target boxes + for box, label in zip(target_boxes, target_labels): + x1, y1, x2, y2 = box + + # Select color based on class ID + box_color = BOX_COLORS[int(label) % len(BOX_COLORS)] + + # Draw box (thick) + draw.rectangle([x1, y1, x2, y2], outline=box_color, width=3) + + label_text = f"{label}" + + # Measure text size + text_width, text_height = draw.textbbox((0, 0), label_text, font=font)[2:4] + + # Draw text background + padding = 2 + draw.rectangle( + [x1, y1 - text_height - padding * 2, x1 + text_width + padding * 2, y1], + fill=box_color + ) + + # Draw text (LABEL_TEXT_COLOR) + draw.text((x1 + padding, y1 - text_height - padding), label_text, + fill=LABEL_TEXT_COLOR, font=font) + + save_path = Path(output_dir) / f"{split}_samples" / f"{target_image_id}_{target_image_path_stem}.webp" + sample_visualization.save(save_path) + +def show_sample(sample): + """for coco dataset/dataloader""" + import matplotlib.pyplot as plt + from torchvision.transforms.v2 import functional as F + from torchvision.utils import draw_bounding_boxes + + image, target = sample + if isinstance(image, PIL.Image.Image): + image = F.to_image_tensor(image) + + image = F.convert_dtype(image, torch.uint8) + annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3) + + fig, ax = plt.subplots() + ax.imshow(annotated_image.permute(1, 2, 0).numpy()) + ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + fig.tight_layout() + fig.show() + plt.show() diff --git a/src/nn/__init__.py b/src/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a96d1338807e767c1da097276d5d804d338dbf3e --- /dev/null +++ b/src/nn/__init__.py @@ -0,0 +1,16 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .arch import * + +# +from .backbone import * +from .backbone import ( + FrozenBatchNorm2d, + freeze_batch_norm2d, + get_activation, +) +from .criterion import * +from .postprocessor import * diff --git a/src/nn/arch/__init__.py b/src/nn/arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4b8b9c9984fa9d2393c8e1a15f6efb5e2c01c4 --- /dev/null +++ b/src/nn/arch/__init__.py @@ -0,0 +1,7 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .classification import ClassHead, Classification +from .yolo import YOLO diff --git a/src/nn/arch/classification.py b/src/nn/arch/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..7b01be85c8cbaea553916ab1e7c846ca3bf1ec1a --- /dev/null +++ b/src/nn/arch/classification.py @@ -0,0 +1,45 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn + +from ...core import register + +__all__ = ["Classification", "ClassHead"] + + +@register() +class Classification(torch.nn.Module): + __inject__ = ["backbone", "head"] + + def __init__(self, backbone: nn.Module, head: nn.Module = None): + super().__init__() + + self.backbone = backbone + self.head = head + + def forward(self, x): + x = self.backbone(x) + + if self.head is not None: + x = self.head(x) + + return x + + +@register() +class ClassHead(nn.Module): + def __init__(self, hidden_dim, num_classes): + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.proj = nn.Linear(hidden_dim, num_classes) + + def forward(self, x): + x = x[0] if isinstance(x, (list, tuple)) else x + x = self.pool(x) + x = x.reshape(x.shape[0], -1) + x = self.proj(x) + return x diff --git a/src/nn/arch/yolo.py b/src/nn/arch/yolo.py new file mode 100644 index 0000000000000000000000000000000000000000..bbec1b51bf25746038aead49fcd15d5b95c88839 --- /dev/null +++ b/src/nn/arch/yolo.py @@ -0,0 +1,42 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch + +from ...core import register + +__all__ = [ + "YOLO", +] + + +@register() +class YOLO(torch.nn.Module): + __inject__ = [ + "backbone", + "neck", + "head", + ] + + def __init__(self, backbone: torch.nn.Module, neck, head): + super().__init__() + self.backbone = backbone + self.neck = neck + self.head = head + + def forward(self, x, **kwargs): + x = self.backbone(x) + x = self.neck(x) + x = self.head(x) + return x + + def deploy( + self, + ): + self.eval() + for m in self.modules(): + if m is not self and hasattr(m, "deploy"): + m.deploy() + return self diff --git a/src/nn/backbone/__init__.py b/src/nn/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..806026f5ab75f77a83c5992f7568ae99d747a76d --- /dev/null +++ b/src/nn/backbone/__init__.py @@ -0,0 +1,17 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .common import ( + FrozenBatchNorm2d, + freeze_batch_norm2d, + get_activation, +) +from .csp_darknet import CSPPAN, CSPDarkNet +from .csp_resnet import CSPResNet +from .hgnetv2 import HGNetv2 +from .presnet import PResNet +from .test_resnet import MResNet +from .timm_model import TimmModel +from .torchvision_model import TorchVisionModel diff --git a/src/nn/backbone/common.py b/src/nn/backbone/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d485fdb8d8c9f1c42f045a82baea67a1ec69c5c0 --- /dev/null +++ b/src/nn/backbone/common.py @@ -0,0 +1,117 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn + + +class ConvNormLayer(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): + super().__init__() + self.conv = nn.Conv2d( + ch_in, + ch_out, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=bias, + ) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + return self.act(self.norm(self.conv(x))) + + +class FrozenBatchNorm2d(nn.Module): + """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py + BatchNorm2d where the batch statistics and the affine parameters are fixed. + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, num_features, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + n = num_features + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + self.num_features = n + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * (rv + self.eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + def extra_repr(self): + return "{num_features}, eps={eps}".format(**self.__dict__) + + +def freeze_batch_norm2d(module: nn.Module) -> nn.Module: + if isinstance(module, nn.BatchNorm2d): + module = FrozenBatchNorm2d(module.num_features) + else: + for name, child in module.named_children(): + _child = freeze_batch_norm2d(child) + if _child is not child: + setattr(module, name, _child) + return module + + +def get_activation(act: str, inplace: bool = True): + """get activation""" + if act is None: + return nn.Identity() + + elif isinstance(act, nn.Module): + return act + + act = act.lower() + + if act == "silu" or act == "swish": + m = nn.SiLU() + + elif act == "relu": + m = nn.ReLU() + + elif act == "leaky_relu": + m = nn.LeakyReLU() + + elif act == "silu": + m = nn.SiLU() + + elif act == "gelu": + m = nn.GELU() + + elif act == "hardsigmoid": + m = nn.Hardsigmoid() + + else: + raise RuntimeError("") + + if hasattr(m, "inplace"): + m.inplace = inplace + + return m diff --git a/src/nn/backbone/csp_darknet.py b/src/nn/backbone/csp_darknet.py new file mode 100644 index 0000000000000000000000000000000000000000..04b44e69f6a299d86ce9c8023d68949ccb0edc14 --- /dev/null +++ b/src/nn/backbone/csp_darknet.py @@ -0,0 +1,203 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import math +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core import register +from .common import get_activation + + +def autopad(k, p=None): + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] + return p + + +def make_divisible(c, d): + return math.ceil(c / d) * d + + +class Conv(nn.Module): + def __init__(self, cin, cout, k=1, s=1, p=None, g=1, act="silu") -> None: + super().__init__() + self.conv = nn.Conv2d(cin, cout, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(cout) + self.act = get_activation(act, inplace=True) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, act="silu"): + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1, act=act) + self.cv2 = Conv(c_, c2, 3, 1, g=g, act=act) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__( + self, c1, c2, n=1, shortcut=True, g=1, e=0.5, act="silu" + ): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1, act=act) + self.cv2 = Conv(c1, c_, 1, 1, act=act) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0, act=act) for _ in range(n))) + self.cv3 = Conv(2 * c_, c2, 1, act=act) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class SPPF(nn.Module): + # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher + def __init__(self, c1, c2, k=5, act="silu"): # equivalent to SPP(k=(5, 9, 13)) + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1, act=act) + self.cv2 = Conv(c_ * 4, c2, 1, 1, act=act) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # suppress torch 1.9.0 max_pool2d() warning + y1 = self.m(x) + y2 = self.m(y1) + return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) + + +@register() +class CSPDarkNet(nn.Module): + __share__ = ["depth_multi", "width_multi"] + + def __init__( + self, + in_channels=3, + width_multi=1.0, + depth_multi=1.0, + return_idx=[2, 3, -1], + act="silu", + ) -> None: + super().__init__() + + channels = [64, 128, 256, 512, 1024] + channels = [make_divisible(c * width_multi, 8) for c in channels] + + depths = [3, 6, 9, 3] + depths = [max(round(d * depth_multi), 1) for d in depths] + + self.layers = nn.ModuleList([Conv(in_channels, channels[0], 6, 2, 2, act=act)]) + for i, (c, d) in enumerate(zip(channels, depths), 1): + layer = nn.Sequential( + *[Conv(c, channels[i], 3, 2, act=act), C3(channels[i], channels[i], n=d, act=act)] + ) + self.layers.append(layer) + + self.layers.append(SPPF(channels[-1], channels[-1], k=5, act=act)) + + self.return_idx = return_idx + self.out_channels = [channels[i] for i in self.return_idx] + self.strides = [[2, 4, 8, 16, 32][i] for i in self.return_idx] + self.depths = depths + self.act = act + + def forward(self, x): + outputs = [] + for _, m in enumerate(self.layers): + x = m(x) + outputs.append(x) + + return [outputs[i] for i in self.return_idx] + + +@register() +class CSPPAN(nn.Module): + """ + P5 ---> 1x1 ---------------------------------> concat --> c3 --> det + | up | conv /2 + P4 ---> concat ---> c3 ---> 1x1 --> concat ---> c3 -----------> det + | up | conv /2 + P3 -----------------------> concat ---> c3 ---------------------> det + """ + + __share__ = [ + "depth_multi", + ] + + def __init__(self, in_channels=[256, 512, 1024], depth_multi=1.0, act="silu") -> None: + super().__init__() + depth = max(round(3 * depth_multi), 1) + + self.out_channels = in_channels + self.fpn_stems = nn.ModuleList( + [ + Conv(cin, cout, 1, 1, act=act) + for cin, cout in zip(in_channels[::-1], in_channels[::-1][1:]) + ] + ) + self.fpn_csps = nn.ModuleList( + [ + C3(cin, cout, depth, False, act=act) + for cin, cout in zip(in_channels[::-1], in_channels[::-1][1:]) + ] + ) + + self.pan_stems = nn.ModuleList([Conv(c, c, 3, 2, act=act) for c in in_channels[:-1]]) + self.pan_csps = nn.ModuleList([C3(c, c, depth, False, act=act) for c in in_channels[1:]]) + + def forward(self, feats): + fpn_feats = [] + for i, feat in enumerate(feats[::-1]): + if i == 0: + feat = self.fpn_stems[i](feat) + fpn_feats.append(feat) + else: + _feat = F.interpolate(fpn_feats[-1], scale_factor=2, mode="nearest") + feat = torch.concat([_feat, feat], dim=1) + feat = self.fpn_csps[i - 1](feat) + if i < len(self.fpn_stems): + feat = self.fpn_stems[i](feat) + fpn_feats.append(feat) + + pan_feats = [] + for i, feat in enumerate(fpn_feats[::-1]): + if i == 0: + pan_feats.append(feat) + else: + _feat = self.pan_stems[i - 1](pan_feats[-1]) + feat = torch.concat([_feat, feat], dim=1) + feat = self.pan_csps[i - 1](feat) + pan_feats.append(feat) + + return pan_feats + + +if __name__ == "__main__": + data = torch.rand(1, 3, 320, 640) + + width_multi = 0.75 + depth_multi = 0.33 + + m = CSPDarkNet(3, width_multi=width_multi, depth_multi=depth_multi, act="silu") + outputs = m(data) + print([o.shape for o in outputs]) + + m = CSPPAN(in_channels=m.out_channels, depth_multi=depth_multi, act="silu") + outputs = m(outputs) + print([o.shape for o in outputs]) diff --git a/src/nn/backbone/csp_resnet.py b/src/nn/backbone/csp_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9436b1f64861370fe72ee816161c46fa51cde4 --- /dev/null +++ b/src/nn/backbone/csp_resnet.py @@ -0,0 +1,302 @@ +""" +https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.6/ppdet/modeling/backbones/cspresnet.py + +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core import register +from .common import get_activation + +__all__ = ["CSPResNet"] + + +donwload_url = { + "s": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_s_pretrained_from_paddle.pth", + "m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_m_pretrained_from_paddle.pth", + "l": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_l_pretrained_from_paddle.pth", + "x": "https://github.com/lyuwenyu/storage/releases/download/v0.1/CSPResNetb_x_pretrained_from_paddle.pth", +} + + +class ConvBNLayer(nn.Module): + def __init__(self, ch_in, ch_out, filter_size=3, stride=1, groups=1, padding=0, act=None): + super().__init__() + self.conv = nn.Conv2d( + ch_in, ch_out, filter_size, stride, padding, groups=groups, bias=False + ) + self.bn = nn.BatchNorm2d(ch_out) + self.act = get_activation(act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + return x + + +class RepVggBlock(nn.Module): + def __init__(self, ch_in, ch_out, act="relu", alpha: bool = False): + super().__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=None) + self.conv2 = ConvBNLayer(ch_in, ch_out, 1, stride=1, padding=0, act=None) + self.act = get_activation(act) + + if alpha: + self.alpha = nn.Parameter( + torch.ones( + 1, + ) + ) + else: + self.alpha = None + + def forward(self, x): + if hasattr(self, "conv"): + y = self.conv(x) + else: + if self.alpha: + y = self.conv1(x) + self.alpha * self.conv2(x) + else: + y = self.conv1(x) + self.conv2(x) + y = self.act(y) + return y + + def convert_to_deploy(self): + if not hasattr(self, "conv"): + self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) + + kernel, bias = self.get_equivalent_kernel_bias() + self.conv.weight.data = kernel + self.conv.bias.data = bias + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + + if self.alpha: + return kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor( + kernel1x1 + ), bias3x3 + self.alpha * bias1x1 + else: + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return F.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch: ConvBNLayer): + if branch is None: + return 0, 0 + kernel = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class BasicBlock(nn.Module): + def __init__(self, ch_in, ch_out, act="relu", shortcut=True, use_alpha=False): + super().__init__() + assert ch_in == ch_out + self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act) + self.conv2 = RepVggBlock(ch_out, ch_out, act=act, alpha=use_alpha) + self.shortcut = shortcut + + def forward(self, x): + y = self.conv1(x) + y = self.conv2(y) + if self.shortcut: + return x + y + else: + return y + + +class EffectiveSELayer(nn.Module): + """Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + + def __init__(self, channels, act="hardsigmoid"): + super(EffectiveSELayer, self).__init__() + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.act = get_activation(act) + + def forward(self, x: torch.Tensor): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.fc(x_se) + x_se = self.act(x_se) + return x * x_se + + +class CSPResStage(nn.Module): + def __init__(self, block_fn, ch_in, ch_out, n, stride, act="relu", attn="eca", use_alpha=False): + super().__init__() + ch_mid = (ch_in + ch_out) // 2 + if stride == 2: + self.conv_down = ConvBNLayer(ch_in, ch_mid, 3, stride=2, padding=1, act=act) + else: + self.conv_down = None + self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act) + self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act) + self.blocks = nn.Sequential( + *[ + block_fn(ch_mid // 2, ch_mid // 2, act=act, shortcut=True, use_alpha=use_alpha) + for i in range(n) + ] + ) + if attn: + self.attn = EffectiveSELayer(ch_mid, act="hardsigmoid") + else: + self.attn = None + + self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act) + + def forward(self, x): + if self.conv_down is not None: + x = self.conv_down(x) + y1 = self.conv1(x) + y2 = self.blocks(self.conv2(x)) + y = torch.concat([y1, y2], dim=1) + if self.attn is not None: + y = self.attn(y) + y = self.conv3(y) + return y + + +@register() +class CSPResNet(nn.Module): + layers = [3, 6, 6, 3] + channels = [64, 128, 256, 512, 1024] + model_cfg = { + "s": { + "depth_mult": 0.33, + "width_mult": 0.50, + }, + "m": { + "depth_mult": 0.67, + "width_mult": 0.75, + }, + "l": { + "depth_mult": 1.00, + "width_mult": 1.00, + }, + "x": { + "depth_mult": 1.33, + "width_mult": 1.25, + }, + } + + def __init__( + self, + name: str, + act="silu", + return_idx=[1, 2, 3], + use_large_stem=True, + use_alpha=False, + pretrained=False, + ): + super().__init__() + depth_mult = self.model_cfg[name]["depth_mult"] + width_mult = self.model_cfg[name]["width_mult"] + + channels = [max(round(c * width_mult), 1) for c in self.channels] + layers = [max(round(l * depth_mult), 1) for l in self.layers] + act = get_activation(act) + + if use_large_stem: + self.stem = nn.Sequential( + OrderedDict( + [ + ( + "conv1", + ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act), + ), + ( + "conv2", + ConvBNLayer( + channels[0] // 2, channels[0] // 2, 3, stride=1, padding=1, act=act + ), + ), + ( + "conv3", + ConvBNLayer( + channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act + ), + ), + ] + ) + ) + else: + self.stem = nn.Sequential( + OrderedDict( + [ + ( + "conv1", + ConvBNLayer(3, channels[0] // 2, 3, stride=2, padding=1, act=act), + ), + ( + "conv2", + ConvBNLayer( + channels[0] // 2, channels[0], 3, stride=1, padding=1, act=act + ), + ), + ] + ) + ) + + n = len(channels) - 1 + self.stages = nn.Sequential( + OrderedDict( + [ + ( + str(i), + CSPResStage( + BasicBlock, + channels[i], + channels[i + 1], + layers[i], + 2, + act=act, + use_alpha=use_alpha, + ), + ) + for i in range(n) + ] + ) + ) + + self._out_channels = channels[1:] + self._out_strides = [4 * 2**i for i in range(n)] + self.return_idx = return_idx + + if pretrained: + if isinstance(pretrained, bool) or "http" in pretrained: + state = torch.hub.load_state_dict_from_url(donwload_url[name], map_location="cpu") + else: + state = torch.load(pretrained, map_location="cpu") + self.load_state_dict(state) + print(f"Load CSPResNet_{name} state_dict") + + def forward(self, x): + x = self.stem(x) + outs = [] + for idx, stage in enumerate(self.stages): + x = stage(x) + if idx in self.return_idx: + outs.append(x) + + return outs diff --git a/src/nn/backbone/hgnetv2.py b/src/nn/backbone/hgnetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b15ec63c6a679dd2c753362d8f21b26925ced81 --- /dev/null +++ b/src/nn/backbone/hgnetv2.py @@ -0,0 +1,581 @@ +""" +reference +- https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py + +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +""" + +import logging +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core import register +from .common import FrozenBatchNorm2d + +# Constants for initialization +kaiming_normal_ = nn.init.kaiming_normal_ +zeros_ = nn.init.zeros_ +ones_ = nn.init.ones_ + +__all__ = ["HGNetv2"] + +def safe_barrier(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + else: + pass + +def safe_get_rank(): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + +class LearnableAffineBlock(nn.Module): + def __init__(self, scale_value=1.0, bias_value=0.0): + super().__init__() + self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True) + self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True) + + def forward(self, x): + return self.scale * x + self.bias + + +class ConvBNAct(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size, + stride=1, + groups=1, + padding="", + use_act=True, + use_lab=False, + ): + super().__init__() + self.use_act = use_act + self.use_lab = use_lab + if padding == "same": + self.conv = nn.Sequential( + nn.ZeroPad2d([0, 1, 0, 1]), + nn.Conv2d(in_chs, out_chs, kernel_size, stride, groups=groups, bias=False), + ) + else: + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size, + stride, + padding=(kernel_size - 1) // 2, + groups=groups, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_chs) + if self.use_act: + self.act = nn.ReLU() + else: + self.act = nn.Identity() + if self.use_act and self.use_lab: + self.lab = LearnableAffineBlock() + else: + self.lab = nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + x = self.lab(x) + return x + + +class LightConvBNAct(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size, + groups=1, + use_lab=False, + ): + super().__init__() + self.conv1 = ConvBNAct( + in_chs, + out_chs, + kernel_size=1, + use_act=False, + use_lab=use_lab, + ) + self.conv2 = ConvBNAct( + out_chs, + out_chs, + kernel_size=kernel_size, + groups=out_chs, + use_act=True, + use_lab=use_lab, + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class StemBlock(nn.Module): + # for HGNetv2 + def __init__(self, in_chs, mid_chs, out_chs, use_lab=False): + super().__init__() + self.stem1 = ConvBNAct( + in_chs, + mid_chs, + kernel_size=3, + stride=2, + use_lab=use_lab, + ) + self.stem2a = ConvBNAct( + mid_chs, + mid_chs // 2, + kernel_size=2, + stride=1, + use_lab=use_lab, + ) + self.stem2b = ConvBNAct( + mid_chs // 2, + mid_chs, + kernel_size=2, + stride=1, + use_lab=use_lab, + ) + self.stem3 = ConvBNAct( + mid_chs * 2, + mid_chs, + kernel_size=3, + stride=2, + use_lab=use_lab, + ) + self.stem4 = ConvBNAct( + mid_chs, + out_chs, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + self.pool = nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True) + + def forward(self, x): + x = self.stem1(x) + x = F.pad(x, (0, 1, 0, 1)) + x2 = self.stem2a(x) + x2 = F.pad(x2, (0, 1, 0, 1)) + x2 = self.stem2b(x2) + x1 = self.pool(x) + x = torch.cat([x1, x2], dim=1) + x = self.stem3(x) + x = self.stem4(x) + return x + + +class EseModule(nn.Module): + def __init__(self, chs): + super().__init__() + self.conv = nn.Conv2d( + chs, + chs, + kernel_size=1, + stride=1, + padding=0, + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + identity = x + x = x.mean((2, 3), keepdim=True) + x = self.conv(x) + x = self.sigmoid(x) + return torch.mul(identity, x) + + +class HG_Block(nn.Module): + def __init__( + self, + in_chs, + mid_chs, + out_chs, + layer_num, + kernel_size=3, + residual=False, + light_block=False, + use_lab=False, + agg="ese", + drop_path=0.0, + ): + super().__init__() + self.residual = residual + + self.layers = nn.ModuleList() + for i in range(layer_num): + if light_block: + self.layers.append( + LightConvBNAct( + in_chs if i == 0 else mid_chs, + mid_chs, + kernel_size=kernel_size, + use_lab=use_lab, + ) + ) + else: + self.layers.append( + ConvBNAct( + in_chs if i == 0 else mid_chs, + mid_chs, + kernel_size=kernel_size, + stride=1, + use_lab=use_lab, + ) + ) + + # feature aggregation + total_chs = in_chs + layer_num * mid_chs + if agg == "se": + aggregation_squeeze_conv = ConvBNAct( + total_chs, + out_chs // 2, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + aggregation_excitation_conv = ConvBNAct( + out_chs // 2, + out_chs, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + self.aggregation = nn.Sequential( + aggregation_squeeze_conv, + aggregation_excitation_conv, + ) + else: + aggregation_conv = ConvBNAct( + total_chs, + out_chs, + kernel_size=1, + stride=1, + use_lab=use_lab, + ) + att = EseModule(out_chs) + self.aggregation = nn.Sequential( + aggregation_conv, + att, + ) + + self.drop_path = nn.Dropout(drop_path) if drop_path else nn.Identity() + + def forward(self, x): + identity = x + output = [x] + for layer in self.layers: + x = layer(x) + output.append(x) + x = torch.cat(output, dim=1) + x = self.aggregation(x) + if self.residual: + x = self.drop_path(x) + identity + return x + + +class HG_Stage(nn.Module): + def __init__( + self, + in_chs, + mid_chs, + out_chs, + block_num, + layer_num, + downsample=True, + light_block=False, + kernel_size=3, + use_lab=False, + agg="se", + drop_path=0.0, + ): + super().__init__() + self.downsample = downsample + if downsample: + self.downsample = ConvBNAct( + in_chs, + in_chs, + kernel_size=3, + stride=2, + groups=in_chs, + use_act=False, + use_lab=use_lab, + ) + else: + self.downsample = nn.Identity() + + blocks_list = [] + for i in range(block_num): + blocks_list.append( + HG_Block( + in_chs if i == 0 else out_chs, + mid_chs, + out_chs, + layer_num, + residual=False if i == 0 else True, + kernel_size=kernel_size, + light_block=light_block, + use_lab=use_lab, + agg=agg, + drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path, + ) + ) + self.blocks = nn.Sequential(*blocks_list) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +@register() +class HGNetv2(nn.Module): + """ + HGNetV2 + Args: + stem_channels: list. Number of channels for the stem block. + stage_type: str. The stage configuration of HGNet. such as the number of channels, stride, etc. + use_lab: boolean. Whether to use LearnableAffineBlock in network. + lr_mult_list: list. Control the learning rate of different stages. + Returns: + model: nn.Layer. Specific HGNetV2 model depends on args. + """ + + arch_configs = { + "B0": { + "stem_channels": [3, 16, 16], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [16, 16, 64, 1, False, False, 3, 3], + "stage2": [64, 32, 256, 1, True, False, 3, 3], + "stage3": [256, 64, 512, 2, True, True, 5, 3], + "stage4": [512, 128, 1024, 1, True, True, 5, 3], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B0_stage1.pth", + }, + "B1": { + "stem_channels": [3, 24, 32], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 64, 1, False, False, 3, 3], + "stage2": [64, 48, 256, 1, True, False, 3, 3], + "stage3": [256, 96, 512, 2, True, True, 5, 3], + "stage4": [512, 192, 1024, 1, True, True, 5, 3], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B1_stage1.pth", + }, + "B2": { + "stem_channels": [3, 24, 32], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 96, 1, False, False, 3, 4], + "stage2": [96, 64, 384, 1, True, False, 3, 4], + "stage3": [384, 128, 768, 3, True, True, 5, 4], + "stage4": [768, 256, 1536, 1, True, True, 5, 4], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B2_stage1.pth", + }, + "B3": { + "stem_channels": [3, 24, 32], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [32, 32, 128, 1, False, False, 3, 5], + "stage2": [128, 64, 512, 1, True, False, 3, 5], + "stage3": [512, 128, 1024, 3, True, True, 5, 5], + "stage4": [1024, 256, 2048, 1, True, True, 5, 5], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B3_stage1.pth", + }, + "B4": { + "stem_channels": [3, 32, 48], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [48, 48, 128, 1, False, False, 3, 6], + "stage2": [128, 96, 512, 1, True, False, 3, 6], + "stage3": [512, 192, 1024, 3, True, True, 5, 6], + "stage4": [1024, 384, 2048, 1, True, True, 5, 6], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B4_stage1.pth", + }, + "B5": { + "stem_channels": [3, 32, 64], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [64, 64, 128, 1, False, False, 3, 6], + "stage2": [128, 128, 512, 2, True, False, 3, 6], + "stage3": [512, 256, 1024, 5, True, True, 5, 6], + "stage4": [1024, 512, 2048, 2, True, True, 5, 6], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B5_stage1.pth", + }, + "B6": { + "stem_channels": [3, 48, 96], + "stage_config": { + # in_channels, mid_channels, out_channels, num_blocks, downsample, light_block, kernel_size, layer_num + "stage1": [96, 96, 192, 2, False, False, 3, 6], + "stage2": [192, 192, 512, 3, True, False, 3, 6], + "stage3": [512, 384, 1024, 6, True, True, 5, 6], + "stage4": [1024, 768, 2048, 3, True, True, 5, 6], + }, + "url": "https://github.com/Peterande/storage/releases/download/dfinev1.0/PPHGNetV2_B6_stage1.pth", + }, + } + + def __init__( + self, + name, + use_lab=False, + return_idx=[1, 2, 3], + freeze_stem_only=True, + freeze_at=0, + freeze_norm=True, + pretrained=True, + local_model_dir="weight/hgnetv2/", + ): + super().__init__() + self.use_lab = use_lab + self.return_idx = return_idx + + stem_channels = self.arch_configs[name]["stem_channels"] + stage_config = self.arch_configs[name]["stage_config"] + download_url = self.arch_configs[name]["url"] + + self._out_strides = [4, 8, 16, 32] + self._out_channels = [stage_config[k][2] for k in stage_config] + + # stem + self.stem = StemBlock( + in_chs=stem_channels[0], + mid_chs=stem_channels[1], + out_chs=stem_channels[2], + use_lab=use_lab, + ) + + # stages + self.stages = nn.ModuleList() + for i, k in enumerate(stage_config): + ( + in_channels, + mid_channels, + out_channels, + block_num, + downsample, + light_block, + kernel_size, + layer_num, + ) = stage_config[k] + self.stages.append( + HG_Stage( + in_channels, + mid_channels, + out_channels, + block_num, + layer_num, + downsample, + light_block, + kernel_size, + use_lab, + ) + ) + + if freeze_at >= 0: + self._freeze_parameters(self.stem) + if not freeze_stem_only: + for i in range(min(freeze_at + 1, len(self.stages))): + self._freeze_parameters(self.stages[i]) + + if freeze_norm: + self._freeze_norm(self) + + if pretrained: + RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m" + try: + model_path = local_model_dir + "PPHGNetV2_" + name + "_stage1.pth" + if os.path.exists(model_path): + state = torch.load(model_path, map_location="cpu") + print(f"Loaded stage1 {name} HGNetV2 from local file.") + else: + # If the file doesn't exist locally, download from the URL + if safe_get_rank() == 0: + print( + GREEN + + "If the pretrained HGNetV2 can't be downloaded automatically. Please check your network connection." + + RESET + ) + print( + GREEN + + "Please check your network connection. Or download the model manually from " + + RESET + + f"{download_url}" + + GREEN + + " to " + + RESET + + f"{local_model_dir}." + + RESET + ) + state = torch.hub.load_state_dict_from_url( + download_url, map_location="cpu", model_dir=local_model_dir + ) + safe_barrier() + else: + safe_barrier() + state = torch.load(local_model_dir) + + print(f"Loaded stage1 {name} HGNetV2 from URL.") + + self.load_state_dict(state) + + except (Exception, KeyboardInterrupt) as e: + if safe_get_rank() == 0: + print(f"{str(e)}") + logging.error( + RED + "CRITICAL WARNING: Failed to load pretrained HGNetV2 model" + RESET + ) + logging.error( + GREEN + + "Please check your network connection. Or download the model manually from " + + RESET + + f"{download_url}" + + GREEN + + " to " + + RESET + + f"{local_model_dir}." + + RESET + ) + exit() + + def _freeze_norm(self, m: nn.Module): + if isinstance(m, nn.BatchNorm2d): + m = FrozenBatchNorm2d(m.num_features) + else: + for name, child in m.named_children(): + _child = self._freeze_norm(child) + if _child is not child: + setattr(m, name, _child) + return m + + def _freeze_parameters(self, m: nn.Module): + for p in m.parameters(): + p.requires_grad = False + + def forward(self, x): + x = self.stem(x) + outs = [] + for idx, stage in enumerate(self.stages): + x = stage(x) + if idx in self.return_idx: + outs.append(x) + return outs diff --git a/src/nn/backbone/presnet.py b/src/nn/backbone/presnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c25c12b51e6132e4d9886203952fba988ef45e --- /dev/null +++ b/src/nn/backbone/presnet.py @@ -0,0 +1,263 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core import register +from .common import FrozenBatchNorm2d, get_activation + +__all__ = ["PResNet"] + + +ResNet_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + # 152: [3, 8, 36, 3], +} + + +donwload_url = { + 18: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth", + 34: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth", + 50: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth", + 101: "https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth", +} + + +class ConvNormLayer(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): + super().__init__() + self.conv = nn.Conv2d( + ch_in, + ch_out, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=bias, + ) + self.norm = nn.BatchNorm2d(ch_out) + self.act = get_activation(act) + + def forward(self, x): + return self.act(self.norm(self.conv(x))) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"): + super().__init__() + + self.shortcut = shortcut + + if not shortcut: + if variant == "d" and stride == 2: + self.short = nn.Sequential( + OrderedDict( + [ + ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)), + ("conv", ConvNormLayer(ch_in, ch_out, 1, 1)), + ] + ) + ) + else: + self.short = ConvNormLayer(ch_in, ch_out, 1, stride) + + self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act) + self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + out = self.branch2a(x) + out = self.branch2b(out) + if self.shortcut: + short = x + else: + short = self.short(x) + + out = out + short + out = self.act(out) + + return out + + +class BottleNeck(nn.Module): + expansion = 4 + + def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"): + super().__init__() + + if variant == "a": + stride1, stride2 = stride, 1 + else: + stride1, stride2 = 1, stride + + width = ch_out + + self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act) + self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act) + self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1) + + self.shortcut = shortcut + if not shortcut: + if variant == "d" and stride == 2: + self.short = nn.Sequential( + OrderedDict( + [ + ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)), + ("conv", ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)), + ] + ) + ) + else: + self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride) + + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + out = self.branch2a(x) + out = self.branch2b(out) + out = self.branch2c(out) + + if self.shortcut: + short = x + else: + short = self.short(x) + + out = out + short + out = self.act(out) + + return out + + +class Blocks(nn.Module): + def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(count): + self.blocks.append( + block( + ch_in, + ch_out, + stride=2 if i == 0 and stage_num != 2 else 1, + shortcut=False if i == 0 else True, + variant=variant, + act=act, + ) + ) + + if i == 0: + ch_in = ch_out * block.expansion + + def forward(self, x): + out = x + for block in self.blocks: + out = block(out) + return out + + +@register() +class PResNet(nn.Module): + def __init__( + self, + depth, + variant="d", + num_stages=4, + return_idx=[0, 1, 2, 3], + act="relu", + freeze_at=-1, + freeze_norm=True, + pretrained=False, + ): + super().__init__() + + block_nums = ResNet_cfg[depth] + ch_in = 64 + if variant in ["c", "d"]: + conv_def = [ + [3, ch_in // 2, 3, 2, "conv1_1"], + [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"], + [ch_in // 2, ch_in, 3, 1, "conv1_3"], + ] + else: + conv_def = [[3, ch_in, 7, 2, "conv1_1"]] + + self.conv1 = nn.Sequential( + OrderedDict( + [ + (name, ConvNormLayer(cin, cout, k, s, act=act)) + for cin, cout, k, s, name in conv_def + ] + ) + ) + + ch_out_list = [64, 128, 256, 512] + block = BottleNeck if depth >= 50 else BasicBlock + + _out_channels = [block.expansion * v for v in ch_out_list] + _out_strides = [4, 8, 16, 32] + + self.res_layers = nn.ModuleList() + for i in range(num_stages): + stage_num = i + 2 + self.res_layers.append( + Blocks( + block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant + ) + ) + ch_in = _out_channels[i] + + self.return_idx = return_idx + self.out_channels = [_out_channels[_i] for _i in return_idx] + self.out_strides = [_out_strides[_i] for _i in return_idx] + + if freeze_at >= 0: + self._freeze_parameters(self.conv1) + for i in range(min(freeze_at, num_stages)): + self._freeze_parameters(self.res_layers[i]) + + if freeze_norm: + self._freeze_norm(self) + + if pretrained: + if isinstance(pretrained, bool) or "http" in pretrained: + state = torch.hub.load_state_dict_from_url( + donwload_url[depth], map_location="cpu", model_dir="weight" + ) + else: + state = torch.load(pretrained, map_location="cpu") + self.load_state_dict(state) + print(f"Load PResNet{depth} state_dict") + + def _freeze_parameters(self, m: nn.Module): + for p in m.parameters(): + p.requires_grad = False + + def _freeze_norm(self, m: nn.Module): + if isinstance(m, nn.BatchNorm2d): + m = FrozenBatchNorm2d(m.num_features) + else: + for name, child in m.named_children(): + _child = self._freeze_norm(child) + if _child is not child: + setattr(m, name, _child) + return m + + def forward(self, x): + conv1 = self.conv1(x) + x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1) + outs = [] + for idx, stage in enumerate(self.res_layers): + x = stage(x) + if idx in self.return_idx: + outs.append(x) + return outs diff --git a/src/nn/backbone/test_resnet.py b/src/nn/backbone/test_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8210693e7d05c115bc07ffd053c2000172eb530f --- /dev/null +++ b/src/nn/backbone/test_resnet.py @@ -0,0 +1,83 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core import register + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(self.expansion * planes), + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + +class _ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super().__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +@register() +class MResNet(nn.Module): + def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None: + super().__init__() + self.model = _ResNet(BasicBlock, num_blocks, num_classes) + + def forward(self, x): + return self.model(x) diff --git a/src/nn/backbone/timm_model.py b/src/nn/backbone/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9d032ed7071c2203a381d99a5e575a810588a0 --- /dev/null +++ b/src/nn/backbone/timm_model.py @@ -0,0 +1,66 @@ +"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. + +https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055#0583 +""" + +import torch +from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names + +from ...core import register +from .utils import IntermediateLayerGetter + + +@register() +class TimmModel(torch.nn.Module): + def __init__( + self, name, return_layers, pretrained=False, exportable=True, features_only=True, **kwargs + ) -> None: + super().__init__() + + import timm + + model = timm.create_model( + name, + pretrained=pretrained, + exportable=exportable, + features_only=features_only, + **kwargs, + ) + # nodes, _ = get_graph_node_names(model) + # print(nodes) + # features = {'': ''} + # model = create_feature_extractor(model, return_nodes=features) + + assert set(return_layers).issubset( + model.feature_info.module_name() + ), f"return_layers should be a subset of {model.feature_info.module_name()}" + + # self.model = model + self.model = IntermediateLayerGetter(model, return_layers) + + return_idx = [model.feature_info.module_name().index(name) for name in return_layers] + self.strides = [model.feature_info.reduction()[i] for i in return_idx] + self.channels = [model.feature_info.channels()[i] for i in return_idx] + self.return_idx = return_idx + self.return_layers = return_layers + + def forward(self, x: torch.Tensor): + outputs = self.model(x) + # outputs = [outputs[i] for i in self.return_idx] + return outputs + + +if __name__ == "__main__": + model = TimmModel(name="resnet34", return_layers=["layer2", "layer3"]) + data = torch.rand(1, 3, 640, 640) + outputs = model(data) + + for output in outputs: + print(output.shape) + + """ + model: + type: TimmModel + name: resnet34 + return_layers: ['layer2', 'layer4'] + """ diff --git a/src/nn/backbone/torchvision_model.py b/src/nn/backbone/torchvision_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a584659f84975459564fe1fe43626b54b732d653 --- /dev/null +++ b/src/nn/backbone/torchvision_model.py @@ -0,0 +1,50 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torchvision + +from ...core import register +from .utils import IntermediateLayerGetter + +__all__ = ["TorchVisionModel"] + + +@register() +class TorchVisionModel(torch.nn.Module): + def __init__(self, name, return_layers, weights=None, **kwargs) -> None: + super().__init__() + + if weights is not None: + weights = getattr(torchvision.models.get_model_weights(name), weights) + + model = torchvision.models.get_model(name, weights=weights, **kwargs) + + # TODO hard code. + if hasattr(model, "features"): + model = IntermediateLayerGetter(model.features, return_layers) + else: + model = IntermediateLayerGetter(model, return_layers) + + self.model = model + + def forward(self, x): + return self.model(x) + + +# TorchVisionModel('swin_t', return_layers=['5', '7']) +# TorchVisionModel('resnet34', return_layers=['layer2','layer3', 'layer4']) + +# TorchVisionModel: +# name: swin_t +# return_layers: ['5', '7'] +# weights: DEFAULT + + +# model: +# type: TorchVisionModel +# name: resnet34 +# return_layers: ['layer2','layer3', 'layer4'] +# weights: DEFAULT diff --git a/src/nn/backbone/utils.py b/src/nn/backbone/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41ba6b51bbd00f89a43a3c45e4e90ad12b9bd462 --- /dev/null +++ b/src/nn/backbone/utils.py @@ -0,0 +1,56 @@ +""" +https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py + +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from collections import OrderedDict +from typing import Dict, List + +import torch.nn as nn + + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + """ + + _version = 3 + + def __init__(self, model: nn.Module, return_layers: List[str]) -> None: + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError( + "return_layers are not present in model. {}".format( + [name for name, _ in model.named_children()] + ) + ) + orig_return_layers = return_layers + return_layers = {str(k): str(k) for k in return_layers} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super().__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + outputs = [] + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + outputs.append(x) + + return outputs diff --git a/src/nn/criterion/__init__.py b/src/nn/criterion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a537a2dbddb39b807258bd3121f8a419c61a8c8d --- /dev/null +++ b/src/nn/criterion/__init__.py @@ -0,0 +1,11 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch.nn as nn + +from ...core import register +from .det_criterion import DetCriterion + +CrossEntropyLoss = register()(nn.CrossEntropyLoss) diff --git a/src/nn/criterion/det_criterion.py b/src/nn/criterion/det_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..80ddd42ba326174738531abe2a70f3a52373f62b --- /dev/null +++ b/src/nn/criterion/det_criterion.py @@ -0,0 +1,188 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.distributed +import torch.nn.functional as F +import torchvision + +from ...core import register +from ...misc import box_ops, dist_utils + + +@register() +class DetCriterion(torch.nn.Module): + """Default Detection Criterion""" + + __share__ = ["num_classes"] + __inject__ = ["matcher"] + + def __init__( + self, + losses, + weight_dict, + num_classes=80, + alpha=0.75, + gamma=2.0, + box_fmt="cxcywh", + matcher=None, + ): + """ + Args: + losses (list[str]): requested losses, support ['boxes', 'vfl', 'focal'] + weight_dict (dict[str, float)]: corresponding losses weight, including + ['loss_bbox', 'loss_giou', 'loss_vfl', 'loss_focal'] + box_fmt (str): in box format, 'cxcywh' or 'xyxy' + matcher (Matcher): matcher used to match source to target + """ + super().__init__() + self.losses = losses + self.weight_dict = weight_dict + self.alpha = alpha + self.gamma = gamma + self.num_classes = num_classes + self.box_fmt = box_fmt + assert matcher is not None, "" + self.matcher = matcher + + def forward(self, outputs, targets, **kwargs): + """ + Args: + outputs: Dict[Tensor], 'pred_boxes', 'pred_logits', 'meta'. + targets, List[Dict[str, Tensor]], len(targets) == batch_size. + kwargs, store other information such as current epoch id. + Return: + losses, Dict[str, Tensor] + """ + matched = self.matcher(outputs, targets) + values = matched["values"] + indices = matched["indices"] + num_boxes = self._get_positive_nums(indices) + + # Compute all the requested losses + losses = {} + for loss in self.losses: + l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + losses.update(l_dict) + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def _get_positive_nums(self, indices): + # number of positive samples + num_pos = sum(len(i) for (i, _) in indices) + num_pos = torch.as_tensor([num_pos], dtype=torch.float32, device=indices[0][0].device) + if dist_utils.is_dist_available_and_initialized(): + torch.distributed.all_reduce(num_pos) + num_pos = torch.clamp(num_pos / dist_utils.get_world_size(), min=1).item() + return num_pos + + def loss_labels_focal(self, outputs, targets, indices, num_boxes): + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1].to( + src_logits.dtype + ) + loss = torchvision.ops.sigmoid_focal_loss( + src_logits, target, self.alpha, self.gamma, reduction="none" + ) + loss = loss.sum() / num_boxes + return {"loss_focal": loss} + + def loss_labels_vfl(self, outputs, targets, indices, num_boxes): + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][j] for t, (_, j) in zip(targets, indices)], dim=0) + + src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy") + target_boxes = torchvision.ops.box_convert( + target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy" + ) + iou, _ = box_ops.elementwise_box_iou(src_boxes.detach(), target_boxes) + + src_logits: torch.Tensor = outputs["pred_logits"] + target_classes_o = torch.cat([t["labels"][j] for t, (_, j) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_o[idx] = iou.to(src_logits.dtype) + target_score = target_score_o.unsqueeze(-1) * target + + src_score = F.sigmoid(src_logits.detach()) + weight = self.alpha * src_score.pow(self.gamma) * (1 - target) + target_score + + loss = F.binary_cross_entropy_with_logits( + src_logits, target_score, weight=weight, reduction="none" + ) + loss = loss.sum() / num_boxes + return {"loss_vfl": loss} + + def loss_boxes(self, outputs, targets, indices, num_boxes): + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + + losses = {} + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy") + target_boxes = torchvision.ops.box_convert( + target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy" + ) + loss_giou = 1 - box_ops.elementwise_generalized_box_iou(src_boxes, target_boxes) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + def loss_boxes_giou(self, outputs, targets, indices, num_boxes): + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + + losses = {} + src_boxes = torchvision.ops.box_convert(src_boxes, in_fmt=self.box_fmt, out_fmt="xyxy") + target_boxes = torchvision.ops.box_convert( + target_boxes, in_fmt=self.box_fmt, out_fmt="xyxy" + ) + loss_giou = 1 - box_ops.elementwise_generalized_box_iou(src_boxes, target_boxes) + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + "boxes": self.loss_boxes, + "giou": self.loss_boxes_giou, + "vfl": self.loss_labels_vfl, + "focal": self.loss_labels_focal, + } + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) diff --git a/src/nn/postprocessor/__init__.py b/src/nn/postprocessor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a322bf18eae34f0ae9a9ed02e9d2df6b3bb38c2 --- /dev/null +++ b/src/nn/postprocessor/__init__.py @@ -0,0 +1,6 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .nms_postprocessor import DetNMSPostProcessor diff --git a/src/nn/postprocessor/box_revert.py b/src/nn/postprocessor/box_revert.py new file mode 100644 index 0000000000000000000000000000000000000000..86c54de542883dd27d19eda05174a0c36c035a23 --- /dev/null +++ b/src/nn/postprocessor/box_revert.py @@ -0,0 +1,66 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from enum import Enum + +import torch +import torchvision +from torch import Tensor + + +class BoxProcessFormat(Enum): + """Box process format + + Available formats are + * ``RESIZE`` + * ``RESIZE_KEEP_RATIO`` + * ``RESIZE_KEEP_RATIO_PADDING`` + """ + + RESIZE = 1 + RESIZE_KEEP_RATIO = 2 + RESIZE_KEEP_RATIO_PADDING = 3 + + +def box_revert( + boxes: Tensor, + orig_sizes: Tensor = None, + eval_sizes: Tensor = None, + inpt_sizes: Tensor = None, + inpt_padding: Tensor = None, + normalized: bool = True, + in_fmt: str = "cxcywh", + out_fmt: str = "xyxy", + process_fmt=BoxProcessFormat.RESIZE, +) -> Tensor: + """ + Args: + boxes(Tensor), [N, :, 4], (x1, y1, x2, y2), pred boxes. + inpt_sizes(Tensor), [N, 2], (w, h). input sizes. + orig_sizes(Tensor), [N, 2], (w, h). origin sizes. + inpt_padding (Tensor), [N, 2], (w_pad, h_pad, ...). + (inpt_sizes + inpt_padding) == eval_sizes + """ + assert in_fmt in ("cxcywh", "xyxy"), "" + + if normalized and eval_sizes is not None: + boxes = boxes * eval_sizes.repeat(1, 2).unsqueeze(1) + + if inpt_padding is not None: + if in_fmt == "xyxy": + boxes -= inpt_padding[:, :2].repeat(1, 2).unsqueeze(1) + elif in_fmt == "cxcywh": + boxes[..., :2] -= inpt_padding[:, :2].repeat(1, 2).unsqueeze(1) + + if orig_sizes is not None: + orig_sizes = orig_sizes.repeat(1, 2).unsqueeze(1) + if inpt_sizes is not None: + inpt_sizes = inpt_sizes.repeat(1, 2).unsqueeze(1) + boxes = boxes * (orig_sizes / inpt_sizes) + else: + boxes = boxes * orig_sizes + + boxes = torchvision.ops.box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt) + return boxes diff --git a/src/nn/postprocessor/detr_postprocessor.py b/src/nn/postprocessor/detr_postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..f161ecf5793f82800e8b6b327b9a866a5014fbf3 --- /dev/null +++ b/src/nn/postprocessor/detr_postprocessor.py @@ -0,0 +1,86 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +__all__ = ["DetDETRPostProcessor"] + +from .box_revert import BoxProcessFormat, box_revert + + +def mod(a, b): + out = a - a // b * b + return out + + +class DetDETRPostProcessor(nn.Module): + def __init__( + self, + num_classes=80, + use_focal_loss=True, + num_top_queries=300, + box_process_format=BoxProcessFormat.RESIZE, + ) -> None: + super().__init__() + self.use_focal_loss = use_focal_loss + self.num_top_queries = num_top_queries + self.num_classes = int(num_classes) + self.box_process_format = box_process_format + self.deploy_mode = False + + def extra_repr(self) -> str: + return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}" + + def forward(self, outputs, **kwargs): + logits, boxes = outputs["pred_logits"], outputs["pred_boxes"] + + if self.use_focal_loss: + scores = F.sigmoid(logits) + scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1) + labels = index % self.num_classes + # labels = mod(index, self.num_classes) # for tensorrt + index = index // self.num_classes + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) + + else: + scores = F.softmax(logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > self.num_top_queries: + scores, index = torch.topk(scores, self.num_top_queries, dim=-1) + labels = torch.gather(labels, dim=1, index=index) + boxes = torch.gather( + boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]) + ) + + if kwargs is not None: + boxes = box_revert( + boxes, + in_fmt="cxcywh", + out_fmt="xyxy", + process_fmt=self.box_process_format, + normalized=True, + **kwargs, + ) + + # TODO for onnx export + if self.deploy_mode: + return labels, boxes, scores + + results = [] + for lab, box, sco in zip(labels, boxes, scores): + result = dict(labels=lab, boxes=box, scores=sco) + results.append(result) + + return results + + def deploy( + self, + ): + self.eval() + self.deploy_mode = True + return self diff --git a/src/nn/postprocessor/nms_postprocessor.py b/src/nn/postprocessor/nms_postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc35924857f92643428c0865c607b8eec5e588c --- /dev/null +++ b/src/nn/postprocessor/nms_postprocessor.py @@ -0,0 +1,86 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from typing import Dict + +import torch +import torch.distributed +import torch.nn.functional as F +import torchvision +from torch import Tensor + +from ...core import register + +__all__ = [ + "DetNMSPostProcessor", +] + + +@register() +class DetNMSPostProcessor(torch.nn.Module): + def __init__( + self, + iou_threshold=0.7, + score_threshold=0.01, + keep_topk=300, + box_fmt="cxcywh", + logit_fmt="sigmoid", + ) -> None: + super().__init__() + self.iou_threshold = iou_threshold + self.score_threshold = score_threshold + self.keep_topk = keep_topk + self.box_fmt = box_fmt.lower() + self.logit_fmt = logit_fmt.lower() + self.logit_func = getattr(F, self.logit_fmt, None) + self.deploy_mode = False + + def forward(self, outputs: Dict[str, Tensor], orig_target_sizes: Tensor): + logits, boxes = outputs["pred_logits"], outputs["pred_boxes"] + pred_boxes = torchvision.ops.box_convert(boxes, in_fmt=self.box_fmt, out_fmt="xyxy") + pred_boxes *= orig_target_sizes.repeat(1, 2).unsqueeze(1) + + values, pred_labels = torch.max(logits, dim=-1) + + if self.logit_func: + pred_scores = self.logit_func(values) + else: + pred_scores = values + + # TODO for onnx export + if self.deploy_mode: + blobs = { + "pred_labels": pred_labels, + "pred_boxes": pred_boxes, + "pred_scores": pred_scores, + } + return blobs + + results = [] + for i in range(logits.shape[0]): + score_keep = pred_scores[i] > self.score_threshold + pred_box = pred_boxes[i][score_keep] + pred_label = pred_labels[i][score_keep] + pred_score = pred_scores[i][score_keep] + + keep = torchvision.ops.batched_nms(pred_box, pred_score, pred_label, self.iou_threshold) + keep = keep[: self.keep_topk] + + blob = { + "labels": pred_label[keep], + "boxes": pred_box[keep], + "scores": pred_score[keep], + } + + results.append(blob) + + return results + + def deploy( + self, + ): + self.eval() + self.deploy_mode = True + return self diff --git a/src/optim/__init__.py b/src/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1762a8bf988816c991d693f2282736941f628a9 --- /dev/null +++ b/src/optim/__init__.py @@ -0,0 +1,9 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .amp import * +from .ema import * +from .optim import * +from .warmup import * diff --git a/src/optim/amp.py b/src/optim/amp.py new file mode 100644 index 0000000000000000000000000000000000000000..255fdadaeb496b337f0ebf53eac84f7259a21c67 --- /dev/null +++ b/src/optim/amp.py @@ -0,0 +1,12 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch.cuda.amp as amp + +from ..core import register + +__all__ = ["GradScaler"] + +GradScaler = register()(amp.grad_scaler.GradScaler) diff --git a/src/optim/ema.py b/src/optim/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..ead78154a38cae68e3ab8ac8aefa0ba91b58bbd8 --- /dev/null +++ b/src/optim/ema.py @@ -0,0 +1,108 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright (c) 2023 lyuwenyu. All Rights Reserved. +""" + +import math +from copy import deepcopy + +import torch +import torch.nn as nn + +from ..core import register +from ..misc import dist_utils + +__all__ = ["ModelEMA"] + + +@register() +class ModelEMA(object): + """ + Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models + Keep a moving average of everything in the model state_dict (parameters and buffers). + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + A smoothed version of the weights is necessary for some training schemes to perform well. + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + + def __init__( + self, model: nn.Module, decay: float = 0.9999, warmups: int = 1000, start: int = 0 + ): + super().__init__() + + self.module = deepcopy(dist_utils.de_parallel(model)).eval() + # if next(model.parameters()).device.type != 'cpu': + # self.module.half() # FP16 EMA + + self.decay = decay + self.warmups = warmups + self.before_start = 0 + self.start = start + self.updates = 0 # number of EMA updates + if warmups == 0: + self.decay_fn = lambda x: decay + else: + self.decay_fn = lambda x: decay * ( + 1 - math.exp(-x / warmups) + ) # decay exponential ramp (to help early epochs) + + for p in self.module.parameters(): + p.requires_grad_(False) + + def update(self, model: nn.Module): + if self.before_start < self.start: + self.before_start += 1 + return + # Update EMA parameters + with torch.no_grad(): + self.updates += 1 + d = self.decay_fn(self.updates) + msd = dist_utils.de_parallel(model).state_dict() + for k, v in self.module.state_dict().items(): + if v.dtype.is_floating_point: + v *= d + v += (1 - d) * msd[k].detach() + + def to(self, *args, **kwargs): + self.module = self.module.to(*args, **kwargs) + return self + + def state_dict( + self, + ): + return dict(module=self.module.state_dict(), updates=self.updates) + + def load_state_dict(self, state, strict=True): + self.module.load_state_dict(state["module"], strict=strict) + if "updates" in state: + self.updates = state["updates"] + + def forwad( + self, + ): + raise RuntimeError("ema...") + + def extra_repr(self) -> str: + return f"decay={self.decay}, warmups={self.warmups}" + + +class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): + """Maintains moving averages of model parameters using an exponential decay. + ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` + `torch.optim.swa_utils.AveragedModel `_ + is used to compute the EMA. + """ + + def __init__(self, model, decay, device="cpu", use_buffers=True): + self.decay_fn = lambda x: decay * (1 - math.exp(-x / 2000)) + + def ema_avg(avg_model_param, model_param, num_averaged): + decay = self.decay_fn(num_averaged) + return decay * avg_model_param + (1 - decay) * model_param + + super().__init__(model, device, ema_avg, use_buffers=use_buffers) diff --git a/src/optim/optim.py b/src/optim/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..daf863d499d9e2845a0253b39fb274e12c9b517e --- /dev/null +++ b/src/optim/optim.py @@ -0,0 +1,22 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler + +from ..core import register + +__all__ = ["AdamW", "SGD", "Adam", "MultiStepLR", "CosineAnnealingLR", "OneCycleLR", "LambdaLR"] + + +SGD = register()(optim.SGD) +Adam = register()(optim.Adam) +AdamW = register()(optim.AdamW) + + +MultiStepLR = register()(lr_scheduler.MultiStepLR) +CosineAnnealingLR = register()(lr_scheduler.CosineAnnealingLR) +OneCycleLR = register()(lr_scheduler.OneCycleLR) +LambdaLR = register()(lr_scheduler.LambdaLR) diff --git a/src/optim/warmup.py b/src/optim/warmup.py new file mode 100644 index 0000000000000000000000000000000000000000..2d021633026fbe8b18dd846419e6d9b9f39ce7b8 --- /dev/null +++ b/src/optim/warmup.py @@ -0,0 +1,56 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from torch.optim.lr_scheduler import LRScheduler + +from ..core import register + + +class Warmup(object): + def __init__( + self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1 + ) -> None: + self.lr_scheduler = lr_scheduler + self.warmup_end_values = [pg["lr"] for pg in lr_scheduler.optimizer.param_groups] + self.last_step = last_step + self.warmup_duration = warmup_duration + self.step() + + def state_dict(self): + return {k: v for k, v in self.__dict__.items() if k != "lr_scheduler"} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + + def get_warmup_factor(self, step, **kwargs): + raise NotImplementedError + + def step( + self, + ): + self.last_step += 1 + if self.last_step >= self.warmup_duration: + return + factor = self.get_warmup_factor(self.last_step) + for i, pg in enumerate(self.lr_scheduler.optimizer.param_groups): + pg["lr"] = factor * self.warmup_end_values[i] + + def finished( + self, + ): + if self.last_step >= self.warmup_duration: + return True + return False + + +@register() +class LinearWarmup(Warmup): + def __init__( + self, lr_scheduler: LRScheduler, warmup_duration: int, last_step: int = -1 + ) -> None: + super().__init__(lr_scheduler, warmup_duration, last_step) + + def get_warmup_factor(self, step): + return min(1.0, (step + 1) / self.warmup_duration) diff --git a/src/solver/__init__.py b/src/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1636fe6ebb4b9283ca264567b2d90e7ec2e7850 --- /dev/null +++ b/src/solver/__init__.py @@ -0,0 +1,15 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from typing import Dict + +from ._solver import BaseSolver +from .clas_solver import ClasSolver +from .det_solver import DetSolver + +TASKS: Dict[str, BaseSolver] = { + "classification": ClasSolver, + "detection": DetSolver, +} diff --git a/src/solver/_solver.py b/src/solver/_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..f45022e8a1c20f74e45604bb411015946e7d34fb --- /dev/null +++ b/src/solver/_solver.py @@ -0,0 +1,783 @@ +import atexit +from datetime import datetime +from pathlib import Path +from typing import Dict + +import torch +import torch.nn as nn + +from ..core import BaseConfig +from ..misc import dist_utils + + +def to(m: nn.Module, device: str): + if m is None: + return None + return m.to(device) + + +def remove_module_prefix(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("module."): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + return new_state_dict + + +class BaseSolver(object): + def __init__(self, cfg: BaseConfig) -> None: + self.cfg = cfg + self.obj365_ids = [ + 0, + 46, + 5, + 58, + 114, + 55, + 116, + 65, + 21, + 40, + 176, + 127, + 249, + 24, + 56, + 139, + 92, + 78, + 99, + 96, + 144, + 295, + 178, + 180, + 38, + 39, + 13, + 43, + 120, + 219, + 148, + 173, + 165, + 154, + 137, + 113, + 145, + 146, + 204, + 8, + 35, + 10, + 88, + 84, + 93, + 26, + 112, + 82, + 265, + 104, + 141, + 152, + 234, + 143, + 150, + 97, + 2, + 50, + 25, + 75, + 98, + 153, + 37, + 73, + 115, + 132, + 106, + 61, + 163, + 134, + 277, + 81, + 133, + 18, + 94, + 30, + 169, + 70, + 328, + 226, + ] + + def _setup(self): + """Avoid instantiating unnecessary classes""" + cfg = self.cfg + if cfg.device: + device = torch.device(cfg.device) + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + self.model = cfg.model + + # NOTE: Must load_tuning_state before EMA instance building + if self.cfg.tuning: + print(f"Tuning checkpoint from {self.cfg.tuning}") + self.load_tuning_state(self.cfg.tuning) + + self.model = dist_utils.warp_model( + self.model.to(device), + sync_bn=cfg.sync_bn, + find_unused_parameters=cfg.find_unused_parameters, + ) + + self.criterion = self.to(cfg.criterion, device) + self.postprocessor = self.to(cfg.postprocessor, device) + + self.ema = self.to(cfg.ema, device) + self.scaler = cfg.scaler + + self.device = device + self.last_epoch = self.cfg.last_epoch + + self.output_dir = Path(cfg.output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.writer = cfg.writer + + if self.writer: + atexit.register(self.writer.close) + if dist_utils.is_main_process(): + self.writer.add_text("config", "{:s}".format(cfg.__repr__()), 0) + self.use_wandb = self.cfg.use_wandb + if self.use_wandb: + try: + import wandb + self.use_wandb = True + except ImportError: + self.use_wandb = False + + def cleanup(self): + if self.writer: + atexit.register(self.writer.close) + + def train(self): + self._setup() + self.optimizer = self.cfg.optimizer + self.lr_scheduler = self.cfg.lr_scheduler + self.lr_warmup_scheduler = self.cfg.lr_warmup_scheduler + + self.train_dataloader = dist_utils.warp_loader( + self.cfg.train_dataloader, shuffle=self.cfg.train_dataloader.shuffle + ) + self.val_dataloader = dist_utils.warp_loader( + self.cfg.val_dataloader, shuffle=self.cfg.val_dataloader.shuffle + ) + + self.evaluator = self.cfg.evaluator + + # NOTE: Instantiating order + if self.cfg.resume: + print(f"Resume checkpoint from {self.cfg.resume}") + self.load_resume_state(self.cfg.resume) + + def eval(self): + self._setup() + + self.val_dataloader = dist_utils.warp_loader( + self.cfg.val_dataloader, shuffle=self.cfg.val_dataloader.shuffle + ) + + self.evaluator = self.cfg.evaluator + + if self.cfg.resume: + print(f"Resume checkpoint from {self.cfg.resume}") + self.load_resume_state(self.cfg.resume) + + def to(self, module, device): + return module.to(device) if hasattr(module, "to") else module + + def state_dict(self): + """State dict, train/eval""" + state = {} + state["date"] = datetime.now().isoformat() + + # For resume + state["last_epoch"] = self.last_epoch + + for k, v in self.__dict__.items(): + if hasattr(v, "state_dict"): + v = dist_utils.de_parallel(v) + state[k] = v.state_dict() + + return state + + def load_state_dict(self, state): + """Load state dict, train/eval""" + if "last_epoch" in state: + self.last_epoch = state["last_epoch"] + print("Load last_epoch") + + for k, v in self.__dict__.items(): + if hasattr(v, "load_state_dict") and k in state: + v = dist_utils.de_parallel(v) + v.load_state_dict(state[k]) + print(f"Load {k}.state_dict") + + if hasattr(v, "load_state_dict") and k not in state: + if k == "ema": + model = getattr(self, "model", None) + if model is not None: + ema = dist_utils.de_parallel(v) + model_state_dict = remove_module_prefix(model.state_dict()) + ema.load_state_dict({"module": model_state_dict}) + print(f"Load {k}.state_dict from model.state_dict") + else: + print(f"Not load {k}.state_dict") + + def load_resume_state(self, path: str): + """Load resume""" + if path.startswith("http"): + state = torch.hub.load_state_dict_from_url(path, map_location="cpu") + else: + state = torch.load(path, map_location="cpu") + + # state['model'] = remove_module_prefix(state['model']) + self.load_state_dict(state) + + def load_tuning_state(self, path: str): + """Load model for tuning and adjust mismatched head parameters""" + if path.startswith("http"): + state = torch.hub.load_state_dict_from_url(path, map_location="cpu") + else: + state = torch.load(path, map_location="cpu") + + module = dist_utils.de_parallel(self.model) + + # Load the appropriate state dict + if "ema" in state: + pretrain_state_dict = state["ema"]["module"] + else: + pretrain_state_dict = state["model"] + + # Adjust head parameters between datasets + try: + adjusted_state_dict = self._adjust_head_parameters( + module.state_dict(), pretrain_state_dict + ) + stat, infos = self._matched_state(module.state_dict(), adjusted_state_dict) + except Exception: + stat, infos = self._matched_state(module.state_dict(), pretrain_state_dict) + + module.load_state_dict(stat, strict=False) + print(f"Load model.state_dict, {infos}") + + @staticmethod + def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]): + missed_list = [] + unmatched_list = [] + matched_state = {} + for k, v in state.items(): + if k in params: + if v.shape == params[k].shape: + matched_state[k] = params[k] + else: + unmatched_list.append(k) + else: + missed_list.append(k) + + return matched_state, {"missed": missed_list, "unmatched": unmatched_list} + + def _adjust_head_parameters(self, cur_state_dict, pretrain_state_dict): + """Adjust head parameters between datasets.""" + # List of parameters to adjust + if ( + pretrain_state_dict["decoder.denoising_class_embed.weight"].size() + != cur_state_dict["decoder.denoising_class_embed.weight"].size() + ): + del pretrain_state_dict["decoder.denoising_class_embed.weight"] + + head_param_names = ["decoder.enc_score_head.weight", "decoder.enc_score_head.bias"] + for i in range(8): + head_param_names.append(f"decoder.dec_score_head.{i}.weight") + head_param_names.append(f"decoder.dec_score_head.{i}.bias") + + adjusted_params = [] + + for param_name in head_param_names: + if param_name in cur_state_dict and param_name in pretrain_state_dict: + cur_tensor = cur_state_dict[param_name] + pretrain_tensor = pretrain_state_dict[param_name] + adjusted_tensor = self.map_class_weights(cur_tensor, pretrain_tensor) + if adjusted_tensor is not None: + pretrain_state_dict[param_name] = adjusted_tensor + adjusted_params.append(param_name) + else: + print(f"Cannot adjust parameter '{param_name}' due to size mismatch.") + + return pretrain_state_dict + + def map_class_weights(self, cur_tensor, pretrain_tensor): + """Map class weights from pretrain model to current model based on class IDs.""" + if pretrain_tensor.size() == cur_tensor.size(): + return pretrain_tensor + + adjusted_tensor = cur_tensor.clone() + adjusted_tensor.requires_grad = False + + if pretrain_tensor.size() > cur_tensor.size(): + for coco_id, obj_id in enumerate(self.obj365_ids): + adjusted_tensor[coco_id] = pretrain_tensor[obj_id + 1] + else: + for coco_id, obj_id in enumerate(self.obj365_ids): + adjusted_tensor[obj_id + 1] = pretrain_tensor[coco_id] + + return adjusted_tensor + + def fit(self): + raise NotImplementedError("") + + def val(self): + raise NotImplementedError("") + + +# obj365_classes = [ +# 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp', 'Glasses', +# 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf', 'Handbag/Satchel', +# 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book', 'Gloves', 'Storage box', +# 'Boat', 'Leather Shoes', 'Flower', 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', +# 'Pillow', 'Boots', 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', +# 'Belt', 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch', +# 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', 'Barrel/bucket', 'Van', +# 'Couch', 'Sandals', 'Bakset', 'Drum', 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', +# 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck', +# 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', +# 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet', +# 'Sink', 'Apple', 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', +# 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', 'Pot', 'Cow', +# 'Cake', 'Dinning Table', 'Sheep', 'Hanger', 'Blackboard/Whiteboard', 'Napkin', +# 'Other Fish', 'Orange/Tangerine', 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', +# 'Machinery Vehicle', 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', +# 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', 'Nightstand', +# 'Tea pot', 'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign', 'Dessert', +# 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', +# 'Baseball Bat', 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza', +# 'Elephant', 'Skateboard', 'Surfboard', 'Gun', 'Skating and Skiing shoes', 'Gas stove', +# 'Donut', 'Bow Tie', 'Carrot', 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', +# 'Pepper', 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks', 'Microwave', +# 'Pigeon', 'Baseball', 'Cutting/chopping Board', 'Coffee Table', 'Side Table', 'Scissors', +# 'Marker', 'Pie', 'Ladder', 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', +# 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg', +# 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards', 'Converter', 'Bathtub', +# 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', +# 'Pear', 'Heavy Truck', 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', +# 'Tennis Racket', 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', +# 'Tennis', 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion', +# 'Green beans', 'Projector', 'Frisbee', 'Washing Machine/Drying Machine', 'Chicken', +# 'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', +# 'Hotair ballon', 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog', +# 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer', 'Goose', 'Tape', +# 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple', 'Golf Ball', 'Ambulance', 'Parking meter', +# 'Mango', 'Key', 'Hurdle', 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', +# 'Megaphone', 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', +# 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone', +# 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', 'Router/modem', 'Poker Card', 'Toaster', +# 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', +# 'Cue', 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap', 'Recorder', +# 'Bear', 'Eggplant', 'Board Eraser', 'Coconut', 'Tape Measur/ Ruler', 'Pig', +# 'Showerhead', 'Globe', 'Chips', 'Steak', 'Crosswalk Sign', 'Stapler', 'Campel', +# 'Formula 1 ', 'Pomegranate', 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', +# 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', +# 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill', +# 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter', 'Grapefruit', +# 'Game board', 'Mop', 'Radish', 'Baozi', 'Target', 'French', 'Spring Rolls', 'Monkey', +# 'Rabbit', 'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', +# 'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle', +# 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster', 'Durian', 'Okra', +# 'Lipstick', 'Cosmetics Mirror', 'Curling', 'Table Tennis ' +# ] + +# coco_classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', +# 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', +# 'stop sign', 'parking meter', 'bench', 'wild bird', 'cat', 'dog', +# 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', +# 'backpack', 'umbrella', 'handbag/satchel', 'tie', 'luggage', 'frisbee', +# 'skating and skiing shoes', 'snowboard', 'baseball', 'kite', 'baseball bat', +# 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', +# 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl/basin', +# 'banana', 'apple', 'sandwich', 'orange/tangerine', 'broccoli', 'carrot', +# 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', +# 'potted plant', 'bed', 'dinning table', 'toilet', 'moniter/tv', 'laptop', +# 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', +# 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', +# 'vase', 'scissors', 'stuffed toy', 'hair dryer', 'toothbrush'] + + +# obj365_classes = [ +# (0, 'Person'), +# (1, 'Sneakers'), +# (2, 'Chair'), +# (3, 'Other Shoes'), +# (4, 'Hat'), +# (5, 'Car'), +# (6, 'Lamp'), +# (7, 'Glasses'), +# (8, 'Bottle'), +# (9, 'Desk'), +# (10, 'Cup'), +# (11, 'Street Lights'), +# (12, 'Cabinet/shelf'), +# (13, 'Handbag/Satchel'), +# (14, 'Bracelet'), +# (15, 'Plate'), +# (16, 'Picture/Frame'), +# (17, 'Helmet'), +# (18, 'Book'), +# (19, 'Gloves'), +# (20, 'Storage box'), +# (21, 'Boat'), +# (22, 'Leather Shoes'), +# (23, 'Flower'), +# (24, 'Bench'), +# (25, 'Potted Plant'), +# (26, 'Bowl/Basin'), +# (27, 'Flag'), +# (28, 'Pillow'), +# (29, 'Boots'), +# (30, 'Vase'), +# (31, 'Microphone'), +# (32, 'Necklace'), +# (33, 'Ring'), +# (34, 'SUV'), +# (35, 'Wine Glass'), +# (36, 'Belt'), +# (37, 'Monitor/TV'), +# (38, 'Backpack'), +# (39, 'Umbrella'), +# (40, 'Traffic Light'), +# (41, 'Speaker'), +# (42, 'Watch'), +# (43, 'Tie'), +# (44, 'Trash bin Can'), +# (45, 'Slippers'), +# (46, 'Bicycle'), +# (47, 'Stool'), +# (48, 'Barrel/bucket'), +# (49, 'Van'), +# (50, 'Couch'), +# (51, 'Sandals'), +# (52, 'Basket'), +# (53, 'Drum'), +# (54, 'Pen/Pencil'), +# (55, 'Bus'), +# (56, 'Wild Bird'), +# (57, 'High Heels'), +# (58, 'Motorcycle'), +# (59, 'Guitar'), +# (60, 'Carpet'), +# (61, 'Cell Phone'), +# (62, 'Bread'), +# (63, 'Camera'), +# (64, 'Canned'), +# (65, 'Truck'), +# (66, 'Traffic cone'), +# (67, 'Cymbal'), +# (68, 'Lifesaver'), +# (69, 'Towel'), +# (70, 'Stuffed Toy'), +# (71, 'Candle'), +# (72, 'Sailboat'), +# (73, 'Laptop'), +# (74, 'Awning'), +# (75, 'Bed'), +# (76, 'Faucet'), +# (77, 'Tent'), +# (78, 'Horse'), +# (79, 'Mirror'), +# (80, 'Power outlet'), +# (81, 'Sink'), +# (82, 'Apple'), +# (83, 'Air Conditioner'), +# (84, 'Knife'), +# (85, 'Hockey Stick'), +# (86, 'Paddle'), +# (87, 'Pickup Truck'), +# (88, 'Fork'), +# (89, 'Traffic Sign'), +# (90, 'Balloon'), +# (91, 'Tripod'), +# (92, 'Dog'), +# (93, 'Spoon'), +# (94, 'Clock'), +# (95, 'Pot'), +# (96, 'Cow'), +# (97, 'Cake'), +# (98, 'Dining Table'), +# (99, 'Sheep'), +# (100, 'Hanger'), +# (101, 'Blackboard/Whiteboard'), +# (102, 'Napkin'), +# (103, 'Other Fish'), +# (104, 'Orange/Tangerine'), +# (105, 'Toiletry'), +# (106, 'Keyboard'), +# (107, 'Tomato'), +# (108, 'Lantern'), +# (109, 'Machinery Vehicle'), +# (110, 'Fan'), +# (111, 'Green Vegetables'), +# (112, 'Banana'), +# (113, 'Baseball Glove'), +# (114, 'Airplane'), +# (115, 'Mouse'), +# (116, 'Train'), +# (117, 'Pumpkin'), +# (118, 'Soccer'), +# (119, 'Skiboard'), +# (120, 'Luggage'), +# (121, 'Nightstand'), +# (122, 'Tea pot'), +# (123, 'Telephone'), +# (124, 'Trolley'), +# (125, 'Head Phone'), +# (126, 'Sports Car'), +# (127, 'Stop Sign'), +# (128, 'Dessert'), +# (129, 'Scooter'), +# (130, 'Stroller'), +# (131, 'Crane'), +# (132, 'Remote'), +# (133, 'Refrigerator'), +# (134, 'Oven'), +# (135, 'Lemon'), +# (136, 'Duck'), +# (137, 'Baseball Bat'), +# (138, 'Surveillance Camera'), +# (139, 'Cat'), +# (140, 'Jug'), +# (141, 'Broccoli'), +# (142, 'Piano'), +# (143, 'Pizza'), +# (144, 'Elephant'), +# (145, 'Skateboard'), +# (146, 'Surfboard'), +# (147, 'Gun'), +# (148, 'Skating and Skiing Shoes'), +# (149, 'Gas Stove'), +# (150, 'Donut'), +# (151, 'Bow Tie'), +# (152, 'Carrot'), +# (153, 'Toilet'), +# (154, 'Kite'), +# (155, 'Strawberry'), +# (156, 'Other Balls'), +# (157, 'Shovel'), +# (158, 'Pepper'), +# (159, 'Computer Box'), +# (160, 'Toilet Paper'), +# (161, 'Cleaning Products'), +# (162, 'Chopsticks'), +# (163, 'Microwave'), +# (164, 'Pigeon'), +# (165, 'Baseball'), +# (166, 'Cutting/chopping Board'), +# (167, 'Coffee Table'), +# (168, 'Side Table'), +# (169, 'Scissors'), +# (170, 'Marker'), +# (171, 'Pie'), +# (172, 'Ladder'), +# (173, 'Snowboard'), +# (174, 'Cookies'), +# (175, 'Radiator'), +# (176, 'Fire Hydrant'), +# (177, 'Basketball'), +# (178, 'Zebra'), +# (179, 'Grape'), +# (180, 'Giraffe'), +# (181, 'Potato'), +# (182, 'Sausage'), +# (183, 'Tricycle'), +# (184, 'Violin'), +# (185, 'Egg'), +# (186, 'Fire Extinguisher'), +# (187, 'Candy'), +# (188, 'Fire Truck'), +# (189, 'Billiards'), +# (190, 'Converter'), +# (191, 'Bathtub'), +# (192, 'Wheelchair'), +# (193, 'Golf Club'), +# (194, 'Briefcase'), +# (195, 'Cucumber'), +# (196, 'Cigar/Cigarette'), +# (197, 'Paint Brush'), +# (198, 'Pear'), +# (199, 'Heavy Truck'), +# (200, 'Hamburger'), +# (201, 'Extractor'), +# (202, 'Extension Cord'), +# (203, 'Tong'), +# (204, 'Tennis Racket'), +# (205, 'Folder'), +# (206, 'American Football'), +# (207, 'Earphone'), +# (208, 'Mask'), +# (209, 'Kettle'), +# (210, 'Tennis'), +# (211, 'Ship'), +# (212, 'Swing'), +# (213, 'Coffee Machine'), +# (214, 'Slide'), +# (215, 'Carriage'), +# (216, 'Onion'), +# (217, 'Green Beans'), +# (218, 'Projector'), +# (219, 'Frisbee'), +# (220, 'Washing Machine/Drying Machine'), +# (221, 'Chicken'), +# (222, 'Printer'), +# (223, 'Watermelon'), +# (224, 'Saxophone'), +# (225, 'Tissue'), +# (226, 'Toothbrush'), +# (227, 'Ice Cream'), +# (228, 'Hot Air Balloon'), +# (229, 'Cello'), +# (230, 'French Fries'), +# (231, 'Scale'), +# (232, 'Trophy'), +# (233, 'Cabbage'), +# (234, 'Hot Dog'), +# (235, 'Blender'), +# (236, 'Peach'), +# (237, 'Rice'), +# (238, 'Wallet/Purse'), +# (239, 'Volleyball'), +# (240, 'Deer'), +# (241, 'Goose'), +# (242, 'Tape'), +# (243, 'Tablet'), +# (244, 'Cosmetics'), +# (245, 'Trumpet'), +# (246, 'Pineapple'), +# (247, 'Golf Ball'), +# (248, 'Ambulance'), +# (249, 'Parking Meter'), +# (250, 'Mango'), +# (251, 'Key'), +# (252, 'Hurdle'), +# (253, 'Fishing Rod'), +# (254, 'Medal'), +# (255, 'Flute'), +# (256, 'Brush'), +# (257, 'Penguin'), +# (258, 'Megaphone'), +# (259, 'Corn'), +# (260, 'Lettuce'), +# (261, 'Garlic'), +# (262, 'Swan'), +# (263, 'Helicopter'), +# (264, 'Green Onion'), +# (265, 'Sandwich'), +# (266, 'Nuts'), +# (267, 'Speed Limit Sign'), +# (268, 'Induction Cooker'), +# (269, 'Broom'), +# (270, 'Trombone'), +# (271, 'Plum'), +# (272, 'Rickshaw'), +# (273, 'Goldfish'), +# (274, 'Kiwi Fruit'), +# (275, 'Router/Modem'), +# (276, 'Poker Card'), +# (277, 'Toaster'), +# (278, 'Shrimp'), +# (279, 'Sushi'), +# (280, 'Cheese'), +# (281, 'Notepaper'), +# (282, 'Cherry'), +# (283, 'Pliers'), +# (284, 'CD'), +# (285, 'Pasta'), +# (286, 'Hammer'), +# (287, 'Cue'), +# (288, 'Avocado'), +# (289, 'Hami Melon'), +# (290, 'Flask'), +# (291, 'Mushroom'), +# (292, 'Screwdriver'), +# (293, 'Soap'), +# (294, 'Recorder'), +# (295, 'Bear'), +# (296, 'Eggplant'), +# (297, 'Board Eraser'), +# (298, 'Coconut'), +# (299, 'Tape Measure/Ruler'), +# (300, 'Pig'), +# (301, 'Showerhead'), +# (302, 'Globe'), +# (303, 'Chips'), +# (304, 'Steak'), +# (305, 'Crosswalk Sign'), +# (306, 'Stapler'), +# (307, 'Camel'), +# (308, 'Formula 1'), +# (309, 'Pomegranate'), +# (310, 'Dishwasher'), +# (311, 'Crab'), +# (312, 'Hoverboard'), +# (313, 'Meatball'), +# (314, 'Rice Cooker'), +# (315, 'Tuba'), +# (316, 'Calculator'), +# (317, 'Papaya'), +# (318, 'Antelope'), +# (319, 'Parrot'), +# (320, 'Seal'), +# (321, 'Butterfly'), +# (322, 'Dumbbell'), +# (323, 'Donkey'), +# (324, 'Lion'), +# (325, 'Urinal'), +# (326, 'Dolphin'), +# (327, 'Electric Drill'), +# (328, 'Hair Dryer'), +# (329, 'Egg Tart'), +# (330, 'Jellyfish'), +# (331, 'Treadmill'), +# (332, 'Lighter'), +# (333, 'Grapefruit'), +# (334, 'Game Board'), +# (335, 'Mop'), +# (336, 'Radish'), +# (337, 'Baozi'), +# (338, 'Target'), +# (339, 'French'), +# (340, 'Spring Rolls'), +# (341, 'Monkey'), +# (342, 'Rabbit'), +# (343, 'Pencil Case'), +# (344, 'Yak'), +# (345, 'Red Cabbage'), +# (346, 'Binoculars'), +# (347, 'Asparagus'), +# (348, 'Barbell'), +# (349, 'Scallop'), +# (350, 'Noodles'), +# (351, 'Comb'), +# (352, 'Dumpling'), +# (353, 'Oyster'), +# (354, 'Table Tennis Paddle'), +# (355, 'Cosmetics Brush/Eyeliner Pencil'), +# (356, 'Chainsaw'), +# (357, 'Eraser'), +# (358, 'Lobster'), +# (359, 'Durian'), +# (360, 'Okra'), +# (361, 'Lipstick'), +# (362, 'Cosmetics Mirror'), +# (363, 'Curling'), +# (364, 'Table Tennis') +# ] diff --git a/src/solver/clas_engine.py b/src/solver/clas_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..228c89921abf29ea7c6bbd4d23ab00f7d78b6f16 --- /dev/null +++ b/src/solver/clas_engine.py @@ -0,0 +1,74 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn + +from ..misc import MetricLogger, SmoothedValue, reduce_dict + + +def train_one_epoch( + model: nn.Module, criterion: nn.Module, dataloader, optimizer, ema, epoch, device +): + """ """ + model.train() + + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) + print_freq = 100 + header = "Epoch: [{}]".format(epoch) + + for imgs, labels in metric_logger.log_every(dataloader, print_freq, header): + imgs = imgs.to(device) + labels = labels.to(device) + + preds = model(imgs) + loss: torch.Tensor = criterion(preds, labels, epoch) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if ema is not None: + ema.update(model) + + loss_reduced_values = {k: v.item() for k, v in reduce_dict({"loss": loss}).items()} + metric_logger.update(**loss_reduced_values) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + return stats + + +@torch.no_grad() +def evaluate(model, criterion, dataloader, device): + model.eval() + + metric_logger = MetricLogger(delimiter=" ") + # metric_logger.add_meter('acc', SmoothedValue(window_size=1, fmt='{global_avg:.4f}')) + # metric_logger.add_meter('loss', SmoothedValue(window_size=1, fmt='{value:.2f}')) + metric_logger.add_meter("acc", SmoothedValue(window_size=1)) + metric_logger.add_meter("loss", SmoothedValue(window_size=1)) + + header = "Test:" + for imgs, labels in metric_logger.log_every(dataloader, 10, header): + imgs, labels = imgs.to(device), labels.to(device) + preds = model(imgs) + + acc = (preds.argmax(dim=-1) == labels).sum() / preds.shape[0] + loss = criterion(preds, labels) + + dict_reduced = reduce_dict({"acc": acc, "loss": loss}) + reduced_values = {k: v.item() for k, v in dict_reduced.items()} + metric_logger.update(**reduced_values) + + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + + stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + return stats diff --git a/src/solver/clas_solver.py b/src/solver/clas_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..d5617e6b8dcd831f32f5504aede9fe3dfb434b60 --- /dev/null +++ b/src/solver/clas_solver.py @@ -0,0 +1,75 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import datetime +import json +import time +from pathlib import Path + +import torch +import torch.nn as nn + +from ..misc import dist_utils +from ._solver import BaseSolver +from .clas_engine import evaluate, train_one_epoch + + +class ClasSolver(BaseSolver): + def fit( + self, + ): + print("Start training") + self.train() + args = self.cfg + + n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + print("Number of params:", n_parameters) + + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True) + + start_time = time.time() + start_epoch = self.last_epoch + 1 + for epoch in range(start_epoch, args.epochs): + if dist_utils.is_dist_available_and_initialized(): + self.train_dataloader.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + self.model, + self.criterion, + self.train_dataloader, + self.optimizer, + self.ema, + epoch=epoch, + device=self.device, + ) + self.lr_scheduler.step() + self.last_epoch += 1 + + if output_dir: + checkpoint_paths = [output_dir / "checkpoint.pth"] + # extra checkpoint before LR drop and every 100 epochs + if (epoch + 1) % args.checkpoint_freq == 0: + checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") + for checkpoint_path in checkpoint_paths: + dist_utils.save_on_master(self.state_dict(epoch), checkpoint_path) + + module = self.ema.module if self.ema else self.model + test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device) + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"test_{k}": v for k, v in test_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters, + } + + if output_dir and dist_utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) diff --git a/src/solver/det_engine.py b/src/solver/det_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..e3467577d61db08ebc89b29d76f9698eb279053f --- /dev/null +++ b/src/solver/det_engine.py @@ -0,0 +1,257 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from DETR (https://github.com/facebookresearch/detr/blob/main/engine.py) +Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +""" + +import math +import sys +from typing import Dict, Iterable, List + +import numpy as np +import torch +import torch.amp +from torch.cuda.amp.grad_scaler import GradScaler +from torch.utils.tensorboard import SummaryWriter + +from ..data import CocoEvaluator +from ..data.dataset import mscoco_category2label +from ..misc import MetricLogger, SmoothedValue, dist_utils, save_samples +from ..optim import ModelEMA, Warmup +from .validator import Validator, scale_boxes + + +def train_one_epoch( + model: torch.nn.Module, + criterion: torch.nn.Module, + data_loader: Iterable, + optimizer: torch.optim.Optimizer, + device: torch.device, + epoch: int, + use_wandb: bool, + max_norm: float = 0, + **kwargs, +): + if use_wandb: + import wandb + + model.train() + criterion.train() + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = "Epoch: [{}]".format(epoch) + + print_freq = kwargs.get("print_freq", 10) + writer: SummaryWriter = kwargs.get("writer", None) + + ema: ModelEMA = kwargs.get("ema", None) + scaler: GradScaler = kwargs.get("scaler", None) + lr_warmup_scheduler: Warmup = kwargs.get("lr_warmup_scheduler", None) + losses = [] + + output_dir = kwargs.get("output_dir", None) + num_visualization_sample_batch = kwargs.get("num_visualization_sample_batch", 1) + + for i, (samples, targets) in enumerate( + metric_logger.log_every(data_loader, print_freq, header) + ): + global_step = epoch * len(data_loader) + i + metas = dict(epoch=epoch, step=i, global_step=global_step, epoch_step=len(data_loader)) + + if global_step < num_visualization_sample_batch and output_dir is not None and dist_utils.is_main_process(): + save_samples(samples, targets, output_dir, "train", normalized=True, box_fmt="cxcywh") + + samples = samples.to(device) + targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] + + if scaler is not None: + with torch.autocast(device_type=str(device), cache_enabled=True): + outputs = model(samples, targets=targets) + + if torch.isnan(outputs["pred_boxes"]).any() or torch.isinf(outputs["pred_boxes"]).any(): + print(outputs["pred_boxes"]) + state = model.state_dict() + new_state = {} + for key, value in model.state_dict().items(): + # Replace 'module' with 'model' in each key + new_key = key.replace("module.", "") + # Add the updated key-value pair to the state dictionary + state[new_key] = value + new_state["model"] = state + dist_utils.save_on_master(new_state, "./NaN.pth") + + with torch.autocast(device_type=str(device), enabled=False): + loss_dict = criterion(outputs, targets, **metas) + + loss = sum(loss_dict.values()) + scaler.scale(loss).backward() + + if max_norm > 0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + else: + outputs = model(samples, targets=targets) + loss_dict = criterion(outputs, targets, **metas) + + loss: torch.Tensor = sum(loss_dict.values()) + optimizer.zero_grad() + loss.backward() + + if max_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + + optimizer.step() + + # ema + if ema is not None: + ema.update(model) + + if lr_warmup_scheduler is not None: + lr_warmup_scheduler.step() + + loss_dict_reduced = dist_utils.reduce_dict(loss_dict) + loss_value = sum(loss_dict_reduced.values()) + losses.append(loss_value.detach().cpu().numpy()) + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + + metric_logger.update(loss=loss_value, **loss_dict_reduced) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + if writer and dist_utils.is_main_process() and global_step % 10 == 0: + writer.add_scalar("Loss/total", loss_value.item(), global_step) + for j, pg in enumerate(optimizer.param_groups): + writer.add_scalar(f"Lr/pg_{j}", pg["lr"], global_step) + for k, v in loss_dict_reduced.items(): + writer.add_scalar(f"Loss/{k}", v.item(), global_step) + + if use_wandb: + wandb.log( + {"lr": optimizer.param_groups[0]["lr"], "epoch": epoch, "train/loss": np.mean(losses)} + ) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate( + model: torch.nn.Module, + criterion: torch.nn.Module, + postprocessor, + data_loader, + coco_evaluator: CocoEvaluator, + device, + epoch: int, + use_wandb: bool, + **kwargs, +): + if use_wandb: + import wandb + + model.eval() + criterion.eval() + coco_evaluator.cleanup() + + metric_logger = MetricLogger(delimiter=" ") + # metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = "Test:" + + # iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessor.keys()) + iou_types = coco_evaluator.iou_types + # coco_evaluator = CocoEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + gt: List[Dict[str, torch.Tensor]] = [] + preds: List[Dict[str, torch.Tensor]] = [] + + output_dir = kwargs.get("output_dir", None) + num_visualization_sample_batch = kwargs.get("num_visualization_sample_batch", 1) + + for i, (samples, targets) in enumerate(metric_logger.log_every(data_loader, 10, header)): + global_step = epoch * len(data_loader) + i + + if global_step < num_visualization_sample_batch and output_dir is not None and dist_utils.is_main_process(): + save_samples(samples, targets, output_dir, "val", normalized=False, box_fmt="xyxy") + + samples = samples.to(device) + targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] + + outputs = model(samples) + # with torch.autocast(device_type=str(device)): + # outputs = model(samples) + + # TODO (lyuwenyu), fix dataset converted using `convert_to_coco_api`? + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + # orig_target_sizes = torch.tensor([[samples.shape[-1], samples.shape[-2]]], device=samples.device) + + results = postprocessor(outputs, orig_target_sizes) + + # if 'segm' in postprocessor.keys(): + # target_sizes = torch.stack([t["size"] for t in targets], dim=0) + # results = postprocessor['segm'](results, outputs, orig_target_sizes, target_sizes) + + res = {target["image_id"].item(): output for target, output in zip(targets, results)} + if coco_evaluator is not None: + coco_evaluator.update(res) + + # validator format for metrics + for idx, (target, result) in enumerate(zip(targets, results)): + gt.append( + { + "boxes": scale_boxes( # from model input size to original img size + target["boxes"], + (target["orig_size"][1], target["orig_size"][0]), + (samples[idx].shape[-1], samples[idx].shape[-2]), + ), + "labels": target["labels"], + } + ) + labels = ( + torch.tensor([mscoco_category2label[int(x.item())] for x in result["labels"].flatten()]) + .to(result["labels"].device) + .reshape(result["labels"].shape) + ) if postprocessor.remap_mscoco_category else result["labels"] + preds.append( + {"boxes": result["boxes"], "labels": labels, "scores": result["scores"]} + ) + + # Conf matrix, F1, Precision, Recall, box IoU + metrics = Validator(gt, preds).compute_metrics() + print("Metrics:", metrics) + if use_wandb: + metrics = {f"metrics/{k}": v for k, v in metrics.items()} + metrics["epoch"] = epoch + wandb.log(metrics) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + + stats = {} + # stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + if "bbox" in iou_types: + stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist() + if "segm" in iou_types: + stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() + + return stats, coco_evaluator diff --git a/src/solver/det_solver.py b/src/solver/det_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaacd2a09eb023793d24a422b86cf9694aa4d06 --- /dev/null +++ b/src/solver/det_solver.py @@ -0,0 +1,227 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright (c) 2023 lyuwenyu. All Rights Reserved. +""" + +import datetime +import json +import time + +import torch + +from ..misc import dist_utils, stats +from ._solver import BaseSolver +from .det_engine import evaluate, train_one_epoch + + +class DetSolver(BaseSolver): + def fit(self): + self.train() + args = self.cfg + metric_names = ["AP50:95", "AP50", "AP75", "APsmall", "APmedium", "APlarge"] + + if self.use_wandb: + import wandb + + wandb.init( + project=args.yaml_cfg["project_name"], + name=args.yaml_cfg["exp_name"], + config=args.yaml_cfg, + ) + wandb.watch(self.model) + + n_parameters, model_stats = stats(self.cfg) + print(model_stats) + print("-" * 42 + "Start training" + "-" * 43) + top1 = 0 + best_stat = { + "epoch": -1, + } + if self.last_epoch > 0: + module = self.ema.module if self.ema else self.model + test_stats, coco_evaluator = evaluate( + module, + self.criterion, + self.postprocessor, + self.val_dataloader, + self.evaluator, + self.device, + self.last_epoch, + self.use_wandb + ) + for k in test_stats: + best_stat["epoch"] = self.last_epoch + best_stat[k] = test_stats[k][0] + top1 = test_stats[k][0] + print(f"best_stat: {best_stat}") + + best_stat_print = best_stat.copy() + start_time = time.time() + start_epoch = self.last_epoch + 1 + for epoch in range(start_epoch, args.epochs): + self.train_dataloader.set_epoch(epoch) + # self.train_dataloader.dataset.set_epoch(epoch) + if dist_utils.is_dist_available_and_initialized(): + self.train_dataloader.sampler.set_epoch(epoch) + + if epoch == self.train_dataloader.collate_fn.stop_epoch: + self.load_resume_state(str(self.output_dir / "best_stg1.pth")) + if self.ema: + self.ema.decay = self.train_dataloader.collate_fn.ema_restart_decay + print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}") + + train_stats = train_one_epoch( + self.model, + self.criterion, + self.train_dataloader, + self.optimizer, + self.device, + epoch, + max_norm=args.clip_max_norm, + print_freq=args.print_freq, + ema=self.ema, + scaler=self.scaler, + lr_warmup_scheduler=self.lr_warmup_scheduler, + writer=self.writer, + use_wandb=self.use_wandb, + output_dir=self.output_dir, + ) + + if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished(): + self.lr_scheduler.step() + + self.last_epoch += 1 + + if self.output_dir and epoch < self.train_dataloader.collate_fn.stop_epoch: + checkpoint_paths = [self.output_dir / "last.pth"] + # extra checkpoint before LR drop and every 100 epochs + if (epoch + 1) % args.checkpoint_freq == 0: + checkpoint_paths.append(self.output_dir / f"checkpoint{epoch:04}.pth") + for checkpoint_path in checkpoint_paths: + dist_utils.save_on_master(self.state_dict(), checkpoint_path) + + module = self.ema.module if self.ema else self.model + test_stats, coco_evaluator = evaluate( + module, + self.criterion, + self.postprocessor, + self.val_dataloader, + self.evaluator, + self.device, + epoch, + self.use_wandb, + output_dir=self.output_dir, + ) + + # TODO + for k in test_stats: + if self.writer and dist_utils.is_main_process(): + for i, v in enumerate(test_stats[k]): + self.writer.add_scalar(f"Test/{k}_{i}".format(k), v, epoch) + + if k in best_stat: + best_stat["epoch"] = ( + epoch if test_stats[k][0] > best_stat[k] else best_stat["epoch"] + ) + best_stat[k] = max(best_stat[k], test_stats[k][0]) + else: + best_stat["epoch"] = epoch + best_stat[k] = test_stats[k][0] + + if best_stat[k] > top1: + best_stat_print["epoch"] = epoch + top1 = best_stat[k] + if self.output_dir: + if epoch >= self.train_dataloader.collate_fn.stop_epoch: + dist_utils.save_on_master( + self.state_dict(), self.output_dir / "best_stg2.pth" + ) + else: + dist_utils.save_on_master( + self.state_dict(), self.output_dir / "best_stg1.pth" + ) + + best_stat_print[k] = max(best_stat[k], top1) + print(f"best_stat: {best_stat_print}") # global best + + if best_stat["epoch"] == epoch and self.output_dir: + if epoch >= self.train_dataloader.collate_fn.stop_epoch: + if test_stats[k][0] > top1: + top1 = test_stats[k][0] + dist_utils.save_on_master( + self.state_dict(), self.output_dir / "best_stg2.pth" + ) + else: + top1 = max(test_stats[k][0], top1) + dist_utils.save_on_master( + self.state_dict(), self.output_dir / "best_stg1.pth" + ) + + elif epoch >= self.train_dataloader.collate_fn.stop_epoch: + best_stat = { + "epoch": -1, + } + if self.ema: + self.ema.decay -= 0.0001 + self.load_resume_state(str(self.output_dir / "best_stg1.pth")) + print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}") + + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"test_{k}": v for k, v in test_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters, + } + + if self.use_wandb: + wandb_logs = {} + for idx, metric_name in enumerate(metric_names): + wandb_logs[f"metrics/{metric_name}"] = test_stats["coco_eval_bbox"][idx] + wandb_logs["epoch"] = epoch + wandb.log(wandb_logs) + + if self.output_dir and dist_utils.is_main_process(): + with (self.output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + # for evaluation logs + if coco_evaluator is not None: + (self.output_dir / "eval").mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ["latest.pth"] + if epoch % 50 == 0: + filenames.append(f"{epoch:03}.pth") + for name in filenames: + torch.save( + coco_evaluator.coco_eval["bbox"].eval, + self.output_dir / "eval" / name, + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + + def val(self): + self.eval() + + module = self.ema.module if self.ema else self.model + test_stats, coco_evaluator = evaluate( + module, + self.criterion, + self.postprocessor, + self.val_dataloader, + self.evaluator, + self.device, + epoch=-1, + use_wandb=False, + ) + + if self.output_dir: + dist_utils.save_on_master( + coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth" + ) + + return diff --git a/src/solver/validator.py b/src/solver/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..14aa742ea5a959e00250d6c4e97e5d5213cce886 --- /dev/null +++ b/src/solver/validator.py @@ -0,0 +1,347 @@ +import copy +from collections import defaultdict +from pathlib import Path +from typing import Dict, List + +import matplotlib.pyplot as plt +import numpy as np +import torch +from loguru import logger +from torchvision.ops import box_iou + + +class Validator: + def __init__( + self, + gt: List[Dict[str, torch.Tensor]], + preds: List[Dict[str, torch.Tensor]], + conf_thresh=0.5, + iou_thresh=0.5, + ) -> None: + """ + Format example: + gt = [{'labels': tensor([0]), 'boxes': tensor([[561.0, 297.0, 661.0, 359.0]])}, ...] + len(gt) is the number of images + bboxes are in format [x1, y1, x2, y2], absolute values + """ + self.gt = gt + self.preds = preds + self.conf_thresh = conf_thresh + self.iou_thresh = iou_thresh + self.thresholds = np.arange(0.2, 1.0, 0.05) + self.conf_matrix = None + + def compute_metrics(self, extended=False) -> Dict[str, float]: + filtered_preds = filter_preds(copy.deepcopy(self.preds), self.conf_thresh) + metrics = self._compute_main_metrics(filtered_preds) + if not extended: + metrics.pop("extended_metrics", None) + return metrics + + def _compute_main_metrics(self, preds): + ( + self.metrics_per_class, + self.conf_matrix, + self.class_to_idx, + ) = self._compute_metrics_and_confusion_matrix(preds) + tps, fps, fns = 0, 0, 0 + ious = [] + extended_metrics = {} + for key, value in self.metrics_per_class.items(): + tps += value["TPs"] + fps += value["FPs"] + fns += value["FNs"] + ious.extend(value["IoUs"]) + + extended_metrics[f"precision_{key}"] = ( + value["TPs"] / (value["TPs"] + value["FPs"]) + if value["TPs"] + value["FPs"] > 0 + else 0 + ) + extended_metrics[f"recall_{key}"] = ( + value["TPs"] / (value["TPs"] + value["FNs"]) + if value["TPs"] + value["FNs"] > 0 + else 0 + ) + + extended_metrics[f"iou_{key}"] = np.mean(value["IoUs"]) + + precision = tps / (tps + fps) if (tps + fps) > 0 else 0 + recall = tps / (tps + fns) if (tps + fns) > 0 else 0 + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + iou = np.mean(ious).item() if ious else 0 + return { + "f1": f1, + "precision": precision, + "recall": recall, + "iou": iou, + "TPs": tps, + "FPs": fps, + "FNs": fns, + "extended_metrics": extended_metrics, + } + + def _compute_matrix_multi_class(self, preds): + metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []}) + for pred, gt in zip(preds, self.gt): + pred_boxes = pred["boxes"] + pred_labels = pred["labels"] + gt_boxes = gt["boxes"] + gt_labels = gt["labels"] + + # isolate each class + labels = torch.unique(torch.cat([pred_labels, gt_labels])) + for label in labels: + pred_cl_boxes = pred_boxes[pred_labels == label] # filter by bool mask + gt_cl_boxes = gt_boxes[gt_labels == label] + + n_preds = len(pred_cl_boxes) + n_gts = len(gt_cl_boxes) + if not (n_preds or n_gts): + continue + if not n_preds: + metrics_per_class[label.item()]["FNs"] += n_gts + metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts) + continue + if not n_gts: + metrics_per_class[label.item()]["FPs"] += n_preds + metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds) + continue + + ious = box_iou(pred_cl_boxes, gt_cl_boxes) # matrix of all IoUs + ious_mask = ious >= self.iou_thresh + + # indeces of boxes that have IoU >= threshold + pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True) + + if not pred_indices.numel(): # no predicts matched gts + metrics_per_class[label.item()]["FNs"] += n_gts + metrics_per_class[label.item()]["IoUs"].extend([0] * n_gts) + metrics_per_class[label.item()]["FPs"] += n_preds + metrics_per_class[label.item()]["IoUs"].extend([0] * n_preds) + continue + + iou_values = ious[pred_indices, gt_indices] + + # sorting by IoU to match hgihest scores first + sorted_indices = torch.argsort(-iou_values) + pred_indices = pred_indices[sorted_indices] + gt_indices = gt_indices[sorted_indices] + iou_values = iou_values[sorted_indices] + + matched_preds = set() + matched_gts = set() + for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values): + if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds: + matched_preds.add(pred_idx.item()) + matched_gts.add(gt_idx.item()) + metrics_per_class[label.item()]["TPs"] += 1 + metrics_per_class[label.item()]["IoUs"].append(iou.item()) + + unmatched_preds = set(range(n_preds)) - matched_preds + unmatched_gts = set(range(n_gts)) - matched_gts + metrics_per_class[label.item()]["FPs"] += len(unmatched_preds) + metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_preds)) + metrics_per_class[label.item()]["FNs"] += len(unmatched_gts) + metrics_per_class[label.item()]["IoUs"].extend([0] * len(unmatched_gts)) + return metrics_per_class + + def _compute_metrics_and_confusion_matrix(self, preds): + # Initialize per-class metrics + metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []}) + + # Collect all class IDs + all_classes = set() + for pred in preds: + all_classes.update(pred["labels"].tolist()) + for gt in self.gt: + all_classes.update(gt["labels"].tolist()) + all_classes = sorted(list(all_classes)) + class_to_idx = {cls_id: idx for idx, cls_id in enumerate(all_classes)} + n_classes = len(all_classes) + conf_matrix = np.zeros((n_classes + 1, n_classes + 1), dtype=int) # +1 for background class + + for pred, gt in zip(preds, self.gt): + pred_boxes = pred["boxes"] + pred_labels = pred["labels"] + gt_boxes = gt["boxes"] + gt_labels = gt["labels"] + + n_preds = len(pred_boxes) + n_gts = len(gt_boxes) + + if n_preds == 0 and n_gts == 0: + continue + + ious = box_iou(pred_boxes, gt_boxes) if n_preds > 0 and n_gts > 0 else torch.tensor([]) + # Assign matches between preds and gts + matched_pred_indices = set() + matched_gt_indices = set() + + if ious.numel() > 0: + # For each pred box, find the gt box with highest IoU + ious_mask = ious >= self.iou_thresh + pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True) + iou_values = ious[pred_indices, gt_indices] + + # Sorting by IoU to match highest scores first + sorted_indices = torch.argsort(-iou_values) + pred_indices = pred_indices[sorted_indices] + gt_indices = gt_indices[sorted_indices] + iou_values = iou_values[sorted_indices] + + for pred_idx, gt_idx, iou in zip(pred_indices, gt_indices, iou_values): + if ( + pred_idx.item() in matched_pred_indices + or gt_idx.item() in matched_gt_indices + ): + continue + matched_pred_indices.add(pred_idx.item()) + matched_gt_indices.add(gt_idx.item()) + + pred_label = pred_labels[pred_idx].item() + gt_label = gt_labels[gt_idx].item() + + pred_cls_idx = class_to_idx[pred_label] + gt_cls_idx = class_to_idx[gt_label] + + # Update confusion matrix + conf_matrix[gt_cls_idx, pred_cls_idx] += 1 + + # Update per-class metrics + if pred_label == gt_label: + metrics_per_class[gt_label]["TPs"] += 1 + metrics_per_class[gt_label]["IoUs"].append(iou.item()) + else: + # Misclassification + metrics_per_class[gt_label]["FNs"] += 1 + metrics_per_class[pred_label]["FPs"] += 1 + metrics_per_class[gt_label]["IoUs"].append(0) + metrics_per_class[pred_label]["IoUs"].append(0) + + # Unmatched predictions (False Positives) + unmatched_pred_indices = set(range(n_preds)) - matched_pred_indices + for pred_idx in unmatched_pred_indices: + pred_label = pred_labels[pred_idx].item() + pred_cls_idx = class_to_idx[pred_label] + # Update confusion matrix: background row + conf_matrix[n_classes, pred_cls_idx] += 1 + # Update per-class metrics + metrics_per_class[pred_label]["FPs"] += 1 + metrics_per_class[pred_label]["IoUs"].append(0) + + # Unmatched ground truths (False Negatives) + unmatched_gt_indices = set(range(n_gts)) - matched_gt_indices + for gt_idx in unmatched_gt_indices: + gt_label = gt_labels[gt_idx].item() + gt_cls_idx = class_to_idx[gt_label] + # Update confusion matrix: background column + conf_matrix[gt_cls_idx, n_classes] += 1 + # Update per-class metrics + metrics_per_class[gt_label]["FNs"] += 1 + metrics_per_class[gt_label]["IoUs"].append(0) + + return metrics_per_class, conf_matrix, class_to_idx + + def save_plots(self, path_to_save) -> None: + path_to_save = Path(path_to_save) + path_to_save.mkdir(parents=True, exist_ok=True) + + if self.conf_matrix is not None: + class_labels = [str(cls_id) for cls_id in self.class_to_idx.keys()] + ["background"] + + plt.figure(figsize=(10, 8)) + plt.imshow(self.conf_matrix, interpolation="nearest", cmap=plt.cm.Blues) + plt.title("Confusion Matrix") + plt.colorbar() + tick_marks = np.arange(len(class_labels)) + plt.xticks(tick_marks, class_labels, rotation=45) + plt.yticks(tick_marks, class_labels) + + # Add labels to each cell + thresh = self.conf_matrix.max() / 2.0 + for i in range(self.conf_matrix.shape[0]): + for j in range(self.conf_matrix.shape[1]): + plt.text( + j, + i, + format(self.conf_matrix[i, j], "d"), + horizontalalignment="center", + color="white" if self.conf_matrix[i, j] > thresh else "black", + ) + + plt.ylabel("True label") + plt.xlabel("Predicted label") + plt.tight_layout() + plt.savefig(path_to_save / "confusion_matrix.png") + plt.close() + + thresholds = self.thresholds + precisions, recalls, f1_scores = [], [], [] + + # Store the original predictions to reset after each threshold + original_preds = copy.deepcopy(self.preds) + + for threshold in thresholds: + # Filter predictions based on the current threshold + filtered_preds = filter_preds(copy.deepcopy(original_preds), threshold) + # Compute metrics with the filtered predictions + metrics = self._compute_main_metrics(filtered_preds) + precisions.append(metrics["precision"]) + recalls.append(metrics["recall"]) + f1_scores.append(metrics["f1"]) + + # Plot Precision and Recall vs Threshold + plt.figure() + plt.plot(thresholds, precisions, label="Precision", marker="o") + plt.plot(thresholds, recalls, label="Recall", marker="o") + plt.xlabel("Threshold") + plt.ylabel("Value") + plt.title("Precision and Recall vs Threshold") + plt.legend() + plt.grid(True) + plt.savefig(path_to_save / "precision_recall_vs_threshold.png") + plt.close() + + # Plot F1 Score vs Threshold + plt.figure() + plt.plot(thresholds, f1_scores, label="F1 Score", marker="o") + plt.xlabel("Threshold") + plt.ylabel("F1 Score") + plt.title("F1 Score vs Threshold") + plt.grid(True) + plt.savefig(path_to_save / "f1_score_vs_threshold.png") + plt.close() + + # Find the best threshold based on F1 Score (last occurence) + best_idx = len(f1_scores) - np.argmax(f1_scores[::-1]) - 1 + best_threshold = thresholds[best_idx] + best_f1 = f1_scores[best_idx] + + logger.info( + f"Best Threshold: {round(best_threshold, 2)} with F1 Score: {round(best_f1, 3)}" + ) + + +def filter_preds(preds, conf_thresh): + for pred in preds: + keep_idxs = pred["scores"] >= conf_thresh + pred["scores"] = pred["scores"][keep_idxs] + pred["boxes"] = pred["boxes"][keep_idxs] + pred["labels"] = pred["labels"][keep_idxs] + return preds + + +def scale_boxes(boxes, orig_shape, resized_shape): + """ + boxes in format: [x1, y1, x2, y2], absolute values + orig_shape: [height, width] + resized_shape: [height, width] + """ + scale_x = orig_shape[1] / resized_shape[1] + scale_y = orig_shape[0] / resized_shape[0] + boxes[:, 0] *= scale_x + boxes[:, 2] *= scale_x + boxes[:, 1] *= scale_y + boxes[:, 3] *= scale_y + return boxes diff --git a/src/zoo/__init__.py b/src/zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90da867ed4a50efd9d7236dda5c46c36dd14d894 --- /dev/null +++ b/src/zoo/__init__.py @@ -0,0 +1,6 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from . import dfine diff --git a/src/zoo/dfine/__init__.py b/src/zoo/dfine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d310da329ddd8d1017df2151ec8eae41a81bfdf --- /dev/null +++ b/src/zoo/dfine/__init__.py @@ -0,0 +1,11 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +from .dfine import DFINE +from .dfine_criterion import DFINECriterion +from .dfine_decoder import DFINETransformer +from .hybrid_encoder import HybridEncoder +from .matcher import HungarianMatcher +from .postprocessor import DFINEPostProcessor diff --git a/src/zoo/dfine/blog.md b/src/zoo/dfine/blog.md new file mode 100644 index 0000000000000000000000000000000000000000..81e1373e2b0b2c5614ba7125e3a016c60a34829e --- /dev/null +++ b/src/zoo/dfine/blog.md @@ -0,0 +1,90 @@ +English Blog | [中文博客](blog_cn.md) + +## 🔥 Revolutionizing Real-Time Object Detection: D-FINE vs. YOLO and Other DETR Models + +In the rapidly evolving field of real-time object detection, **D-FINE** emerges as a revolutionary approach that significantly surpasses existing models like **YOLOv10**, **YOLO11**, and **RT-DETR v1/v2/v3**, raising the performance ceiling for real-time object detection. After pretraining on the large-scale dataset Objects365, **D-FINE** far exceeds its competitor **LW-DETR**, achieving up to **59.3%** AP on the COCO dataset while maintaining excellent frame rates, parameter counts, and computational complexity. This positions **D-FINE** as a leader in the realm of real-time object detection, laying the groundwork for future research advancements. + +Currently, all code, weights, logs, compilation tools, and the FiftyOne visualization tool for **D-FINE** have been fully open-sourced, thanks to the codebase provided by **RT-DETR**. This includes pretraining tutorials, custom dataset tutorials, and more. We will continue to update with improvement insights and tuning strategies. We welcome everyone to raise issues and collectively promote the **D-FINE** series. We also hope you can leave a ⭐; it's the best encouragement for us. + +**GitHub Repo**: https://github.com/Peterande/D-FINE + +**ArXiv Paper**: https://arxiv.org/abs/2410.13842 + +--- + +### 🔍 Exploring the Key Innovations Behind D-FINE + +**D-FINE** redefines the regression task in DETR-based object detectors as **FDR**, and based on this, develops a performance-enhancing self-distillation mechanism **GO-LSD**. Below is a brief introduction to **FDR** and **GO-LSD**: + +#### **FDR (Fine-grained Distribution Refinement)** Decouples the Bounding Box Generation Process: + +1. **Initial Box Prediction**: Similar to traditional DETR methods, the decoder of **D-FINE** transforms object queries into several initial bounding boxes in the first layer. These boxes do not need to be highly accurate and serve only as an initialization. +2. **Fine-Grained Distribution Optimization**: Unlike traditional methods that directly decode new bounding boxes, **D-FINE** generates four sets of probability distributions based on these initial bounding boxes in the decoder layers and iteratively optimizes these distributions layer by layer. These distributions essentially act as a "fine-grained intermediate representation" of the detection boxes. Coupled with a carefully designed weighting function **W(n)**, **D-FINE** can adjust the initial bounding boxes by fine-tuning these representations, allowing for subtle modifications or significant shifts of the edges (top, bottom, left, right). The specific process is illustrated in the figure: + +

+ Fine-grained Distribution Refinement Process +

+ +For readability, we will not elaborate on the mathematical formulas and the Fine-Grained Localization (FGL) Loss that aids optimization here. Interested readers can refer to the original paper for derivations. + +The main advantages of redefining the bounding box regression task as **FDR** are: + +1. **Simplified Supervision**: While optimizing detection boxes using traditional L1 loss and IoU loss, the "residual" between labels and predictions can be additionally used to constrain these intermediate probability distributions. This allows each decoder layer to more effectively focus on and address the localization errors it currently faces. As the number of layers increases, their optimization objectives become progressively simpler, thereby simplifying the overall optimization process. + +2. **Robustness in Complex Scenarios**: The values of these probability distributions inherently represent the confidence level of fine-tuning for each edge. This enables **D-FINE** to independently model the uncertainty of each edge at different network depths, thereby exhibiting stronger robustness in complex real-world scenarios such as occlusion, motion blur, and low-light conditions, compared to directly regressing four fixed values. + +3. **Flexible Optimization Mechanism**: The probability distributions are transformed into final bounding box offsets through a weighted sum. The carefully designed weighting function ensures fine-grained adjustments when the initial box is accurate and provides significant corrections when necessary. + +4. **Research Potential and Scalability**: By transforming the regression task into a probability distribution prediction problem consistent with classification tasks, **FDR** not only enhances compatibility with other tasks but also enables object detection models to benefit from innovations in areas such as knowledge distillation, multi-task learning, and distribution optimization, opening new avenues for future research. + +--- + +#### **GO-LSD (Global Optimal Localization Self-Distillation)** Integrates Knowledge Distillation into FDR-Based Detectors Seamlessly + +Based on the above, object detectors equipped with the **FDR** framework satisfy the following two points: + +1. **Ability to Achieve Knowledge Transfer**: As Hinton mentioned in the paper *"Distilling the Knowledge in a Neural Network"*, probabilities are "knowledge." The network's output becomes probability distributions, and these distributions carry localization knowledge. By calculating the KLD loss, this "knowledge" can be transferred from deeper layers to shallower layers. This is something that traditional fixed box representations (Dirac δ functions) cannot achieve. + +2. **Consistent Optimization Objectives**: Since each decoder layer in the **FDR** framework shares a common goal: reducing the residual between the initial bounding box and the ground truth bounding box, the precise probability distributions generated by the final layer can serve as the ultimate target for each preceding layer and guide them through distillation. + +Thus, based on **FDR**, we propose **GO-LSD (Global Optimal Localization Self-Distillation)**. By implementing localization knowledge distillation between network layers, we further extend the capabilities of **D-FINE**. The specific process is illustrated in the figure: + +

+ GO-LSD Process +

+ +Similarly, for readability, we will not elaborate on the mathematical formulas and the Decoupled Distillation Focal (DDF) Loss that aids optimization here. Interested readers can refer to the original paper for derivations. + +This results in a synergistic win-win effect: as training progresses, the predictions of the final layer become increasingly accurate, and its generated soft labels can better help the preceding layers improve prediction accuracy. Conversely, the earlier layers learn to localize accurately more quickly, simplifying the optimization tasks of the deeper layers and further enhancing overall accuracy. + +--- + +### Visualization of D-FINE Predictions + +The following visualization showcases **D-FINE**'s predictions in various complex detection scenarios. These scenarios include occlusion, low-light conditions, motion blur, depth-of-field effects, and densely populated scenes. Despite these challenges, **D-FINE** still produces accurate localization results. + +

+ D-FINE Predictions in Complex Scenarios +

+ +Additionally, the visualization below shows the prediction results of the first layer and the last layer, the corresponding distributions of the four edges, and the weighted distributions. It can be seen that the localization of the predicted boxes becomes more precise as the distributions are optimized. + +

+ +

+ +--- + +### Frequently Asked Questions + +#### Question 1: Will FDR and GO-LSD increase the inference cost? + +**Answer**: No, FDR and the original prediction have almost no difference in speed, parameter count, and computational complexity, making it a seamless replacement. + +#### Question 2: Will FDR and GO-LSD increase the training cost? + +**Answer**: The increase in training cost mainly comes from how to generate the labels of the distributions. We have optimized this process, keeping the increase in training time and memory consumption to 6% and 2%, respectively, making it almost negligible. + +#### Question 3: Why is D-FINE faster and more lightweight than the RT-DETR series? + +**Answer**: Directly applying FDR and GO-LSD will significantly improve performance but will not make the network faster or lighter. Therefore, we performed a series of lightweight optimizations on RT-DETR. These adjustments led to a performance drop, but our methods compensated for these losses, achieving a perfect balance of speed, parameters, computational complexity, and performance. diff --git a/src/zoo/dfine/blog_cn.md b/src/zoo/dfine/blog_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..efe88bfc9e0e1de515cfcea53d11eeb412e9ad2f --- /dev/null +++ b/src/zoo/dfine/blog_cn.md @@ -0,0 +1,90 @@ +[English Blog](blog.md) | 中文博客 + +## 🔥 革新实时目标检测:D-FINE 与 YOLO 和其他 DETR 模型的对比 + +在快速发展的实时目标检测领域,**D-FINE** 作为一项革命性的方法,显著超越了现有模型(如 **YOLOv10**、**YOLO11** 及 **RT-DETR v1/v2/v3**),提升了实时目标检测的性能上限。经过大规模数据集 Objects365 的预训练,**D-FINE** 远超其竞争对手 **LW-DETR**,在 COCO 数据集上实现了高达 **59.3%** 的 AP,同时保持了卓越的帧率、参数量和计算复杂度。这使得 **D-FINE** 成为实时目标检测领域的佼佼者,为未来的研究奠定了基础。 + +目前,D-FINE 的所有代码、权重、日志、编译工具,以及 FiftyOne 可视化工具已经全部开源,感谢 RT-DETR 提供的 codebase。其中还包括了预训练教程、自定义数据集教程等。之后还会陆续更新一些改进心得和调参攻略,欢迎大家多提 issue,共同将 D-FINE 系列发扬光大。同时希望您能随手留下一颗 ⭐,这是对我们最好的鼓励。 + +**Github Repo**: https://github.com/Peterande/D-FINE + +**Arxiv Paper**: https://arxiv.org/abs/2410.13842 + +--- + +### 🔍 探索 D-FINE 背后的关键创新 + +**D-FINE** 将基于 DETR 的目标检测器中的回归任务重新定义为 FDR,并在此基础上开发出了无感提升性能的自蒸馏机制 GO-LSD。下面对 FDR 和 GO-LSD 进行简要介绍: + +#### FDR (Fine-grained Distribution Refinement) 将检测框的生成过程拆解为: + +1. **初始框预测**:与传统 DETR 方法类似,**D-FINE** 的解码器 (decoder) 会在第一层将 object queries 转变为若干个初始的边界框,这些框不需要特别精准,仅作为一种初始化。 +2. **细粒度的分布优化**:**D-FINE** 解码层不会像传统方法那样直接解码出新的边界框,而是基于这些初始化的边界框,生成四组概率分布;并迭代地对这四组概率分布进行逐层优化。这些分布本质上是作为检测框的一种“细粒度中间表征”;配合精心设计的加权函数 W(n),**D-FINE** 能够通过微调这些表征来实现对初始边界框的调整,包含对其上下左右边缘进行细微的小幅度修正亦或是大幅度的搬移,具体的流程如图: + +

+ 精细分布优化过程 +

+ +为了方便阅读,我们不在此赘述数学公式及帮助优化的损失函数 Fine-Grained Localization (FGL) Loss,有兴趣的可以根据原文推导。 + +将边界框回归任务重新定义为 FDR 的主要优势在于: + +1. **简化的监督**:在使用传统的 L1 损失、IoU 损失优化检测框的同时,可以额外用标签和预测结果之间的“残差”来约束这些中间态的概率分布。这使每个解码层 (decoder layer) 能够更有效地关注并解决其当前面临的定位误差,随着层数加深,其优化目标变得越来越简单,从而简化了整体优化过程。 + +2. **复杂场景下的鲁棒性**:这些概率分布的值本质上代表了对每个边界“微调”的自信程度。这使 **D-FINE** 能够在不同网络深度独立建模每个边界的不确定性,从而在遮挡、运动模糊和低光照等复杂的实际场景下表现出更强的鲁棒性,相比直接回归四个固定值要更为稳健。 + +3. **灵活的优化机制**:概率分布通过加权求和转化为最终的边界框偏移值。精心设计的加权函数确保在初始框准确时进行细微调整,而在必要时则提供较大的修正。 + +4. **研究潜力与可扩展性**:FDR 通过将回归任务转变为同分类任务一致的概率分布预测问题,不仅提高了与其他任务的兼容性,还使得目标检测模型可以受益于知识蒸馏、多任务学习和分布优化等更多领域的创新,为未来的研究打开了新的大门。 + +--- + +#### GO-LSD (Global Optimal Localization Self-Distillation) 将知识蒸馏无痛应用到 FDR 框架检测器 + +根据上文,搭载 FDR 框架的目标检测器满足了以下两点: + +1. **能够实现知识传递**:Hinton 早在 *"Distilling the Knowledge in a Neural Network"* 一文中就说过:概率即“知识”;网络输出变成了概率分布,而概率分布携带定位知识 (Localization Knowledge),而通过计算 KLD 损失,可以将这些“知识”从深层传递到浅层。这是传统固定框表示(狄拉克 δ 函数)无法实现的。 + +2. **一致的优化目标**:由于 FDR 架构中每一个解码层都共享一个共同目标:减少初始边界框与真实边界框之间的残差;因此最后一层生成的精确概率分布可以作为前面每一层的最终目标,并通过蒸馏引导前几层。 + +于是,基于 FDR,我们提出了 GO-LSD(全局最优定位自蒸馏)。通过在网络层间实现定位知识蒸馏,进一步扩展了 **D-FINE** 的能力,具体流程如图: + +

+ GO-LSD过程 +

+ +同样的,为了方便阅读,我们不在此赘述数学公式及帮助优化的损失函数 Decoupled Distillation Focal (DDF) Loss,有兴趣的可以根据原文推导。 + +这产生了一种双赢的协同效应:随着训练的进行,最后一层的预测变得越来越准确,其生成的软标签能够更好地帮助前几层提高预测准确性。反过来,前几层学会更快地定位到准确位置,简化了深层的优化任务,进一步提高了整体准确性。 + +--- + +### D-FINE 预测的可视化 + +以下可视化展示了 **D-FINE** 在各种复杂检测场景中的预测结果。这些场景包括遮挡、低光照、运动模糊、景深效果和密集场景。尽管面对这些挑战,**D-FINE** 依然能够产生准确的定位结果。 + +

+ D-FINE在复杂场景中的预测 +

+ +同时下面给出的可视化结果展示了第一层和最后一层的预测结果、对应四条边的分布、以及加权后的分布。可以看到,预测框的定位会随着分布的优化而变得更加精准。 + +

+ +

+ +--- + +### 常见问题解答 + +#### 问题1:FDR 和 GO-LSD 会带来更多的推理成本吗? + +**回答**:并不会,FDR 和原始的预测几乎没有在速度、参数量和计算复杂度上的任何区别,完全是无感替换。 + +#### 问题2:FDR 和 GO-LSD 会带来更多的训练成本吗? + +**回答**:训练成本的增加主要来源于如何生成分布的标签。我们已经对该过程进行了优化,将额外训练时长和显存占用控制在了 6% 和 2%,几乎无感。 + +#### 问题3:D-FINE 为什么会比 RT-DETR 系列更快、更轻量? + +**回答**:直接应用 FDR 和 GO-LSD 只会显著提高性能,并不会让网络更快、更轻。所以我们对 RT-DETR 进行了一系列的轻量化处理,这些处理带来了性能的下降,但我们的方法弥补了这些损失,实现了速度-参数-计算量-性能的完美平衡。 diff --git a/src/zoo/dfine/box_ops.py b/src/zoo/dfine/box_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8ad3c791b91bea4c8061d775610b37ed9e4512 --- /dev/null +++ b/src/zoo/dfine/box_ops.py @@ -0,0 +1,93 @@ +""" +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +https://github.com/facebookresearch/detr/blob/main/util/box_ops.py +""" + +import torch +from torch import Tensor +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [ + (x_c - 0.5 * w.clamp(min=0.0)), + (y_c - 0.5 * h.clamp(min=0.0)), + (x_c + 0.5 * w.clamp(min=0.0)), + (y_c + 0.5 * h.clamp(min=0.0)), + ] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x: Tensor) -> Tensor: + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1: Tensor, boxes2: Tensor): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = masks * x.unsqueeze(0) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = masks * y.unsqueeze(0) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/src/zoo/dfine/denoising.py b/src/zoo/dfine/denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..c52cf3229e83f81ab5b4958381c106cd25c50c18 --- /dev/null +++ b/src/zoo/dfine/denoising.py @@ -0,0 +1,121 @@ +"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. +Modifications Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +""" + +import torch + +from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh +from .utils import inverse_sigmoid + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """cnd""" + if num_denoising <= 0: + return None, None, None, None + + num_gts = [len(t["labels"]) for t in targets] + device = targets[0]["labels"].device + + max_gt_num = max(num_gts) + if max_gt_num == 0: + dn_meta = {"dn_positive_idx": None, "dn_num_group": 0, "dn_num_split": [0, num_queries]} + return None, None, None, dn_meta + + num_group = num_denoising // max_gt_num + num_group = 1 if num_group == 0 else num_group + # pad gt to max_num of a batch + bs = len(num_gts) + + input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device) + + for i in range(bs): + num_gt = num_gts[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile([1, 2 * num_group]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) + # positive and negative mask + negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts]) + # total denoising queries + num_denoising = int(max_gt_num * 2 * num_group) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = box_cxcywh_to_xyxy(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + # shrink_mask = torch.zeros_like(rand_sign) + # shrink_mask[:, :, :2] = (rand_sign[:, :, :2] == 1) # rand_sign == 1 → (x1, y1) ↘ → smaller bbox + # shrink_mask[:, :, 2:] = (rand_sign[:, :, 2:] == -1) # rand_sign == -1 → (x2, y2) ↖ → smaller bbox + # mask = rand_part > (upper_bound / (upper_bound+1)) + # # this is to make sure the dn bbox can be reversed to the original bbox by dfine head. + # rand_sign = torch.where((shrink_mask * (1 - negative_gt_mask) * mask).bool(), \ + # rand_sign * upper_bound / (upper_bound+1) / rand_part, rand_sign) + known_bbox += rand_sign * rand_part * diff + known_bbox = torch.clip(known_bbox, min=0.0, max=1.0) + input_query_bbox = box_xyxy_to_cxcywh(known_bbox) + input_query_bbox[input_query_bbox < 0] *= -1 + input_query_bbox_unact = inverse_sigmoid(input_query_bbox) + + input_query_logits = class_embed(input_query_class) + + tgt_size = num_denoising + num_queries + attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising:, :num_denoising] = True + + # reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), + max_gt_num * 2 * (i + 1) : num_denoising, + ] = True + if i == num_group - 1: + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * i * 2] = True + else: + attn_mask[ + max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), + max_gt_num * 2 * (i + 1) : num_denoising, + ] = True + attn_mask[max_gt_num * 2 * i : max_gt_num * 2 * (i + 1), : max_gt_num * 2 * i] = True + + dn_meta = { + "dn_positive_idx": dn_positive_idx, + "dn_num_group": num_group, + "dn_num_split": [num_denoising, num_queries], + } + + # print(input_query_class.shape) # torch.Size([4, 196, 256]) + # print(input_query_bbox.shape) # torch.Size([4, 196, 4]) + # print(attn_mask.shape) # torch.Size([496, 496]) + + return input_query_logits, input_query_bbox_unact, attn_mask, dn_meta diff --git a/src/zoo/dfine/dfine.py b/src/zoo/dfine/dfine.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb1f8f54bbe856b37666a9a70e1ccf41ba10d4a --- /dev/null +++ b/src/zoo/dfine/dfine.py @@ -0,0 +1,47 @@ +""" +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +""" + +import torch.nn as nn + +from ...core import register + +__all__ = [ + "DFINE", +] + + +@register() +class DFINE(nn.Module): + __inject__ = [ + "backbone", + "encoder", + "decoder", + ] + + def __init__( + self, + backbone: nn.Module, + encoder: nn.Module, + decoder: nn.Module, + ): + super().__init__() + self.backbone = backbone + self.decoder = decoder + self.encoder = encoder + + def forward(self, x, targets=None): + x = self.backbone(x) + x = self.encoder(x) + x = self.decoder(x, targets) + + return x + + def deploy( + self, + ): + self.eval() + for m in self.modules(): + if hasattr(m, "convert_to_deploy"): + m.convert_to_deploy() + return self diff --git a/src/zoo/dfine/dfine_criterion.py b/src/zoo/dfine/dfine_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..208f9f340f207aa67f16112cb4141a21b7197a32 --- /dev/null +++ b/src/zoo/dfine/dfine_criterion.py @@ -0,0 +1,525 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright (c) 2023 lyuwenyu. All Rights Reserved. +""" + +import copy + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from ...core import register +from ...misc.dist_utils import get_world_size, is_dist_available_and_initialized +from .box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou +from .dfine_utils import bbox2distance + + +@register() +class DFINECriterion(nn.Module): + """This class computes the loss for D-FINE.""" + + __share__ = [ + "num_classes", + ] + __inject__ = [ + "matcher", + ] + + def __init__( + self, + matcher, + weight_dict, + losses, + alpha=0.2, + gamma=2.0, + num_classes=80, + reg_max=32, + boxes_weight_format=None, + share_matched_indices=False, + ): + """Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals. + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. + num_classes: number of object categories, omitting the special no-object category. + reg_max (int): Max number of the discrete bins in D-FINE. + boxes_weight_format: format for boxes weight (iou, ). + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.boxes_weight_format = boxes_weight_format + self.share_matched_indices = share_matched_indices + self.alpha = alpha + self.gamma = gamma + self.fgl_targets, self.fgl_targets_dn = None, None + self.own_targets, self.own_targets_dn = None, None + self.reg_max = reg_max + self.num_pos, self.num_neg = None, None + + def loss_labels_focal(self, outputs, targets, indices, num_boxes): + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + loss = torchvision.ops.sigmoid_focal_loss( + src_logits, target, self.alpha, self.gamma, reduction="none" + ) + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + + return {"loss_focal": loss} + + def loss_labels_vfl(self, outputs, targets, indices, num_boxes, values=None): + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + if values is None: + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)) + ious = torch.diag(ious).detach() + else: + ious = values + + src_logits = outputs["pred_logits"] + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_o[idx] = ious.to(target_score_o.dtype) + target_score = target_score_o.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score + + loss = F.binary_cross_entropy_with_logits( + src_logits, target_score, weight=weight, reduction="none" + ) + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_vfl": loss} + + def loss_boxes(self, outputs, targets, indices, num_boxes, boxes_weight=None): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert "pred_boxes" in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + losses = {} + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag( + generalized_box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)) + ) + loss_giou = loss_giou if boxes_weight is None else loss_giou * boxes_weight + losses["loss_giou"] = loss_giou.sum() / num_boxes + + return losses + + def loss_local(self, outputs, targets, indices, num_boxes, T=5): + """Compute Fine-Grained Localization (FGL) Loss + and Decoupled Distillation Focal (DDF) Loss.""" + + losses = {} + if "pred_corners" in outputs: + idx = self._get_src_permutation_idx(indices) + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + + pred_corners = outputs["pred_corners"][idx].reshape(-1, (self.reg_max + 1)) + ref_points = outputs["ref_points"][idx].detach() + with torch.no_grad(): + if self.fgl_targets_dn is None and "is_dn" in outputs: + self.fgl_targets_dn = bbox2distance( + ref_points, + box_cxcywh_to_xyxy(target_boxes), + self.reg_max, + outputs["reg_scale"], + outputs["up"], + ) + if self.fgl_targets is None and "is_dn" not in outputs: + self.fgl_targets = bbox2distance( + ref_points, + box_cxcywh_to_xyxy(target_boxes), + self.reg_max, + outputs["reg_scale"], + outputs["up"], + ) + + target_corners, weight_right, weight_left = ( + self.fgl_targets_dn if "is_dn" in outputs else self.fgl_targets + ) + + ious = torch.diag( + box_iou( + box_cxcywh_to_xyxy(outputs["pred_boxes"][idx]), box_cxcywh_to_xyxy(target_boxes) + )[0] + ) + weight_targets = ious.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach() + + losses["loss_fgl"] = self.unimodal_distribution_focal_loss( + pred_corners, + target_corners, + weight_right, + weight_left, + weight_targets, + avg_factor=num_boxes, + ) + + if "teacher_corners" in outputs: + pred_corners = outputs["pred_corners"].reshape(-1, (self.reg_max + 1)) + target_corners = outputs["teacher_corners"].reshape(-1, (self.reg_max + 1)) + if torch.equal(pred_corners, target_corners): + losses["loss_ddf"] = pred_corners.sum() * 0 + else: + weight_targets_local = outputs["teacher_logits"].sigmoid().max(dim=-1)[0] + + mask = torch.zeros_like(weight_targets_local, dtype=torch.bool) + mask[idx] = True + mask = mask.unsqueeze(-1).repeat(1, 1, 4).reshape(-1) + + weight_targets_local[idx] = ious.reshape_as(weight_targets_local[idx]).to( + weight_targets_local.dtype + ) + weight_targets_local = ( + weight_targets_local.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach() + ) + + loss_match_local = ( + weight_targets_local + * (T**2) + * ( + nn.KLDivLoss(reduction="none")( + F.log_softmax(pred_corners / T, dim=1), + F.softmax(target_corners.detach() / T, dim=1), + ) + ).sum(-1) + ) + if "is_dn" not in outputs: + batch_scale = ( + 8 / outputs["pred_boxes"].shape[0] + ) # Avoid the influence of batch size per GPU + self.num_pos, self.num_neg = ( + (mask.sum() * batch_scale) ** 0.5, + ((~mask).sum() * batch_scale) ** 0.5, + ) + loss_match_local1 = loss_match_local[mask].mean() if mask.any() else 0 + loss_match_local2 = loss_match_local[~mask].mean() if (~mask).any() else 0 + losses["loss_ddf"] = ( + loss_match_local1 * self.num_pos + loss_match_local2 * self.num_neg + ) / (self.num_pos + self.num_neg) + + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def _get_go_indices(self, indices, indices_aux_list): + """Get a matching union set across all decoder layers.""" + results = [] + for indices_aux in indices_aux_list: + indices = [ + (torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]])) + for idx1, idx2 in zip(indices.copy(), indices_aux.copy()) + ] + + for ind in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]: + unique, counts = torch.unique(ind, return_counts=True, dim=0) + count_sort_indices = torch.argsort(counts, descending=True) + unique_sorted = unique[count_sort_indices] + column_to_row = {} + for idx in unique_sorted: + row_idx, col_idx = idx[0].item(), idx[1].item() + if row_idx not in column_to_row: + column_to_row[row_idx] = col_idx + final_rows = torch.tensor(list(column_to_row.keys()), device=ind.device) + final_cols = torch.tensor(list(column_to_row.values()), device=ind.device) + results.append((final_rows.long(), final_cols.long())) + return results + + def _clear_cache(self): + self.fgl_targets, self.fgl_targets_dn = None, None + self.own_targets, self.own_targets_dn = None, None + self.num_pos, self.num_neg = None, None + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + "boxes": self.loss_boxes, + "focal": self.loss_labels_focal, + "vfl": self.loss_labels_vfl, + "local": self.loss_local, + } + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets, **kwargs): + """This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if "aux" not in k} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets)["indices"] + self._clear_cache() + + # Get the matching union set across all decoder layers. + if "aux_outputs" in outputs: + indices_aux_list, cached_indices, cached_indices_enc = [], [], [] + for i, aux_outputs in enumerate(outputs["aux_outputs"] + [outputs["pre_outputs"]]): + indices_aux = self.matcher(aux_outputs, targets)["indices"] + cached_indices.append(indices_aux) + indices_aux_list.append(indices_aux) + for i, aux_outputs in enumerate(outputs["enc_aux_outputs"]): + indices_enc = self.matcher(aux_outputs, targets)["indices"] + cached_indices_enc.append(indices_enc) + indices_aux_list.append(indices_enc) + indices_go = self._get_go_indices(indices, indices_aux_list) + + num_boxes_go = sum(len(x[0]) for x in indices_go) + num_boxes_go = torch.as_tensor( + [num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device + ) + if is_dist_available_and_initialized(): + torch.distributed.all_reduce(num_boxes_go) + num_boxes_go = torch.clamp(num_boxes_go / get_world_size(), min=1).item() + else: + assert "aux_outputs" in outputs, "" + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor( + [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device + ) + if is_dist_available_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + indices_in = indices_go if loss in ["boxes", "local"] else indices + num_boxes_in = num_boxes_go if loss in ["boxes", "local"] else num_boxes + meta = self.get_loss_meta_info(loss, outputs, targets, indices_in) + l_dict = self.get_loss(loss, outputs, targets, indices_in, num_boxes_in, **meta) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + losses.update(l_dict) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "aux_outputs" in outputs: + for i, aux_outputs in enumerate(outputs["aux_outputs"]): + aux_outputs["up"], aux_outputs["reg_scale"] = outputs["up"], outputs["reg_scale"] + for loss in self.losses: + indices_in = indices_go if loss in ["boxes", "local"] else cached_indices[i] + num_boxes_in = num_boxes_go if loss in ["boxes", "local"] else num_boxes + meta = self.get_loss_meta_info(loss, aux_outputs, targets, indices_in) + l_dict = self.get_loss( + loss, aux_outputs, targets, indices_in, num_boxes_in, **meta + ) + + l_dict = { + k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict + } + l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of auxiliary traditional head output at first decoder layer. + if "pre_outputs" in outputs: + aux_outputs = outputs["pre_outputs"] + for loss in self.losses: + indices_in = indices_go if loss in ["boxes", "local"] else cached_indices[-1] + num_boxes_in = num_boxes_go if loss in ["boxes", "local"] else num_boxes + meta = self.get_loss_meta_info(loss, aux_outputs, targets, indices_in) + l_dict = self.get_loss(loss, aux_outputs, targets, indices_in, num_boxes_in, **meta) + + l_dict = { + k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict + } + l_dict = {k + "_pre": v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of encoder auxiliary losses. + if "enc_aux_outputs" in outputs: + assert "enc_meta" in outputs, "" + class_agnostic = outputs["enc_meta"]["class_agnostic"] + if class_agnostic: + orig_num_classes = self.num_classes + self.num_classes = 1 + enc_targets = copy.deepcopy(targets) + for t in enc_targets: + t["labels"] = torch.zeros_like(t["labels"]) + else: + enc_targets = targets + + for i, aux_outputs in enumerate(outputs["enc_aux_outputs"]): + for loss in self.losses: + indices_in = indices_go if loss == "boxes" else cached_indices_enc[i] + num_boxes_in = num_boxes_go if loss == "boxes" else num_boxes + meta = self.get_loss_meta_info(loss, aux_outputs, enc_targets, indices_in) + l_dict = self.get_loss( + loss, aux_outputs, enc_targets, indices_in, num_boxes_in, **meta + ) + l_dict = { + k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict + } + l_dict = {k + f"_enc_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + if class_agnostic: + self.num_classes = orig_num_classes + + # In case of cdn auxiliary losses. For dfine + if "dn_outputs" in outputs: + assert "dn_meta" in outputs, "" + indices_dn = self.get_cdn_matched_indices(outputs["dn_meta"], targets) + dn_num_boxes = num_boxes * outputs["dn_meta"]["dn_num_group"] + dn_num_boxes = dn_num_boxes if dn_num_boxes > 0 else 1 + + for i, aux_outputs in enumerate(outputs["dn_outputs"]): + aux_outputs["is_dn"] = True + aux_outputs["up"], aux_outputs["reg_scale"] = outputs["up"], outputs["reg_scale"] + for loss in self.losses: + meta = self.get_loss_meta_info(loss, aux_outputs, targets, indices_dn) + l_dict = self.get_loss( + loss, aux_outputs, targets, indices_dn, dn_num_boxes, **meta + ) + l_dict = { + k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict + } + l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of auxiliary traditional head output at first decoder layer. + if "dn_pre_outputs" in outputs: + aux_outputs = outputs["dn_pre_outputs"] + for loss in self.losses: + meta = self.get_loss_meta_info(loss, aux_outputs, targets, indices_dn) + l_dict = self.get_loss( + loss, aux_outputs, targets, indices_dn, dn_num_boxes, **meta + ) + l_dict = { + k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict + } + l_dict = {k + "_dn_pre": v for k, v in l_dict.items()} + losses.update(l_dict) + + # For debugging Objects365 pre-train. + losses = {k: torch.nan_to_num(v, nan=0.0) for k, v in losses.items()} + return losses + + def get_loss_meta_info(self, loss, outputs, targets, indices): + if self.boxes_weight_format is None: + return {} + + src_boxes = outputs["pred_boxes"][self._get_src_permutation_idx(indices)] + target_boxes = torch.cat([t["boxes"][j] for t, (_, j) in zip(targets, indices)], dim=0) + + if self.boxes_weight_format == "iou": + iou, _ = box_iou( + box_cxcywh_to_xyxy(src_boxes.detach()), box_cxcywh_to_xyxy(target_boxes) + ) + iou = torch.diag(iou) + elif self.boxes_weight_format == "giou": + iou = torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes.detach()), box_cxcywh_to_xyxy(target_boxes) + ) + ) + else: + raise AttributeError() + + if loss in ("boxes",): + meta = {"boxes_weight": iou} + elif loss in ("vfl",): + meta = {"values": iou} + else: + meta = {} + + return meta + + @staticmethod + def get_cdn_matched_indices(dn_meta, targets): + """get_cdn_matched_indices""" + dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] + num_gts = [len(t["labels"]) for t in targets] + device = targets[0]["labels"].device + + dn_match_indices = [] + for i, num_gt in enumerate(num_gts): + if num_gt > 0: + gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device) + gt_idx = gt_idx.tile(dn_num_group) + assert len(dn_positive_idx[i]) == len(gt_idx) + dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append( + ( + torch.zeros(0, dtype=torch.int64, device=device), + torch.zeros(0, dtype=torch.int64, device=device), + ) + ) + + return dn_match_indices + + def feature_loss_function(self, fea, target_fea): + loss = (fea - target_fea) ** 2 * ((fea > 0) | (target_fea > 0)).float() + return torch.abs(loss) + + def unimodal_distribution_focal_loss( + self, pred, label, weight_right, weight_left, weight=None, reduction="sum", avg_factor=None + ): + dis_left = label.long() + dis_right = dis_left + 1 + + loss = F.cross_entropy(pred, dis_left, reduction="none") * weight_left.reshape( + -1 + ) + F.cross_entropy(pred, dis_right, reduction="none") * weight_right.reshape(-1) + + if weight is not None: + weight = weight.float() + loss = loss * weight + + if avg_factor is not None: + loss = loss.sum() / avg_factor + elif reduction == "mean": + loss = loss.mean() + elif reduction == "sum": + loss = loss.sum() + + return loss + + def get_gradual_steps(self, outputs): + num_layers = len(outputs["aux_outputs"]) + 1 if "aux_outputs" in outputs else 1 + step = 0.5 / (num_layers - 1) + opt_list = [0.5 + step * i for i in range(num_layers)] if num_layers > 1 else [1] + return opt_list diff --git a/src/zoo/dfine/dfine_decoder.py b/src/zoo/dfine/dfine_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..32c7ecaa43a395920df88300728501e879c71af5 --- /dev/null +++ b/src/zoo/dfine/dfine_decoder.py @@ -0,0 +1,959 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright (c) 2023 lyuwenyu. All Rights Reserved. +""" + +import copy +import functools +import math +from collections import OrderedDict +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +from ...core import register +from .denoising import get_contrastive_denoising_training_group +from .dfine_utils import distance2bbox, weighting_function +from .utils import ( + bias_init_with_prob, + deformable_attention_core_func_v2, + get_activation, + inverse_sigmoid, +) + +__all__ = ["DFINETransformer"] + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.act = get_activation(act) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MSDeformableAttention(nn.Module): + def __init__( + self, + embed_dim=256, + num_heads=8, + num_levels=4, + num_points=4, + method="default", + offset_scale=0.5, + ): + """Multi-Scale Deformable Attention""" + super(MSDeformableAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.offset_scale = offset_scale + + if isinstance(num_points, list): + assert len(num_points) == num_levels, "" + num_points_list = num_points + else: + num_points_list = [num_points for _ in range(num_levels)] + + self.num_points_list = num_points_list + + num_points_scale = [1 / n for n in num_points_list for _ in range(n)] + self.register_buffer( + "num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32) + ) + + self.total_points = num_heads * sum(num_points_list) + self.method = method + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2) + self.attention_weights = nn.Linear(embed_dim, self.total_points) + + self.ms_deformable_attn_core = functools.partial( + deformable_attention_core_func_v2, method=self.method + ) + + self._reset_parameters() + + if method == "discrete": + for p in self.sampling_offsets.parameters(): + p.requires_grad = False + + def _reset_parameters(self): + # sampling_offsets + init.constant_(self.sampling_offsets.weight, 0) + thetas = torch.arange(self.num_heads, dtype=torch.float32) * ( + 2.0 * math.pi / self.num_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape( + 1, -1, 1 + ) + grid_init *= scaling + self.sampling_offsets.bias.data[...] = grid_init.flatten() + + # attention_weights + init.constant_(self.attention_weights.weight, 0) + init.constant_(self.attention_weights.bias, 0) + + def forward( + self, + query: torch.Tensor, + reference_points: torch.Tensor, + value: torch.Tensor, + value_spatial_shapes: List[int], + ): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + + sampling_offsets: torch.Tensor = self.sampling_offsets(query) + sampling_offsets = sampling_offsets.reshape( + bs, Len_q, self.num_heads, sum(self.num_points_list), 2 + ) + + attention_weights = self.attention_weights(query).reshape( + bs, Len_q, self.num_heads, sum(self.num_points_list) + ) + attention_weights = F.softmax(attention_weights, dim=-1) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(value_spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) + sampling_locations = ( + reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + + sampling_offsets / offset_normalizer + ) + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1) + offset = ( + sampling_offsets + * num_points_scale + * reference_points[:, :, None, :, 2:] + * self.offset_scale + ) + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format( + reference_points.shape[-1] + ) + ) + + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list + ) + + return output + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + n_levels=4, + n_points=4, + cross_attn_method="default", + layer_scale=None, + ): + super(TransformerDecoderLayer, self).__init__() + if layer_scale is not None: + dim_feedforward = round(layer_scale * dim_feedforward) + d_model = round(layer_scale * d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # cross attention + self.cross_attn = MSDeformableAttention( + d_model, n_head, n_levels, n_points, method=cross_attn_method + ) + self.dropout2 = nn.Dropout(dropout) + + # gate + self.gateway = Gate(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = get_activation(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + self._reset_parameters() + + def _reset_parameters(self): + init.xavier_uniform_(self.linear1.weight) + init.xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward( + self, target, reference_points, value, spatial_shapes, attn_mask=None, query_pos_embed=None + ): + # self attention + q = k = self.with_pos_embed(target, query_pos_embed) + + target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) + target = target + self.dropout1(target2) + target = self.norm1(target) + + # cross attention + target2 = self.cross_attn( + self.with_pos_embed(target, query_pos_embed), reference_points, value, spatial_shapes + ) + + target = self.gateway(target, self.dropout2(target2)) + + # ffn + target2 = self.forward_ffn(target) + target = target + self.dropout4(target2) + target = self.norm3(target.clamp(min=-65504, max=65504)) + + return target + + +class Gate(nn.Module): + def __init__(self, d_model): + super(Gate, self).__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + bias = bias_init_with_prob(0.5) + init.constant_(self.gate.bias, bias) + init.constant_(self.gate.weight, 0) + self.norm = nn.LayerNorm(d_model) + + def forward(self, x1, x2): + gate_input = torch.cat([x1, x2], dim=-1) + gates = torch.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + return self.norm(gate1 * x1 + gate2 * x2) + + +class Integral(nn.Module): + """ + A static layer that calculates integral results from a distribution. + + This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, + where Pr(n) is the softmax probability vector representing the discrete + distribution, and W(n) is the non-uniform Weighting Function. + + Args: + reg_max (int): Max number of the discrete bins. Default is 32. + It can be adjusted based on the dataset or task requirements. + """ + + def __init__(self, reg_max=32): + super(Integral, self).__init__() + self.reg_max = reg_max + + def forward(self, x, project): + shape = x.shape + x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1) + x = F.linear(x, project.to(x.device)).reshape(-1, 4) + return x.reshape(list(shape[:-1]) + [-1]) + + +class LQE(nn.Module): + def __init__(self, k, hidden_dim, num_layers, reg_max): + super(LQE, self).__init__() + self.k = k + self.reg_max = reg_max + self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers) + init.constant_(self.reg_conf.layers[-1].bias, 0) + init.constant_(self.reg_conf.layers[-1].weight, 0) + + def forward(self, scores, pred_corners): + B, L, _ = pred_corners.size() + prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), dim=-1) + prob_topk, _ = prob.topk(self.k, dim=-1) + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(B, L, -1)) + return scores + quality_score + + +class TransformerDecoder(nn.Module): + """ + Transformer Decoder implementing Fine-grained Distribution Refinement (FDR). + + This decoder refines object detection predictions through iterative updates across multiple layers, + utilizing attention mechanisms, location quality estimators, and distribution refinement techniques + to improve bounding box accuracy and robustness. + """ + + def __init__( + self, + hidden_dim, + decoder_layer, + decoder_layer_wide, + num_layers, + num_head, + reg_max, + reg_scale, + up, + eval_idx=-1, + layer_scale=2, + ): + super(TransformerDecoder, self).__init__() + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.layer_scale = layer_scale + self.num_head = num_head + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max + self.layers = nn.ModuleList( + [copy.deepcopy(decoder_layer) for _ in range(self.eval_idx + 1)] + + [copy.deepcopy(decoder_layer_wide) for _ in range(num_layers - self.eval_idx - 1)] + ) + self.lqe_layers = nn.ModuleList( + [copy.deepcopy(LQE(4, 64, 2, reg_max)) for _ in range(num_layers)] + ) + + def value_op(self, memory, value_proj, value_scale, memory_mask, memory_spatial_shapes): + """ + Preprocess values for MSDeformableAttention. + """ + value = value_proj(memory) if value_proj is not None else memory + value = F.interpolate(memory, size=value_scale) if value_scale is not None else value + if memory_mask is not None: + value = value * memory_mask.to(value.dtype).unsqueeze(-1) + value = value.reshape(value.shape[0], value.shape[1], self.num_head, -1) + split_shape = [h * w for h, w in memory_spatial_shapes] + return value.permute(0, 2, 3, 1).split(split_shape, dim=-1) + + def convert_to_deploy(self): + self.project = weighting_function(self.reg_max, self.up, self.reg_scale, deploy=True) + self.layers = self.layers[: self.eval_idx + 1] + self.lqe_layers = nn.ModuleList( + [nn.Identity()] * (self.eval_idx) + [self.lqe_layers[self.eval_idx]] + ) + + def forward( + self, + target, + ref_points_unact, + memory, + spatial_shapes, + bbox_head, + score_head, + query_pos_head, + pre_bbox_head, + integral, + up, + reg_scale, + attn_mask=None, + memory_mask=None, + dn_meta=None, + ): + output = target + output_detach = pred_corners_undetach = 0 + value = self.value_op(memory, None, None, memory_mask, spatial_shapes) + + dec_out_bboxes = [] + dec_out_logits = [] + dec_out_pred_corners = [] + dec_out_refs = [] + if not hasattr(self, "project"): + project = weighting_function(self.reg_max, up, reg_scale) + else: + project = self.project + + ref_points_detach = F.sigmoid(ref_points_unact) + + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = query_pos_head(ref_points_detach).clamp(min=-10, max=10) + + # TODO Adjust scale if needed for detachable wider layers + if i >= self.eval_idx + 1 and self.layer_scale > 1: + query_pos_embed = F.interpolate(query_pos_embed, scale_factor=self.layer_scale) + value = self.value_op( + memory, None, query_pos_embed.shape[-1], memory_mask, spatial_shapes + ) + output = F.interpolate(output, size=query_pos_embed.shape[-1]) + output_detach = output.detach() + + output = layer( + output, ref_points_input, value, spatial_shapes, attn_mask, query_pos_embed + ) + + if i == 0: + # Initial bounding box predictions with inverse sigmoid refinement + pre_bboxes = F.sigmoid(pre_bbox_head(output) + inverse_sigmoid(ref_points_detach)) + pre_scores = score_head[0](output) + ref_points_initial = pre_bboxes.detach() + + # Refine bounding box corners using FDR, integrating previous layer's corrections + pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach + inter_ref_bbox = distance2bbox( + ref_points_initial, integral(pred_corners, project), reg_scale + ) + + if self.training or i == self.eval_idx: + scores = score_head[i](output) + # Lqe does not affect the performance here. + scores = self.lqe_layers[i](scores, pred_corners) + dec_out_logits.append(scores) + dec_out_bboxes.append(inter_ref_bbox) + dec_out_pred_corners.append(pred_corners) + dec_out_refs.append(ref_points_initial) + + if not self.training: + break + + pred_corners_undetach = pred_corners + ref_points_detach = inter_ref_bbox.detach() + output_detach = output.detach() + + return ( + torch.stack(dec_out_bboxes), + torch.stack(dec_out_logits), + torch.stack(dec_out_pred_corners), + torch.stack(dec_out_refs), + pre_bboxes, + pre_scores, + ) + + +@register() +class DFINETransformer(nn.Module): + __share__ = ["num_classes", "eval_spatial_size"] + + def __init__( + self, + num_classes=80, + hidden_dim=256, + num_queries=300, + feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_points=4, + nhead=8, + num_layers=6, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_query_content=False, + eval_spatial_size=None, + eval_idx=-1, + eps=1e-2, + aux_loss=True, + cross_attn_method="default", + query_select_method="default", + reg_max=32, + reg_scale=4.0, + layer_scale=1, + ): + super().__init__() + assert len(feat_channels) <= num_levels + assert len(feat_strides) == len(feat_channels) + + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + scaled_dim = round(layer_scale * hidden_dim) + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_layers = num_layers + self.eval_spatial_size = eval_spatial_size + self.aux_loss = aux_loss + self.reg_max = reg_max + + assert query_select_method in ("default", "one2many", "agnostic"), "" + assert cross_attn_method in ("default", "discrete"), "" + self.cross_attn_method = cross_attn_method + self.query_select_method = query_select_method + + # backbone feature projection + self._build_input_proj_layer(feat_channels) + + # Transformer module + self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False) + self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False) + decoder_layer = TransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_levels, + num_points, + cross_attn_method=cross_attn_method, + ) + decoder_layer_wide = TransformerDecoderLayer( + hidden_dim, + nhead, + dim_feedforward, + dropout, + activation, + num_levels, + num_points, + cross_attn_method=cross_attn_method, + layer_scale=layer_scale, + ) + self.decoder = TransformerDecoder( + hidden_dim, + decoder_layer, + decoder_layer_wide, + num_layers, + nhead, + reg_max, + self.reg_scale, + self.up, + eval_idx, + layer_scale, + ) + # denoising + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + if num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + num_classes + 1, hidden_dim, padding_idx=num_classes + ) + init.normal_(self.denoising_class_embed.weight[:-1]) + + # decoder embedding + self.learn_query_content = learn_query_content + if learn_query_content: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) + + # if num_select_queries != self.num_queries: + # layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu') + # self.encoder = TransformerEncoder(layer, 1) + + self.enc_output = nn.Sequential( + OrderedDict( + [ + ("proj", nn.Linear(hidden_dim, hidden_dim)), + ( + "norm", + nn.LayerNorm( + hidden_dim, + ), + ), + ] + ) + ) + + if query_select_method == "agnostic": + self.enc_score_head = nn.Linear(hidden_dim, 1) + else: + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + + # decoder head + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + self.dec_score_head = nn.ModuleList( + [nn.Linear(hidden_dim, num_classes) for _ in range(self.eval_idx + 1)] + + [nn.Linear(scaled_dim, num_classes) for _ in range(num_layers - self.eval_idx - 1)] + ) + self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + self.dec_bbox_head = nn.ModuleList( + [ + MLP(hidden_dim, hidden_dim, 4 * (self.reg_max + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + MLP(scaled_dim, scaled_dim, 4 * (self.reg_max + 1), 3) + for _ in range(num_layers - self.eval_idx - 1) + ] + ) + self.integral = Integral(self.reg_max) + + # init encoder output anchors and valid_mask + if self.eval_spatial_size: + anchors, valid_mask = self._generate_anchors() + self.register_buffer("anchors", anchors) + self.register_buffer("valid_mask", valid_mask) + # init encoder output anchors and valid_mask + if self.eval_spatial_size: + self.anchors, self.valid_mask = self._generate_anchors() + + self._reset_parameters(feat_channels) + + def convert_to_deploy(self): + self.dec_score_head = nn.ModuleList( + [nn.Identity()] * (self.eval_idx) + [self.dec_score_head[self.eval_idx]] + ) + self.dec_bbox_head = nn.ModuleList( + [ + self.dec_bbox_head[i] if i <= self.eval_idx else nn.Identity() + for i in range(len(self.dec_bbox_head)) + ] + ) + + def _reset_parameters(self, feat_channels): + bias = bias_init_with_prob(0.01) + init.constant_(self.enc_score_head.bias, bias) + init.constant_(self.enc_bbox_head.layers[-1].weight, 0) + init.constant_(self.enc_bbox_head.layers[-1].bias, 0) + + init.constant_(self.pre_bbox_head.layers[-1].weight, 0) + init.constant_(self.pre_bbox_head.layers[-1].bias, 0) + + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + init.constant_(cls_.bias, bias) + if hasattr(reg_, "layers"): + init.constant_(reg_.layers[-1].weight, 0) + init.constant_(reg_.layers[-1].bias, 0) + + init.xavier_uniform_(self.enc_output[0].weight) + if self.learn_query_content: + init.xavier_uniform_(self.tgt_embed.weight) + init.xavier_uniform_(self.query_pos_head.layers[0].weight) + init.xavier_uniform_(self.query_pos_head.layers[1].weight) + for m, in_channels in zip(self.input_proj, feat_channels): + if in_channels != self.hidden_dim: + init.xavier_uniform_(m[0].weight) + + def _build_input_proj_layer(self, feat_channels): + self.input_proj = nn.ModuleList() + for in_channels in feat_channels: + if in_channels == self.hidden_dim: + self.input_proj.append(nn.Identity()) + else: + self.input_proj.append( + nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), + ( + "norm", + nn.BatchNorm2d( + self.hidden_dim, + ), + ), + ] + ) + ) + ) + + in_channels = feat_channels[-1] + + for _ in range(self.num_levels - len(feat_channels)): + if in_channels == self.hidden_dim: + self.input_proj.append(nn.Identity()) + else: + self.input_proj.append( + nn.Sequential( + OrderedDict( + [ + ( + "conv", + nn.Conv2d( + in_channels, self.hidden_dim, 3, 2, padding=1, bias=False + ), + ), + ("norm", nn.BatchNorm2d(self.hidden_dim)), + ] + ) + ) + ) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats: List[torch.Tensor]): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) + # [num_levels, 2] + spatial_shapes.append([h, w]) + + # [b, l, c] + feat_flatten = torch.concat(feat_flatten, 1) + return feat_flatten, spatial_shapes + + def _generate_anchors( + self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu" + ): + if spatial_shapes is None: + spatial_shapes = [] + eval_h, eval_w = self.eval_spatial_size + for s in self.feat_strides: + spatial_shapes.append([int(eval_h / s), int(eval_w / s)]) + + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") + grid_xy = torch.stack([grid_x, grid_y], dim=-1) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype) + wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl) + lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4) + anchors.append(lvl_anchors) + + anchors = torch.concat(anchors, dim=1).to(device) + valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + def _get_decoder_input( + self, memory: torch.Tensor, spatial_shapes, denoising_logits=None, denoising_bbox_unact=None + ): + # prepare input for decoder + if self.training or self.eval_spatial_size is None: + anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) + else: + anchors = self.anchors + valid_mask = self.valid_mask + if memory.shape[0] > 1: + anchors = anchors.repeat(memory.shape[0], 1, 1) + + # memory = torch.where(valid_mask, memory, 0) + # TODO fix type error for onnx export + memory = valid_mask.to(memory.dtype) * memory + + output_memory: torch.Tensor = self.enc_output(memory) + enc_outputs_logits: torch.Tensor = self.enc_score_head(output_memory) + + enc_topk_bboxes_list, enc_topk_logits_list = [], [] + enc_topk_memory, enc_topk_logits, enc_topk_anchors = self._select_topk( + output_memory, enc_outputs_logits, anchors, self.num_queries + ) + + enc_topk_bbox_unact: torch.Tensor = self.enc_bbox_head(enc_topk_memory) + enc_topk_anchors + + if self.training: + enc_topk_bboxes = F.sigmoid(enc_topk_bbox_unact) + enc_topk_bboxes_list.append(enc_topk_bboxes) + enc_topk_logits_list.append(enc_topk_logits) + + # if self.num_select_queries != self.num_queries: + # raise NotImplementedError('') + + if self.learn_query_content: + content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1]) + else: + content = enc_topk_memory.detach() + + enc_topk_bbox_unact = enc_topk_bbox_unact.detach() + + if denoising_bbox_unact is not None: + enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) + content = torch.concat([denoising_logits, content], dim=1) + + return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list + + def _select_topk( + self, + memory: torch.Tensor, + outputs_logits: torch.Tensor, + outputs_anchors_unact: torch.Tensor, + topk: int, + ): + if self.query_select_method == "default": + _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1) + + elif self.query_select_method == "one2many": + _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1) + topk_ind = topk_ind // self.num_classes + + elif self.query_select_method == "agnostic": + _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1) + + topk_ind: torch.Tensor + + topk_anchors = outputs_anchors_unact.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_anchors_unact.shape[-1]) + ) + + topk_logits = ( + outputs_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1]) + ) + if self.training + else None + ) + + topk_memory = memory.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1]) + ) + + return topk_memory, topk_logits, topk_anchors + + def forward(self, feats, targets=None): + # input projection and embedding + memory, spatial_shapes = self._get_encoder_input(feats) + + # prepare denoising training + if self.training and self.num_denoising > 0: + denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = ( + get_contrastive_denoising_training_group( + targets, + self.num_classes, + self.num_queries, + self.denoising_class_embed, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=1.0, + ) + ) + else: + denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None + + init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = ( + self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact) + ) + + # decoder + out_bboxes, out_logits, out_corners, out_refs, pre_bboxes, pre_logits = self.decoder( + init_ref_contents, + init_ref_points_unact, + memory, + spatial_shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + self.pre_bbox_head, + self.integral, + self.up, + self.reg_scale, + attn_mask=attn_mask, + dn_meta=dn_meta, + ) + + if self.training and dn_meta is not None: + dn_pre_logits, pre_logits = torch.split(pre_logits, dn_meta["dn_num_split"], dim=1) + dn_pre_bboxes, pre_bboxes = torch.split(pre_bboxes, dn_meta["dn_num_split"], dim=1) + dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta["dn_num_split"], dim=2) + dn_out_logits, out_logits = torch.split(out_logits, dn_meta["dn_num_split"], dim=2) + + dn_out_corners, out_corners = torch.split(out_corners, dn_meta["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(out_refs, dn_meta["dn_num_split"], dim=2) + + if self.training: + out = { + "pred_logits": out_logits[-1], + "pred_boxes": out_bboxes[-1], + "pred_corners": out_corners[-1], + "ref_points": out_refs[-1], + "up": self.up, + "reg_scale": self.reg_scale, + } + else: + out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]} + + if self.training and self.aux_loss: + out["aux_outputs"] = self._set_aux_loss2( + out_logits[:-1], + out_bboxes[:-1], + out_corners[:-1], + out_refs[:-1], + out_corners[-1], + out_logits[-1], + ) + out["enc_aux_outputs"] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) + out["pre_outputs"] = {"pred_logits": pre_logits, "pred_boxes": pre_bboxes} + out["enc_meta"] = {"class_agnostic": self.query_select_method == "agnostic"} + + if dn_meta is not None: + out["dn_outputs"] = self._set_aux_loss2( + dn_out_logits, + dn_out_bboxes, + dn_out_corners, + dn_out_refs, + dn_out_corners[-1], + dn_out_logits[-1], + ) + out["dn_pre_outputs"] = {"pred_logits": dn_pre_logits, "pred_boxes": dn_pre_bboxes} + out["dn_meta"] = dn_meta + + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @torch.jit.unused + def _set_aux_loss2( + self, + outputs_class, + outputs_coord, + outputs_corners, + outputs_ref, + teacher_corners=None, + teacher_logits=None, + ): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + { + "pred_logits": a, + "pred_boxes": b, + "pred_corners": c, + "ref_points": d, + "teacher_corners": teacher_corners, + "teacher_logits": teacher_logits, + } + for a, b, c, d in zip(outputs_class, outputs_coord, outputs_corners, outputs_ref) + ] diff --git a/src/zoo/dfine/dfine_utils.py b/src/zoo/dfine/dfine_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e2abbff433624d5facb4c9a87f41d6dec1f5293e --- /dev/null +++ b/src/zoo/dfine/dfine_utils.py @@ -0,0 +1,169 @@ +""" +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +""" + +import torch + +from .box_ops import box_xyxy_to_cxcywh + + +def weighting_function(reg_max, up, reg_scale, deploy=False): + """ + Generates the non-uniform Weighting Function W(n) for bounding box regression. + + Args: + reg_max (int): Max number of the discrete bins. + up (Tensor): Controls upper bounds of the sequence, + where maximum offset is ±up * H / W. + reg_scale (float): Controls the curvature of the Weighting Function. + Larger values result in flatter weights near the central axis W(reg_max/2)=0 + and steeper weights at both ends. + deploy (bool): If True, uses deployment mode settings. + + Returns: + Tensor: Sequence of Weighting Function. + """ + if deploy: + upper_bound1 = (abs(up[0]) * abs(reg_scale)).item() + upper_bound2 = (abs(up[0]) * abs(reg_scale) * 2).item() + step = (upper_bound1 + 1) ** (2 / (reg_max - 2)) + left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)] + right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)] + values = ( + [-upper_bound2] + + left_values + + [torch.zeros_like(up[0][None])] + + right_values + + [upper_bound2] + ) + return torch.tensor(values, dtype=up.dtype, device=up.device) + else: + upper_bound1 = abs(up[0]) * abs(reg_scale) + upper_bound2 = abs(up[0]) * abs(reg_scale) * 2 + step = (upper_bound1 + 1) ** (2 / (reg_max - 2)) + left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)] + right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)] + values = ( + [-upper_bound2] + + left_values + + [torch.zeros_like(up[0][None])] + + right_values + + [upper_bound2] + ) + return torch.cat(values, 0) + + +def translate_gt(gt, reg_max, reg_scale, up): + """ + Decodes bounding box ground truth (GT) values into distribution-based GT representations. + + This function maps continuous GT values into discrete distribution bins, which can be used + for regression tasks in object detection models. It calculates the indices of the closest + bins to each GT value and assigns interpolation weights to these bins based on their proximity + to the GT value. + + Args: + gt (Tensor): Ground truth bounding box values, shape (N, ). + reg_max (int): Maximum number of discrete bins for the distribution. + reg_scale (float): Controls the curvature of the Weighting Function. + up (Tensor): Controls the upper bounds of the Weighting Function. + + Returns: + Tuple[Tensor, Tensor, Tensor]: + - indices (Tensor): Index of the left bin closest to each GT value, shape (N, ). + - weight_right (Tensor): Weight assigned to the right bin, shape (N, ). + - weight_left (Tensor): Weight assigned to the left bin, shape (N, ). + """ + gt = gt.reshape(-1) + function_values = weighting_function(reg_max, up, reg_scale) + + # Find the closest left-side indices for each value + diffs = function_values.unsqueeze(0) - gt.unsqueeze(1) + mask = diffs <= 0 + closest_left_indices = torch.sum(mask, dim=1) - 1 + + # Calculate the weights for the interpolation + indices = closest_left_indices.float() + + weight_right = torch.zeros_like(indices) + weight_left = torch.zeros_like(indices) + + valid_idx_mask = (indices >= 0) & (indices < reg_max) + valid_indices = indices[valid_idx_mask].long() + + # Obtain distances + left_values = function_values[valid_indices] + right_values = function_values[valid_indices + 1] + + left_diffs = torch.abs(gt[valid_idx_mask] - left_values) + right_diffs = torch.abs(right_values - gt[valid_idx_mask]) + + # Valid weights + weight_right[valid_idx_mask] = left_diffs / (left_diffs + right_diffs) + weight_left[valid_idx_mask] = 1.0 - weight_right[valid_idx_mask] + + # Invalid weights (out of range) + invalid_idx_mask_neg = indices < 0 + weight_right[invalid_idx_mask_neg] = 0.0 + weight_left[invalid_idx_mask_neg] = 1.0 + indices[invalid_idx_mask_neg] = 0.0 + + invalid_idx_mask_pos = indices >= reg_max + weight_right[invalid_idx_mask_pos] = 1.0 + weight_left[invalid_idx_mask_pos] = 0.0 + indices[invalid_idx_mask_pos] = reg_max - 0.1 + + return indices, weight_right, weight_left + + +def distance2bbox(points, distance, reg_scale): + """ + Decodes edge-distances into bounding box coordinates. + + Args: + points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h], + where (x, y) is the center and (w, h) are width and height. + distance (Tensor): (B, N, 4) or (N, 4), representing distances from the + point to the left, top, right, and bottom boundaries. + + reg_scale (float): Controls the curvature of the Weighting Function. + + Returns: + Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h]. + """ + reg_scale = abs(reg_scale) + x1 = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale) + y1 = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale) + x2 = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale) + y2 = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale) + + bboxes = torch.stack([x1, y1, x2, y2], -1) + + return box_xyxy_to_cxcywh(bboxes) + + +def bbox2distance(points, bbox, reg_max, reg_scale, up, eps=0.1): + """ + Converts bounding box coordinates to distances from a reference point. + + Args: + points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center. + bbox (Tensor): (n, 4) bounding boxes in "xyxy" format. + reg_max (float): Maximum bin value. + reg_scale (float): Controling curvarture of W(n). + up (Tensor): Controling upper bounds of W(n). + eps (float): Small value to ensure target < reg_max. + + Returns: + Tensor: Decoded distances. + """ + reg_scale = abs(reg_scale) + left = (points[:, 0] - bbox[:, 0]) / (points[..., 2] / reg_scale + 1e-16) - 0.5 * reg_scale + top = (points[:, 1] - bbox[:, 1]) / (points[..., 3] / reg_scale + 1e-16) - 0.5 * reg_scale + right = (bbox[:, 2] - points[:, 0]) / (points[..., 2] / reg_scale + 1e-16) - 0.5 * reg_scale + bottom = (bbox[:, 3] - points[:, 1]) / (points[..., 3] / reg_scale + 1e-16) - 0.5 * reg_scale + four_lens = torch.stack([left, top, right, bottom], -1) + four_lens, weight_right, weight_left = translate_gt(four_lens, reg_max, reg_scale, up) + if reg_max is not None: + four_lens = four_lens.clamp(min=0, max=reg_max - eps) + return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach() diff --git a/src/zoo/dfine/hybrid_encoder.py b/src/zoo/dfine/hybrid_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ec487d50a8d7d13bdfb1184232996ac59261d352 --- /dev/null +++ b/src/zoo/dfine/hybrid_encoder.py @@ -0,0 +1,488 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright (c) 2023 lyuwenyu. All Rights Reserved. +""" + +import copy +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...core import register +from .utils import get_activation + +__all__ = ["HybridEncoder"] + + +class ConvNormLayer_fuse(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None): + super().__init__() + padding = (kernel_size - 1) // 2 if padding is None else padding + self.conv = nn.Conv2d( + ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias + ) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else get_activation(act) + self.ch_in, self.ch_out, self.kernel_size, self.stride, self.g, self.padding, self.bias = ( + ch_in, + ch_out, + kernel_size, + stride, + g, + padding, + bias, + ) + + def forward(self, x): + if hasattr(self, "conv_bn_fused"): + y = self.conv_bn_fused(x) + else: + y = self.norm(self.conv(x)) + return self.act(y) + + def convert_to_deploy(self): + if not hasattr(self, "conv_bn_fused"): + self.conv_bn_fused = nn.Conv2d( + self.ch_in, + self.ch_out, + self.kernel_size, + self.stride, + groups=self.g, + padding=self.padding, + bias=True, + ) + + kernel, bias = self.get_equivalent_kernel_bias() + self.conv_bn_fused.weight.data = kernel + self.conv_bn_fused.bias.data = bias + self.__delattr__("conv") + self.__delattr__("norm") + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor() + + return kernel3x3, bias3x3 + + def _fuse_bn_tensor(self): + kernel = self.conv.weight + running_mean = self.norm.running_mean + running_var = self.norm.running_var + gamma = self.norm.weight + beta = self.norm.bias + eps = self.norm.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ConvNormLayer(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size, stride, g=1, padding=None, bias=False, act=None): + super().__init__() + padding = (kernel_size - 1) // 2 if padding is None else padding + self.conv = nn.Conv2d( + ch_in, ch_out, kernel_size, stride, groups=g, padding=padding, bias=bias + ) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + return self.act(self.norm(self.conv(x))) + + +class SCDown(nn.Module): + def __init__(self, c1, c2, k, s): + super().__init__() + self.cv1 = ConvNormLayer_fuse(c1, c2, 1, 1) + self.cv2 = ConvNormLayer_fuse(c2, c2, k, s, c2) + + def forward(self, x): + return self.cv2(self.cv1(x)) + + +class VGGBlock(nn.Module): + def __init__(self, ch_in, ch_out, act="relu"): + super().__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) + self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) + self.act = nn.Identity() if act is None else act + + def forward(self, x): + if hasattr(self, "conv"): + y = self.conv(x) + else: + y = self.conv1(x) + self.conv2(x) + + return self.act(y) + + def convert_to_deploy(self): + if not hasattr(self, "conv"): + self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) + + kernel, bias = self.get_equivalent_kernel_bias() + self.conv.weight.data = kernel + self.conv.bias.data = bias + self.__delattr__("conv1") + self.__delattr__("conv2") + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return F.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch: ConvNormLayer): + if branch is None: + return 0, 0 + kernel = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class ELAN(nn.Module): + # csp-elan + def __init__(self, c1, c2, c3, c4, n=2, bias=False, act="silu", bottletype=VGGBlock): + super().__init__() + self.c = c3 + self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act) + self.cv2 = nn.Sequential( + bottletype(c3 // 2, c4, act=get_activation(act)), + ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act), + ) + self.cv3 = nn.Sequential( + bottletype(c4, c4, act=get_activation(act)), + ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act), + ) + self.cv4 = ConvNormLayer_fuse(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act) + + def forward(self, x): + # y = [self.cv1(x)] + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class RepNCSPELAN4(nn.Module): + # csp-elan + def __init__(self, c1, c2, c3, c4, n=3, bias=False, act="silu"): + super().__init__() + self.c = c3 // 2 + self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act) + self.cv2 = nn.Sequential( + CSPLayer(c3 // 2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), + ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act), + ) + self.cv3 = nn.Sequential( + CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), + ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act), + ) + self.cv4 = ConvNormLayer_fuse(c3 + (2 * c4), c2, 1, 1, bias=bias, act=act) + + def forward_chunk(self, x): + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + def forward(self, x): + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class CSPLayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + num_blocks=3, + expansion=1.0, + bias=False, + act="silu", + bottletype=VGGBlock, + ): + super(CSPLayer, self).__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = ConvNormLayer_fuse(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.conv2 = ConvNormLayer_fuse(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.bottlenecks = nn.Sequential( + *[ + bottletype(hidden_channels, hidden_channels, act=get_activation(act)) + for _ in range(num_blocks) + ] + ) + if hidden_channels != out_channels: + self.conv3 = ConvNormLayer_fuse(hidden_channels, out_channels, 1, 1, bias=bias, act=act) + else: + self.conv3 = nn.Identity() + + def forward(self, x): + x_1 = self.conv1(x) + x_1 = self.bottlenecks(x_1) + x_2 = self.conv2(x) + return self.conv3(x_1 + x_2) + + +# transformer +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.normalize_before = normalize_before + + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.activation = get_activation(activation) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: + residual = src + if self.normalize_before: + src = self.norm1(src) + q = k = self.with_pos_embed(src, pos_embed) + src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: + output = src + for layer in self.layers: + output = layer(output, src_mask=src_mask, pos_embed=pos_embed) + + if self.norm is not None: + output = self.norm(output) + + return output + + +@register() +class HybridEncoder(nn.Module): + __share__ = [ + "eval_spatial_size", + ] + + def __init__( + self, + in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + hidden_dim=256, + nhead=8, + dim_feedforward=1024, + dropout=0.0, + enc_act="gelu", + use_encoder_idx=[2], + num_encoder_layers=1, + pe_temperature=10000, + expansion=1.0, + depth_mult=1.0, + act="silu", + eval_spatial_size=None, + ): + super().__init__() + self.in_channels = in_channels + self.feat_strides = feat_strides + self.hidden_dim = hidden_dim + self.use_encoder_idx = use_encoder_idx + self.num_encoder_layers = num_encoder_layers + self.pe_temperature = pe_temperature + self.eval_spatial_size = eval_spatial_size + self.out_channels = [hidden_dim for _ in range(len(in_channels))] + self.out_strides = feat_strides + + # channel projection + self.input_proj = nn.ModuleList() + for in_channel in in_channels: + proj = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False)), + ("norm", nn.BatchNorm2d(hidden_dim)), + ] + ) + ) + + self.input_proj.append(proj) + + # encoder transformer + encoder_layer = TransformerEncoderLayer( + hidden_dim, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=enc_act, + ) + + self.encoder = nn.ModuleList( + [ + TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) + for _ in range(len(use_encoder_idx)) + ] + ) + + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(in_channels) - 1, 0, -1): + self.lateral_convs.append(ConvNormLayer_fuse(hidden_dim, hidden_dim, 1, 1)) + self.fpn_blocks.append( + RepNCSPELAN4( + hidden_dim * 2, + hidden_dim, + hidden_dim * 2, + round(expansion * hidden_dim // 2), + round(3 * depth_mult), + ) + # CSPLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion, bottletype=VGGBlock) + ) + + # bottom-up pan + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(in_channels) - 1): + self.downsample_convs.append( + nn.Sequential( + SCDown(hidden_dim, hidden_dim, 3, 2), + ) + ) + self.pan_blocks.append( + RepNCSPELAN4( + hidden_dim * 2, + hidden_dim, + hidden_dim * 2, + round(expansion * hidden_dim // 2), + round(3 * depth_mult), + ) + # CSPLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion, bottletype=VGGBlock) + ) + + self._reset_parameters() + + def _reset_parameters(self): + if self.eval_spatial_size: + for idx in self.use_encoder_idx: + stride = self.feat_strides[idx] + pos_embed = self.build_2d_sincos_position_embedding( + self.eval_spatial_size[1] // stride, + self.eval_spatial_size[0] // stride, + self.hidden_dim, + self.pe_temperature, + ) + setattr(self, f"pos_embed{idx}", pos_embed) + # self.register_buffer(f'pos_embed{idx}', pos_embed) + + @staticmethod + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): + """ """ + grid_w = torch.arange(int(w), dtype=torch.float32) + grid_h = torch.arange(int(h), dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + assert ( + embed_dim % 4 == 0 + ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward(self, feats): + assert len(feats) == len(self.in_channels) + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + + # encoder + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.use_encoder_idx): + h, w = proj_feats[enc_ind].shape[2:] + # flatten [B, C, H, W] to [B, HxW, C] + src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_spatial_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + w, h, self.hidden_dim, self.pe_temperature + ).to(src_flatten.device) + else: + pos_embed = getattr(self, f"pos_embed{enc_ind}", None).to(src_flatten.device) + + memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed) + proj_feats[enc_ind] = ( + memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() + ) + + # broadcasting and fusion + inner_outs = [proj_feats[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_heigh = inner_outs[0] + feat_low = proj_feats[idx - 1] + feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh) + inner_outs[0] = feat_heigh + upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest") + inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx]( + torch.concat([upsample_feat, feat_low], dim=1) + ) + inner_outs.insert(0, inner_out) + + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_height = inner_outs[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_height], dim=1)) + outs.append(out) + + return outs diff --git a/src/zoo/dfine/matcher.py b/src/zoo/dfine/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6321b317906b7c2a4d210021c718f76f1cc0f0 --- /dev/null +++ b/src/zoo/dfine/matcher.py @@ -0,0 +1,160 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +Modules to compute the matching cost and solve the corresponding LSAP. + +Copyright (c) 2024 The D-FINE Authors All Rights Reserved. +""" + +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from ...core import register +from .box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +@register() +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + __share__ = [ + "use_focal_loss", + ] + + def __init__(self, weight_dict, use_focal_loss=False, alpha=0.25, gamma=2.0): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = weight_dict["cost_class"] + self.cost_bbox = weight_dict["cost_bbox"] + self.cost_giou = weight_dict["cost_giou"] + + self.use_focal_loss = use_focal_loss + self.alpha = alpha + self.gamma = gamma + + assert ( + self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0 + ), "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs: Dict[str, torch.Tensor], targets, return_topk=False): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + if self.use_focal_loss: + out_prob = F.sigmoid(outputs["pred_logits"].flatten(0, 1)) + else: + out_prob = ( + outputs["pred_logits"].flatten(0, 1).softmax(-1) + ) # [batch_size * num_queries, num_classes] + + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + if self.use_focal_loss: + out_prob = out_prob[:, tgt_ids] + neg_cost_class = ( + (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log()) + ) + pos_cost_class = ( + self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log()) + ) + cost_class = pos_cost_class - neg_cost_class + else: + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix 3 * self.cost_bbox + 2 * self.cost_class + self.cost_giou + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + C = torch.nan_to_num(C, nan=1.0) + indices_pre = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + indices = [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices_pre + ] + + # Compute topk indices + if return_topk: + return { + "indices_o2m": self.get_top_k_matches( + C, sizes=sizes, k=return_topk, initial_indices=indices_pre + ) + } + + return {"indices": indices} # , 'indices_o2m': C.min(-1)[1]} + + def get_top_k_matches(self, C, sizes, k=1, initial_indices=None): + indices_list = [] + # C_original = C.clone() + for i in range(k): + indices_k = ( + [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + if i > 0 + else initial_indices + ) + indices_list.append( + [ + (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) + for i, j in indices_k + ] + ) + for c, idx_k in zip(C.split(sizes, -1), indices_k): + idx_k = np.stack(idx_k) + c[:, idx_k] = 1e6 + indices_list = [ + ( + torch.cat([indices_list[i][j][0] for i in range(k)], dim=0), + torch.cat([indices_list[i][j][1] for i in range(k)], dim=0), + ) + for j in range(len(sizes)) + ] + # C.copy_(C_original) + return indices_list diff --git a/src/zoo/dfine/postprocessor.py b/src/zoo/dfine/postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..ac34b565bc4119d14a353a8840980a6136603051 --- /dev/null +++ b/src/zoo/dfine/postprocessor.py @@ -0,0 +1,93 @@ +""" +Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from ...core import register + +__all__ = ["DFINEPostProcessor"] + + +def mod(a, b): + out = a - a // b * b + return out + + +@register() +class DFINEPostProcessor(nn.Module): + __share__ = ["num_classes", "use_focal_loss", "num_top_queries", "remap_mscoco_category"] + + def __init__( + self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False + ) -> None: + super().__init__() + self.use_focal_loss = use_focal_loss + self.num_top_queries = num_top_queries + self.num_classes = int(num_classes) + self.remap_mscoco_category = remap_mscoco_category + self.deploy_mode = False + + def extra_repr(self) -> str: + return f"use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}" + + # def forward(self, outputs, orig_target_sizes): + def forward(self, outputs, orig_target_sizes: torch.Tensor): + logits, boxes = outputs["pred_logits"], outputs["pred_boxes"] + # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + + bbox_pred = torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy") + bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) + + if self.use_focal_loss: + scores = F.sigmoid(logits) + scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1) + # TODO for older tensorrt + # labels = index % self.num_classes + labels = mod(index, self.num_classes) + index = index // self.num_classes + boxes = bbox_pred.gather( + dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]) + ) + + else: + scores = F.softmax(logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > self.num_top_queries: + scores, index = torch.topk(scores, self.num_top_queries, dim=-1) + labels = torch.gather(labels, dim=1, index=index) + boxes = torch.gather( + boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]) + ) + + # TODO for onnx export + if self.deploy_mode: + return labels, boxes, scores + + # TODO + if self.remap_mscoco_category: + from ...data.dataset import mscoco_label2category + + labels = ( + torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()]) + .to(boxes.device) + .reshape(labels.shape) + ) + + results = [] + for lab, box, sco in zip(labels, boxes, scores): + result = dict(labels=lab, boxes=box, scores=sco) + results.append(result) + + return results + + def deploy( + self, + ): + self.eval() + self.deploy_mode = True + return self diff --git a/src/zoo/dfine/utils.py b/src/zoo/dfine/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..630f5bd0adea5c67e911073437a57a3a8dbeaeb7 --- /dev/null +++ b/src/zoo/dfine/utils.py @@ -0,0 +1,182 @@ +""" +D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement +Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. +--------------------------------------------------------------------------------- +Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) +Copyright (c) 2023 lyuwenyu. All Rights Reserved. +""" + +import math +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: + x = x.clip(min=0.0, max=1.0) + return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps)) + + +def bias_init_with_prob(prior_prob=0.01): + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-math.log((1 - prior_prob) / prior_prob)) + return bias_init + + +def deformable_attention_core_func( + value, value_spatial_shapes, sampling_locations, attention_weights +): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, _, n_head, c = value.shape + _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape + + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.split(split_shape, dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[level].flatten(2).permute(0, 2, 1).reshape(bs * n_head, c, h, w) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level].permute(0, 2, 1, 3, 4).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape( + bs * n_head, 1, Len_q, n_levels * n_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .reshape(bs, n_head * c, Len_q) + ) + + return output.permute(0, 2, 1) + + +def deformable_attention_core_func_v2( + value: torch.Tensor, + value_spatial_shapes, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, + num_points_list: List[int], + method="default", +): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels * n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels * n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, n_head, c, _ = value[0].shape + _, Len_q, _, _, _ = sampling_locations.shape + + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + + elif method == "discrete": + sampling_grids = sampling_locations + + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_locations_list = sampling_grids.split(num_points_list, dim=-2) + + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + value_l = value[level].reshape(bs * n_head, c, h, w) + sampling_grid_l: torch.Tensor = sampling_locations_list[level] + + if method == "default": + sampling_value_l = F.grid_sample( + value_l, sampling_grid_l, mode="bilinear", padding_mode="zeros", align_corners=False + ) + + elif method == "discrete": + # n * m, seq, n, 2 + sampling_coord = ( + sampling_grid_l * torch.tensor([[w, h]], device=value_l.device) + 0.5 + ).to(torch.int64) + + # FIX ME? for rectangle input + sampling_coord = sampling_coord.clamp(0, h - 1) + sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) + + s_idx = ( + torch.arange(sampling_coord.shape[0], device=value_l.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l: torch.Tensor = value_l[ + s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0] + ] # n l c + + sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape( + bs * n_head, c, Len_q, num_points_list[level] + ) + + sampling_value_list.append(sampling_value_l) + + attn_weights = attention_weights.permute(0, 2, 1, 3).reshape( + bs * n_head, 1, Len_q, sum(num_points_list) + ) + weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights + output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q) + + return output.permute(0, 2, 1) + + +def get_activation(act: str, inpace: bool = True): + """get activation""" + if act is None: + return nn.Identity() + + elif isinstance(act, nn.Module): + return act + + act = act.lower() + + if act == "silu" or act == "swish": + m = nn.SiLU() + + elif act == "relu": + m = nn.ReLU() + + elif act == "leaky_relu": + m = nn.LeakyReLU() + + elif act == "silu": + m = nn.SiLU() + + elif act == "gelu": + m = nn.GELU() + + elif act == "hardsigmoid": + m = nn.Hardsigmoid() + + else: + raise RuntimeError("") + + if hasattr(m, "inplace"): + m.inplace = inpace + + return m