""" | |
Misc Hook | |
Author: Xiaoyang Wu ([email protected]) | |
Please cite our work if the code is helpful to you. | |
""" | |
import sys
import glob
import os
import shutil
import time
import torch
import torch.utils.data
from collections import OrderedDict

if sys.version_info >= (3, 10):
    from collections.abc import Sequence
else:
    from collections import Sequence  # removed from `collections` in Python 3.10

from pointcept.utils.timer import Timer
from pointcept.utils.comm import is_main_process, synchronize, get_world_size
from pointcept.utils.cache import shared_dict
import pointcept.utils.comm as comm
from pointcept.engines.test import TESTERS

from .default import HookBase
from .builder import HOOKS


@HOOKS.register_module()
class IterationTimer(HookBase):
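    """
    Track data-loading and per-iteration batch time, and append the current /
    average timings plus a remaining-time estimate to ``comm_info["iter_info"]``.
    Statistics gathered during the first ``warmup_iter`` iterations are discarded.
    """
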
    def __init__(self, warmup_iter=1):
        self._warmup_iter = warmup_iter
        self._start_time = time.perf_counter()
        self._iter_timer = Timer()
        self._remain_iter = 0

    def before_train(self):
        self._start_time = time.perf_counter()
        self._remain_iter = self.trainer.max_epoch * len(self.trainer.train_loader)

    def before_epoch(self):
        self._iter_timer.reset()

    def before_step(self):
        data_time = self._iter_timer.seconds()
        self.trainer.storage.put_scalar("data_time", data_time)

    def after_step(self):
        batch_time = self._iter_timer.seconds()
        self._iter_timer.reset()
        self.trainer.storage.put_scalar("batch_time", batch_time)
        self._remain_iter -= 1
        remain_time = self._remain_iter * self.trainer.storage.history("batch_time").avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = "{:02d}:{:02d}:{:02d}".format(int(t_h), int(t_m), int(t_s))
        if "iter_info" in self.trainer.comm_info.keys():
            info = (
                "Data {data_time_val:.3f} ({data_time_avg:.3f}) "
                "Batch {batch_time_val:.3f} ({batch_time_avg:.3f}) "
                "Remain {remain_time} ".format(
                    data_time_val=self.trainer.storage.history("data_time").val,
                    data_time_avg=self.trainer.storage.history("data_time").avg,
                    batch_time_val=self.trainer.storage.history("batch_time").val,
                    batch_time_avg=self.trainer.storage.history("batch_time").avg,
                    remain_time=remain_time,
                )
            )
            self.trainer.comm_info["iter_info"] += info
        if self.trainer.comm_info["iter"] <= self._warmup_iter:
            # Reset timing history during warm-up so it does not skew the averages.
            self.trainer.storage.history("data_time").reset()
            self.trainer.storage.history("batch_time").reset()


@HOOKS.register_module()
class InformationWriter(HookBase):
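    """
    Compose the per-iteration training log line (epoch/iteration progress, model
    output scalars, learning rate) and mirror those scalars to the TensorBoard
    writer after every step and every epoch.
    """
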
    def __init__(self):
        self.curr_iter = 0
        self.model_output_keys = []

    def before_train(self):
        self.trainer.comm_info["iter_info"] = ""
        self.curr_iter = self.trainer.start_epoch * len(self.trainer.train_loader)

    def before_step(self):
        self.curr_iter += 1
        # MSC pretraining does not provide offset information, so the more
        # detailed progress line below (scan / point counts) stays commented out.
        # info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] " \
        #        "Scan {batch_size} ({points_num}) ".format(
        #     epoch=self.trainer.epoch + 1, max_epoch=self.trainer.max_epoch,
        #     iter=self.trainer.comm_info["iter"], max_iter=len(self.trainer.train_loader),
        #     batch_size=len(self.trainer.comm_info["input_dict"]["offset"]),
        #     points_num=self.trainer.comm_info["input_dict"]["offset"][-1]
        # )
        info = "Train: [{epoch}/{max_epoch}][{iter}/{max_iter}] ".format(
            epoch=self.trainer.epoch + 1,
            max_epoch=self.trainer.max_epoch,
            iter=self.trainer.comm_info["iter"] + 1,
            max_iter=len(self.trainer.train_loader),
        )
        self.trainer.comm_info["iter_info"] += info

    def after_step(self):
        if "model_output_dict" in self.trainer.comm_info.keys():
            model_output_dict = self.trainer.comm_info["model_output_dict"]
            self.model_output_keys = model_output_dict.keys()
            for key in self.model_output_keys:
                self.trainer.storage.put_scalar(key, model_output_dict[key].item())

        for key in self.model_output_keys:
            self.trainer.comm_info["iter_info"] += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).val
            )
        lr = self.trainer.optimizer.state_dict()["param_groups"][0]["lr"]
        self.trainer.comm_info["iter_info"] += "Lr: {lr:.5f}".format(lr=lr)
        self.trainer.logger.info(self.trainer.comm_info["iter_info"])
        self.trainer.comm_info["iter_info"] = ""  # reset iter info
        if self.trainer.writer is not None:
            self.trainer.writer.add_scalar("lr", lr, self.curr_iter)
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train_batch/" + key,
                    self.trainer.storage.history(key).val,
                    self.curr_iter,
                )

    def after_epoch(self):
        epoch_info = "Train result: "
        for key in self.model_output_keys:
            epoch_info += "{key}: {value:.4f} ".format(
                key=key, value=self.trainer.storage.history(key).avg
            )
        self.trainer.logger.info(epoch_info)
        if self.trainer.writer is not None:
            for key in self.model_output_keys:
                self.trainer.writer.add_scalar(
                    "train/" + key,
                    self.trainer.storage.history(key).avg,
                    self.trainer.epoch + 1,
                )


@HOOKS.register_module()
class CheckpointSaver(HookBase):
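    """
    Save ``model_last.pth`` after every epoch, copy it to ``model_best.pth``
    whenever the monitored validation metric improves, and keep an
    ``epoch_{n}.pth`` snapshot every ``save_freq`` epochs when ``save_freq`` is set.

    Hooks like this one are typically enabled from the config hook list, e.g. an
    entry of the form ``dict(type="CheckpointSaver", save_freq=None)``.
    """
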
    def __init__(self, save_freq=None):
        # None or int; None means only the latest model is saved.
        self.save_freq = save_freq

    def after_epoch(self):
        if is_main_process():
            is_best = False
            if self.trainer.cfg.evaluate:
                current_metric_value = self.trainer.comm_info["current_metric_value"]
                current_metric_name = self.trainer.comm_info["current_metric_name"]
                if current_metric_value > self.trainer.best_metric_value:
                    self.trainer.best_metric_value = current_metric_value
                    is_best = True
                    self.trainer.logger.info(
                        "Best validation {} updated to: {:.4f}".format(
                            current_metric_name, current_metric_value
                        )
                    )
                self.trainer.logger.info(
                    "Currently Best {}: {:.4f}".format(
                        current_metric_name, self.trainer.best_metric_value
                    )
                )
            filename = os.path.join(
                self.trainer.cfg.save_path, "model", "model_last.pth"
            )
            self.trainer.logger.info("Saving checkpoint to: " + filename)
            # Write to a temporary file and rename atomically, so an interrupted
            # save cannot corrupt the previous checkpoint.
            torch.save(
                {
                    "epoch": self.trainer.epoch + 1,
                    "state_dict": self.trainer.model.state_dict(),
                    "optimizer": self.trainer.optimizer.state_dict(),
                    "scheduler": self.trainer.scheduler.state_dict(),
                    "scaler": (
                        self.trainer.scaler.state_dict()
                        if self.trainer.cfg.enable_amp
                        else None
                    ),
                    "best_metric_value": self.trainer.best_metric_value,
                },
                filename + ".tmp",
            )
            os.replace(filename + ".tmp", filename)
            if is_best:
                shutil.copyfile(
                    filename,
                    os.path.join(self.trainer.cfg.save_path, "model", "model_best.pth"),
                )
            if self.save_freq and (self.trainer.epoch + 1) % self.save_freq == 0:
                shutil.copyfile(
                    filename,
                    os.path.join(
                        self.trainer.cfg.save_path,
                        "model",
                        f"epoch_{self.trainer.epoch + 1}.pth",
                    ),
                )


@HOOKS.register_module()
class CheckpointLoader(HookBase):
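    """
    Load model weights before training starts, optionally renaming parameters by
    replacing ``keywords`` with ``replacement`` in each state-dict key. When
    ``cfg.resume`` is set, also restore the epoch counter, best metric value,
    optimizer, scheduler, and AMP scaler state.
    """
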
    def __init__(self, keywords="", replacement=None, strict=False):
        self.keywords = keywords
        self.replacement = replacement if replacement is not None else keywords
        self.strict = strict

    def before_train(self):
        self.trainer.logger.info("=> Loading checkpoint & weight ...")
        if self.trainer.cfg.weight and os.path.isfile(self.trainer.cfg.weight):
            self.trainer.logger.info(f"Loading weight at: {self.trainer.cfg.weight}")
            checkpoint = torch.load(
                self.trainer.cfg.weight,
                map_location=lambda storage, loc: storage.cuda(),
            )
            self.trainer.logger.info(
                f"Loading layer weights with keyword: {self.keywords}, "
                f"replace keyword with: {self.replacement}"
            )
            weight = OrderedDict()
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("module."):
                    key = "module." + key  # xxx.xxx -> module.xxx.xxx
                # Now all keys start with "module." whether or not DDP is used.
                if self.keywords in key:
                    key = key.replace(self.keywords, self.replacement)
                if comm.get_world_size() == 1:
                    key = key[7:]  # module.xxx.xxx -> xxx.xxx
                weight[key] = value
            load_state_info = self.trainer.model.load_state_dict(
                weight, strict=self.strict
            )
            self.trainer.logger.info(f"Missing keys: {load_state_info.missing_keys}")
            if self.trainer.cfg.resume:
                self.trainer.logger.info(
                    f"Resuming training at epoch: {checkpoint['epoch']}"
                )
                self.trainer.start_epoch = checkpoint["epoch"]
                self.trainer.best_metric_value = checkpoint["best_metric_value"]
                self.trainer.optimizer.load_state_dict(checkpoint["optimizer"])
                self.trainer.scheduler.load_state_dict(checkpoint["scheduler"])
                if self.trainer.cfg.enable_amp:
                    self.trainer.scaler.load_state_dict(checkpoint["scaler"])
        else:
            self.trainer.logger.info(f"No weight found at: {self.trainer.cfg.weight}")


@HOOKS.register_module()
class PreciseEvaluator(HookBase):
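    """
    After training finishes, build the configured tester and run a final, precise
    evaluation on either the last checkpoint (``test_last=True``) or the best one.
    """
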
    def __init__(self, test_last=False):
        self.test_last = test_last

    def after_train(self):
        self.trainer.logger.info(
            ">>>>>>>>>>>>>>>> Start Precise Evaluation >>>>>>>>>>>>>>>>"
        )
        torch.cuda.empty_cache()
        cfg = self.trainer.cfg
        tester = TESTERS.build(
            dict(type=cfg.test.type, cfg=cfg, model=self.trainer.model)
        )
        if self.test_last:
            self.trainer.logger.info("=> Testing on model_last ...")
        else:
            self.trainer.logger.info("=> Testing on model_best ...")
            best_path = os.path.join(
                self.trainer.cfg.save_path, "model", "model_best.pth"
            )
            checkpoint = torch.load(best_path)
            state_dict = checkpoint["state_dict"]
            tester.model.load_state_dict(state_dict, strict=True)
        tester.test()


@HOOKS.register_module()
class DataCacheOperator(HookBase):
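    """
    Warm up the shared-memory data cache: on the main process, iterate over the
    training dataset once and store every sample under the key
    ``Pointcept-{name}``; all ranks then meet at the synchronization barrier.
    """
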
    def __init__(self, data_root, split):
        self.data_root = data_root
        self.split = split
        self.data_list = self.get_data_list()

    def get_data_list(self):
        if isinstance(self.split, str):
            data_list = glob.glob(os.path.join(self.data_root, self.split))
        elif isinstance(self.split, Sequence):
            data_list = []
            for split in self.split:
                data_list += glob.glob(os.path.join(self.data_root, split))
        else:
            raise NotImplementedError
        return data_list

    def get_cache_name(self, data_path):
        data_name = data_path.replace(os.path.dirname(self.data_root), "")
        return "pointcept" + data_name.replace(os.path.sep, "-")

    def before_train(self):
        self.trainer.logger.info(
            f"=> Caching dataset: {self.data_root}, split: {self.split} ..."
        )
        if is_main_process():
            dataset = self.trainer.train_loader.dataset
            for i in range(len(dataset)):
                data_dict = dataset[i]
                name = data_dict["name"]
                shared_dict(f"Pointcept-{name}", data_dict)
        synchronize()


@HOOKS.register_module()
class RuntimeProfiler(HookBase):
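    """
    Profile a few forward and/or backward passes with ``torch.profiler``, log the
    aggregated operator tables, and export Chrome traces to ``cfg.save_path``.
    """
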
    def __init__(
        self,
        forward=True,
        backward=True,
        interrupt=False,
        warm_up=2,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.forward = forward
        self.backward = backward
        self.interrupt = interrupt
        self.warm_up = warm_up
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import profile, record_function, ProfilerActivity

        for i, input_dict in enumerate(self.trainer.train_loader):
            if i == self.warm_up + 1:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            if self.forward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as forward_prof:
                    with record_function("model_inference"):
                        output_dict = self.trainer.model(input_dict)
            else:
                output_dict = self.trainer.model(input_dict)
            loss = output_dict["loss"]
            if self.backward:
                with profile(
                    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                    record_shapes=True,
                    profile_memory=True,
                    with_stack=True,
                ) as backward_prof:
                    with record_function("model_inference"):
                        loss.backward()
            self.trainer.logger.info(f"Profile: [{i + 1}/{self.warm_up + 1}]")
        if self.forward:
            self.trainer.logger.info(
                "Forward profile: \n"
                + str(
                    forward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            forward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "forward_trace.json")
            )
        if self.backward:
            self.trainer.logger.info(
                "Backward profile: \n"
                + str(
                    backward_prof.key_averages().table(
                        sort_by=self.sort_by, row_limit=self.row_limit
                    )
                )
            )
            backward_prof.export_chrome_trace(
                os.path.join(self.trainer.cfg.save_path, "backward_trace.json")
            )
        if self.interrupt:
            sys.exit(0)


@HOOKS.register_module()
class RuntimeProfilerV2(HookBase):
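    """
    Schedule-based variant of ``RuntimeProfiler`` built on
    ``torch.profiler.schedule`` with a TensorBoard trace handler; it profiles
    ``(wait + warmup + active) * repeat`` iterations.
    """
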
    def __init__(
        self,
        interrupt=False,
        wait=1,
        warmup=1,
        active=10,
        repeat=1,
        sort_by="cuda_time_total",
        row_limit=30,
    ):
        self.interrupt = interrupt
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat
        self.sort_by = sort_by
        self.row_limit = row_limit

    def before_train(self):
        self.trainer.logger.info("Profiling runtime ...")
        from torch.profiler import (
            profile,
            record_function,
            ProfilerActivity,
            schedule,
            tensorboard_trace_handler,
        )

        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            on_trace_ready=tensorboard_trace_handler(self.trainer.cfg.save_path),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
        for i, input_dict in enumerate(self.trainer.train_loader):
            if i >= (self.wait + self.warmup + self.active) * self.repeat:
                break
            for key in input_dict.keys():
                if isinstance(input_dict[key], torch.Tensor):
                    input_dict[key] = input_dict[key].cuda(non_blocking=True)
            with record_function("model_forward"):
                output_dict = self.trainer.model(input_dict)
                loss = output_dict["loss"]
            with record_function("model_backward"):
                loss.backward()
            prof.step()
            self.trainer.logger.info(
                f"Profile: [{i + 1}/{(self.wait + self.warmup + self.active) * self.repeat}]"
            )
        self.trainer.logger.info(
            "Profile: \n"
            + str(
                prof.key_averages().table(
                    sort_by=self.sort_by, row_limit=self.row_limit
                )
            )
        )
        prof.stop()
        if self.interrupt:
            sys.exit(0)