"""
Trainer

Author: Xiaoyang Wu ([email protected])
Please cite our work if the code is helpful to you.
"""
import os
import weakref
import torch
import torch.nn as nn
import torch.utils.data
from functools import partial

# collections.abc.Iterator is the canonical import location on all supported
# Python versions; the alias in plain collections was removed in Python 3.10,
# so no version check is needed here.
from collections.abc import Iterator

from tensorboardX import SummaryWriter
from .defaults import create_ddp_model, worker_init_fn
from .hooks import HookBase, build_hooks
import pointcept.utils.comm as comm
from pointcept.datasets import build_dataset, point_collate_fn, collate_fn
from pointcept.models import build_model
from pointcept.utils.logger import get_root_logger
from pointcept.utils.optimizer import build_optimizer
from pointcept.utils.scheduler import build_scheduler
from pointcept.utils.events import EventStorage, ExceptionWriter
from pointcept.utils.registry import Registry


TRAINERS = Registry("trainers")
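
# Illustrative sketch (not executed here): trainers registered in TRAINERS are
# built by name from the config at launch time. Assuming the mmcv-style
# Registry API used across pointcept.utils, construction looks roughly like:
#
#   trainer = TRAINERS.build(dict(type="DefaultTrainer", cfg=cfg))
#   trainer.train()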


class TrainerBase:
    def __init__(self) -> None:
        self.hooks = []
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = 0
        self.max_iter = 0
        self.comm_info = dict()
        self.data_iterator: Iterator = enumerate([])
        self.storage: EventStorage
        self.writer: SummaryWriter
    def register_hooks(self, hooks) -> None:
        hooks = build_hooks(hooks)
        for h in hooks:
            assert isinstance(h, HookBase)
            # Hold the trainer through a weak proxy so hooks and the trainer
            # do not keep strong references to each other; a reference cycle
            # here can delay garbage collection and leak memory.
            h.trainer = weakref.proxy(self)
        self.hooks.extend(hooks)

    def train(self):
        with EventStorage() as self.storage:
            self.before_train()
            for self.epoch in range(self.start_epoch, self.max_epoch):
                self.before_epoch()
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    self.before_step()
                    self.run_step()
                    self.after_step()
                self.after_epoch()
            self.after_train()
    def before_train(self):
        for h in self.hooks:
            h.before_train()

    def before_epoch(self):
        for h in self.hooks:
            h.before_epoch()

    def before_step(self):
        for h in self.hooks:
            h.before_step()

    def run_step(self):
        raise NotImplementedError

    def after_step(self):
        for h in self.hooks:
            h.after_step()

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()

    def after_train(self):
        # Synchronize all processes before the after-train hooks run.
        comm.synchronize()
        for h in self.hooks:
            h.after_train()
        if comm.is_main_process():
            self.writer.close()
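
# Minimal custom-hook sketch (illustrative, not part of the trainer). Any
# HookBase subclass may override the lifecycle methods dispatched above;
# `self.trainer` is the weak proxy installed by register_hooks, and
# "model_output_dict" is populated by Trainer.run_step below:
#
#   class LossLoggerHook(HookBase):
#       def after_step(self):
#           output_dict = self.trainer.comm_info["model_output_dict"]
#           self.trainer.logger.info(f"loss: {output_dict['loss'].item():.4f}")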


@TRAINERS.register_module("DefaultTrainer")
class Trainer(TrainerBase):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = cfg.eval_epoch
        self.best_metric_value = -torch.inf
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "train.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.logger.info(f"Save path: {cfg.save_path}")
        self.logger.info(f"Config:\n{cfg.pretty_text}")
        self.logger.info("=> Building model ...")
        self.model = self.build_model()
        self.logger.info("=> Building writer ...")
        self.writer = self.build_writer()
        self.logger.info("=> Building train dataset & dataloader ...")
        self.train_loader = self.build_train_loader()
        self.logger.info("=> Building val dataset & dataloader ...")
        self.val_loader = self.build_val_loader()
        self.logger.info("=> Building optimizer, scheduler, scaler(amp) ...")
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()
        self.scaler = self.build_scaler()
        self.logger.info("=> Building hooks ...")
        self.register_hooks(self.cfg.hooks)
    def train(self):
        with EventStorage() as self.storage, ExceptionWriter():
            self.before_train()
            self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
            for self.epoch in range(self.start_epoch, self.max_epoch):
                if comm.get_world_size() > 1:
                    # Reseed the distributed sampler so each epoch shuffles
                    # differently across ranks.
                    self.train_loader.sampler.set_epoch(self.epoch)
                self.model.train()
                self.data_iterator = enumerate(self.train_loader)
                self.before_epoch()
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    self.before_step()
                    self.run_step()
                    self.after_step()
                self.after_epoch()
            self.after_train()
    def run_step(self):
        input_dict = self.comm_info["input_dict"]
        for key in input_dict.keys():
            if isinstance(input_dict[key], torch.Tensor):
                input_dict[key] = input_dict[key].cuda(non_blocking=True)
        with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
            output_dict = self.model(input_dict)
            loss = output_dict["loss"]
        self.optimizer.zero_grad()
        if self.cfg.enable_amp:
            self.scaler.scale(loss).backward()
            # Unscale so gradient clipping operates on true gradient values.
            self.scaler.unscale_(self.optimizer)
            if self.cfg.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
            self.scaler.step(self.optimizer)

            # If gradients contained inf/NaN, GradScaler skips optimizer.step()
            # and shrinks the scale; only advance the LR scheduler when the
            # scale did not shrink, i.e. when the optimizer actually stepped.
            scaler = self.scaler.get_scale()
            self.scaler.update()
            if scaler <= self.scaler.get_scale():
                self.scheduler.step()
        else:
            loss.backward()
            if self.cfg.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
            self.optimizer.step()
            self.scheduler.step()
        if self.cfg.empty_cache:
            torch.cuda.empty_cache()
        self.comm_info["model_output_dict"] = output_dict
    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()
        if self.cfg.empty_cache_per_epoch:
            torch.cuda.empty_cache()
    def build_model(self):
        model = build_model(self.cfg.model)
        if self.cfg.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        return model

    def build_writer(self):
        # Only the main process writes tensorboard events.
        writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
        self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
        return writer
    def build_train_loader(self):
        train_data = build_dataset(self.cfg.data.train)

        if comm.get_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
        else:
            train_sampler = None

        # Seed dataloader workers deterministically per rank when a seed is set.
        init_fn = (
            partial(
                worker_init_fn,
                num_workers=self.cfg.num_worker_per_gpu,
                rank=comm.get_rank(),
                seed=self.cfg.seed,
            )
            if self.cfg.seed is not None
            else None
        )

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.cfg.batch_size_per_gpu,
            shuffle=(train_sampler is None),
            num_workers=self.cfg.num_worker_per_gpu,
            sampler=train_sampler,
            collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
            pin_memory=True,
            worker_init_fn=init_fn,
            drop_last=True,
            persistent_workers=True,
        )
        return train_loader
    def build_val_loader(self):
        val_loader = None
        if self.cfg.evaluate:
            val_data = build_dataset(self.cfg.data.val)
            if comm.get_world_size() > 1:
                val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
            else:
                val_sampler = None
            val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.cfg.batch_size_val_per_gpu,
                shuffle=False,
                num_workers=self.cfg.num_worker_per_gpu,
                pin_memory=True,
                sampler=val_sampler,
                collate_fn=collate_fn,
            )
        return val_loader
    def build_optimizer(self):
        return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)

    def build_scheduler(self):
        assert hasattr(self, "optimizer")
        assert hasattr(self, "train_loader")
        # The scheduler is stepped once per iteration, so its horizon is
        # iterations per epoch times the number of training epochs.
        self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
        return build_scheduler(self.cfg.scheduler, self.optimizer)

    def build_scaler(self):
        scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
        return scaler
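
# Usage sketch (illustrative; entry points such as tools/train.py drive this
# in practice, and the exact config fields are project-specific):
#
#   cfg = ...  # pointcept config with model / data / optimizer / scheduler
#   trainer = Trainer(cfg)
#   trainer.train()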


@TRAINERS.register_module("MultiDatasetTrainer")
class MultiDatasetTrainer(Trainer):
    def build_train_loader(self):
        from pointcept.datasets import MultiDatasetDataloader

        train_data = build_dataset(self.cfg.data.train)
        train_loader = MultiDatasetDataloader(
            train_data,
            self.cfg.batch_size_per_gpu,
            self.cfg.num_worker_per_gpu,
            self.cfg.mix_prob,
            self.cfg.seed,
        )
        self.comm_info["iter_per_epoch"] = len(train_loader)
        return train_loader