"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

from pathlib import Path
from typing import Any, Callable

import torch
import torch.nn as nn
from torch.cuda.amp.grad_scaler import GradScaler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

__all__ = [
    "BaseConfig",
]


class BaseConfig(object):
    """Mutable training configuration: components are attached through the
    validated setters below, and dependent objects (dataloaders, EMA, scaler,
    writer) are built lazily on first access."""

    def __init__(self) -> None:
        super().__init__()

        self.task: str = None

        # core components
        self._model: nn.Module = None
        self._postprocessor: nn.Module = None
        self._criterion: nn.Module = None
        self._optimizer: Optimizer = None
        self._lr_scheduler: LRScheduler = None
        self._lr_warmup_scheduler: LRScheduler = None
        self._train_dataloader: DataLoader = None
        self._val_dataloader: DataLoader = None
        self._ema: nn.Module = None
        self._scaler: GradScaler = None
        self._train_dataset: Dataset = None
        self._val_dataset: Dataset = None
        self._collate_fn: Callable = None
        self._evaluator: Callable[[nn.Module, DataLoader, str], Any] = None
        self._writer: SummaryWriter = None

        # dataloader settings
        self.num_workers: int = 0
        self.batch_size: int = None
        self._train_batch_size: int = None
        self._val_batch_size: int = None
        self._train_shuffle: bool = None
        self._val_shuffle: bool = None

        # checkpoint / fine-tuning paths
        self.resume: str = None
        self.tuning: str = None

        self.epochs: int = None
        self.last_epoch: int = -1

        # training options
        self.use_amp: bool = False
        self.use_ema: bool = False
        self.ema_decay: float = 0.9999
        self.ema_warmups: int = 2000
        self.sync_bn: bool = False
        self.clip_max_norm: float = 0.0
        self.find_unused_parameters: bool = None

        # logging / reproducibility
        self.seed: int = None
        self.print_freq: int = None
        self.checkpoint_freq: int = 1
        self.output_dir: str = None
        self.summary_dir: str = None
        self.device: str = ""

    @property
    def model(self) -> nn.Module:
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your model class"
        self._model = m

    @property
    def postprocessor(self) -> nn.Module:
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your postprocessor class"
        self._postprocessor = m

    @property
    def criterion(self) -> nn.Module:
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f"{type(m)} != nn.Module, please check your criterion class"
        self._criterion = m

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(m, Optimizer), f"{type(m)} != optim.Optimizer, please check your optimizer class"
        self._optimizer = m

    @property
    def lr_scheduler(self) -> LRScheduler:
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(m, LRScheduler), f"{type(m)} != LRScheduler, please check your lr_scheduler class"
        self._lr_scheduler = m

    @property
    def lr_warmup_scheduler(self) -> LRScheduler:
        return self._lr_warmup_scheduler

    @lr_warmup_scheduler.setter
    def lr_warmup_scheduler(self, m):
        self._lr_warmup_scheduler = m

    @property
    def train_dataloader(self) -> DataLoader:
        if self._train_dataloader is None and self.train_dataset is not None:
            loader = DataLoader(
                self.train_dataset,
                batch_size=self.train_batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                shuffle=self.train_shuffle,
            )
            # stash the shuffle flag on the loader so downstream code can query it
            loader.shuffle = self.train_shuffle
            self._train_dataloader = loader

        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self) -> DataLoader:
        if self._val_dataloader is None and self.val_dataset is not None:
            loader = DataLoader(
                self.val_dataset,
                batch_size=self.val_batch_size,
                num_workers=self.num_workers,
                drop_last=False,
                collate_fn=self.collate_fn,
                shuffle=self.val_shuffle,
                # persistent_workers requires num_workers > 0; guard it so the
                # default num_workers=0 does not raise a ValueError
                persistent_workers=self.num_workers > 0,
            )
            # stash the shuffle flag on the loader so downstream code can query it
            loader.shuffle = self.val_shuffle
            self._val_dataloader = loader

        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader

    @property
    def ema(self) -> nn.Module:
        if self._ema is None and self.use_ema and self.model is not None:
            # imported lazily so the EMA module is only needed when enabled
            from ..optim import ModelEMA

            self._ema = ModelEMA(self.model, self.ema_decay, self.ema_warmups)
        return self._ema

    @ema.setter
    def ema(self, obj):
        self._ema = obj

    @property
    def scaler(self) -> GradScaler:
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
        self._scaler = obj

    @property
    def val_shuffle(self) -> bool:
        if self._val_shuffle is None:
            print("warning: defaulting to val_shuffle=False")
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be bool"
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self) -> bool:
        if self._train_shuffle is None:
            print("warning: defaulting to train_shuffle=True")
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), "shuffle must be bool"
        self._train_shuffle = shuffle

    @property
    def train_batch_size(self) -> int:
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f"warning: falling back to train_batch_size=batch_size={self.batch_size}")
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), "batch_size must be int"
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self) -> int:
        if self._val_batch_size is None and isinstance(self.batch_size, int):
            print(f"warning: falling back to val_batch_size=batch_size={self.batch_size}")
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), "batch_size must be int"
        self._val_batch_size = batch_size

    @property
    def train_dataset(self) -> Dataset:
        return self._train_dataset

    @train_dataset.setter
    def train_dataset(self, dataset):
        assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
        self._train_dataset = dataset

    @property
    def val_dataset(self) -> Dataset:
        return self._val_dataset

    @val_dataset.setter
    def val_dataset(self, dataset):
        assert isinstance(dataset, Dataset), f"{type(dataset)} must be Dataset"
        self._val_dataset = dataset

    @property
    def collate_fn(self) -> Callable:
        return self._collate_fn

    @collate_fn.setter
    def collate_fn(self, fn):
        assert callable(fn), f"{type(fn)} must be Callable"
        self._collate_fn = fn

    @property
    def evaluator(self) -> Callable:
        return self._evaluator

    @evaluator.setter
    def evaluator(self, fn):
        assert callable(fn), f"{type(fn)} must be Callable"
        self._evaluator = fn

    @property
    def writer(self) -> SummaryWriter:
        if self._writer is None:
            # prefer an explicit summary_dir; otherwise log under output_dir/summary
            if self.summary_dir:
                self._writer = SummaryWriter(self.summary_dir)
            elif self.output_dir:
                self._writer = SummaryWriter(Path(self.output_dir) / "summary")
        return self._writer

    @writer.setter
    def writer(self, m):
        assert isinstance(m, SummaryWriter), f"{type(m)} must be SummaryWriter"
        self._writer = m

    def __repr__(self):
        # list only the public (non-underscore) settings
        s = ""
        for k, v in self.__dict__.items():
            if not k.startswith("_"):
                s += f"{k}: {v}\n"
        return s
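

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original file):
    # wire standard torch components through the validated setters and let
    # the lazy properties build the rest. The toy model, dataset, and values
    # below are assumptions chosen only to make the example self-contained.
    import torch.optim as optim
    from torch.utils.data import TensorDataset

    cfg = BaseConfig()
    cfg.model = nn.Linear(4, 2)  # setter asserts it is an nn.Module
    cfg.optimizer = optim.SGD(cfg.model.parameters(), lr=0.1)
    cfg.train_dataset = TensorDataset(torch.randn(16, 4))  # setter asserts Dataset
    cfg.batch_size = 8  # train/val batch sizes fall back to this value
    loader = cfg.train_dataloader  # DataLoader is built lazily on first access
    print(cfg)  # __repr__ prints the public settings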