"""
Miscellaneous training hooks: iteration timing, console logging, checkpoint
save/load, shared-memory dataset caching, and runtime profiling.
"""

import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict

# `collections.Sequence` was removed in Python 3.10; the `collections.abc`
# location has existed since Python 3.3, so import it unconditionally.
from collections.abc import Sequence

from pointcept.utils.timer import Timer
from pointcept.utils.comm import is_main_process, synchronize, get_world_size
from pointcept.utils.cache import shared_dict
import pointcept.utils.comm as comm

# from pointcept.engines.test import TESTERS
from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class IterationTimer(HookBase):
    """Measure per-iteration data-loading/batch time and estimate remaining
    training time, appending the stats to ``comm_info["iter_info"]``.

    Args:
        warmup_iter: number of leading iterations per run whose timings are
            discarded (they include warm-up overhead and would skew the ETA).
    """

    def __init__(self, warmup_iter=1):
        self._warmup_iter = warmup_iter
        self._start_time = time.perf_counter()
        self._iter_timer = Timer()
        self._remain_iter = 0

    def before_train(self):
        self._start_time = time.perf_counter()
        # Total iterations left across all epochs; decremented every step.
        self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)

    def before_epoch(self):
        self._iter_timer.reset()

    def before_step(self):
        # Time elapsed since the end of the previous step == data-loading time.
        data_time = self._iter_timer.seconds()
        self.trainer.storage.put_scalar("data_time", data_time)

    def after_step(self):
        batch_time = self._iter_timer.seconds()
        self._iter_timer.reset()
        self.trainer.storage.put_scalar("batch_time", batch_time)
        self._remain_iter -= 1
        # ETA = remaining iterations * running average batch time, as HH:MM:SS.
        remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
        if "iter_info" in self.trainer.comm_info.keys():
            info = (
                "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
                "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
                "Remain {remain_time} ".format(
                    data_time_val=self.trainer.storage.history("data_time").val,
                    data_time_avg=self.trainer.storage.history("data_time").avg,
                    batch_time_val=self.trainer.storage.history("batch_time").val,
                    batch_time_avg=self.trainer.storage.history("batch_time").avg,
                    remain_time=remain_time,
                )
            )
            self.trainer.comm_info["iter_info"] += info
        # Drop warm-up timings so averages (and the ETA) are not skewed.
        if self.trainer.comm_info["iter"] <= self._warmup_iter:
            self.trainer.storage.history("data_time").reset()
            self.trainer.storage.history("batch_time").reset()


@HOOKS.register_module()
class InformationWriter(HookBase):
    """Build the per-iteration log line, emit it through the trainer logger,
    and mirror scalars (loss terms, lr) to the TensorBoard writer if present."""

    def __init__(self):
        self.curr_iter = 0
        self.model_output_keys = []

    def before_train(self):
        self.trainer.comm_info["iter_info"] = ""
        # Resume-aware global iteration counter for writer x-axes.
        self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)

    def before_step(self):
        self.curr_iter += 1
        # MSC pretrain does not have offset information, so scan/point counts
        # (batch_size / points_num from input_dict["offset"]) are not logged here.
        info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
            epoch=self.trainer.epoch + 1,
            max_epoch=self.trainer.max_epoch,
            iter=self.trainer.comm_info["iter"] + 1,
            max_iter=len(self.trainer.train_loader),
        )
        self.trainer.comm_info["iter_info"] += info

    def after_step(self):
        if "model_output_dict" in self.trainer.comm_info.keys():
            model_output_dict = self.trainer.comm_info["model_output_dict"]
            self.model_output_keys = model_output_dict.keys()
            for key in self.model_output_keys:
                self.trainer.storage.put_scalar(key, model_output_dict[key].item())

        for key in self.model_output_keys:
            self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).val
            )
        lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
        self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
        self.trainer.logger.info(self.trainer.comm_info["iter_info"])
        self.trainer.comm_info["iter_info"] = ""  # reset iter info
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train_batch/" + key,
                    self.trainer.storage.history(key).val,
                    self.curr_iter,
                )

    def after_epoch(self):
        # Epoch summary uses running averages rather than last-step values.
        epoch_info = "Train result: "
        for key in self.model_output_keys:
            epoch_info += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).avg
            )
        self.trainer.logger.info(epoch_info)
        if self.trainer.writer is not None:
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train/" + key,
                    self.trainer.storage.history(key).avg,
                    self.trainer.epoch + 1,
                )


@HOOKS.register_module()
class CheckpointSaver(HookBase):
    """Save the latest checkpoint each epoch (main process only), copy it to
    ``model_best.pth`` when the tracked metric improves, and optionally keep
    periodic per-epoch snapshots.

    Args:
        save_freq: if an int, additionally keep a copy every ``save_freq``
            epochs; ``None`` keeps only the last (and best) checkpoint.
    """

    def __init__(self, save_freq=None):
        self.save_freq = save_freq  # None or int; None -> only save model_last

    def after_epoch(self):
        if is_main_process():
            is_best = False
            if self.trainer.cfg.evaluate:
                current_metric_value = self.trainer.comm_info["current_metric_value"]
                current_metric_name = self.trainer.comm_info["current_metric_name"]
                if current_metric_value > self.trainer.best_metric_value:
                    self.trainer.best_metric_value = current_metric_value
                    is_best = True
                    self.trainer.logger.info(
                        "Best validation {} updated to: {:.4f}".format(
                            current_metric_name, current_metric_value
                        )
                    )
                self.trainer.logger.info(
                    "Currently Best {}: {:.4f}".format(
                        current_metric_name, self.trainer.best_metric_value
                    )
                )
            filename = os.path.join(
                self.trainer.cfg.save_path, "model", "model_last.pth"
            )
            self.trainer.logger.info("Saving checkpoint to: " + filename)
            # Write to a temp file, then atomically rename with os.replace so
            # a crash mid-save never leaves a truncated model_last.pth behind.
            torch.save(
                {
                    "epoch": self.trainer.epoch + 1,
                    "state_dict": self.trainer.model.state_dict(),
                    "optimizer": self.trainer.optimizer.state_dict(),
                    "scheduler": self.trainer.scheduler.state_dict(),
                    "scaler": self.trainer.scaler.state_dict()
                    if self.trainer.cfg.enable_amp
                    else None,
                    "best_metric_value": self.trainer.best_metric_value,
                },
                filename + ".tmp",
            )
            os.replace(filename + ".tmp", filename)
            if is_best:
                shutil.copyfile(
                    filename,
                    os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
                )
            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
                shutil.copyfile(
                    filename,
                    os.path.join(
                        self.trainer.cfg.save_path,
                        "model",
                        f"epoch_{self.trainer.epoch + 1}.pth",
                    ),
                )


@HOOKS.register_module()
class CheckpointLoader(HookBase):
    """Load pretrained weights (with optional state-dict key remapping) and,
    when ``cfg.resume`` is set, restore optimizer/scheduler/scaler state.

    Args:
        keywords: substring to look for in checkpoint state-dict keys.
        replacement: text substituted for ``keywords``; defaults to
            ``keywords`` itself (i.e. no renaming).
        strict: forwarded to ``Module.load_state_dict``.
    """

    def __init__(self, keywords="", replacement=None, strict=False):
        self.keywords = keywords
        self.replacement = replacement if replacement is not None else keywords
        self.strict = strict

    def before_train(self):
        self.trainer.logger.info("=> Loading checkpoint & weight ...")
        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
            checkpoint = torch.load(
                self.trainer.cfg.weight,
                map_location=lambda storage, loc: storage.cuda(),
            )
            self.trainer.logger.info(
                f"Loading layer weights with keyword: {self.keywords}, "
                f"replace keyword with: {self.replacement}"
            )
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                # BUG FIX: always normalize to the "module." prefix. The old
                # code only added it when world_size > 1, yet still stripped
                # key[7:] unconditionally for world_size == 1 — corrupting
                # every key of a checkpoint saved without the prefix.
                if not key.startswith("module."):
                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # Now all keys contain "module." no matter DDP or not.
                if self.keywords in key:
                    key = key.replace(self.keywords, self.replacement)
                if comm.get_world_size() == 1:
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                weight[key] = value
            load_state_info = self.trainer.model.load_state_dict(
                weight, strict=self.strict
            )
            # load_state_dict returns (missing_keys, unexpected_keys).
            self.trainer.logger.info(f"Missing keys: {load_state_info[0]}")
            if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming train at eval epoch: {checkpoint['epoch']}"
                )
                self.trainer.start_epoch = checkpoint["epoch"]
                self.trainer.best_metric_value = checkpoint["best_metric_value"]
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
                if self.trainer.cfg.enable_amp:
                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
        else:
            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")


@HOOKS.register_module()
class DataCacheOperator(HookBase):
    """Pre-load every ``*.pth`` sample of the given split(s) into the shared
    cache before training (main process loads; other ranks wait on the
    barrier and then read from the shared cache).

    Args:
        data_root: dataset root directory.
        split: a split name or a sequence of split names.
    """

    def __init__(self, data_root, split):
        self.data_root = data_root
        self.split = split
        self.data_list = self.get_data_list()

    def get_data_list(self):
        if isinstance(self.split, str):
            data_list = glob.glob(os.path.join(self.data_root, self.split, "*.pth"))
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(os.path.join(self.data_root, split, "*.pth"))
        else:
            raise NotImplementedError
        return data_list

    def get_cache_name(self, data_path):
        # Derive a cache key from the path relative to the data root's parent,
        # e.g. /root/scannet/train/scene.pth -> "pointcept-scannet-train-scene".
        data_name = data_path.replace(os.path.dirname(self.data_root), "").split(".")[0]
        return "pointcept" + data_name.replace(os.path.sep, "-")

    def before_train(self):
        self.trainer.logger.info(
            f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
        )
        if is_main_process():
            for data_path in self.data_list:
                cache_name = self.get_cache_name(data_path)
                data = torch.load(data_path)
                shared_dict(cache_name, data)
        synchronize()


@HOOKS.register_module()
class RuntimeProfiler(HookBase):
    """Profile forward and/or backward passes with ``torch.profiler.profile``
    on a few warm-up batches, log the summary tables, and export Chrome
    traces to the save path.

    Args:
        forward: profile the forward pass.
        backward: profile the backward pass.
        interrupt: exit the process after profiling.
        warm_up: number of batches to run before the profiled one.
        sort_by: column used to sort the summary tables.
        row_limit: rows to show per table.
    """

    def __init__(
        self,
        forward=True,
        backward=True,
        interrupt=False,
        warm_up=2,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.forward = forward
        self.backward = backward
        self.interrupt = interrupt
        self.warm_up = warm_up
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import profile, record_function, ProfilerActivity

        for i, input_dict in enumerate(self.trainer.train_loader):
            if i == self.warm_up + 1:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            if self.forward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as forward_prof:
                    with record_function("model_inference"):
                        output_dict = self.trainer.model(input_dict)
            else:
                output_dict = self.trainer.model(input_dict)
            loss = output_dict["loss"]
            if self.backward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as backward_prof:
                    # Label fixed from the copy-pasted "model_inference" so
                    # backward events are distinguishable in the trace
                    # (matches RuntimeProfilerV2's naming).
                    with record_function("model_backward"):
                        loss.backward()
            self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
        if self.forward:
            self.trainer.logger.info(
                "Forward profile: \n"
                + str(
                    forward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            forward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
            )
        if self.backward:
            self.trainer.logger.info(
                "Backward profile: \n"
                + str(
                    backward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            backward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
            )
        if self.interrupt:
            sys.exit(0)


@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
    """Schedule-based profiler (wait/warmup/active/repeat) that streams traces
    to TensorBoard via ``tensorboard_trace_handler`` and logs one summary
    table at the end.

    Args:
        interrupt: exit the process after profiling.
        wait/warmup/active/repeat: forwarded to ``torch.profiler.schedule``.
        sort_by: column used to sort the summary table.
        row_limit: rows to show in the table.
    """

    def __init__(
        self,
        interrupt=False,
        wait=1,
        warmup=1,
        active=10,
        repeat=1,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.interrupt = interrupt
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import (
            profile,
            record_function,
            ProfilerActivity,
            schedule,
            tensorboard_trace_handler,
        )

        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
        for i, input_dict in enumerate(self.trainer.train_loader):
            # One full schedule cycle is (wait + warmup + active) * repeat steps.
            if i >= (self.wait + self.warmup + self.active) * self.repeat:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with record_function("model_forward"):
                output_dict = self.trainer.model(input_dict)
                loss = output_dict["loss"]
            with record_function("model_backward"):
                loss.backward()
            prof.step()
            self.trainer.logger.info(
                f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
            )
        self.trainer.logger.info(
            "Profile: \n"
            + str(
                prof.key_averages().table(
                    sort_by=self.sort_by, row_limit=self.row_limit
                )
            )
        )
        prof.stop()
        if self.interrupt:
            sys.exit(0)