""" | |
Trainer | |
Author: Xiaoyang Wu ([email protected]) | |
Please cite our work if the code is helpful to you. | |
""" | |
import os | |
import sys | |
import weakref | |
import torch | |
import torch.nn as nn | |
import torch.utils.data | |
from functools import partial | |
if sys.version_info >= (3, 10): | |
from collections.abc import Iterator | |
else: | |
from collections import Iterator | |
from tensorboardX import SummaryWriter | |
from .defaults import create_ddp_model, worker_init_fn | |
from .hooks import HookBase, build_hooks | |
import pointcept.utils.comm as comm | |
from pointcept.datasets import build_dataset, point_collate_fn, collate_fn | |
from pointcept.models import build_model | |
from pointcept.utils.logger import get_root_logger | |
from pointcept.utils.optimizer import build_optimizer | |
from pointcept.utils.scheduler import build_scheduler | |
from pointcept.utils.events import EventStorage, ExceptionWriter | |
from pointcept.utils.registry import Registry | |
TRAINERS = Registry("trainers") | |


class TrainerBase:
    def __init__(self) -> None:
        self.hooks = []
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = 0
        self.max_iter = 0
        self.comm_info = dict()
        self.data_iterator: Iterator = enumerate([])
        self.storage: EventStorage
        self.writer: SummaryWriter

    def register_hooks(self, hooks) -> None:
        hooks = build_hooks(hooks)
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self.hooks.extend(hooks)

    def train(self):
        with EventStorage() as self.storage:
            # => before train
            self.before_train()
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train
            self.after_train()

    def before_train(self):
        for h in self.hooks:
            h.before_train()

    def before_epoch(self):
        for h in self.hooks:
            h.before_epoch()

    def before_step(self):
        for h in self.hooks:
            h.before_step()

    def run_step(self):
        raise NotImplementedError

    def after_step(self):
        for h in self.hooks:
            h.after_step()

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()

    def after_train(self):
        # Sync GPUs before running the after-train hooks
        comm.synchronize()
        for h in self.hooks:
            h.after_train()
        if comm.is_main_process():
            self.writer.close()
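

# Hook dispatch order driven by TrainerBase.train():
#   before_train
#     (for each epoch)  before_epoch
#       (for each batch)  before_step -> run_step -> after_step
#     after_epoch
#   after_train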


class Trainer(TrainerBase):
    def __init__(self, cfg):
        super(Trainer, self).__init__()
        self.epoch = 0
        self.start_epoch = 0
        self.max_epoch = cfg.eval_epoch
        self.best_metric_value = -torch.inf
        self.logger = get_root_logger(
            log_file=os.path.join(cfg.save_path, "train.log"),
            file_mode="a" if cfg.resume else "w",
        )
        self.logger.info("=> Loading config ...")
        self.cfg = cfg
        self.logger.info(f"Save path: {cfg.save_path}")
        self.logger.info(f"Config:\n{cfg.pretty_text}")
        self.logger.info("=> Building model ...")
        self.model = self.build_model()
        self.logger.info("=> Building writer ...")
        self.writer = self.build_writer()
        self.logger.info("=> Building train dataset & dataloader ...")
        self.train_loader = self.build_train_loader()
        self.logger.info("=> Building val dataset & dataloader ...")
        self.val_loader = self.build_val_loader()
self.logger.info("=> Building optimize, scheduler, scaler(amp) ...") | |
self.optimizer = self.build_optimizer() | |
self.scheduler = self.build_scheduler() | |
self.scaler = self.build_scaler() | |
self.logger.info("=> Building hooks ...") | |
self.register_hooks(self.cfg.hooks) | |
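
    # Typical use (sketch, names illustrative): given a parsed config `cfg`
    # produced by the project's config/launch utilities,
    #
    #   trainer = Trainer(cfg)
    #   trainer.train()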

    def train(self):
        with EventStorage() as self.storage, ExceptionWriter():
            # => before train
            self.before_train()
            self.logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
            for self.epoch in range(self.start_epoch, self.max_epoch):
                # => before epoch
                # TODO: optimize to iteration based
                if comm.get_world_size() > 1:
                    self.train_loader.sampler.set_epoch(self.epoch)
                self.model.train()
                self.data_iterator = enumerate(self.train_loader)
                self.before_epoch()
                # => run_epoch
                for (
                    self.comm_info["iter"],
                    self.comm_info["input_dict"],
                ) in self.data_iterator:
                    # => before_step
                    self.before_step()
                    # => run_step
                    self.run_step()
                    # => after_step
                    self.after_step()
                # => after epoch
                self.after_epoch()
            # => after train
            self.after_train()

    def run_step(self):
        input_dict = self.comm_info["input_dict"]
        for key in input_dict.keys():
            if isinstance(input_dict[key], torch.Tensor):
                input_dict[key] = input_dict[key].cuda(non_blocking=True)
        with torch.cuda.amp.autocast(enabled=self.cfg.enable_amp):
            output_dict = self.model(input_dict)
            loss = output_dict["loss"]
        self.optimizer.zero_grad()
        if self.cfg.enable_amp:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            if self.cfg.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
            self.scaler.step(self.optimizer)
            # With AMP enabled, scaler.step() skips optimizer.step() when the loss
            # scale produces inf/NaN gradients. Only step the scheduler when the
            # optimizer actually stepped (i.e. the scale did not decrease), which
            # also avoids the torch warning about calling scheduler.step() before
            # optimizer.step().
            scale_before = self.scaler.get_scale()
            self.scaler.update()
            if scale_before <= self.scaler.get_scale():
                self.scheduler.step()
        else:
            loss.backward()
            if self.cfg.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
            self.optimizer.step()
            self.scheduler.step()
        if self.cfg.empty_cache:
            torch.cuda.empty_cache()
        self.comm_info["model_output_dict"] = output_dict

    def after_epoch(self):
        for h in self.hooks:
            h.after_epoch()
        self.storage.reset_histories()
        if self.cfg.empty_cache_per_epoch:
            torch.cuda.empty_cache()

    def build_model(self):
        model = build_model(self.cfg.model)
        if self.cfg.sync_bn:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # logger.info(f"Model: \n{self.model}")
        self.logger.info(f"Num params: {n_parameters}")
        model = create_ddp_model(
            model.cuda(),
            broadcast_buffers=False,
            find_unused_parameters=self.cfg.find_unused_parameters,
        )
        return model

    def build_writer(self):
        writer = SummaryWriter(self.cfg.save_path) if comm.is_main_process() else None
        self.logger.info(f"Tensorboard writer logging dir: {self.cfg.save_path}")
        return writer

    def build_train_loader(self):
        train_data = build_dataset(self.cfg.data.train)

        if comm.get_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
        else:
            train_sampler = None

        init_fn = (
            partial(
                worker_init_fn,
                num_workers=self.cfg.num_worker_per_gpu,
                rank=comm.get_rank(),
                seed=self.cfg.seed,
            )
            if self.cfg.seed is not None
            else None
        )

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.cfg.batch_size_per_gpu,
            shuffle=(train_sampler is None),
            num_workers=self.cfg.num_worker_per_gpu,
            sampler=train_sampler,
            collate_fn=partial(point_collate_fn, mix_prob=self.cfg.mix_prob),
            pin_memory=True,
            worker_init_fn=init_fn,
            drop_last=True,
            persistent_workers=True,
        )
        return train_loader
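
    # point_collate_fn(mix_prob=...) batches point-cloud samples and, with
    # probability mix_prob, mixes neighbouring scenes into one sample
    # (Mix3D-style augmentation); the exact mixing behaviour is an assumption
    # based on how the collate function is used here.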

    def build_val_loader(self):
        val_loader = None
        if self.cfg.evaluate:
            val_data = build_dataset(self.cfg.data.val)
            if comm.get_world_size() > 1:
                val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
            else:
                val_sampler = None
            val_loader = torch.utils.data.DataLoader(
                val_data,
                batch_size=self.cfg.batch_size_val_per_gpu,
                shuffle=False,
                num_workers=self.cfg.num_worker_per_gpu,
                pin_memory=True,
                sampler=val_sampler,
                collate_fn=collate_fn,
            )
        return val_loader

    def build_optimizer(self):
        return build_optimizer(self.cfg.optimizer, self.model, self.cfg.param_dicts)

    def build_scheduler(self):
        assert hasattr(self, "optimizer")
        assert hasattr(self, "train_loader")
        self.cfg.scheduler.total_steps = len(self.train_loader) * self.cfg.eval_epoch
        return build_scheduler(self.cfg.scheduler, self.optimizer)
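
    # The scheduler is stepped once per optimizer step in run_step(), so
    # total_steps = iterations per epoch * eval_epoch lets it finish its
    # schedule exactly at the end of training.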

    def build_scaler(self):
        scaler = torch.cuda.amp.GradScaler() if self.cfg.enable_amp else None
        return scaler


class MultiDatasetTrainer(Trainer):
    def build_train_loader(self):
        from pointcept.datasets import MultiDatasetDataloader

        train_data = build_dataset(self.cfg.data.train)
        train_loader = MultiDatasetDataloader(
            train_data,
            self.cfg.batch_size_per_gpu,
            self.cfg.num_worker_per_gpu,
            self.cfg.mix_prob,
            self.cfg.seed,
        )
        self.comm_info["iter_per_epoch"] = len(train_loader)
        return train_loader
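
    # MultiDatasetDataloader is assumed to handle per-dataset batching, sample
    # mixing, and worker seeding internally (hence no explicit sampler or
    # collate_fn here); comm_info["iter_per_epoch"] records the epoch length so
    # hooks can use it before iteration starts.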